[PyTorch-XPU] NotImplementedError: No registered fallback function for aten::view #631

Closed
uniartisan opened this issue May 19, 2024 · 7 comments
Labels: ARC (ARC GPU), Crash (Execution crashes), XPU/GPU (XPU/GPU specific issues)

@uniartisan

Describe the bug

https://github.com/uniartisan/RWKV_Pytorch/blob/dev/train/train-test.py
I trained with the script linked above. It runs normally on both CPU and CUDA, but on XPU the backward pass fails with the error shown below.

To reproduce the problem:

  1. Initialize an empty small model using the test/test_rwkv_v6_init_params.py script.
  2. Use my dataset:
    demo.zip
  3. Run the training script (don't forget to change the dataset and model paths).
/home/lzy/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in NativeLayerNormBackward0. Traceback of forward call that caused the error:
  File "/home/lzy/Data/workspace/RWKV_Pytorch/train/train-test.py", line 107, in <module>
    token_out, state_new = model.forward_parallel(x_i, state)
  File "/home/lzy/Data/workspace/RWKV_Pytorch/src/model.py", line 519, in forward_parallel
    x = self.ln0(x)
  File "/home/lzy/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lzy/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lzy/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 196, in forward
    return F.layer_norm(
  File "/home/lzy/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/functional.py", line 2543, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
 (Triggered internally at /build/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  0%|                                                                                                     | 0/182 [00:09<?, ?it/s]
Traceback (most recent call last):
  File "/home/lzy/Data/workspace/RWKV_Pytorch/train/train-test.py", line 112, in <module>
    loss_weight.backward()
  File "/home/lzy/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/lzy/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
NotImplementedError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::view.  This usually means that this function requires a non-empty list of Tensors, or that you (the operator writer) forgot to register a fallback function.  Available functions are [CPU, XPU, Meta, QuantizedCPU, QuantizedXPU, MkldnnCPU, NestedTensorCPU, NestedTensorXPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

CPU: registered at /build/pytorch/build/aten/src/ATen/RegisterCPU.cpp:31188 [kernel]
XPU: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:10616 [kernel]
Meta: registered at /build/pytorch/build/aten/src/ATen/RegisterMeta.cpp:26829 [kernel]
QuantizedCPU: registered at /build/pytorch/build/aten/src/ATen/RegisterQuantizedCPU.cpp:951 [kernel]
QuantizedXPU: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterQuantizedXPU.cpp:518 [kernel]
MkldnnCPU: registered at /build/pytorch/build/aten/src/ATen/RegisterMkldnnCPU.cpp:515 [kernel]
NestedTensorCPU: registered at /build/pytorch/build/aten/src/ATen/RegisterNestedTensorCPU.cpp:719 [kernel]
NestedTensorXPU: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterNestedTensorXPU.cpp:843 [kernel]
BackendSelect: fallthrough registered at /build/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /build/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /build/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /build/pytorch/build/aten/src/ATen/RegisterFunctionalization_3.cpp:24446 [kernel]
Named: registered at /build/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: fallthrough registered at /build/pytorch/aten/src/ATen/ConjugateFallback.cpp:21 [kernel]
Negative: fallthrough registered at /build/pytorch/aten/src/ATen/native/NegateFallback.cpp:23 [kernel]
ZeroTensor: registered at /build/pytorch/build/aten/src/ATen/RegisterZeroTensor.cpp:161 [kernel]
ADInplaceOrView: registered at /build/pytorch/torch/csrc/autograd/generated/ADInplaceOrViewType_1.cpp:5074 [kernel]
AutogradOther: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradCPU: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradCUDA: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradHIP: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradXLA: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradMPS: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradIPU: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradXPU: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradHPU: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradVE: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradLazy: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradMTIA: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradPrivateUse1: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradPrivateUse2: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradPrivateUse3: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradMeta: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18179 [autograd kernel]
AutogradNestedTensor: registered at /build/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp:18158 [kernel]
Tracer: registered at /build/pytorch/torch/csrc/autograd/generated/TraceType_3.cpp:14610 [kernel]
AutocastCPU: fallthrough registered at /build/pytorch/aten/src/ATen/autocast_mode.cpp:382 [backend fallback]
AutocastXPU: fallthrough registered at /build/intel-pytorch-extension/csrc/gpu/aten/amp/autocast_mode.cpp:45 [backend fallback]
AutocastCUDA: fallthrough registered at /build/pytorch/aten/src/ATen/autocast_mode.cpp:249 [backend fallback]
FuncTorchBatched: registered at /build/pytorch/aten/src/ATen/functorch/BatchRulesViews.cpp:565 [kernel]
FuncTorchVmapMode: fallthrough registered at /build/pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /build/pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1079 [kernel]
VmapMode: fallthrough registered at /build/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /build/pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /build/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /build/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /build/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at /build/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]
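
The error message above is long, but its structure is regular: a schema name, followed by a bracketed list of dispatch keys that do have a kernel registered. The following is a small sketch, using only the standard library, that pulls both out of such a message. The function name `parse_dispatch_error` and the shortened sample string are illustrative assumptions, not part of PyTorch.

```python
import re
from typing import Optional


def parse_dispatch_error(message: str) -> tuple[Optional[str], list[str]]:
    """Return (schema_name, dispatch_keys) extracted from a dispatcher error."""
    # The schema name is followed by a sentence-ending period and whitespace.
    schema = re.search(r"registered for schema (\S+?)\.\s", message)
    # The registered dispatch keys are listed inside one pair of brackets.
    keys = re.search(r"Available functions are \[([^\]]*)\]", message)
    schema_name = schema.group(1) if schema else None
    key_list = []
    if keys:
        key_list = [k.strip() for k in keys.group(1).split(",") if k.strip()]
    return schema_name, key_list


# Shortened stand-in for the real message (assumption: same wording, fewer keys).
sample = (
    "no fallback function is registered for schema aten::view.  "
    "Available functions are [CPU, XPU, Meta, AutogradXPU]."
)
schema, keys = parse_dispatch_error(sample)
print(schema, "XPU" in keys, "AutogradXPU" in keys)
```

In the full message both XPU and AutogradXPU appear in the list, so a kernel for aten::view is registered on XPU. The message itself says the call reached the dispatcher with no tensor arguments, which points at how the op was called during the backward pass rather than at a missing kernel.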

Versions

Collecting environment information...
PyTorch version: 2.1.0.post2+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.30+xpu
IPEX commit: 474a6b3cb
Build type: Release

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: 2024.1.0 (2024.1.0.20240308)
CMake version: N/A
Libc version: glibc-2.35

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: 2024.1
MKL version: 2024.1
GPU models and configuration: 
[0] _DeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) Level-Zero', dev_type='gpu', driver_version='1.3.28202', has_fp64=0, total_memory=15473MB, max_compute_units=512, gpu_eu_count=512)
Intel OpenCL ICD version: 23.52.28202.52-821~22.04
Level Zero version: 1.3.28202.52-821~22.04

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      39 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             24
On-line CPU(s) list:                0-23
Vendor ID:                          GenuineIntel
Model name:                         13th Gen Intel(R) Core(TM) i7-13700KF
CPU family:                         6
Model:                              183
Thread(s) per core:                 2
Core(s) per socket:                 16
Socket(s):                          1
Stepping:                           1
CPU max MHz:                        5400.0000
CPU min MHz:                        800.0000
BogoMIPS:                           6835.20
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          640 KiB (16 instances)
L1i cache:                          768 KiB (16 instances)
L2 cache:                           24 MiB (10 instances)
L3 cache:                           30 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-23
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.30+xpu
[pip3] intel-extension-for-pytorch-deepspeed==2.1.30
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.2.4
[pip3] torch==2.1.0.post2+cxx11.abi
[pip3] torchaudio==2.1.0.post2+cxx11.abi
[pip3] torchmetrics==1.4.0.post0
[pip3] torchvision==0.16.0.post2+cxx11.abi
[conda] intel-extension-for-pytorch 2.1.30+xpu               pypi_0    pypi
[conda] intel-extension-for-pytorch-deepspeed 2.1.30                   pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-lightning         2.2.4                    pypi_0    pypi
[conda] torch                     2.1.0.post2+cxx11.abi          pypi_0    pypi
[conda] torchaudio                2.1.0.post2+cxx11.abi          pypi_0    pypi
[conda] torchmetrics              1.4.0.post0              pypi_0    pypi
[conda] torchvision               0.16.0.post2+cxx11.abi          pypi_0    pypi
@uniartisan
Author

https://github.com/uniartisan/RWKV_Pytorch/blob/dev/train/train-test.py#L51

Some additional information: the script automatically detects the XPU device, so when reproducing you can either change the device manually or leave the 'cpu' setting as is.
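
For readers who want to switch devices while reproducing, here is a minimal sketch of what such automatic device selection might look like. It is not the code from train-test.py. The helper name `pick_device` is hypothetical, and the sketch assumes that on this PyTorch 2.1 / IPEX stack `torch.xpu` only becomes available once `intel_extension_for_pytorch` has been imported.

```python
import importlib.util


def pick_device(prefer_xpu: bool = True) -> str:
    """Return "xpu", "cuda" or "cpu", depending on what is usable."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed, so there is nothing to choose
    import torch

    if prefer_xpu and importlib.util.find_spec("intel_extension_for_pytorch"):
        try:
            # Assumption: importing IPEX is what registers the torch.xpu backend.
            import intel_extension_for_pytorch  # noqa: F401
        except Exception:
            pass  # a broken IPEX install should not prevent a CPU/CUDA run

    xpu = getattr(torch, "xpu", None)
    if prefer_xpu and xpu is not None and xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


print(pick_device())                  # normal auto-detection
print(pick_device(prefer_xpu=False))  # skip XPU to compare against CPU/CUDA
```

Passing `prefer_xpu=False` is a convenient way to run the same script on CPU or CUDA and confirm that the failure is specific to XPU.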

If you change the opset here to 16, the issue above is not triggered. (Yes, I wrote several different implementations of the same model.)

This should help narrow down which specific operation is causing the problem.
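
Since the anomaly-mode warning in the report names NativeLayerNormBackward0, one way to narrow things down further is to run LayerNorm forward and backward in isolation. The sketch below does that. The function name, tensor shapes, and loss are hypothetical choices, and it is not confirmed that this small case triggers the bug, because the thread does not state which inputs do.

```python
def layer_norm_backward_ok(device: str = "cpu", hidden: int = 8) -> bool:
    """Run LayerNorm forward + backward on `device`; True if gradients arrive."""
    import torch  # local import so the helper can be defined without PyTorch

    torch.manual_seed(0)
    ln = torch.nn.LayerNorm(hidden).to(device)
    x = torch.randn(2, 4, hidden, device=device, requires_grad=True)

    # Anomaly mode is what produced the "Traceback of forward call" in the report.
    with torch.autograd.detect_anomaly():
        loss = ln(x).square().mean()
        loss.backward()

    return (
        x.grad is not None
        and x.grad.shape == x.shape
        and ln.weight.grad is not None
    )
```

Calling `layer_norm_backward_ok("cpu")` first and then `layer_norm_backward_ok("xpu")` (after importing `intel_extension_for_pytorch`) gives a quick comparison between the two backends. If the small case passes on XPU, the shapes and dtypes used in the training script would be the next thing to try.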

@wangkl2
Member

wangkl2 commented May 20, 2024

Thanks for reporting this. I will try to reproduce the issue and get back to you later.

@wangkl2 wangkl2 self-assigned this May 20, 2024
@wangkl2 wangkl2 added ARC ARC GPU Windows and removed Windows labels May 20, 2024
@wangkl2
Member

wangkl2 commented May 22, 2024

@uniartisan We have reproduced the error you encountered and will get back to you once we have found the root cause. Thanks.

@wangkl2 wangkl2 added XPU/GPU XPU/GPU specific issues Crash Execution crashes labels May 30, 2024
@wangkl2
Member

wangkl2 commented May 30, 2024

@uniartisan We have found the root cause of this NotImplementedError and fixed the bug in the LayerNorm layer. Training RWKV now works with both opset=16 and opset=18. We will publish a release including this fix soon.

@uniartisan
Author

> @uniartisan This NotImplementedError has been root caused and we have fixed the bug in LayerNorm layer. It works for both opset=16 and opset=18 in training RWKV now. We will have a release including this fix soon.

Thank you for your efforts. I will try the new release as soon as it is available.

@wangkl2
Member

wangkl2 commented Jun 3, 2024

Hi @uniartisan, this LayerNorm issue has been fixed by commit 97b37e2 on the release/xpu/2.1.30/ branch and shipped as a patch release. You can reinstall the wheels in your environment with:

python -m pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30.post0 oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

Please try it out!
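
After reinstalling, it can be worth confirming that the patched wheel is the one actually in use. Below is a sketch using only the standard library. The helper names are hypothetical, and the parser is deliberately simplistic: it only understands versions shaped like 2.1.30+xpu or 2.1.30.post0. Treating any later version as containing the fix is an assumption, not something stated in this thread.

```python
import re
from importlib import metadata
from typing import Optional

FIXED = "2.1.30.post0"  # patch release named in the comment above


def version_key(version: str) -> tuple[tuple[int, ...], int]:
    """Map '2.1.30+xpu' to ((2, 1, 30), -1) and '2.1.30.post0' to ((2, 1, 30), 0)."""
    public = version.split("+", 1)[0]  # drop a local label such as '+xpu'
    match = re.fullmatch(r"(\d+(?:\.\d+)*)(?:\.post(\d+))?", public)
    if match is None:
        raise ValueError(f"unsupported version string: {version!r}")
    release = tuple(int(part) for part in match.group(1).split("."))
    # A missing .postN sorts before .post0, as in PEP 440.
    post = int(match.group(2)) if match.group(2) is not None else -1
    return release, post


def has_layernorm_fix(installed: str, fixed: str = FIXED) -> bool:
    """True if `installed` is at least the patch release (assumed to keep the fix)."""
    return version_key(installed) >= version_key(fixed)


def installed_ipex_version() -> Optional[str]:
    """Return the installed IPEX version, or None if the package is absent."""
    try:
        return metadata.version("intel-extension-for-pytorch")
    except metadata.PackageNotFoundError:
        return None


current = installed_ipex_version()
if current is None:
    print("intel-extension-for-pytorch is not installed")
else:
    status = "patched" if has_layernorm_fix(current) else "needs upgrade"
    print(current, "->", status)
```

With the version from the original report (2.1.30+xpu) `has_layernorm_fix` returns False, and with 2.1.30.post0 it returns True.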

@uniartisan
Author

It has been fixed! :)
