[Pipelining] Support for uneven microbatch sizes #126600

Closed
wconstab opened this issue May 18, 2024 · 1 comment
Labels
module: pipelining (Pipeline Parallelism), oncall: distributed (Add this issue/PR to distributed oncall triage queue)

Comments

@wconstab
Contributor

wconstab commented May 18, 2024

Discovered in pytorch/torchtitan#345 while running with pp degree 3 and batch_size 8: part of the stack chunks microbatches correctly (stage 0's output sizes are (3, ..), (3, ..), (2, ..) across the 3 microbatches), while stages 1 and 2 produce outputs of sizes 3, 3, 3, which causes a shape mismatch at the loss layer.
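
For reference, the (3, 3, 2) sizes are exactly what even-as-possible chunking of a batch of 8 into 3 microbatches produces. A minimal standalone sketch of that arithmetic (assuming `torch.tensor_split`-style chunking, not the actual pipelining splitter):

```python
import torch

# batch_size=8 split into 3 microbatches along dim 0 cannot be even:
# tensor_split yields chunks of sizes [3, 3, 2].
full_batch = torch.randn(8, 2048, 256)
chunks = torch.tensor_split(full_batch, 3, dim=0)
print([c.shape[0] for c in chunks])  # [3, 3, 2] -- matches stage 0's output sizes in the log below
```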

It may be that we are incorrectly reusing the larger, initially-sized buffer for all subsequent microbatches.

Either way, we need to decide: either raise an error on uneven microbatch sizes, or handle the differently sized buffers correctly. A hypothetical guard for the first option is sketched below.
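
As a strawman for the error-raising option, a minimal sketch (the function name and message are illustrative, not an existing torch.distributed.pipelining API):

```python
def check_even_microbatches(batch_size: int, n_microbatches: int) -> None:
    # Hypothetical guard: reject configurations whose batch does not divide
    # evenly into microbatches instead of silently reusing stale-sized buffers.
    if batch_size % n_microbatches != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by "
            f"n_microbatches={n_microbatches}; uneven microbatch sizes are not supported"
        )

check_even_microbatches(8, 3)  # raises for the pp=3, batch_size=8 run above
```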

```
[rank0]:[rank0]:V0517 17:32:21.484000 140492652921984 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 0] Forwarded chunk 0, outputs: Tensor(torch.Size([3, 2048, 256]), grad=True, dtype=torch.float32)
[rank0]:[rank0]:V0517 17:32:21.485000 140492652921984 torch/distributed/pipelining/_PipelineStage.py:288] [Stage 0] Sending tensor to Stage 1: torch.Size([3, 2048, 256])
[rank0]:[rank0]:V0517 17:32:21.488000 140492652921984 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 0] Forwarded chunk 1, outputs: Tensor(torch.Size([3, 2048, 256]), grad=True, dtype=torch.float32)
[rank0]:[rank0]:V0517 17:32:21.488000 140492652921984 torch/distributed/pipelining/_PipelineStage.py:288] [Stage 0] Sending tensor to Stage 1: torch.Size([3, 2048, 256])
[rank0]:[rank0]:V0517 17:32:21.490000 140492652921984 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 0] Forwarded chunk 2, outputs: Tensor(torch.Size([2, 2048, 256]), grad=True, dtype=torch.float32)
[rank0]:[rank0]:V0517 17:32:21.490000 140492652921984 torch/distributed/pipelining/_PipelineStage.py:288] [Stage 0] Sending tensor to Stage 1: torch.Size([2, 2048, 256])
[rank1]:[rank1]:V0517 17:32:21.605000 139671607760000 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 1] Forwarded chunk 0, outputs: Tensor(torch.Size([3, 2048, 256]), grad=True, dtype=torch.float32)
[rank1]:[rank1]:V0517 17:32:21.605000 139671607760000 torch/distributed/pipelining/_PipelineStage.py:288] [Stage 1] Sending tensor to Stage 2: torch.Size([3, 2048, 256])
[rank1]:[rank1]:V0517 17:32:21.642000 139671607760000 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 1] Forwarded chunk 1, outputs: Tensor(torch.Size([3, 2048, 256]), grad=True, dtype=torch.float32)
[rank1]:[rank1]:V0517 17:32:21.643000 139671607760000 torch/distributed/pipelining/_PipelineStage.py:288] [Stage 1] Sending tensor to Stage 2: torch.Size([3, 2048, 256])
[rank1]:[rank1]:V0517 17:32:21.652000 139671607760000 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 1] Forwarded chunk 2, outputs: Tensor(torch.Size([3, 2048, 256]), grad=True, dtype=torch.float32)
[rank1]:[rank1]:V0517 17:32:21.652000 139671607760000 torch/distributed/pipelining/_PipelineStage.py:288] [Stage 1] Sending tensor to Stage 2: torch.Size([3, 2048, 256])
[rank2]:[rank2]:V0517 17:32:21.822000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 2] Forwarded chunk 0, outputs: Tensor(torch.Size([3, 2048, 2256]), grad=True, dtype=torch.float32)
[rank2]:[rank2]:V0517 17:32:21.853000 139947356648576 torch/distributed/pipelining/PipelineSchedule.py:43] [2] Loss of microbatch 0: 8.228721618652344
[rank2]:[rank2]:V0517 17:32:21.896000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:541] [Stage 2] Backwarded chunk 0
[rank2]:[rank2]:V0517 17:32:21.897000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:201] [Stage 2] Grad send info: [1]
[rank2]:[rank2]:V0517 17:32:21.897000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:320] [Stage 2] Sending gradient to Stage 1: torch.Size([3, 2048, 256])
[rank2]:[rank2]:V0517 17:32:21.903000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 2] Forwarded chunk 1, outputs: Tensor(torch.Size([3, 2048, 2256]), grad=True, dtype=torch.float32)
[rank2]:[rank2]:V0517 17:32:21.906000 139947356648576 torch/distributed/pipelining/PipelineSchedule.py:43] [2] Loss of microbatch 1: 8.194499969482422
[rank2]:[rank2]:V0517 17:32:21.909000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:541] [Stage 2] Backwarded chunk 1
[rank2]:[rank2]:V0517 17:32:21.909000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:320] [Stage 2] Sending gradient to Stage 1: torch.Size([3, 2048, 256])
[rank2]:[rank2]:V0517 17:32:21.911000 139947356648576 torch/distributed/pipelining/_PipelineStage.py:498] [Stage 2] Forwarded chunk 2, outputs: Tensor(torch.Size([3, 2048, 2256]), grad=True, dtype=torch.float32)
[rank2]:[rank2]: Traceback (most recent call last):
[rank2]:[rank2]:   File "/data/users/whc/torchtitan/train.py", line 468, in <module>
[rank2]:[rank2]:     main(config)
[rank2]:[rank2]:   File "/data/users/whc/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
[rank2]:[rank2]:     return f(*args, **kwargs)
[rank2]:[rank2]:   File "/data/users/whc/torchtitan/train.py", line 342, in main
[rank2]:[rank2]:     pp_schedule.step(target=labels, losses=losses)
[rank2]:[rank2]:   File "/data/users/whc/pytorch/torch/distributed/pipelining/PipelineSchedule.py", line 277, in step
[rank2]:[rank2]:     self._step_microbatches(args_split, kwargs_split, targets_split, losses)
[rank2]:[rank2]:   File "/data/users/whc/torchtitan/torchtitan/parallelisms/pipelining_utils.py", line 165, in _step_microbatches
[rank2]:[rank2]:     self._maybe_compute_loss(self._stage, output, target_mbs, i)
[rank2]:[rank2]:   File "/data/users/whc/pytorch/torch/distributed/pipelining/PipelineSchedule.py", line 41, in _maybe_compute_loss
[rank2]:[rank2]:     loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
[rank2]:[rank2]:   File "/data/users/whc/pytorch/torch/distributed/pipelining/PipelineSchedule.py", line 153, in _compute_loss
[rank2]:[rank2]:     return self._loss_fn(output, target)  # type: ignore[misc]
[rank2]:[rank2]:   File "/data/users/whc/torchtitan/train.py", line 174, in loss_fn
[rank2]:[rank2]:     return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
[rank2]:[rank2]:   File "/data/users/whc/pytorch/torch/nn/functional.py", line 3103, in cross_entropy
[rank2]:[rank2]:     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
[rank2]:[rank2]: ValueError: Expected input batch_size (6144) to match target batch_size (4096).
```
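
The shapes in the error are consistent with the stale-buffer hypothesis: 6144 = 3 × 2048 comes from a size-3 output being reused for microbatch 2, while 4096 = 2 × 2048 is the correctly chunked size-2 target. A standalone repro of just the failing loss call (hypothetical tensors, outside the pipeline):

```python
import torch
import torch.nn.functional as F

pred = torch.randn(3, 2048, 2256)           # stale size-3 output reused for the last microbatch
labels = torch.randint(0, 2256, (2, 2048))  # correctly chunked size-2 target
F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
# ValueError: Expected input batch_size (6144) to match target batch_size (4096).
```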

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k

@wconstab added the oncall: distributed and module: pipelining labels on May 18, 2024
@wconstab
Contributor Author

Discussed offline. We do not support this.
