Distributed Checkpoint doesn't verify shapes are correct #126604
Labels
module: distributed_checkpoint
oncall: distributed checkpointing
oncall: distributed
triaged
🐛 Describe the bug
I blindly torch.cat the 8 Llama3-70B checkpoint shards, save the result with DCP, then dcp.load it into a correctly sharded model (2D DTensor + FSDP). torch.cat concatenates along the first dimension (dim=0, the default), so we get shapes like these:

DCP doesn't complain about this when loading, even though the shapes are completely wrong (example: w2 should be (8192, 28672)). The norm weights are the most obvious case: they end up with 8x as many elements as they should have.
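The failure mode can be reproduced with plain shape arithmetic, without torch. This is a hypothetical sketch: it assumes the 8 shards split w2 evenly along its second dimension (8 shards of (8192, 3584)) and that the norm weight is replicated on every shard. The helper names cat_shape, merge_shape and find_shape_mismatches are made up for illustration and are not DCP APIs. It shows why a blind concatenation on dim 0 yields the wrong shapes, and the kind of per-parameter comparison that would catch them before loading.

```python
# Hypothetical sketch: pure-Python shape arithmetic, no torch required.
# The shard layouts below are assumptions for illustration only.
from typing import Dict, List, Optional, Tuple

Shape = Tuple[int, ...]


def cat_shape(shapes: List[Shape], dim: int) -> Shape:
    """Shape that torch.cat(tensors, dim=dim) would produce for these shapes."""
    first = shapes[0]
    dim = dim % len(first)  # accept negative dims, as torch.cat does
    for s in shapes[1:]:
        # All dimensions except the concatenation dim must match.
        if len(s) != len(first) or any(
            a != b for i, (a, b) in enumerate(zip(s, first)) if i != dim
        ):
            raise ValueError(f"incompatible shapes for cat on dim {dim}: {shapes}")
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


def merge_shape(shapes: List[Shape], dim: Optional[int]) -> Shape:
    """dim=None means the parameter is replicated: keep one copy instead of concatenating."""
    if dim is None:
        if any(s != shapes[0] for s in shapes):
            raise ValueError(f"replicated shards disagree: {shapes}")
        return shapes[0]
    return cat_shape(shapes, dim)


def find_shape_mismatches(
    expected: Dict[str, Shape], saved: Dict[str, Shape]
) -> Dict[str, Tuple[Shape, Shape]]:
    """Return {name: (expected_shape, saved_shape)} for every shared key whose shape differs."""
    return {
        name: (shape, saved[name])
        for name, shape in expected.items()
        if name in saved and saved[name] != shape
    }


# Assumed layout: 8 model-parallel shards, w2 split along dim 1,
# norm weight replicated on every shard.
N = 8
w2_shards = [(8192, 28672 // N)] * N
norm_shards = [(8192,)] * N

# What a blind concatenation on dim 0 produces.
blind = {
    "w2.weight": cat_shape(w2_shards, dim=0),
    "norm.weight": cat_shape(norm_shards, dim=0),
}
# Per-parameter merge: concatenate w2 on dim 1, keep one copy of the norm.
correct = {
    "w2.weight": merge_shape(w2_shards, dim=1),
    "norm.weight": merge_shape(norm_shards, dim=None),
}
expected = {"w2.weight": (8192, 28672), "norm.weight": (8192,)}

print(blind)  # {'w2.weight': (65536, 3584), 'norm.weight': (65536,)}
print(find_shape_mismatches(expected, blind))
print(find_shape_mismatches(expected, correct))  # {}
```

In a real checkpoint, the saved shapes could be collected from DCP's metadata (for example, FileSystemReader(path).read_metadata().state_dict_metadata, whose tensor entries expose a size) and compared against the shapes in the target model's state_dict() before calling dcp.load.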
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC