
[inductor] aten.index_put_ runtime shape mismatch on H100 but not on A100 #126614

Open
bhack opened this issue May 18, 2024 · 9 comments
Labels
module: inductor · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@bhack (Contributor) commented May 18, 2024

🐛 Describe the bug

Compiling the function from #121504 (already analyzed there by @williamwen42) works on A100/L4, but on H100 the same compiled function fails with the error below.

Error logs

  aten.index_put_(buf6, [reinterpret_tensor(buf0, (1, 1, 4624, 82, 82), (0, 0, 6784, 82, 1), 0)], reinterpret_tensor(buf7, (1040400, ), (1, ), 0), False)
  File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 1060, in __call__
    return self_._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape mismatch: value tensor of shape [1040400] cannot be broadcast to indexing result of shape [1033097]
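
For reference, a minimal eager-mode sketch of this class of failure (illustrative tensors, not the original model or the generated inductor code): with a single boolean index, index_put_ requires the value tensor to broadcast to the number of True entries in the mask, not to the mask's shape.

```python
import torch

# Illustrative sketch, not the original repro: a boolean index makes index_put_
# expect the values to match mask.sum() (the indexing result), not mask.numel().
buf = torch.zeros(10)
mask = torch.zeros(10, dtype=torch.bool)
mask[[0, 2, 5]] = True                       # 3 True entries

ok_vals = torch.ones(3)                      # numel == mask.sum() -> works
torch.ops.aten.index_put_(buf, [mask], ok_vals, False)

bad_vals = torch.ones(4)                     # numel != mask.sum()
try:
    torch.ops.aten.index_put_(buf, [mask], bad_vals, False)
except RuntimeError as e:
    # "shape mismatch: value tensor of shape [4] cannot be broadcast to
    #  indexing result of shape [3]" -- the same class of error as above
    print(e)
```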

Minified repro

No response

Versions

nightly

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

@bhack (Contributor, Author) commented May 18, 2024

Also, on A100/L4 (where the case above runs without errors) we hit a different runtime error with other input sizes:

aten.index_put_(buf4, [reinterpret_tensor(buf0, (1, 1, s2*s3, 14 + s2, 14 + s3), (0, 0, 196 + (14*s2) + (14*s3) + (s2*s3), 14 + s3, 1), 0)], buf6, False)

RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements,    See https://github.com/pytorch/pytorch/issues/51871

@xmfan (Member) commented May 20, 2024

Could you share a repro or a full error trace?

@bhack (Contributor, Author) commented May 20, 2024

As I mentioned, the repro is at #121504.

What is a full error trace to debug this?

@xmfan (Member) commented May 20, 2024

The full stack trace that came with the RuntimeError: shape mismatch:

@bhack (Contributor, Author) commented May 21, 2024

On the H100 with PyTorch nightly:

Traceback (most recent call last):
  File "/workspace/tools/eval.py", line 135, in <module>
    main()
  File "/workspace/tools/eval.py", line 130, in main
    main_worker(0, cfg, enable_amp=args.amp)
  File "/workspace/tools/eval.py", line 30, in main_worker
    evaluator.evaluating()
  File "/workspace/networks/managers/evaluator.py", line 442, in evaluating
    engine.add_reference_frame(current_img,
  File "/workspace/networks/engines/aotv3_engine.py", line 648, in add_reference_frame
    aot_engine.add_reference_frame(img,
  File "/workspace/networks/engines/aotv3_engine.py", line 239, in add_reference_frame
    self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/networks/models/aotv3.py", line 189, in LSTT_forward
    lstt_embs, lstt_memories = self.MSLSTT(curr_embs, long_term_memories,
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 414, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/networks/layers/transformer.py", line 581, in forward
    output, memories = layer(output,
                       ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/networks/layers/transformer.py", line 753, in forward
    def forward(self,
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 414, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/networks/layers/attention.py", line 310, in forward
    @torch.compile
  File "/workspace/networks/layers/attention.py", line 344, in torch_dynamo_resume_in_forward_at_344
    qk = self.correlation_sampler(q, k).view(
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 548, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 998, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 203, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 118, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 434, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1078, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 927, in run
    return model(new_inputs)
           ^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_root/ut/cutmbnzthsr64p23ilpnn2ym54twqj4lwpqj5v3shylgqucshcur.py", line 660, in call
    aten.index_put_(buf6, [reinterpret_tensor(buf0, (1, 1, 5244, 90, 83), (0, 0, 7552, 83, 1), 0)], reinterpret_tensor(buf7, (1179900, ), (1, ), 0), False)
  File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 1060, in __call__
    return self_._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape mismatch: value tensor of shape [1179900] cannot be broadcast to indexing result of shape [1165408]

xmfan changed the title from "aten.index_put_ runtime shape mismatch" to "[inductor] aten.index_put_ runtime shape mismatch on H100 but not on A100" on May 21, 2024
@bhack (Contributor, Author) commented May 21, 2024

@ezyang For the problem on A100/L4 instead, do you know where index_put ends up requiring the nonzero op?
Is it derived from:

// The index is a TensorList containing kLong, kBool or kByte tensors or nulls. Byte
// tensors (boolean masks) are expanded to long tensors via nonzero(). Null

mlazos added the triaged label on May 23, 2024
@ezyang (Contributor) commented May 29, 2024

When buf0 is a boolean mask, this results in data-dependent compute (nonzero call) because we must determine all the True entries in the boolean mask to determine which entries we write to. If buf0 is integer indices this is not needed.
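
A small eager-mode sketch of that equivalence (illustrative tensors, not the inductor-generated code): a boolean index is expanded via nonzero(), which is data-dependent, whereas integer indices name the write targets directly.

```python
import torch

t = torch.zeros(6)
mask = torch.tensor([True, False, True, False, False, True])
vals = torch.tensor([1.0, 2.0, 3.0])

# Boolean mask: the True positions must be materialized first (data-dependent).
idx = mask.nonzero(as_tuple=True)
t.index_put_(idx, vals)            # same effect as t[mask] = vals

# Integer indices: no nonzero() pass is needed, the targets are known up front.
t2 = torch.zeros(6)
t2.index_put_((torch.tensor([0, 2, 5]),), vals)
```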

@bhack (Contributor, Author) commented May 30, 2024

OK, so in this case we are going to hit NVIDIA/cccl#1422 again for some specific inputs.
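
A back-of-the-envelope sketch (hypothetical sizes, not the actual workload) of when the boolean mask from the A100 log above crosses the INT_MAX element limit that the nonzero() expansion enforces:

```python
# The mask in the log has shape (1, 1, s2*s3, 14 + s2, 14 + s3),
# so its element count grows roughly like s**4 for s2 == s3 == s.
INT_MAX = 2**31 - 1
for s in (90, 150, 250):          # hypothetical spatial sizes
    numel = (s * s) * (14 + s) * (14 + s)
    print(s, numel, numel > INT_MAX)
# At large enough inputs the mask alone exceeds INT_MAX elements,
# which is where the nonzero() limitation is hit.
```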

But do you know what is happening on H100 instead?

@bhack (Contributor, Author) commented May 30, 2024

Tested again on the 20240530 nightly: the error in #126614 (comment) on H100 is still there.
