[inductor] Fix edge case in JIT vs. AOT fusion after finalizing MultiTemplateBuffer #126622
Conversation
torch/_inductor/scheduler.py
```py
# V.graph.buffers.remove(new_node)
# idx = V.graph.buffers.index(orig_node)
# V.graph.buffers[idx] = new_node
```
The other option is to swap the old buffer with the new buffer, but I think sorting the buffers by name at the start will be more deterministic.
nvm going with this option
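The swap option under discussion can be sketched like this (a minimal illustration with a hypothetical `swap_buffer` helper, not the actual change in `torch/_inductor/scheduler.py`): replacing the original buffer with the finalized one at the same index leaves every other buffer's position, and hence its min/max order, unchanged, unlike remove-and-append.

```python
# Sketch of the "swap old buffer with new buffer" option. The helper name
# swap_buffer is hypothetical; the point is that an in-place replacement
# preserves the list order, whereas remove-and-append reorders it.

def swap_buffer(buffers, orig_node, new_node):
    idx = buffers.index(orig_node)  # locate the original buffer's slot
    buffers[idx] = new_node         # overwrite in place; order is unchanged
    return buffers

buffers = ["buf0", "buf1", "buf2", "buf3", "buf4"]
swap_buffer(buffers, "buf3", "buf3_final")
print(buffers)  # ['buf0', 'buf1', 'buf2', 'buf3_final', 'buf4']
```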
…izing MultiTemplateBuffer"

# Context

Here's a peripheral scenario causing the JIT pass and the AOT pass to pick different fusions.

```py
# JIT -- buf3 is a MultiTemplateBuffer
V.graph.buffers = [buf0, buf1, buf2, buf3, buf4]
                               ^           ^
# JIT pass calls finalize_multi_template_buffers()
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]

# AOT, note proximity_score(buf2, buf4) is "better" for fusion than JIT
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]
                               ^     ^
```

It happens like this:
* JIT starts with the original set of nodes from `V.graph.buffers`.
* In JIT, `finalize_multi_template_buffers()` is called, which can change the order of the buffers.
* This makes the order of the buffers/scheduler nodes different.
* Now, each node's min/max order is different than before.
* As a result, the proximity between two nodes is different.

https://github.com/pytorch/pytorch/blob/ad67553c5c1672d65b810acd7a6a01e11695098b/torch/_inductor/scheduler.py#L2316-L2335

# Error

```
$ TORCH_LOGS="+fusion" python test/inductor/test_max_autotune.py -k test_jit_fusion_matches_aot_fusion
======================================================================
FAIL: test_jit_fusion_matches_aot_fusion (__main__.TestMaxAutotune)
----------------------------------------------------------------------
Traceback (most recent call last):
  ...
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1718, in compile_to_fn
    code, linemap = self.codegen_with_cpp_wrapper()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1618, in codegen_with_cpp_wrapper
    return self.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1636, in codegen
    self.scheduler.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2602, in codegen
    self.get_backend(device).codegen_node(node)  # type: ignore[possibly-undefined]
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 66, in codegen_node
    return self._triton_scheduling.codegen_node(node)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3377, in codegen_node
    return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3602, in codegen_node_schedule
    final_kernel.call_kernel(final_kernel.kernel_name)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3055, in call_kernel
    grid = wrapper.generate_default_grid(name, grid)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cpp_wrapper_cuda.py", line 174, in generate_default_grid
    params is not None
AssertionError: cuda kernel parameters for triton_poi_fused_add_0 should already exist at this moment, only found dict_keys(['Placeholder.DESCRIPTIVE_NAME', 'triton_poi_fused_add_mul_0', 'triton_poi_fused_pow_1'])
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang
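To make the order-sensitivity concrete, here is a minimal, self-contained sketch of the scenario above. The `proximity` function is a hypothetical stand-in, not inductor's actual scoring code; it only shows how moving `buf3` to the end changes how close `buf2` and `buf4` look to each other.

```python
# Hypothetical stand-in for a schedule-order proximity score: two nodes
# that sit closer together in the buffer list score as better fusion
# candidates. Illustration only, not torch._inductor's real scoring.

def proximity(order, a, b):
    # Distance between two buffers in schedule order; smaller is "closer".
    return abs(order.index(a) - order.index(b))

# JIT pass: buf3 (the MultiTemplateBuffer) still sits between buf2 and buf4.
jit_order = ["buf0", "buf1", "buf2", "buf3", "buf4"]

# After finalize_multi_template_buffers(), buf3 has moved to the end, so
# the AOT pass sees buf2 and buf4 as adjacent.
aot_order = ["buf0", "buf1", "buf2", "buf4", "buf3"]

print(proximity(jit_order, "buf2", "buf4"))  # 2
print(proximity(aot_order, "buf2", "buf4"))  # 1
```

With the scores disagreeing between the two passes, the AOT pass can pick a `buf2`/`buf4` fusion the JIT pass never compiled a kernel for, which is exactly the missing-kernel assertion in the traceback above.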
torch/_inductor/scheduler.py
```py
sort_by_name = lambda n: n.get_name()
nodes = sorted(nodes, key=sort_by_name)
```
I tried sorting the nodes by name first, but ran into test errors:
inductor/test_torchinductor.py::CpuTests::test_buffer_batch_norm_cpu
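For context on why a plain name sort is hazardous (the thread doesn't say which assertion failed, so this is only one plausible explanation): lexicographic order on buffer names diverges from creation order once the numeric suffix reaches double digits.

```python
# Lexicographic sorting of buffer names does not match creation order
# once indices pass 9: "buf10" sorts before "buf2".
names = [f"buf{i}" for i in range(12)]
print(sorted(names)[:5])  # ['buf0', 'buf1', 'buf10', 'buf11', 'buf2']
```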
Please fix the lint issue. Otherwise, LGTM. Thanks!
@pytorchbot merge

Merge failed. Reason: this PR needs a label. If it doesn't have one, please add the label; to add a label, you can comment to pytorchbot. Details for Dev Infra team: raised by workflow job.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Looks good - where did this come up?

The issue failed some internal models.

An internal model in production. Interestingly enough, the error came up after #125772. Here's what I saw in the logs.