Skip to content

Commit

Permalink
[TorchToTosa] Refactoring to separate construction of legal/illegal o…
Browse files Browse the repository at this point in the history
…ps and conversion patterns. (#3759)

This PR refactors TorchToTosa to separate the construction of
legal/illegal ops and conversion patterns in their own functions:

1. populateTorchToTosaConversionLegalOps -- populate any ops that are
legal after the conversion pass
2. populateTorchToTosaConversionIllegalOps -- populate any ops that are
illegal after the conversion pass
3. populateTorchToTosaConversionPatterns -- populate the ops conversion
patterns

Currently the (il)legality of the ops that are (il)legal after the
conversion pass runs is embedded within the conversion pattern. Our end
goal is to write a new pass pipeline that converts `torch` ops to a mix
of `tosa`, `linalg`, `tensor`, etc dialect ops. The reason we want to
also emit `tosa` ops (instead of using the existing `TorchToLinalg` to
emit `linalg`+`tensor`+...) is because some operations like `conv2d`
encodes the padding behavior in the op in `tosa` unlike the `linalg`
version -- this helps in lowering the `tosa.conv2d` to a custom
implementation that does padding on the fly.

To implement this new pipeline we need to be able to separate out the
illegal `tosa` ops from the conversion pattern itself. Otherwise we will
hit an issue for ops like `AtenMaxDimOp` which can be lowered to both
`tosa` and `linalg + others` dialects. Not all `AtenMaxDimOp` can be
lowered successfully to `tosa` as the implementation uses `tosa.reshape`
which cannot handle multiple dynamic dimensions but the `TorchToLinalg`
lowering can handle it. In the current behavior the pipeline will stop
as soon as the existing `TorchToTosa` conversion runs as `AtenMaxDimOp`
will be marked as an illegal op.

Essentially we want to be able to control what the legality of the ops
should be independent of the conversion pattern. This is also inline
with the conversion patterns in the llvm-mlir repo such as
https://github.com/llvm/llvm-project/blob/000e790be35b77a01872851646d54432a203542c/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp#L718


"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY."
  • Loading branch information
sahas3 authored Dec 12, 2024
1 parent 5a5cc6b commit f03a576
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 222 deletions.
15 changes: 14 additions & 1 deletion include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,25 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include <memory>

namespace mlir {
namespace torch {

/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
/// dialect.
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);

/// Collect a set of patterns to convert Torch operations to Tosa dialect +
/// return the set of illegalOps
std::set<StringRef>
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
RewritePatternSet &patterns);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
Loading

0 comments on commit f03a576

Please sign in to comment.