inductor compile failing for batched tensor on fuse #126617

Closed
johnnv1 opened this issue May 18, 2024 · 2 comments
Labels
oncall: cpu inductor CPU Inductor issues for Intel team to triage oncall: pt2

Comments

johnnv1 commented May 18, 2024

🐛 Describe the bug

I'm not sure the title of this issue is right, so please change it to whatever you think is most appropriate. In short, I couldn't figure out the cause of the problem :smile:

When running torch 2.3.0 against the kornia test suite, we got some new errors around dynamo (CI in kornia/kornia#2912). The kornia.geometry.transform.homography_warp operator fails to run with batch size != 1.

It worked fine with the previous torch version, and with 2.3.0 it works on CUDA but not on CPU. Eager mode is OK; it fails under the inductor backend.

The failing code snippet

import logging

import torch

import kornia
from kornia.utils import eye_like

torch._logging.set_logs(dynamo=logging.DEBUG)
torch._dynamo.config.verbose = True

align_corners = True
normalized_coordinates = True
device = torch.device("cpu")
dtype = torch.float32
batch_size = 3

# generate input data
height, width = 128, 64
eye_size = 3  # identity 3x3

patch_src = torch.rand(batch_size, 1, height, width, device=device, dtype=dtype)

# create base homography
dst_homo_src = eye_like(eye_size, patch_src)

# generate homography noise
homo_delta = torch.rand_like(dst_homo_src) * 0.3

dst_homo_src_i = dst_homo_src + homo_delta


op_optimized = torch.compile(kornia.geometry.transform.homography_warp, backend="inductor")

patch_dst_optimized = op_optimized(
    patch_src,
    dst_homo_src_i,
    (height, width),
    align_corners=align_corners,
    normalized_coordinates=normalized_coordinates,
)

Error logs

Traceback:

V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] TRACED GRAPH TENSOR SIZES
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] ===== __compiled_fn_0 =====
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] l_src_homo_dst_: (3, 3, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] l_patch_src_: (3, 1, 128, 64)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] xs: (64,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] ys: (128,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] truediv: (64,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] sub: (64,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] xs_1: (64,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] truediv_1: (128,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] sub_1: (128,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] ys_1: (128,)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] getitem: (64, 128)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] getitem_1: (64, 128)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] base_grid: (64, 128, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] permute: (128, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] grid: (1, 128, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] grid_1: (3, 128, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] src_homo_dst: (3, 1, 3, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] to: (3, 128, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] points_1: (384, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] trans_1: (3, 3, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] trans_2: (384, 3, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] points_1_h: (384, 64, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] permute_1: (384, 3, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] points_0_h: (384, 64, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] points_0_h_1: (384, 64, 3)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] z_vec: (384, 64, 1)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] abs_1: (384, 64, 1)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] mask: (384, 64, 1)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] add: (384, 64, 1)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] truediv_2: (384, 64, 1)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] ones_like: (384, 64, 1)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] scale: (384, 64, 1)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] getitem_3: (384, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] points_0: (384, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] flow: (3, 128, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] warped_grid: (3, 128, 64, 2)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] grid_sample: (3, 1, 128, 64)
V0518 11:04:26.829000 124833143125056 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] 
I0518 11:04:26.829000 124833143125056 torch/_dynamo/logging.py:55] [0/0] Step 2: calling compiler function inductor
V0518 11:04:27.377000 124833143125056 torch/fx/experimental/symbolic_shapes.py:4119] [0/0] eval True == True [statically known]
V0518 11:04:27.399000 124833143125056 torch/fx/experimental/symbolic_shapes.py:4119] [0/0] eval False == False [statically known]
W0518 11:04:28.477000 124833143125056 torch/_dynamo/repro/after_dynamo.py:98] [0/0] Compiled Fx GraphModule failed. Creating script to minify the error.
W0518 11:04:28.480000 124833143125056 torch/_dynamo/debug_utils.py:276] [0/0] Writing minified repro to:
W0518 11:04:28.480000 124833143125056 torch/_dynamo/debug_utils.py:276] [0/0] /tmp/kornia/torch_compile_debug/run_2024_05_18_11_04_28_479708-pid_21799/minifier/minifier_launcher.py
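
For reference, the leading dimension 384 in the traced sizes above comes from merging the batch and height axes (3 * 128): `points_1` is `to.reshape(-1, 64, 2)` and `trans_2` is `repeat_interleave(trans_1, repeats=128, dim=0)`, both taken from the dynamo minified repro. Below is a minimal plain-Python sketch of that shape arithmetic. The helper `infer_reshape` is hypothetical and written only for illustration; it is not a torch API.

```python
from math import prod


def infer_reshape(shape, new_shape):
    """Resolve at most one -1 in new_shape, mirroring Tensor.reshape semantics.

    Hypothetical helper for illustration only; not part of torch.
    """
    if sum(1 for d in new_shape if d == -1) > 1:
        raise ValueError("only one dimension can be inferred")
    total = prod(shape)
    known = prod(d for d in new_shape if d != -1)
    if known == 0 or total % known != 0:
        raise ValueError(f"cannot reshape {shape} into {new_shape}")
    resolved = tuple(total // known if d == -1 else d for d in new_shape)
    if prod(resolved) != total:
        raise ValueError(f"cannot reshape {shape} into {new_shape}")
    return resolved


# `to` in the traced graph has size (3, 128, 64, 2).
points_1 = infer_reshape((3, 128, 64, 2), (-1, 64, 2))

# `src_homo_dst` has size (3, 1, 3, 3) and is reshaped to (-1, 3, 3).
trans_1 = infer_reshape((3, 1, 3, 3), (-1, 3, 3))

# repeat_interleave(repeats=128, dim=0) multiplies the leading dimension by 128.
trans_2 = (trans_1[0] * 128,) + trans_1[1:]

print(points_1, trans_1, trans_2)
```

The three results correspond to the `points_1`, `trans_1`, and `trans_2` entries in the log.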

Traceback (most recent call last):
  File "/tmp/kornia/t.py", line 34, in <module>
    patch_dst_optimized = op_optimized(
                          ^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
             ^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
        ^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2268, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 971, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1168, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1241, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1222, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/__init__.py", line 1729, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1330, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 58, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 903, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 628, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 443, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 648, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 119, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1257, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 438, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 714, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1307, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1250, in compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
                                                             ^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1205, in codegen
    self.scheduler = Scheduler(self.buffers)
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py", line 1312, in __init__
    self.fuse_nodes()
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py", line 1686, in fuse_nodes
    self.fuse_nodes_once()
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py", line 1830, in fuse_nodes_once
    node3 = self.get_backend(device).fuse(node1, node2)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py", line 3532, in fuse
    assert vars1 == vars2, (vars1, vars2)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: ((3, 8192), (24576,))
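
A quick arithmetic check on the two tuples in the AssertionError, as a plain-Python sketch. Both describe the same number of elements, which equals batch_size * height * width from the snippet (3 * 128 * 64); they differ only in how that range is split. Reading `vars1` and `vars2` as the loop ranges of the two nodes being fused is an assumption based on the variable names in the traceback, not something confirmed here.

```python
from math import prod

# Values taken from the failing snippet and from the AssertionError message.
batch_size, height, width = 3, 128, 64
vars1 = (3, 8192)
vars2 = (24576,)


def element_count(ranges):
    """Total number of elements covered by a tuple of ranges."""
    return prod(ranges)


same_total = element_count(vars1) == element_count(vars2)
matches_output = element_count(vars2) == batch_size * height * width

# The assertion in cpp.py compares the tuples themselves, so an equal total
# element count is not enough for it to pass.
print(same_total, matches_output, vars1 == vars2)
```

So the two sides agree on the total size, and the mismatch is in whether the batch axis is kept separate or flattened into a single range.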

Minified repro

# $TORCHDYNAMO_REPRO_AFTER="dynamo" python t.py
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
torch._dynamo.config.verbose = True

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()



    def forward(self, L_src_homo_dst_ : torch.Tensor, L_patch_src_ : torch.Tensor):
        l_src_homo_dst_ = L_src_homo_dst_
        l_patch_src_ = L_patch_src_
        xs = torch.linspace(0, 63, 64, device = device(type='cpu'), dtype = torch.float32)
        ys = torch.linspace(0, 127, 128, device = device(type='cpu'), dtype = torch.float32)
        truediv = xs / 63;  xs = None
        sub = truediv - 0.5;  truediv = None
        xs_1 = sub * 2;  sub = None
        truediv_1 = ys / 127;  ys = None
        sub_1 = truediv_1 - 0.5;  truediv_1 = None
        ys_1 = sub_1 * 2;  sub_1 = None
        meshgrid = torch.functional.meshgrid([xs_1, ys_1], indexing = 'ij');  xs_1 = ys_1 = None
        getitem = meshgrid[0]
        getitem_1 = meshgrid[1];  meshgrid = None
        base_grid = torch.stack((getitem, getitem_1), dim = -1);  getitem = getitem_1 = None
        permute = base_grid.permute(1, 0, 2);  base_grid = None
        grid = permute.unsqueeze(0);  permute = None
        grid_1 = grid.expand(3, -1, -1, -1);  grid = None
        src_homo_dst = l_src_homo_dst_.view(3, 1, 3, 3);  l_src_homo_dst_ = None
        to = grid_1.to(src_homo_dst);  grid_1 = None
        points_1 = to.reshape(-1, 64, 2);  to = None
        trans_1 = src_homo_dst.reshape(-1, 3, 3);  src_homo_dst = None
        trans_2 = torch.repeat_interleave(trans_1, repeats = 128, dim = 0);  trans_1 = None
        points_1_h = torch._C._nn.pad(points_1, [0, 1], 'constant', 1.0);  points_1 = None
        permute_1 = trans_2.permute(0, 2, 1);  trans_2 = None
        points_0_h = torch.bmm(points_1_h, permute_1);  points_1_h = permute_1 = None
        points_0_h_1 = torch.squeeze(points_0_h, dim = -1);  points_0_h = None
        z_vec = points_0_h_1[(Ellipsis, slice(-1, None, None))]
        abs_1 = torch.abs(z_vec)
        mask = abs_1 > 1e-08;  abs_1 = None
        add = z_vec + 1e-08
        truediv_2 = 1.0 / add;  add = None
        ones_like = torch.ones_like(z_vec);  z_vec = None
        scale = torch.where(mask, truediv_2, ones_like);  mask = truediv_2 = ones_like = None
        getitem_3 = points_0_h_1[(Ellipsis, slice(None, -1, None))];  points_0_h_1 = None
        points_0 = scale * getitem_3;  scale = getitem_3 = None
        flow = points_0.reshape([3, 128, 64, 2]);  points_0 = None
        warped_grid = flow.view(3, 128, 64, 2);  flow = None
        grid_sample = torch.nn.functional.grid_sample(l_patch_src_, warped_grid, mode = 'bilinear', padding_mode = 'zeros', align_corners = True);  l_patch_src_ = warped_grid = None
        return (grid_sample,)


mod = Repro()

def load_args(reader):
    buf0 = reader.storage('18c6cbf7f976b2076947af5f4dcb0618d1439075', 108)
    reader.tensor(buf0, (3, 3, 3), is_leaf=True)  # L_src_homo_dst_
    buf1 = reader.storage('f6ebd3c18a355ea3ca38dee4894ff94ed45abcdd', 98304)
    reader.tensor(buf1, (3, 1, 128, 64), is_leaf=True)  # L_patch_src_
load_args._version = 0

if __name__ == '__main__':
    from torch._dynamo.repro.after_dynamo import run_repro
    run_repro(mod, load_args, accuracy=False, command='minify',
        save_dir='/tmp/kornia/torch_compile_debug/run_2024_05_18_10_51_25_841050-pid_19688/minifier/checkpoints', autocast=False, backend='inductor')
# $TORCHDYNAMO_REPRO_AFTER="aot" python t.py
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
torch._dynamo.config.verbose = True





isolate_fails_code_str = None



# torch version: 2.3.0
# torch cuda version: 12.1
# torch git version: 97ff6cfd9c86c5c09d7ce775ab64ec5c99230f5d


# CUDA Info: 
# nvcc not found
# GPU Hardware Info: 
# NVIDIA GeForce RTX 3060 Ti : 1 


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    
    
    def forward(self, arg0_1, arg1_1):
        iota = torch.ops.prims.iota.default(64, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        lt = torch.ops.aten.lt.Scalar(iota, 32.0)
        convert_element_type = torch.ops.prims.convert_element_type.default(iota, torch.float32)
        mul = torch.ops.aten.mul.Tensor(convert_element_type, 1.0);  convert_element_type = None
        add = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        sub = torch.ops.aten.sub.Tensor(63, iota);  iota = None
        convert_element_type_1 = torch.ops.prims.convert_element_type.default(sub, torch.float32);  sub = None
        mul_1 = torch.ops.aten.mul.Tensor(convert_element_type_1, 1.0);  convert_element_type_1 = None
        sub_1 = torch.ops.aten.sub.Tensor(63, mul_1);  mul_1 = None
        where = torch.ops.aten.where.self(lt, add, sub_1);  lt = add = sub_1 = None
        iota_1 = torch.ops.prims.iota.default(128, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        lt_1 = torch.ops.aten.lt.Scalar(iota_1, 64.0)
        convert_element_type_2 = torch.ops.prims.convert_element_type.default(iota_1, torch.float32)
        mul_2 = torch.ops.aten.mul.Tensor(convert_element_type_2, 1.0);  convert_element_type_2 = None
        add_1 = torch.ops.aten.add.Tensor(mul_2, 0);  mul_2 = None
        sub_2 = torch.ops.aten.sub.Tensor(127, iota_1);  iota_1 = None
        convert_element_type_3 = torch.ops.prims.convert_element_type.default(sub_2, torch.float32);  sub_2 = None
        mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_3, 1.0);  convert_element_type_3 = None
        sub_3 = torch.ops.aten.sub.Tensor(127, mul_3);  mul_3 = None
        where_1 = torch.ops.aten.where.self(lt_1, add_1, sub_3);  lt_1 = add_1 = sub_3 = None
        div = torch.ops.aten.div.Tensor(where, 63);  where = None
        sub_4 = torch.ops.aten.sub.Tensor(div, 0.5);  div = None
        mul_4 = torch.ops.aten.mul.Tensor(sub_4, 2);  sub_4 = None
        div_1 = torch.ops.aten.div.Tensor(where_1, 127);  where_1 = None
        sub_5 = torch.ops.aten.sub.Tensor(div_1, 0.5);  div_1 = None
        mul_5 = torch.ops.aten.mul.Tensor(sub_5, 2);  sub_5 = None
        view = torch.ops.aten.view.default(mul_4, [-1, 1]);  mul_4 = None
        expand = torch.ops.aten.expand.default(view, [64, 128]);  view = None
        view_1 = torch.ops.aten.view.default(mul_5, [1, -1]);  mul_5 = None
        expand_1 = torch.ops.aten.expand.default(view_1, [64, 128]);  view_1 = None
        unsqueeze = torch.ops.aten.unsqueeze.default(expand, 2);  expand = None
        unsqueeze_1 = torch.ops.aten.unsqueeze.default(expand_1, 2);  expand_1 = None
        cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 2);  unsqueeze = unsqueeze_1 = None
        permute = torch.ops.aten.permute.default(cat, [1, 0, 2]);  cat = None
        unsqueeze_2 = torch.ops.aten.unsqueeze.default(permute, 0);  permute = None
        expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [3, -1, -1, -1]);  unsqueeze_2 = None
        view_2 = torch.ops.aten.view.default(arg0_1, [3, 1, 3, 3]);  arg0_1 = None
        clone = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format);  expand_2 = None
        view_3 = torch.ops.aten.view.default(clone, [384, 64, 2]);  clone = None
        view_4 = torch.ops.aten.view.default(view_2, [-1, 3, 3]);  view_2 = None
        unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_4, 1);  view_4 = None
        expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [3, 128, 3, 3]);  unsqueeze_3 = None
        clone_1 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format);  expand_3 = None
        view_5 = torch.ops.aten.view.default(clone_1, [384, 3, 3]);  clone_1 = None
        constant_pad_nd = torch.ops.aten.constant_pad_nd.default(view_3, [0, 1], 1.0);  view_3 = None
        permute_1 = torch.ops.aten.permute.default(view_5, [0, 2, 1]);  view_5 = None
        bmm = torch.ops.aten.bmm.default(constant_pad_nd, permute_1);  constant_pad_nd = permute_1 = None
        squeeze = torch.ops.aten.squeeze.dim(bmm, -1);  bmm = None
        slice_1 = torch.ops.aten.slice.Tensor(squeeze, 2, -1, 9223372036854775807)
        abs_1 = torch.ops.aten.abs.default(slice_1)
        gt = torch.ops.aten.gt.Scalar(abs_1, 1e-08);  abs_1 = None
        add_2 = torch.ops.aten.add.Tensor(slice_1, 1e-08);  slice_1 = None
        reciprocal = torch.ops.aten.reciprocal.default(add_2);  add_2 = None
        mul_6 = torch.ops.aten.mul.Tensor(reciprocal, 1.0);  reciprocal = None
        full_default = torch.ops.aten.full.default([384, 64, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_2 = torch.ops.aten.where.self(gt, mul_6, full_default);  gt = mul_6 = full_default = None
        slice_2 = torch.ops.aten.slice.Tensor(squeeze, 2, 0, -1);  squeeze = None
        mul_7 = torch.ops.aten.mul.Tensor(where_2, slice_2);  where_2 = slice_2 = None
        view_6 = torch.ops.aten.view.default(mul_7, [3, 128, 64, 2]);  mul_7 = None
        iota_2 = torch.ops.prims.iota.default(3, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        view_8 = torch.ops.aten.view.default(iota_2, [3, 1, 1, 1]);  iota_2 = None
        iota_3 = torch.ops.prims.iota.default(1, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        full_default_1 = torch.ops.aten.full.default([1, 1, 1, 1], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        select = torch.ops.aten.select.int(view_6, 3, 0)
        select_1 = torch.ops.aten.select.int(view_6, 3, 1);  view_6 = None
        mul_8 = torch.ops.aten.mul.Tensor(select, 31.5);  select = None
        add_3 = torch.ops.aten.add.Tensor(mul_8, 31.5);  mul_8 = None
        mul_9 = torch.ops.aten.mul.Tensor(select_1, 63.5);  select_1 = None
        add_4 = torch.ops.aten.add.Tensor(mul_9, 63.5);  mul_9 = None
        floor = torch.ops.aten.floor.default(add_3)
        floor_1 = torch.ops.aten.floor.default(add_4)
        add_5 = torch.ops.aten.add.Tensor(floor, 1)
        add_6 = torch.ops.aten.add.Tensor(floor_1, 1)
        sub_6 = torch.ops.aten.sub.Tensor(add_5, add_3)
        sub_7 = torch.ops.aten.sub.Tensor(add_6, add_4)
        mul_10 = torch.ops.aten.mul.Tensor(sub_6, sub_7);  sub_6 = sub_7 = None
        sub_8 = torch.ops.aten.sub.Tensor(add_3, floor)
        sub_9 = torch.ops.aten.sub.Tensor(add_6, add_4)
        mul_11 = torch.ops.aten.mul.Tensor(sub_8, sub_9);  sub_8 = sub_9 = None
        sub_10 = torch.ops.aten.sub.Tensor(add_5, add_3)
        sub_11 = torch.ops.aten.sub.Tensor(add_4, floor_1)
        mul_12 = torch.ops.aten.mul.Tensor(sub_10, sub_11);  sub_10 = sub_11 = None
        sub_12 = torch.ops.aten.sub.Tensor(add_3, floor);  add_3 = None
        sub_13 = torch.ops.aten.sub.Tensor(add_4, floor_1);  add_4 = None
        mul_13 = torch.ops.aten.mul.Tensor(sub_12, sub_13);  sub_12 = sub_13 = None
        ge = torch.ops.aten.ge.Scalar(floor, 0)
        lt_2 = torch.ops.aten.lt.Scalar(floor, 64)
        ge_1 = torch.ops.aten.ge.Scalar(floor_1, 0)
        lt_3 = torch.ops.aten.lt.Scalar(floor_1, 128)
        logical_and = torch.ops.aten.logical_and.default(ge_1, lt_3);  ge_1 = lt_3 = None
        logical_and_1 = torch.ops.aten.logical_and.default(lt_2, logical_and);  lt_2 = logical_and = None
        logical_and_2 = torch.ops.aten.logical_and.default(ge, logical_and_1);  ge = logical_and_1 = None
        convert_element_type_4 = torch.ops.prims.convert_element_type.default(floor, torch.int64)
        convert_element_type_5 = torch.ops.prims.convert_element_type.default(floor_1, torch.int64)
        full_default_2 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_3 = torch.ops.aten.where.self(logical_and_2, convert_element_type_4, full_default_2);  convert_element_type_4 = full_default_2 = None
        view_10 = torch.ops.aten.view.default(where_3, [3, 1, 128, 64]);  where_3 = None
        full_default_3 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_4 = torch.ops.aten.where.self(logical_and_2, convert_element_type_5, full_default_3);  convert_element_type_5 = full_default_3 = None
        view_11 = torch.ops.aten.view.default(where_4, [3, 1, 128, 64]);  where_4 = None
        full_default_4 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_5 = torch.ops.aten.where.self(logical_and_2, mul_10, full_default_4);  logical_and_2 = mul_10 = full_default_4 = None
        view_12 = torch.ops.aten.view.default(where_5, [3, 1, 128, 64]);  where_5 = None
        index = torch.ops.aten.index.Tensor(arg1_1, [view_8, full_default_1, view_11, view_10]);  view_11 = view_10 = None
        mul_14 = torch.ops.aten.mul.Tensor(index, view_12);  index = view_12 = None
        ge_2 = torch.ops.aten.ge.Scalar(add_5, 0)
        lt_4 = torch.ops.aten.lt.Scalar(add_5, 64)
        ge_3 = torch.ops.aten.ge.Scalar(floor_1, 0)
        lt_5 = torch.ops.aten.lt.Scalar(floor_1, 128)
        logical_and_3 = torch.ops.aten.logical_and.default(ge_3, lt_5);  ge_3 = lt_5 = None
        logical_and_4 = torch.ops.aten.logical_and.default(lt_4, logical_and_3);  lt_4 = logical_and_3 = None
        logical_and_5 = torch.ops.aten.logical_and.default(ge_2, logical_and_4);  ge_2 = logical_and_4 = None
        convert_element_type_6 = torch.ops.prims.convert_element_type.default(add_5, torch.int64)
        convert_element_type_7 = torch.ops.prims.convert_element_type.default(floor_1, torch.int64);  floor_1 = None
        full_default_5 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_6 = torch.ops.aten.where.self(logical_and_5, convert_element_type_6, full_default_5);  convert_element_type_6 = full_default_5 = None
        view_13 = torch.ops.aten.view.default(where_6, [3, 1, 128, 64]);  where_6 = None
        full_default_6 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_7 = torch.ops.aten.where.self(logical_and_5, convert_element_type_7, full_default_6);  convert_element_type_7 = full_default_6 = None
        view_14 = torch.ops.aten.view.default(where_7, [3, 1, 128, 64]);  where_7 = None
        full_default_7 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_8 = torch.ops.aten.where.self(logical_and_5, mul_11, full_default_7);  logical_and_5 = mul_11 = full_default_7 = None
        view_15 = torch.ops.aten.view.default(where_8, [3, 1, 128, 64]);  where_8 = None
        index_1 = torch.ops.aten.index.Tensor(arg1_1, [view_8, full_default_1, view_14, view_13]);  view_14 = view_13 = None
        mul_15 = torch.ops.aten.mul.Tensor(index_1, view_15);  index_1 = view_15 = None
        add_7 = torch.ops.aten.add.Tensor(mul_14, mul_15);  mul_14 = mul_15 = None
        ge_4 = torch.ops.aten.ge.Scalar(floor, 0)
        lt_6 = torch.ops.aten.lt.Scalar(floor, 64)
        ge_5 = torch.ops.aten.ge.Scalar(add_6, 0)
        lt_7 = torch.ops.aten.lt.Scalar(add_6, 128)
        logical_and_6 = torch.ops.aten.logical_and.default(ge_5, lt_7);  ge_5 = lt_7 = None
        logical_and_7 = torch.ops.aten.logical_and.default(lt_6, logical_and_6);  lt_6 = logical_and_6 = None
        logical_and_8 = torch.ops.aten.logical_and.default(ge_4, logical_and_7);  ge_4 = logical_and_7 = None
        convert_element_type_8 = torch.ops.prims.convert_element_type.default(floor, torch.int64);  floor = None
        convert_element_type_9 = torch.ops.prims.convert_element_type.default(add_6, torch.int64)
        full_default_8 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_9 = torch.ops.aten.where.self(logical_and_8, convert_element_type_8, full_default_8);  convert_element_type_8 = full_default_8 = None
        view_16 = torch.ops.aten.view.default(where_9, [3, 1, 128, 64]);  where_9 = None
        full_default_9 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_10 = torch.ops.aten.where.self(logical_and_8, convert_element_type_9, full_default_9);  convert_element_type_9 = full_default_9 = None
        view_17 = torch.ops.aten.view.default(where_10, [3, 1, 128, 64]);  where_10 = None
        full_default_10 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_11 = torch.ops.aten.where.self(logical_and_8, mul_12, full_default_10);  logical_and_8 = mul_12 = full_default_10 = None
        view_18 = torch.ops.aten.view.default(where_11, [3, 1, 128, 64]);  where_11 = None
        index_2 = torch.ops.aten.index.Tensor(arg1_1, [view_8, full_default_1, view_17, view_16]);  view_17 = view_16 = None
        mul_16 = torch.ops.aten.mul.Tensor(index_2, view_18);  index_2 = view_18 = None
        add_8 = torch.ops.aten.add.Tensor(add_7, mul_16);  add_7 = mul_16 = None
        ge_6 = torch.ops.aten.ge.Scalar(add_5, 0)
        lt_8 = torch.ops.aten.lt.Scalar(add_5, 64)
        ge_7 = torch.ops.aten.ge.Scalar(add_6, 0)
        lt_9 = torch.ops.aten.lt.Scalar(add_6, 128)
        logical_and_9 = torch.ops.aten.logical_and.default(ge_7, lt_9);  ge_7 = lt_9 = None
        logical_and_10 = torch.ops.aten.logical_and.default(lt_8, logical_and_9);  lt_8 = logical_and_9 = None
        logical_and_11 = torch.ops.aten.logical_and.default(ge_6, logical_and_10);  ge_6 = logical_and_10 = None
        convert_element_type_10 = torch.ops.prims.convert_element_type.default(add_5, torch.int64);  add_5 = None
        convert_element_type_11 = torch.ops.prims.convert_element_type.default(add_6, torch.int64);  add_6 = None
        full_default_11 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_12 = torch.ops.aten.where.self(logical_and_11, convert_element_type_10, full_default_11);  convert_element_type_10 = full_default_11 = None
        view_19 = torch.ops.aten.view.default(where_12, [3, 1, 128, 64]);  where_12 = None
        full_default_12 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_13 = torch.ops.aten.where.self(logical_and_11, convert_element_type_11, full_default_12);  convert_element_type_11 = full_default_12 = None
        view_20 = torch.ops.aten.view.default(where_13, [3, 1, 128, 64]);  where_13 = None
        full_default_13 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        where_14 = torch.ops.aten.where.self(logical_and_11, mul_13, full_default_13);  logical_and_11 = mul_13 = full_default_13 = None
        view_21 = torch.ops.aten.view.default(where_14, [3, 1, 128, 64]);  where_14 = None
        index_3 = torch.ops.aten.index.Tensor(arg1_1, [view_8, full_default_1, view_20, view_19]);  arg1_1 = view_8 = full_default_1 = view_20 = view_19 = None
        mul_17 = torch.ops.aten.mul.Tensor(index_3, view_21);  index_3 = view_21 = None
        add_9 = torch.ops.aten.add.Tensor(add_8, mul_17);  add_8 = mul_17 = None
        return (add_9,)
        
def load_args(reader):
    buf0 = reader.storage(None, 108)
    reader.tensor(buf0, (3, 3, 3), is_leaf=True)  # arg0_1
    buf1 = reader.storage(None, 98304)
    reader.tensor(buf1, (3, 1, 128, 64), is_leaf=True)  # arg1_1
load_args._version = 0
mod = Repro()
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    with torch.no_grad():
        run_repro(mod, load_args, accuracy=False, command='minify', save_dir='/tmp/kornia/torch_compile_debug/run_2024_05_18_10_57_22_980811-pid_20828/minifier/checkpoints', tracing_mode='real', check_str=None)
    

Versions

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Ti
Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             12
On-line CPU(s) list:                0-11
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 5 5600X 6-Core Processor
CPU family:                         25
Model:                              33
Thread(s) per core:                 2
Core(s) per socket:                 6
Socket(s):                          1
Stepping:                           0
Frequency boost:                    enabled
CPU max MHz:                        4650,2920
CPU min MHz:                        2200,0000
BogoMIPS:                           7399.51
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                     AMD-V
L1d cache:                          192 KiB (6 instances)
L1i cache:                          192 KiB (6 instances)
L2 cache:                           3 MiB (6 instances)
L3 cache:                           32 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.0
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py311h5eee18b_1  
[conda] mkl_fft                   1.3.8           py311h5eee18b_0  
[conda] mkl_random                1.2.4           py311hdb19cb5_0  
[conda] numpy                     1.26.4          py311h08b1b3b_0  
[conda] numpy-base                1.26.4          py311hf175353_0  
[conda] pytorch                   2.3.0           py3.11_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.3.0               py311_cu121    pytorch
[conda] torchtriton               2.3.0                     py311    pytorch
[conda] torchvision               0.18.0              py311_cu121    pytorch

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

johnnv1 added a commit to johnnv1/kornia that referenced this issue May 18, 2024
@ezyang ezyang added the oncall: cpu inductor CPU Inductor issues for Intel team to triage label May 19, 2024
@jgong5
Collaborator

jgong5 commented May 20, 2024

I cannot repro the problem with the latest pytorch mainline: d9c3485. Can you check if the pytorch trunk has your problem fixed? Thanks. @johnnv1

@johnnv1
Author

johnnv1 commented May 20, 2024

> I cannot repro the problem with the latest pytorch mainline: d9c3485. Can you check if the pytorch trunk has your problem fixed? Thanks. @johnnv1

Yeah, it's fine in the nightly (torch-2.4.0.dev20240515+cu121) already, sorry for raising the issue :)

@johnnv1 johnnv1 closed this as completed May 20, 2024
edgarriba pushed a commit to kornia/kornia that referenced this issue May 21, 2024
* chore (CI): ensure support to pytorch 2.3.0

* chore: skip specific dynamo tests for torch 2.3.0

- Report in pytorch/pytorch#126617

* chore: skip specific dynamo tests for torch 2.3.0

- Report in pytorch/pytorch#126619
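
The version-specific skip mentioned in the commits above could be sketched roughly as follows. This is a minimal illustration, not the code from the kornia commit: the helper names `release_tuple` and `should_skip_cpu_inductor_test` are hypothetical, and it assumes that only the 2.3.0 release is affected, since that is the only failing version reported in this thread (the 2.4.0.dev20240515 nightly is reported to work).

```python
import re

# Assumption: 2.3.0 is the only release known from this thread to fail.
AFFECTED_RELEASE = (2, 3, 0)


def release_tuple(version: str) -> tuple:
    """Extract (major, minor, patch) from a string like '2.4.0.dev20240515+cu121'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def should_skip_cpu_inductor_test(version: str) -> bool:
    """Return True when the given torch version matches the affected release."""
    return release_tuple(version) == AFFECTED_RELEASE
```

In a test suite, the result of `should_skip_cpu_inductor_test(torch.__version__)` could be passed as the condition to `pytest.mark.skipif`, so the skip lifts automatically once a newer torch is installed.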