ncclCommWatchdog always terminates the process and prevents error handling if CUDA context is corrupted #126544

Open
szmigacz opened this issue May 17, 2024 · 1 comment
Labels
module: c10d - Issues/PRs related to collective communications and process groups
oncall: distributed - Add this issue/PR to distributed oncall triage queue
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

szmigacz (Contributor) commented May 17, 2024

🐛 Describe the bug

ncclCommWatchdog uses abort to terminate the Python interpreter process if the CUDA context becomes corrupted while an NCCL collective is being executed. It does not respect TORCH_NCCL_ASYNC_ERROR_HANDLING=0 (NoHandling), TORCH_NCCL_ASYNC_ERROR_HANDLING=2 (CleanUpOnly), or TORCH_NCCL_ENABLE_MONITORING=0.

The watchdog always terminates the process and prevents any possible error handling (e.g., performing cleanup, logging the failure, or notifying other ranks that an error occurred).

Repro:

import os
import time

import torch


rank = int(os.getenv('LOCAL_RANK'))

torch.cuda.set_device(rank)
device = torch.device(f'cuda:{rank}')

torch.distributed.init_process_group(backend='nccl')

tensor = torch.ones(1024 * 1024, dtype=torch.int64, device=device)
a = torch.ones(1, dtype=torch.int64, device=device)
b = torch.ones(1, dtype=torch.int64, device=device)

torch.cuda.synchronize()

try:
    # schedule a bunch of all-reduces to fill the queue
    for _ in range(100):
        torch.distributed.all_reduce(tensor)
    # perform invalid memory access to trigger device assertion and crash
    # CUDA context while all_reduce is running
    a[b] = 0
    # sync to discover that CUDA crashed
    torch.cuda.synchronize()
except Exception as ex:
    print(f'Exception: {ex}')
    
    # my expectation is that all ranks enter this loop and keep running 
    while True:
        print(f'{rank} is running')
        time.sleep(0.05)

Run on a machine with at least 2 GPUs:

python3 -m torch.distributed.run --nproc-per-node 2 watchdog.py

All combinations of TORCH_NCCL_ASYNC_ERROR_HANDLING={0,1,2,3} x TORCH_NCCL_ENABLE_MONITORING={0,1} trigger the same failure.
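
For reference, a small driver like the following (a hypothetical helper, not part of the original report) runs the full sweep; every combination aborts the same way (SIGABRT, exitcode -6):

import itertools
import os
import subprocess

# Run watchdog.py under every combination of the two env vars listed above.
for aeh, mon in itertools.product('0123', '01'):
    env = {**os.environ,
           'TORCH_NCCL_ASYNC_ERROR_HANDLING': aeh,
           'TORCH_NCCL_ENABLE_MONITORING': mon}
    print(f'--- ASYNC_ERROR_HANDLING={aeh} ENABLE_MONITORING={mon} ---')
    subprocess.run(['python3', '-m', 'torch.distributed.run',
                    '--nproc-per-node', '2', 'watchdog.py'], env=env)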

Traceback:

WARNING:__main__:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [0,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

0 is running
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [0,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

1 is running
[rank0]:[E ProcessGroupNCCL.cpp:1414] [PG 0 Rank 0] Process group watchdog thread terminated with exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fc966b7a897 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fc966b2ab25 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fc966f2a718 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7fc91a84ae36 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7fc91a84ef38 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7fc91a8545ac in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7fc91a85531c in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7fc9662b0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7fc967c32ac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126850 (0x7fc967cc4850 in /lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 0 Rank 0] Process group watchdog thread terminated with exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fc966b7a897 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fc966b2ab25 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fc966f2a718 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7fc91a84ae36 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7fc91a84ef38 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7fc91a8545ac in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7fc91a85531c in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7fc9662b0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7fc967c32ac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126850 (0x7fc967cc4850 in /lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1418 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fc966b7a897 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe32e33 (0x7fc91a4d7e33 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7fc9662b0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7fc967c32ac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126850 (0x7fc967cc4850 in /lib/x86_64-linux-gnu/libc.so.6)

1 is running
[rank1]:[E ProcessGroupNCCL.cpp:1414] [PG 0 Rank 1] Process group watchdog thread terminated with exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fefa74cf897 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fefa747fb25 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fefa75a7718 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7fef5b24ae36 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7fef5b24ef38 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7fef5b2545ac in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7fef5b25531c in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7fefa6cb0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7fefa859fac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126850 (0x7fefa8631850 in /lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 0 Rank 1] Process group watchdog thread terminated with exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fefa74cf897 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fefa747fb25 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fefa75a7718 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7fef5b24ae36 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7fef5b24ef38 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7fef5b2545ac in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7fef5b25531c in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7fefa6cb0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7fefa859fac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126850 (0x7fefa8631850 in /lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1418 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fefa74cf897 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe32e33 (0x7fef5aed7e33 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7fefa6cb0253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7fefa859fac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126850 (0x7fefa8631850 in /lib/x86_64-linux-gnu/libc.so.6)

E0517 08:24:03.805000 140617718325696 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: -6) local_rank: 0 (pid: 970611) of binary: /tmp/venv/bin/python3
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/tmp/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 883, in <module>
    main()
  File "/tmp/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/tmp/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/tmp/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/tmp/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/tmp/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
=======================================================
watchdog.py FAILED
-------------------------------------------------------
Failures:
[1]:
  time      : 2024-05-17_08:24:03
  host      : ubuntu.cfxlab-colfax.private
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 970612)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 970612
-------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-05-17_08:24:03
  host      : ubuntu.cfxlab-colfax.private
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 970611)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 970611
=======================================================

Versions

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-97-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: Quadro RTX 8000
GPU 1: Quadro RTX 8000

Nvidia driver version: 550.54.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             20
On-line CPU(s) list:                0-19
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Core(TM) i9-9820X CPU @ 3.30GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 10
Socket(s):                          1
Stepping:                           4
CPU max MHz:                        4200.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           6599.98
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req md_clear flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          320 KiB (10 instances)
L1i cache:                          320 KiB (10 instances)
L2 cache:                           10 MiB (10 instances)
L3 cache:                           16.5 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-19
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        KVM: Mitigation: VMX disabled
Vulnerability L1tf:                 Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; IBRS, IBPB conditional, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT vulnerable

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.0
[pip3] triton==2.3.0
[conda] Could not collect

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@mikaylagawarecki mikaylagawarecki added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 20, 2024
@yf225 yf225 added the module: c10d Issues/PRs related to collective communications and process groups label May 20, 2024
@wconstab wconstab added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 20, 2024
kwen2501 (Contributor) commented
Looking closer at the stack trace:

frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fc966f2a718 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7fc91a84ae36 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7fc91a84ef38 in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7fc91a8545ac in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7fc91a85531c in /tmp/venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)

The trigger is in finishedGPUExecutionInternal. What does it do?

bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
  // Checking the work's corresponding CUDA event's status
  if (!ncclEndEvent_->query()) {
    return false;
  }
  return true;
}

ncclEndEvent_->query() calls cudaEventQuery(). Due to the sticky nature of CUDA errors, the error raised by the main thread is also hit by the watchdog, just in an unexpected place.
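
To make that sticky-error behavior concrete, here is a minimal sketch (not from the original report; assumes one CUDA GPU): once a device-side assert fires, every later CUDA runtime call in the process returns the same error, including the cudaEventQuery() the watchdog issues from its own thread.

import torch

# Minimal sketch of CUDA's "sticky" error semantics.
a = torch.ones(1, dtype=torch.int64, device='cuda')
b = torch.ones(1, dtype=torch.int64, device='cuda')

try:
    a[b] = 0                  # index 1 is out of bounds -> device-side assert
    torch.cuda.synchronize()  # surfaces the asynchronous error
except RuntimeError as ex:
    print(f'first error: {ex}')

try:
    torch.cuda.synchronize()  # an unrelated later call hits the same sticky error
except RuntimeError as ex:
    print(f'sticky error again: {ex}')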

From the watchdog's point of view, it is somewhat "innocent" -- it cannot distinguish whether the CUDA error came from compute kernels launched by the main thread or from NCCL kernels. If it came from an NCCL kernel, should the watchdog not report it? That would not be ideal either.

@eqy proposed an env var to control this in #126587. Maybe that's the way to go for the moment? That is, if we cannot decide what to do, maybe it is better to leave the decision to the user?

Cc: @shuqiangzhang @wconstab

pytorchmergebot pushed a commit that referenced this issue May 28, 2024
…#126587)

Doesn't affect current behavior by default, for #126544
I'm not sure what the exact mechanism is here, but CUDA errors appear to already be thrown in the main process, meaning that the watchdog separately rethrows the same CUDA error. This rethrown error causes the process to be terminated, since it cannot be handled from user code (which has no visibility into the watchdog thread).

Pull Request resolved: #126587
Approved by: https://github.com/kwen2501
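
For illustration, a hedged usage sketch of that opt-out follows; the env var name TORCH_NCCL_RETHROW_CUDA_ERRORS is my reading of #126587 and should be verified against the PR:

import os
import subprocess

# Hedged sketch: disable the watchdog's rethrow so that the exception raised
# in the main thread remains the single place where the error is handled.
# The env var name is my reading of PR #126587 -- verify it against the PR.
env = {**os.environ, 'TORCH_NCCL_RETHROW_CUDA_ERRORS': '0'}
subprocess.run(['python3', '-m', 'torch.distributed.run',
                '--nproc-per-node', '2', 'watchdog.py'], env=env)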
Aidyn-A pushed a commit to tinglvv/pytorch that referenced this issue May 30, 2024 (same change, pytorch#126587).