[export] allow complex guards as runtime asserts #126627
Conversation
CI status (artifacts and rendered test results at hud.pytorch.org/pr/126627): as of commit 308beff with merge base fed536d — 16 new failures, 2 unrelated failures, plus one flaky and one unstable job on trunk.
@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 7423bf7 to 8f7c1fe.
Summary: With the current state of export's dynamic shapes, we struggle with guards and constraints that are beyond the current dynamic shapes language, expressed with dims and derived dims. While we can compile and guarantee correctness for guards within the current language (e.g. min/max ranges, linear relationships, integer divisibility), we struggle to dynamically compile guards which extend beyond that. For these "complex" guards, we typically do one of the following: 1) raise a constraint violation error, along the lines of "not all values of `<symbol>` in the specified range satisfy `<guard>`", with or without suggested fixes; 2) specialize to the provided static values and suggest removing dynamism; or 3) fail compilation due to some arbitrary unsupported case. Previous [work](#124949) went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.

In this PR, relying heavily on [hybrid backed-unbacked symints](#121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](https://github.com/pytorch/pytorch/blob/d7de4c9d809697b36ae0fd9e16815f6e3b4d985b/torch/fx/experimental/symbolic_shapes.py#L1824), we add a flag `allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph. Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation, i.e. guards that we expect to be true.

Popular examples revolve around reshapes:

```python
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 == s2
```

This leads to the following exported program:

```python
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None
        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```

Another case is symbol divisibility:

```python
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # Eq(Mod(s0 * s1, s0 - 1), 0)
```

Applying deferred runtime asserts also helps dynamic compilation for explicit complex guards and guards around unbacked symints that may cause problems for export. For example, here we are able to compile a program containing a not-equals guard around an unbacked symint, and later guarantee correctness for a convoluted guard as part of the graph:

```python
# not equal + unbacked symint
def forward(self, x, y):  # x: [s0], y: sizeless tensor
    n = y.item()
    if x.shape[0] != n * 2:
        return x * 2 + n
```

The compiled exported program then contains the following code, without requiring the user to add in any additional torch._check() calls:

```python
def forward(self, x: "f32[s0]", y: "i64[]"):
    ...
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(y);  y = None
    sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
    mul: "Sym(-2*u0)" = -2 * _local_scalar_dense
    add: "Sym(s0 - 2*u0)" = sym_size_int + mul;  sym_size_int = mul = None
    eq: "Sym(Eq(s0 - 2*u0, 0))" = add == 0;  add = None
    sym_not: "Sym(Ne(s0 - 2*u0, 0))" = torch.sym_not(eq);  eq = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(sym_not, "Runtime assertion failed for expression Ne(s0 - 2*u0, 0) on node 'sym_not'");  sym_not = None
    ...
```

Convoluted guard:

```python
def forward(self, x, y):
    if x.shape[0] ** 2 == y.shape[0] * 3:
        return x * 2.0, y * 3.0
```

One additional benefit here over the current state of exported programs is that this adds further correctness guarantees: previously with explicit complex guards, if compilation succeeded, the guards would be ignored at runtime, treated as given. As shown above, the runtime asserts appear as math ops in the graph, resulting in an _assert_scalar call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops during runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.

This PR additionally removes the recently added [_disable_forced_specializations](#124949) flag, as it is subsumed. Another note: the decision to defer explicit and unbacked-symint-related guards as runtime asserts is separate from deferring "implicit" guards (i.e. hybrid backed-unbacked symint guards), and can be removed. Hybrid backed-unbacked symints, IIUC, apply only to guards emitted by torch ops, while this extends to all boolean guards, which seems like a good thing, but might not be.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

Differential Revision: D57530401
Pulled By: pianpwk
This pull request was exported from Phabricator. Differential Revision: D57530401
Force-pushed from 8f7c1fe to f4cd946.
Summary (revised): Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example, we can generate runtime asserts for not-equal guards, and for complex asserts like the following:

```python
class Foo(torch.nn.Module):
    def forward(self, x, y):
        # check that negation of first guard also shows up as runtime assertion
        if x.shape[0] == y.shape[0]:  # False
            return x + y
        elif x.shape[0] == y.shape[0] ** 3:  # False
            return x + 2, y + 3
        elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
            return x * 2.0, y * 3.0
```

For the above graph we will generate 3 runtime assertions: the negations of the first 2 conditions, and the 3rd condition as a guard. This PR doesn't change any behavior around data-dependent errors or unbacked symints yet; we can also defer those to runtime asserts, but that requires a lot more work.

Differential Revision: D57530401
Pulled By: pianpwk
Force-pushed from f4cd946 to 308beff.
…ded (#127554)

Summary: This [PR](#126627) was in testing while this [PR](#127132) got merged in, and testing did not catch some undesired behavior: runtime asserts for complex guards were added twice, and the command-line flag TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 no longer omitted runtime asserts. This adds back the flags `should_insert_runtime_assertion` (for general runtime asserts) and `should_insert_unassociated_runtime_assertions`, so guard-related asserts aren't added twice.

Test Plan: Added a test case which counts the number of runtime asserts and checks whether they are silenced.

Differential Revision: D57978699
Pulled By: pianpwk
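The double-insertion fix described above can be sketched as a gating function. This is only an illustration: the two flag names and the environment variable come from the summary, while the function name, arguments, and data layout are hypothetical.

```python
def insert_runtime_asserts(assert_exprs,
                           should_insert_runtime_assertion=True,
                           should_insert_unassociated_runtime_assertions=True):
    """Sketch: decide which runtime-assert expressions to emit into a graph.

    assert_exprs: list of (expression_string, is_unassociated) pairs, where
    is_unassociated marks guard-related asserts not tied to a graph node.
    """
    emitted = []
    seen = set()
    if not should_insert_runtime_assertion:
        # TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 path: emit nothing
        return emitted
    for expr, unassociated in assert_exprs:
        if unassociated and not should_insert_unassociated_runtime_assertions:
            continue  # guard-related asserts handled by another pass
        if expr in seen:
            continue  # avoid inserting the same assert twice
        seen.add(expr)
        emitted.append(expr)
    return emitted
```

With both flags on, duplicate expressions collapse to one assert; turning off the first flag suppresses all asserts, mirroring the environment-variable behavior.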
With the current state of export's dynamic shapes, we struggle with guards and constraints that are beyond the current dynamic shapes language, expressed with dims and derived dims. While we can compile and guarantee correctness for guards within the current language (e.g. min/max ranges, linear relationships, integer divisibility) we struggle to dynamically compile guards which extend beyond that.
For these "complex" guards, we typically do one of the following: 1) raise a constraint violation error, along the lines of "not all values of `<symbol>` in the specified range satisfy `<guard>`", with or without suggested fixes; 2) specialize to the provided static values and suggest removing dynamism; or 3) fail compilation due to some arbitrary unsupported case. Previous work went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.
In this PR, relying on hybrid backed-unbacked symints, deferred runtime asserts, and the function `_is_supported_equivalence()`, we add a flag `allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph. Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation, i.e. guards that we expect to be true. Popular examples revolve around reshapes:
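The reshape case can be sketched at the shape level in plain Python (no torch; the function and argument names here are illustrative): for `x.reshape([-1]) + y` with `x: [s0, s1]` and `y: [s2]`, the graph emits a runtime assert for `Eq(s0*s1 - s2, 0)` before the view.

```python
def reshape_add_shape(x_shape, y_shape):
    # Shape-level sketch of `x.reshape([-1]) + y` plus the runtime assert
    # the exported graph emits before the view.
    s0, s1 = x_shape
    (s2,) = y_shape
    # emitted runtime assertion: Eq(s0*s1 - s2, 0)
    if s0 * s1 - s2 != 0:
        raise RuntimeError(
            "Runtime assertion failed for expression Eq(s0*s1 - s2, 0)"
        )
    # view: [s0, s1] -> [s0*s1]; the elementwise add keeps that shape
    return (s0 * s1,)
```

Valid inputs pass through unchanged; mismatched shapes now fail on the assert rather than compiling to a specialized (static) graph.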
Another case is symbol divisibility:
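For `x.reshape([-1, x.shape[0] - 1])` with `x: [s0, s1]`, the guard is `Eq(Mod(s0*s1, s0 - 1), 0)`. A shape-level plain-Python sketch (illustrative names; the real graph emits this as an `_assert_scalar`):

```python
def reshape_div_shape(x_shape):
    # Shape-level sketch of `x.reshape([-1, x.shape[0] - 1])` plus the
    # divisibility assert the exported graph emits.
    s0, s1 = x_shape
    # emitted runtime assertion: Eq(Mod(s0*s1, s0 - 1), 0)
    if (s0 * s1) % (s0 - 1) != 0:
        raise RuntimeError(
            "Runtime assertion failed for expression Eq(Mod(s0*s1, s0 - 1), 0)"
        )
    return ((s0 * s1) // (s0 - 1), s0 - 1)
```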
Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example we can generate runtime asserts for not-equal guards, and complex asserts like the following:
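A plain-Python stand-in for such a module (shapes only; the function and size names are illustrative): three shape-dependent branches on sizes `s_x` and `s_y`, where tracing takes the third branch.

```python
# Stand-in for the branchy module: only the third condition holds
# for inputs that reach the traced branch.
def branchy_forward(s_x, s_y):
    if s_x == s_y:               # negated as a runtime assert
        return "branch 1"
    elif s_x == s_y ** 3:        # negated as a runtime assert
        return "branch 2"
    elif s_x ** 2 == s_y * 3:    # asserted as given
        return "branch 3"

# The three runtime asserts generated for the traced (third) branch:
def branchy_runtime_asserts(s_x, s_y):
    assert not (s_x == s_y), "Ne(s_x, s_y)"
    assert not (s_x == s_y ** 3), "Ne(s_x, s_y**3)"
    assert s_x ** 2 == s_y * 3, "Eq(s_x**2, 3*s_y)"
    return True
```

For example, sizes `s_x = 6, s_y = 12` satisfy all three asserts and reach the third branch.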
For the above graph we will generate 3 runtime assertions: the negations of the first two conditions, and the third condition as a guard.
One additional benefit here over the current state of exported programs is that this adds further correctness guarantees - previously with explicit complex guards, if compilation succeeded, the guards would be ignored at runtime, treated as given.
As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter, resulting in an `_assert_scalar` call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops during runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.

Ideally this PR would subsume and remove the recently added `_disable_forced_specializations` flag, but that flag still handles one additional case of specialization: single-variable equalities where the symbol is solvable for a concrete value; see this PR.
This PR doesn't change any behavior around data-dependent errors or unbacked symints yet; deferring those to runtime asserts could be further work.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang