[inductor] aten.index_put_ runtime shape mismatch on H100 but not on A100 #126614
Also, on A100/L4, where this runs without the runtime error above, we hit a different runtime error with other input sizes: aten.index_put_(buf4, [reinterpret_tensor(buf0, (1, 1, s2*s3, 14 + s2, 14 + s3), (0, 0, 196 + (14*s2) + (14*s3) + (s2*s3), 14 + s3, 1), 0)], buf6, False)
RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements. See https://github.com/pytorch/pytorch/issues/51871
Could you share a repro or a full error trace?
As I mentioned, the repro is at #121504. What is a full error trace to debug this?
The full stack trace that came with the
On the H100 with pytorch nightly:
Traceback (most recent call last):
File "/workspace/tools/eval.py", line 135, in <module>
main()
File "/workspace/tools/eval.py", line 130, in main
main_worker(0, cfg, enable_amp=args.amp)
File "/workspace/tools/eval.py", line 30, in main_worker
evaluator.evaluating()
File "/workspace/networks/managers/evaluator.py", line 442, in evaluating
engine.add_reference_frame(current_img,
File "/workspace/networks/engines/aotv3_engine.py", line 648, in add_reference_frame
aot_engine.add_reference_frame(img,
File "/workspace/networks/engines/aotv3_engine.py", line 239, in add_reference_frame
self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/networks/models/aotv3.py", line 189, in LSTT_forward
lstt_embs, lstt_memories = self.MSLSTT(curr_embs, long_term_memories,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 414, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/networks/layers/transformer.py", line 581, in forward
output, memories = layer(output,
^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/networks/layers/transformer.py", line 753, in forward
def forward(self,
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 414, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/networks/layers/attention.py", line 310, in forward
@torch.compile
File "/workspace/networks/layers/attention.py", line 344, in torch_dynamo_resume_in_forward_at_344
qk = self.correlation_sampler(q, k).view(
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 548, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 998, in forward
return compiled_fn(full_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 203, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 118, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 434, in wrapper
return compiled_fn(runtime_args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1078, in __call__
return self.current_callable(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 927, in run
return model(new_inputs)
^^^^^^^^^^^^^^^^^
File "/tmp/torchinductor_root/ut/cutmbnzthsr64p23ilpnn2ym54twqj4lwpqj5v3shylgqucshcur.py", line 660, in call
aten.index_put_(buf6, [reinterpret_tensor(buf0, (1, 1, 5244, 90, 83), (0, 0, 7552, 83, 1), 0)], reinterpret_tensor(buf7, (1179900, ), (1, ), 0), False)
File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 1060, in __call__
return self_._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape mismatch: value tensor of shape [1179900] cannot be broadcast to indexing result of shape [1165408]
@ezyang For the problem on the A100/L4 instead, do you know where it comes from? (See pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp, lines 8 to 9 at 0756f9f.)
When buf0 is a boolean mask, this results in data-dependent compute (a nonzero call), because we must find all the True entries in the boolean mask to determine which entries we write to. If buf0 is integer indices, this is not needed.
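To illustrate the point above (a minimal sketch with made-up shapes, not the buffers from the trace): with a boolean mask, index_put_ effectively converts the mask to integer indices via nonzero, so the value tensor must match the number of True entries, a quantity only known at runtime. A mismatched value tensor produces the same kind of "shape mismatch" RuntimeError reported in this issue.

```python
import torch

x = torch.zeros(4)
mask = torch.tensor([True, False, True, False])

# Boolean-mask index_put_: the mask is internally turned into integer
# indices (a nonzero() call), so the value tensor must have exactly as
# many elements as there are True entries in the mask (here, 2).
x.index_put_([mask], torch.tensor([1.0, 2.0]))
print(x.tolist())  # [1.0, 0.0, 2.0, 0.0]

# A value tensor whose size does not match the number of True entries
# (and cannot broadcast to it) raises a "shape mismatch" RuntimeError,
# analogous to the one in the H100 trace above.
try:
    x.index_put_([mask], torch.tensor([1.0, 2.0, 3.0]))
except RuntimeError as e:
    print("shape mismatch" in str(e))  # True
```

With integer indices instead of a boolean mask, the required value shape is known from the index tensors alone, so no data-dependent nonzero call is needed.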
OK, so in this case we are going to hit NVIDIA/cccl#1422 again for some specific inputs. But do you know what is happening on the H100 instead?
Tested again on
🐛 Describe the bug
Recently, compiling #121504 (already analyzed there by @williamwen42) works on A100/L4, but on H100, with the same function compiled, I got this error.
Error logs
Minified repro
No response
Versions
nightly
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire