
[export] allow complex guards as runtime asserts #126627

Closed


@pianpwk pianpwk commented May 18, 2024

With the current state of export's dynamic shapes, we struggle with guards and constraints that go beyond the dynamic shapes language expressed with dims and derived dims. While we can compile and guarantee correctness for guards within that language (e.g. min/max ranges, linear relationships, integer divisibility), we struggle to dynamically compile guards that extend beyond it.

For these "complex" guards, we typically do one of the following: 1) raise a constraint violation error, along the lines of "not all values of <symbol> in the specified range satisfy <guard>", with or without suggested fixes, 2) specialize to the provided static values and suggest removing dynamism, or 3) fail compilation due to some arbitrary unsupported case. Previous work went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.

In this PR, relying on hybrid backed-unbacked symints, deferred runtime asserts, and the function `_is_supported_equivalence()`, we add a flag `allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph.

Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation - guards that we ~expect to be true. Popular examples revolve around reshapes:

```
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 = s2
```

This leads to the following exported program:

```
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None

        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```
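The assert logic in that graph reduces to a scalar check before the view. Here is a hand-written plain-Python mimic of the generated sym-op chain (an illustration only, not the actual exported module):

```python
def reshape_guard_check(x_shape, y_shape):
    # Mirrors the exported program's nodes:
    # eq = ((-1 * s2) + (s0 * s1)) == 0, asserted before the view op.
    s0, s1 = x_shape
    s2 = y_shape[0]
    mul = -1 * s2
    mul_1 = s0 * s1
    add = mul + mul_1
    if not add == 0:
        raise RuntimeError(
            "Runtime assertion failed for expression Eq(s0*s1 - s2, 0)"
        )
    return [mul_1]  # output shape of x.reshape([-1]) + y

print(reshape_guard_check((2, 3), (6,)))  # valid: 2*3 == 6 -> [6]
```

With mismatched shapes such as `x: [2, 3]`, `y: [5]`, the check raises instead of producing silently wrong results.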

Another case is symbol divisibility:

```
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # Eq(Mod(s0 * s1, s0 - 1), 0)
```
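The divisibility requirement that reshape implies can be checked directly; a quick sketch of the guard (illustrative helper, not emitted graph code):

```python
def divisibility_guard(s0, s1):
    # Eq(Mod(s0 * s1, s0 - 1), 0): the s0 * s1 elements must split
    # evenly into rows of width s0 - 1.
    return (s0 * s1) % (s0 - 1) == 0

print(divisibility_guard(3, 4))  # 12 % 2 == 0 -> True
print(divisibility_guard(4, 5))  # 20 % 3 != 0 -> False
```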

Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example, we can generate runtime asserts for not-equal guards, and for complex guards like the following:

```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        # check that negation of first guard also shows up as runtime assertion
        if x.shape[0] == y.shape[0]:  # False
            return x + y
        elif x.shape[0] == y.shape[0] ** 3:  # False
            return x + 2, y + 3
        elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
            return x * 2.0, y * 3.0
```

For the above graph we will generate 3 runtime assertions: the negations of the first 2 conditions, and the 3rd condition itself.
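Those three deferred assertions can be sketched as plain checks, using the illustrative names s0 = x.shape[0] and s2 = y.shape[0] (this is a hand-written illustration of what gets asserted, not the emitted graph code):

```python
def deferred_asserts(s0, s2):
    # Returns the list of failed assertions; empty means all hold.
    checks = [
        ("Ne(s0, s2)", s0 != s2),               # negation of branch 1
        ("Ne(s0, s2**3)", s0 != s2 ** 3),       # negation of branch 2
        ("Eq(s0**2, 3*s2)", s0 ** 2 == 3 * s2), # branch 3 condition
    ]
    return [expr for expr, ok in checks if not ok]

print(deferred_asserts(6, 12))  # 36 == 3*12, 6 != 12, 6 != 1728 -> []
print(deferred_asserts(4, 4))   # fails Ne(s0, s2) and Eq(s0**2, 3*s2)
```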

One additional benefit here over the current state of exported programs is that this adds further correctness guarantees - previously with explicit complex guards, if compilation succeeded, the guards would be ignored at runtime, treated as given.

As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter, resulting in an `_assert_scalar` call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops at runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.

Ideally this PR would subsume and remove the recently added `_disable_forced_specializations` flag, but that flag still handles one additional case of specialization: single-variable equalities where the symbol is solvable for a concrete value (see this PR).

This PR doesn't change any behavior around data-dependent errors/unbacked symints yet; that could be future work.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang


pytorch-bot bot commented May 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126627

Note: Links to docs will display an error until the docs builds have been completed.

❌ 16 New Failures, 2 Unrelated Failures

As of commit 308beff with merge base fed536d (image):


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor module: cpu CPU specific problem (e.g., perf, algorithm) module: dynamo release notes: fx release notes category labels May 18, 2024
@facebook-github-bot

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@facebook-github-bot facebook-github-bot force-pushed the pianpwk/allow_complex_guards_as_runtime_asserts branch from 7423bf7 to 8f7c1fe on May 20, 2024 23:21
facebook-github-bot pushed a commit that referenced this pull request May 20, 2024
Summary:
With the current state of export's dynamic shapes, we struggle with guards and constraints that are beyond the current dynamic shapes language, expressed with dims and derived dims. While we can compile and guarantee correctness for guards within the current language (e.g. min/max ranges, linear relationships, integer divisibility) we struggle to dynamically compile guards which extend beyond that.

For these "complex" guards, we typically do either of the following: 1) raise a constraint violation error, along the lines of "not all values of <symbol> in the specified range satisfy <guard>", with or without suggested fixes, 2) specialize to the provided static values and suggest removing dynamism, or 3) fail compilation due to some arbitrary unsupported case. Previous [work](#124949) went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.

In this PR, heavily relying on [hybrid backed-unbacked symints](#121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](https://github.com/pytorch/pytorch/blob/d7de4c9d809697b36ae0fd9e16815f6e3b4d985b/torch/fx/experimental/symbolic_shapes.py#L1824), we add a flag `allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph.

Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation - guards that we ~expect to be true. Popular examples revolve around reshapes:
```
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 = s2
```

This leads to the following exported program:

```
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None
        
        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```
Another case is symbol divisibility:
```
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # Eq(Mod(s0 * s1, s0 - 1), 0)
```

Applying deferred runtime asserts also helps dynamic compilation for explicit complex guards and guards around unbacked symints that may cause problems for export. For example, here we are able to compile a program containing a not-equals guard around an unbacked symint, and later guarantee correctness for a convoluted guard as part of the graph:
```
# not equal + unbacked symint
def forward(self, x, y):  # x: [s0], y: sizeless tensor
    n = y.item()
    if x.shape[0] != n * 2:
        return x * 2 + n
```

The compiled exported program then contains the following code, without requiring the user to add in any additional torch._check() calls:

```
def forward(self, x: "f32[s0]", y: "i64[]"):
    ...
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(y);  y = None
    sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
    mul: "Sym(-2*u0)" = -2 * _local_scalar_dense
    add: "Sym(s0 - 2*u0)" = sym_size_int + mul;  sym_size_int = mul = None
    eq: "Sym(Eq(s0 - 2*u0, 0))" = add == 0;  add = None
    sym_not: "Sym(Ne(s0 - 2*u0, 0))" = torch.sym_not(eq);  eq = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(sym_not, "Runtime assertion failed for expression Ne(s0 - 2*u0, 0) on node 'sym_not'");  sym_not = None
    ...
```
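The Ne assert in that graph reduces to a simple scalar check on the backed symbol s0 and the unbacked symbol u0; a hand-written mimic (an illustration, not the exported module itself):

```python
def ne_guard_check(s0, u0):
    # Mirrors the graph: sym_not(Eq(s0 - 2*u0, 0)) must hold.
    add = s0 + (-2 * u0)
    sym_not = not (add == 0)
    if not sym_not:
        raise RuntimeError(
            "Runtime assertion failed for expression Ne(s0 - 2*u0, 0)"
        )
    return True

print(ne_guard_check(5, 2))  # 5 - 4 != 0 -> passes, prints True
```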
Convoluted guard:
```
def forward(self, x, y):
    if x.shape[0] ** 2 == y.shape[0] * 3:
        return x * 2.0, y * 3.0
```

One additional benefit here over the current state of exported programs is that this adds further correctness guarantees - previously with explicit complex guards, if compilation succeeded, the guards would be ignored at runtime, treated as given.

As shown above, the runtime asserts appear as math ops in the graph resulting in an _assert_scalar call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops during runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.

This PR additionally removes the recently added [_disable_forced_specializations](#124949) flag, as it is subsumed. 

Another note: the decision to defer explicit & unbacked symint-related guards as runtime asserts is separate from deferring "implicit" (i.e. hybrid backed-unbacked symint guards), and can be removed. Hybrid backed-unbacked symints IIUC apply only to guards emitted by torch ops, and this extends to all boolean guards, which seems like a good thing, but might not be.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang


Differential Revision: D57530401

Pulled By: pianpwk
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D57530401

@facebook-github-bot facebook-github-bot force-pushed the pianpwk/allow_complex_guards_as_runtime_asserts branch from 8f7c1fe to f4cd946 on May 20, 2024 23:23
facebook-github-bot pushed a commit that referenced this pull request May 20, 2024

@facebook-github-bot facebook-github-bot force-pushed the pianpwk/allow_complex_guards_as_runtime_asserts branch from f4cd946 to 308beff on May 21, 2024 21:47

@pianpwk pianpwk closed this May 24, 2024
facebook-github-bot pushed a commit that referenced this pull request May 30, 2024
Summary:
This [PR](#126627) was in testing while this [PR](#127132) got merged in, and some undesired behavior was not caught by testing: runtime asserts for complex guards got added twice, and the command line flag `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1` no longer omitted runtime asserts.

Adding back flags `should_insert_runtime_assertion` (for general runtime asserts), and `should_insert_unassociated_runtime_assertions` so guard-related asserts aren't added twice.

Test Plan: Added a test case that counts the number of runtime asserts and checks whether they are silenced when requested.

Differential Revision: D57978699