DCP sees 1/2 of the expected size of each tensor in 3D parallel #126595
Labels: module: distributed_checkpoint, oncall: distributed
Comments
mikaylagawarecki added the oncall: distributed label on May 20, 2024
#127071 will fix the issue.
BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this issue on May 31, 2024:
…lattening when loading (pytorch#127071)

Fixes pytorch#126595

**What does this PR do?**

This PR unflattens the optimizer state_dict, similar to what TorchRec does. The current `get_optimizer_state_dict()` converts the parameter IDs to FQNs in order to avoid any conflict between different optimizers on different ranks. The currently returned optimizer state_dict looks like the following:

```
{
    "state": {
        "layer1.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor},
        "layer2.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor},
    },
    "param_group": [
        {"lr": 0.0, "betas": (0.9, 0.95), ..., "params": ["layer1.weight", "layer2.weight"]}
    ]
}
```

While this avoids the conflict and supports the use case of merging multiple optimizers (e.g., optimizer in backward), the current optimizer state_dict still cannot support MPMD (e.g., pipeline parallelism). The root cause is `param_group`, which cannot generate unique keys during saving -- DCP will flatten the dict, but for `param_group` DCP will get keys like `param_group.lr` or `param_group.params`. These keys will conflict when using pipeline parallelism. This PR flattens the optimizer state_dict to the following form:

```
{
    "state.layer1.weight.step": 10,
    "state.layer2.weight.step": 10,
    "state.layer1.weight.exp_avg": SomeTensor,
    "state.layer2.weight.exp_avg": SomeTensor,
    "state.layer1.weight.exp_avg_sq": SomeTensor,
    "state.layer2.weight.exp_avg_sq": SomeTensor,
    "param_group.layer1.weight.lr": 0.1,
    "param_group.layer2.weight.lr": 0.1,
    "param_group.layer1.weight.betas": (0.9, 0.95),
    "param_group.layer2.weight.betas": (0.9, 0.95),
}
```

This allows distributed state_dict (DSD) to support MPMD (e.g., pipeline parallelism).

**Pros and Cons**

*Pros*
1. Can support optimizer resharding (e.g., changing the parallelisms from 3D to 2D, or changing the number of workers).
2. Users don't need to manually add a prefix to each optimizer.
3. Allows users to merge the optimizer states easily. One use case is loop-based pipeline parallelism.

*Cons*
1. The implementation makes a strong assumption about the structure of `param_groups` and its values. If the assumption changes, or some customized optimizers do not meet the assumption, the implementation will break.
2. Extra values are saved in the checkpoints. The assumption here is that `param_group` generally contains scalars, which are cheap to save.

Pull Request resolved: pytorch#127071
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: pytorch#127070
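As a rough illustration of the key flattening described in the commit message above (not the PR's actual implementation), a nested optimizer state_dict can be turned into dot-joined top-level keys along these lines; `flatten_state_dict` is a hypothetical helper written only for this sketch:

```python
# Minimal sketch of the flattening idea from the commit message: nested
# optimizer state_dict entries become dot-joined top-level keys so that
# every value gets a unique key during saving.
def flatten_state_dict(nested, prefix=""):
    flat = {}
    for key, value in nested.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_state_dict(value, full_key))
        else:
            flat[full_key] = value
    return flat

# {"state": {"layer1.weight": {"step": 10}}} -> {"state.layer1.weight.step": 10}
print(flatten_state_dict({"state": {"layer1.weight": {"step": 10}}}))
```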
Quite possibly this is a bug in the 3D parallel implementation itself, but I'm trying to debug why I see this warning and subsequently fail with:
```
ValueError: Failed to validate global plan
torch/distributed/checkpoint/default_planner.py:495] key:model.layers.0.attention.wq.weight invalid fill tensor-volume: 65536 chunks-volume: 32768
```
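For reference, the numbers in that line correspond to a simple coverage check: the chunks recorded for a key should add up to the volume of the full tensor. A simplified sketch of that arithmetic (not the actual DCP code), using chunk shapes assumed from the log:

```python
# Simplified sketch of the coverage check the error message points at:
# the summed volume of the recorded chunks must equal the full tensor volume.
from math import prod

def chunks_cover_tensor(full_shape, chunk_shapes):
    return sum(prod(s) for s in chunk_shapes) == prod(full_shape)

# For model.layers.0.attention.wq.weight the planner apparently sees two
# [64, 256] chunks for a [256, 256] tensor: 32768 != 65536, so validation fails.
print(chunks_cover_tensor((256, 256), [(64, 256), (64, 256)]))  # False
```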
The repro is on the 8-GPU CI for pytorch/torchtitan#344 (log link).
The same warning is issued for every weight in the model. For the remainder I'll focus on just one, `model.layers.0.attention.wq.weight`.
DCP sees the shape of `wq.weight` as [256, 256], which is its correct full shape per the model code.
Some debugging:
It looks like DCP only sees 2 chunks of size 64. I'm wondering if sharding for both FSDP and TP is happening on the same dim and one of those shardings is being ignored here?
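If that suspicion is right, the numbers line up. A back-of-the-envelope sketch (the 2-way FSDP and 2-way TP degrees here are assumptions for illustration, not taken from the repro config):

```python
# Hypothetical layout matching the suspicion above: FSDP and TP both shard
# dim 0 of the [256, 256] wq.weight. With 2-way FSDP x 2-way TP, the full
# tensor should be covered by 4 chunks of [64, 256]. If one of the two
# shardings is dropped from the checkpoint metadata, only 2 chunks remain,
# which is exactly the observed chunks-volume of 32768 (half of 65536).
fsdp_degree, tp_degree = 2, 2
rows, cols = 256, 256

rows_per_chunk = rows // (fsdp_degree * tp_degree)                  # 64
expected_volume = fsdp_degree * tp_degree * rows_per_chunk * cols   # 65536
observed_volume = tp_degree * rows_per_chunk * cols                 # 32768 (one sharding ignored)
print(expected_volume, observed_volume)
```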
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k @LucasLLC