
torch.compiler.allow_in_graph does not create a call_module op in fx.Graph in torch 2.3.0 #126566

Open
kilianyp opened this issue May 17, 2024 · 12 comments
Labels
module: dynamo, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

kilianyp commented May 17, 2024

🐛 Describe the bug

I have been testing 2.3 and noticed some diverging behaviour with allow_in_graph:

import torch


@torch.compiler.allow_in_graph
class AllowInGraphLayer(torch.nn.Module):
    def forward(self, x):
        return x + x


def test_allow_in_graph_dynamo():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.add = AllowInGraphLayer()

        def forward(self, x):
            return self.add(x)

    def backend(gm, _):
        gm.graph.print_tabular()
        return gm.forward

    model = Model()
    model = torch.compile(model, backend=backend)
    model(torch.rand(10))

Output for torch 2.2:

opcode       name           target         args                 kwargs
-----------  -------------  -------------  -------------------  --------
placeholder  l_x_           L_x_           ()                   {}
call_module  l__self___add  L__self___add  (l_x_,)              {}
output       output         output         ((l__self___add,),)  {}

Output for torch 2.3:

opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    l_x_    L_x_                     ()            {}
call_function  add     <built-in function add>  (l_x_, l_x_)  {}
output         output  output                   ((add,),)     {}

Is this expected?

Versions

For torch 2.3:

>>> import torch
>>> torch.__version__
'2.3.0+cu118'

For torch 2.2:

>>> torch.__version__
'2.2.1+cu118'
Collecting environment information...        
PyTorch version: 2.3.0+cu118                                                                          
Is debug build: False                                                                                 
CUDA used to build PyTorch: 11.8                
ROCM used to build PyTorch: N/A                                                                                                                                                                             
                                                   
OS: Ubuntu 22.04.2 LTS (x86_64)                                                                       
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0                                                    
Clang version: Could not collect                                                                                                                                                                            
CMake version: Could not collect             
Libc version: glibc-2.35                                                                              

Python version: 3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-1030-gcp-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 550.54.14            
cuDNN version: Could not collect
HIP runtime version: N/A      
MIOpen runtime version: N/A     
Is XNNPACK available: True
                                                   
CPU:                                                                                                  
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          8
On-line CPU(s) list:             0-7
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                      6
Model:                           79
Thread(s) per core:              2
Core(s) per socket:              4
Socket(s):                       1
Stepping:                        0
BogoMIPS:                        4399.99
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities

Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       128 KiB (4 instances)
L1i cache:                       128 KiB (4 instances)
L2 cache:                        1 MiB (4 instances)
L3 cache:                        55 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-7
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:          Mitigation; IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; IBRS, IBPB conditional, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Mitigation; Clear CPU buffers; SMT Host state unknown

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] onnx==1.14.1
[pip3] onnxruntime==1.15.1
[pip3] torch==2.3.0+cu118
[pip3] torchaudio==2.3.0+cu118
[pip3] torchvision==0.18.0+cu118
[pip3] triton==2.3.0
[conda] Could not collect

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

kilianyp changed the title from "torch.compiler.allow_in_graph does not create a call_module call in torch 2.3.0" to "torch.compiler.allow_in_graph does not create a call_module op in fx.Graph in torch 2.3.0" on May 17, 2024
xmfan (Member) commented May 20, 2024

Yes, this seems expected. In this case, it looks like we inlined x + x directly into the graph as a call_function. I believe this results in faster code since it foregoes the call_module overhead. Is there a concern with this behavior?

Btw, you no longer need to add @torch.compiler.allow_in_graph for Dynamo to trace this snippet in 2.3.

xmfan added the triaged label May 20, 2024
ezyang (Contributor) commented May 20, 2024

No, this doesn't look expected to me. I assume the reason they're allowing the module in the graph is that they want it to show up as-is so that, e.g., a custom compiler pass can pattern-match on it or something.
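
For illustration, here is a minimal sketch of the kind of pass being described here; it is not the reporter's actual backend, and it assumes the AllowInGraphLayer class from the repro above. The backend walks the Dynamo-produced FX graph and matches call_module nodes backed by that class:

import torch

def pattern_matching_backend(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.op == "call_module":
            submod = gm.get_submodule(node.target)
            if isinstance(submod, AllowInGraphLayer):
                # A real pass would rewrite the node here, e.g. insert a
                # call_function to a hardware-specific op and then call
                # node.replace_all_uses_with(...) on the result.
                print(f"matched {node.target}: {type(submod).__name__}")
    return gm.forward

Used as torch.compile(model, backend=pattern_matching_backend) with the Model from the repro, the match fires under torch 2.2 (where the layer appears as a call_module node) but not under 2.3 (where it is inlined as a call_function), which is exactly the regression reported here.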

kilianyp (Author) commented:

e.g., a custom compiler pass can pattern match on it or something.

Yes, to be more precise, we need a mechanism to hide ops in certain modules for our custom backend. allow_in_graph only works somewhat well for that, so I'd be happy to talk about alternatives.

I added some more context here:
#125244 (comment)

zou3519 (Contributor) commented May 21, 2024

We discussed this at triage review. If you just want to hide an op, instead of a full-blown nn.Module, please create a custom op (see #125244 (comment))

If you want to instead hide an nn.Module -- we're not sure this was intended to work in the first place and we don't support that well today.
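
For reference, a minimal sketch of that custom-op route. It uses the torch.library.custom_op decorator, which is available from PyTorch 2.4 onward (in 2.3 the same can be done with torch.library.define / torch.library.impl); the op name "mylib::opaque_add" is made up for illustration:

import torch

# Define a custom op; Dynamo records calls to it as a single call_function
# node and does not trace into its Python body.
@torch.library.custom_op("mylib::opaque_add", mutates_args=())
def opaque_add(x: torch.Tensor) -> torch.Tensor:
    return x + x

# Fake/meta implementation so the op can be traced without real data.
@opaque_add.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

class OpaqueModel(torch.nn.Module):
    def forward(self, x):
        return opaque_add(x)

A backend then sees torch.ops.mylib.opaque_add in the compiled graph and can pattern-match on that single node.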

kilianyp (Author) commented May 21, 2024

Thanks for looking into it 🙏

I always interpreted allow_in_graph to mean that the annotated function or module should be part of the graph as a call_function or call_module.

Here https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html it's described as

The annotated callable goes as is in the TorchDynamo graph. For example, a black-box for TorchDynamo Dynamo. Note that AOT Autograd will trace through it, so the allow_in_graph is only a Dynamo-level concept.

which IMO is how it behaved in 2.2, but no longer in 2.3.

What is the intended usage of allow_in_graph then? Or will it be deprecated?

If you just want to hide an op, instead of a full-blown nn.Module, please create a custom op

I see how this can be used, but it makes things harder when working with code from a third-party library, where our backend also sometimes needs to ignore some functions/modules. (To give more context: we sometimes need to ignore functions that are not relevant for our backend but still cause issues for it, so we need this escape hatch. DeepSpeed, for example, introduces several graph breaks.)

For our use case, the ideal would be to be able to easily scope all ops within a function or module with a decorator/context manager, and have that information visible in the metadata of the op, for example:

@torch._dynamo.mark("custom_attribute", True)
class CustomModule(torch.nn.Module):
    ...
# -> node.meta["custom_attribute"]

One could use the stack trace to handle that without any changes to Dynamo, but the issues there are:

  • meta['stack_trace'] isn't always complete (I think because of graph breaks), so it's not always clear whether the function is scoped.
  • meta['source_fn_stack'] only contains the last function call (at least in 2.2).
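
To make the proposal above concrete, here is a minimal sketch of how a backend could consume such annotations. The source_fn_stack key is one Dynamo already populates (with the caveats listed above); custom_attribute is the hypothetical key from the proposal and does not exist today:

def annotation_aware_backend(gm, example_inputs):
    for node in gm.graph.nodes:
        if node.op in ("call_function", "call_module", "call_method"):
            # Populated by Dynamo today; may be incomplete after graph breaks.
            source_fns = node.meta.get("source_fn_stack", [])
            # Hypothetical key from the proposal above.
            custom = node.meta.get("custom_attribute")
            print(node.name, custom, [entry[0] for entry in source_fns])
    return gm.forward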

kilianyp (Author) commented:

Btw, I traced the change in behaviour back to this PR: #116312.

TBH I don't quite get the motivation behind allowing torch classes to still stay as call_module ops in the graph while third-party libraries are not given that option.

zou3519 (Contributor) commented May 22, 2024

What is the intended usage of allow_in_graph then?

allow_in_graph is really just intended for PyTorch developers, not third-party developers. It's a low-level tool we use to control what goes into the graph. As the documentation says ("Note that AOT Autograd will trace through it, so the allow_in_graph is only a Dynamo-level concept."), it doesn't completely black-box a Python function throughout the entire torch.compile stack.

I do want to understand your use case more though. Is the problem that you have calls to multiple third-party APIs that induce graph breaks and you wish to avoid all of those?

kilianyp (Author) commented:

I do want to understand your use case more though. Is the problem that you have calls to multiple third-party APIs that induce graph breaks and you wish to avoid all of those?

We are using it to replace some operations in the graph with our own custom calls to run on our hardware.
However, there are some operations we don't want to run on our hardware (for example, ops that can be computed at compile time).

I am aware it is possible to implement the logic to detect those automatically, but it's nice for the user to be able to indicate that manually instead of relying on the backend to correctly detect it.

For that reason, we have been using torch.compiler.disable and torch.compiler.allow_in_graph to hide that from the backend, such that it doesn't have to skip those operations.
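
As a concrete illustration of that escape hatch (a minimal sketch; the helper below is made up), torch.compiler.disable keeps the decorated function out of the traced graph entirely: Dynamo graph-breaks around the call and runs it eagerly, so the backend never sees its ops:

import torch

@torch.compiler.disable
def host_side_setup(n):
    # Made-up helper: something we want computed eagerly on the host rather
    # than handed to the hardware backend.
    return torch.arange(n, dtype=torch.float32)

class Wrapper(torch.nn.Module):
    def forward(self, x):
        # Dynamo graph-breaks at this call; the backend only sees the multiply.
        return x * host_side_setup(x.shape[0])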

I guess we could move forward by improving the skipping logic to automatically detect this in the backend itself, but having the escape hatch would definitely be valuable. Hence the proposal to annotate only the metadata of the node, which seems to me like it could be less intrusive.

I would assume that this isn't really the main use case for torch.compiler; it's also not what we are planning to use in the long term, but more of an experimental framework for quicker iterations.

Generally, the main requirement has been to get the FX graph; before this we used torch.fx.symbolic_trace. I have seen make_fx (https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L1395), but in some tests it wasn't as robust as torch.compile, which is why we went this route.
Looking into the code, my understanding is that torch.export still uses this experimental make_fx function. Could you comment on the long-term plan here? My understanding is that torch.export also uses Dynamo, but it seems to reimplement some of the logic compared to torch.compile.

ezyang (Contributor) commented May 29, 2024

We are using it to replace some operations in the graph with our own custom calls to run on our hardware.
However, some operations we don't want to run on our hardware (for example ops that can be calculated during compile time).

Ah. You don't want allow_in_graph for this. You want a HOP (higher-order operator). cc @Chillee The compiler would see the HOP and run all the operators in its subgraph in traditional eager mode.

That being said, something like constant propagation at compile time is pretty easy for the compiler to do, and you should make your compiler do it. Remember the graphs are all straight line and functionalized (if you're using AOTAutograd), so the trivial dependency analysis will work.
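
For illustration, a minimal sketch (not PyTorch's own constant-folding pass) of the trivial dependency analysis mentioned here: because the graph is straight-line, a single forward walk is enough to mark the nodes that depend only on compile-time-known values and could therefore be evaluated once by the compiler:

import torch

def find_constant_nodes(gm: torch.fx.GraphModule):
    constant_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            # Attributes baked into the module are treated as compile-time known.
            constant_nodes.add(node)
        elif node.op in ("call_function", "call_method"):
            # Constant if every Node input is itself constant; literal-only
            # args count as constant, placeholders never do.
            if all(arg in constant_nodes for arg in node.all_input_nodes):
                constant_nodes.add(node)
    # A real pass would evaluate this subgraph once and replace the nodes with
    # baked-in constants before handing the graph to the hardware backend.
    return constant_nodes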

zou3519 (Contributor) commented May 29, 2024

@ezyang is the proposal to annotate operations that they don't want to run on their hardware with a HOP?

ezyang (Contributor) commented May 29, 2024

Yes, for that particular use case I believe a HOP is most appropriate.

zou3519 (Contributor) commented May 29, 2024

Related: someone else wanted a generic way to annotate groups of ops so a backend can do something with them (#126393). I don't want to allow users to write their own HOPs yet (because that involves breaking open Dynamo internals), but providing a generic annotation HOP seems like it would be sufficient.
