[pipelining] Add back support for multi-use parameters/buffers #126626

Closed
kwen2501 opened this issue May 18, 2024 · 0 comments
Labels:
module: pipelining (Pipeline Parallelism)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)

kwen2501 commented May 18, 2024

🚀 The feature, motivation and pitch

When running tracer mode with torchtitan, the following NotImplementedError was raised:

Parameter freqs_cis used in multiple stages:
{submod_0: None, submod_1: None}.
Currently, we do not support multi-use parameters.

The multi-use comes from the Transformer's forward function, which passes the same self.freqs_cis to every layer:

for layer in self.layers.values():
    h = layer(h, self.freqs_cis)

Support was temporarily dropped when we refactored the tracer (_IR.py) to use the unflattener. We should add it back.
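
To make the failure mode concrete, here is a minimal sketch in plain Python (no PyTorch involved). It assumes a hypothetical mapping from stage name to the attributes that stage reads, and the attribute names other than freqs_cis are made up for illustration. It is not the tracer's actual check, only the idea behind the error message above: once the layers are split across stages, freqs_cis shows up in more than one of them.

```python
from collections import defaultdict

# Hypothetical description of a split model: each stage lists the
# attributes (parameters/buffers) that its forward pass reads.
stage_attr_uses = {
    "submod_0": ["layers.0.weight", "freqs_cis"],
    "submod_1": ["layers.1.weight", "freqs_cis"],
}


def find_multi_use(stage_attr_uses):
    """Return {attr: [stages]} for attributes read by more than one stage."""
    users = defaultdict(list)
    for stage, attrs in stage_attr_uses.items():
        # set() so an attribute read twice inside one stage counts once
        for attr in set(attrs):
            users[attr].append(stage)
    return {attr: stages for attr, stages in users.items() if len(stages) > 1}


multi_use = find_multi_use(stage_attr_uses)
print(multi_use)  # {'freqs_cis': ['submod_0', 'submod_1']}
```

The per-layer weights each belong to a single stage, so only freqs_cis is reported.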


cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@kwen2501 kwen2501 self-assigned this May 18, 2024
kwen2501 added a commit that referenced this issue May 19, 2024
kwen2501 added a commit that referenced this issue May 19, 2024
Resolves #126626

ghstack-source-id: 35a3783f260d57972079289291f4ce827584d037
Pull Request resolved: #126653
@mikaylagawarecki mikaylagawarecki added oncall: distributed Add this issue/PR to distributed oncall triage queue module: pipelining Pipeline Parallelism labels May 20, 2024
@yf225 yf225 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 20, 2024
Aidyn-A pushed a commit to tinglvv/pytorch that referenced this issue May 30, 2024
…ch#126653)

## Motivation
Resolves pytorch#126626 to support TorchTitan.

With this PR, we add back support for cases where a parameter or buffer is used in multiple stages. One such case is the LLaMA model in torchtitan:
```
for layer in self.layers.values():
    h = layer(h, self.freqs_cis)
```

## Solution
Step 1:
Remove the previous `if len(node.users) == 1` guards.
Step 2:
Call `move_param_to_callee` once for each stage ("callee") that uses the parameter.
Step 3:
Delay deletion of the root's `get_attr` node (the node that fetches the param) until the param has been sunk into every stage that uses it.

The PR also cleans up the old code around this (dropping the TRANSMIT mode and supporting REPLICATE mode only).
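
The three steps can be sketched roughly as follows. This is a toy model in plain Python, not the code in `_IR.py`: the `Stage` and `GetAttrNode` classes, the `sink_params` helper, and the signature given to `move_param_to_callee` here are all assumptions made for illustration. Only the overall shape (no single-user guard, one move per stage, deferred deletion, REPLICATE semantics) comes from the description above.

```python
from dataclasses import dataclass, field


@dataclass
class Stage:
    """Hypothetical stand-in for a stage submodule ("callee")."""
    name: str
    params: dict = field(default_factory=dict)


@dataclass
class GetAttrNode:
    """Hypothetical stand-in for a root-level get_attr node."""
    target: str   # attribute name on the root module
    users: list   # stages that consume this attribute


def move_param_to_callee(root_params, callee, target):
    # REPLICATE mode: the callee gets its own reference to the value.
    callee.params[target] = root_params[target]


def sink_params(root_params, get_attr_nodes):
    to_delete = []
    for node in get_attr_nodes:
        # Step 1: no `len(node.users) == 1` guard; any number of users is fine.
        # Step 2: one move per stage that uses the parameter.
        for stage in node.users:
            move_param_to_callee(root_params, stage, node.target)
        # Step 3: mark for deletion only after every user has been served.
        to_delete.append(node)
    # Deferred deletion of the root-level get_attr nodes (by identity).
    get_attr_nodes[:] = [
        n for n in get_attr_nodes if not any(n is d for d in to_delete)
    ]


freqs = [0.1, 0.2]  # placeholder value standing in for freqs_cis
root_params = {"freqs_cis": freqs}
s0, s1 = Stage("submod_0"), Stage("submod_1")
nodes = [GetAttrNode("freqs_cis", [s0, s1])]
sink_params(root_params, nodes)
```

After the call, both stages hold `freqs_cis` and the root-level node list is empty. Deleting the node inside the inner loop instead would remove it before the second stage had received its copy, which is the situation Step 3 avoids.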

## Test
Changed the `ExampleCode` model to use `mm_param1` in multiple stages.

Pull Request resolved: pytorch#126653
Approved by: https://github.com/pianpwk
bigfootjon pushed a commit that referenced this issue Jun 5, 2024

(cherry picked from commit 8090145)