[Pipelining] Support for uneven microbatch sizes #126600
Labels: module: pipelining, oncall: distributed
Discovered in pytorch/torchtitan#345 while running with pp degree 3 and batch_size 8: it appears that only some layers of the stack chunk microbatches correctly. Stage 0's output sizes are (3, ..), (3, ..), (2, ..) for the 3 microbatches, while stages 1 and 2 have output sizes of (3, ..), (3, ..), (3, ..), which causes a shape mismatch at the loss layer.
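For reference, a minimal standalone sketch (not the pipelining internals) of how splitting batch_size 8 into 3 microbatches yields the uneven (3, 3, 2) sizes reported above; `torch.tensor_split` is one way to produce this split:

```python
import torch

batch = torch.randn(8, 16)  # batch_size 8, arbitrary hidden dim 16
microbatches = torch.tensor_split(batch, 3, dim=0)
print([mb.shape for mb in microbatches])
# [torch.Size([3, 16]), torch.Size([3, 16]), torch.Size([2, 16])]
```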
It may be that we are improperly reusing the larger, initially sized buffer for all subsequent microbatches.
Either way, we need to decide: either raise an error on uneven microbatch sizes, or handle the per-microbatch buffers correctly. A sketch of both options follows below.
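A hypothetical sketch of the two options; the helper names below are illustrative, not part of the actual pipelining API:

```python
import torch

def _validate_even_chunks(batch_size: int, n_microbatches: int) -> None:
    # Option 1 (hypothetical): fail fast instead of silently reusing a
    # stale buffer shape for the smaller trailing microbatch.
    if batch_size % n_microbatches != 0:
        raise ValueError(
            f"batch_size {batch_size} is not divisible by "
            f"n_microbatches {n_microbatches}; uneven microbatch sizes "
            f"are not supported"
        )

def _make_recv_buffers(example_batch: torch.Tensor, n_microbatches: int):
    # Option 2 (hypothetical): size each recv buffer from its own chunk of
    # the example batch, rather than reusing the (larger) first chunk's
    # buffer for every microbatch.
    return [
        torch.empty_like(chunk)
        for chunk in torch.tensor_split(example_batch, n_microbatches, dim=0)
    ]
```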
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k