[FxImporter] Synchronize the collection of symbolic torch ops (#3236)
penguin-wwy authored Apr 29, 2024
1 parent 5684dc0 commit 9f64748
Showing 2 changed files with 6 additions and 14 deletions.
16 changes: 4 additions & 12 deletions python/torch_mlir/extras/fx_importer.py
@@ -236,26 +236,16 @@
 # set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
 
 if _IS_TORCH_2_1_OR_EARLIER:
-    SYMBOLIC_TORCH_OPS = {
-        torch.ops.aten.sym_size,
-        torch.ops.aten.sym_stride,
-        torch.ops.aten.sym_numel,
-    }
-
     SYMBOLIC_OP_TO_TORCH_OP = {
         (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
         (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
         (torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default,
         (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
         (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
     }
-else:
-    SYMBOLIC_TORCH_OPS = {
-        torch.ops.aten.sym_size.int,
-        torch.ops.aten.sym_stride.int,
-        torch.ops.aten.sym_numel.default,
-    }
 
+    SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP}
+else:
     SYMBOLIC_OP_TO_TORCH_OP = {
         torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
         torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
@@ -264,6 +254,8 @@
         torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
     }
 
+    SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
+
 
 @dataclass(frozen=True)
 class SparsityMeta:
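The net effect of this change: SYMBOLIC_TORCH_OPS is now derived from the keys of SYMBOLIC_OP_TO_TORCH_OP instead of being maintained by hand, so adding a new mapping entry automatically registers the op as symbolic. The two branches differ only because the legacy (torch <= 2.1) mapping is keyed by (op, arity) tuples, while the newer one is keyed by the op overloads themselves. A minimal sketch of the pattern, using placeholder strings rather than real torch.ops handles:

# Sketch only (not code from the repository): the mapping is the single source
# of truth, and the op set is derived from its keys so the two cannot drift apart.

# torch <= 2.1: keys are (op, arity) tuples, so the symbolic op is key[0].
legacy_mapping = {
    ("sym_size", 1): "size.default",
    ("sym_size", 2): "size.int",
    ("sym_numel", 1): "numel.default",
}
legacy_ops = {key[0] for key in legacy_mapping}  # {"sym_size", "sym_numel"}

# torch > 2.1: keys are the op overloads themselves, so the set is just the keys.
modern_mapping = {
    "sym_size.int": "size.int",
    "sym_numel.default": "numel.default",
}
modern_ops = {key for key in modern_mapping}  # equivalent to set(modern_mapping)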
4 changes: 2 additions & 2 deletions python/torch_mlir/fx.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-from typing import Optional, Union, Dict, Tuple, Any
+from typing import Optional, Union, Dict, Tuple, Any, Callable
 
 import warnings
 
@@ -25,7 +25,7 @@ def export_and_import(
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
     hooks: Optional[FxImporterHooks] = None,
-    decomposition_table: Optional[list] = None,
+    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
     func_name: str = "main",
     enable_graph_printing: bool = False,
     **kwargs,
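For callers, the tightened annotation documents that decomposition_table is an operator-to-callable mapping rather than an arbitrary list. A hedged sketch of how such a table might be built and passed in, assuming torch._decomp.get_decompositions is available and that export_and_import accepts the module and example inputs positionally; names and argument order may differ across torch / torch-mlir versions:

# Sketch only, under the assumptions stated above.
from typing import Callable, Dict

import torch
from torch._decomp import get_decompositions

from torch_mlir import fx


class SiluMul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x * y)


# get_decompositions returns a mapping from operator overloads to Python
# callables, matching the Dict[torch._ops.OperatorBase, Callable] annotation.
decomp_table: Dict[torch._ops.OperatorBase, Callable] = get_decompositions(
    [torch.ops.aten.silu.default]
)

module = fx.export_and_import(
    SiluMul(),
    torch.randn(4, 4),
    torch.randn(4, 4),
    decomposition_table=decomp_table,
    func_name="main",
)
print(module)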

