torch.compiler.allow_in_graph does not create a call_module op in fx.Graph in torch 2.3.0 #126566
Yes, this seems expected. In this case, it looks like we inlined the module. Btw, you no longer need to add @torch.compiler.allow_in_graph for Dynamo to trace this snippet in 2.3.
No, this doesn't look expected to me. I assume the reason they're allowing the module in graph is that they want it to show up as-is so that, e.g., a custom compiler pass can pattern match on it or something.
Yes, to be more precise, we need a mechanism to hide ops in certain modules for our custom backend. I added some more context here:
We discussed this at triage review. If you just want to hide an op, instead of a full-blown nn.Module, please create a custom op (see #125244 (comment)). If you want to instead hide an nn.Module -- we're not sure this was intended to work in the first place, and we don't support that well today.
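(For reference, a minimal sketch of what "create a custom op" could look like with the torch.library Python API; the namespace mylib and the op name are illustrative, and the exact registration style recommended in #125244 may differ.)

```python
import torch

# Hypothetical namespace and op name, for illustration only.
lib = torch.library.Library("mylib", "DEF")
lib.define("black_box(Tensor x) -> Tensor")

def black_box_impl(x):
    # Whatever computation the backend wants to see as a single opaque op.
    return x.sin() + 1.0

# One kernel for all backends, written in terms of other torch ops,
# so it should also work under fake tensors during compilation.
lib.impl("black_box", black_box_impl, "CompositeExplicitAutograd")

def fn(x):
    # Dynamo records this as a single call_function node on
    # torch.ops.mylib.black_box instead of inlining its body.
    return torch.ops.mylib.black_box(x) * 2
```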
Thanks for looking into it 🙏 I always interpreted allow_in_graph as keeping the annotated callable as-is in the graph. Here https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html it's described as making the annotated callable go into the TorchDynamo graph as-is, which IMO is how it behaved in 2.2, but no longer in 2.3. What is the intended usage of allow_in_graph then?
I see how this can be used, but it makes things harder when working with code from a third-party library, where it's also sometimes necessary for our backend to ignore some functions/modules. (To give more context, we sometimes need to ignore those functions because they are not relevant for our backend but still cause issues for it, so we need this escape hatch. For example deepspeed, which introduces several graph breaks.) For our use case the ideal would be to be able to easily scope all ops within a function or module with a decorator/context manager, and have that information visible in the metadata of the op, e.g.:

```python
@torch._dynamo.mark("custom_attribute", True)
class CustomModule(torch.nn.Module):
    ...

# -> node.meta["custom_attribute"]
```

One could use the stack trace to handle that without any changes to dynamo, but here the issue is:
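(As a rough sketch of the stack-trace workaround mentioned above, assuming the backend receives the Dynamo-captured fx.GraphModule, whose nodes usually carry a stack_trace entry in node.meta; the filename check is purely illustrative.)

```python
import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Tag every node whose recorded Python stack trace points into a given
    # file; "third_party_lib.py" is just an illustrative name.
    for node in gm.graph.nodes:
        if "third_party_lib.py" in (node.meta.get("stack_trace") or ""):
            node.meta["custom_attribute"] = True
    return gm.forward  # run the captured graph unmodified

compiled = torch.compile(torch.nn.Linear(4, 4), backend=my_backend)
compiled(torch.randn(2, 4))
```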
Btw, I traced the change in behaviour back to this PR: #116312. TBH I don't quite get the motivation behind allowing torch classes to still stay as call_module.
allow_in_graph is really just intended for PyTorch developers, not third-party developers. It's a low-level tool we use to control what goes into the graph. As the documentation says ("Note that AOT Autograd will trace through it, so the allow_in_graph is only a Dynamo-level concept."), it doesn't completely black-box a Python function throughout the entire torch.compile stack. I do want to understand your use case more, though. Is the problem that you have calls to multiple third-party APIs that induce graph breaks and you wish to avoid all of those?
We are using it to replace some operations in the graph with our own custom calls to run on our hardware. I am aware it is possible to implement the logic to detect those automatically, but it's nice for the user to be able to indicate that manually instead of relying on the backend to correctly detect it. For that reason, we have been using allow_in_graph. I guess we could move forward by improving the skipping logic to automatically detect this in the backend itself, but having the escape hatch would definitely be valuable. Hence the proposal for being able to annotate only the metadata of the node, which seems to me like it could be less intrusive. I would assume that this isn't really the main use case for torch.compiler; it's also not what we are planning to use for the long term, but more of an experimental framework for quicker iterations. Generally, the main requirement has been to get the fx graph, before we used
Ah. You don't want allow_in_graph, then. That being said, something like constant propagation at compile time is pretty easy for the compiler to do, and you should make your compiler do it. Remember the graphs are all straight-line and functionalized (if you're using AOTAutograd), so the trivial dependency analysis will work.
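(As a rough illustration of that suggestion, a minimal constant-folding pass over an fx.GraphModule might look like the following; this is a sketch only, and a real pass would handle kwargs, nested args, non-tensor results, and dtypes/devices more carefully.)

```python
import torch
import torch.fx as fx

def fold_constants(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.kwargs:
            continue
        # Only fold when every positional argument is a plain Python constant.
        if any(isinstance(a, fx.Node) for a in node.args):
            continue
        value = node.target(*node.args)
        if not isinstance(value, torch.Tensor):
            continue
        # Stash the result on the module and replace the node with a get_attr.
        name = f"_folded_{node.name}"
        gm.register_buffer(name, value)
        with gm.graph.inserting_before(node):
            const_node = gm.graph.get_attr(name)
        node.replace_all_uses_with(const_node)
        gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
```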
@ezyang is the proposal to annotate operations that they don't want to run on their hardware with a HOP?
Yes, for that particular use case I believe a HOP is most appropriate.
Related: someone else wanted a generic way to annotate groups of ops so a backend can do something with them (#126393). I don't want to allow users to write their own HOPs yet (because that involves breaking open Dynamo internals), but providing a generic annotation HOP seems like it would be sufficient.
🐛 Describe the bug
I have been testing 2.3 but noticed some diverging behaviour with allow_in_graph.
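(The original repro snippet and outputs are not preserved in this copy of the issue; below is a minimal sketch of the kind of setup the report describes, with illustrative names and a backend that just prints the captured graph.)

```python
import torch

@torch.compiler.allow_in_graph
class CustomOp(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1.0

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.op = CustomOp()

    def forward(self, x):
        return self.op(x) * 2

def print_graph_backend(gm, example_inputs):
    # Per the report: 2.2 keeps the allowed module as a call_module node,
    # while 2.3 inlines it into individual ops.
    print(gm.graph)
    return gm.forward

compiled = torch.compile(Model(), backend=print_graph_backend)
compiled(torch.randn(4))
```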
Output for torch 2.2
Output for torch 2.3
Is this expected?
Versions
For torch 2.3:
For torch 2.2:
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng