Skip to content

Commit

Permalink
Generalize aten.view pattern in scalarize shapes (#3856)
Browse files Browse the repository at this point in the history
Extends the existing pattern to allow finding matching dims from the
back as well as the front.
  • Loading branch information
zjgarvey authored Nov 7, 2024
1 parent 7058f45 commit 8519ecc
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 73 deletions.
161 changes: 88 additions & 73 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1099,97 +1099,112 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
int64_t outRank = resultTy.getSizes().size();

SmallVector<int64_t> sizes(selfTy.getSizes());
int64_t endMatchingDim = -1;
// input sizes vs. provided view sizes comparison loop
for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
int64_t leftMatchEnd = 0;
// compare input sizes with provided dims from left
for (; leftMatchEnd < std::min(outRank, inRank); leftMatchEnd++) {
int64_t providedSize;
bool providedStatic =
matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
// if sizes[i] is static, it must match a constant in viewSizes[i]
if (sizes[i] != Torch::kUnknownSize) {
if (!providedStatic)
return rewriter.notifyMatchFailure(
op, "unsupported: found static input dim, but unable to match "
"provided view size on a constant. See position : " +
std::to_string(i));
if (providedSize != sizes[i]) {
endMatchingDim = i;
bool providedStatic = matchPattern(viewSizes[leftMatchEnd],
m_TorchConstantInt(&providedSize));
// static dim case
if (sizes[leftMatchEnd] != Torch::kUnknownSize) {
// if can't infer equality of dims, set end index and break
if (!providedStatic || providedSize != sizes[leftMatchEnd])
break;
}
continue;
}
// the remaining assumes sizes[i] is dynamic
// if provided dim is static, we can't verify it is a flatten/unflatten
// unless -1
if (i == outRank - 1 && providedStatic && providedSize == -1) {
endMatchingDim = i;
// the remaining assumes sizes[leftMatchEnd] is dynamic
// if provided dim is static, we can't match.
if (providedStatic)
break;
auto sizeIntOp = viewSizes[leftMatchEnd].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, break
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
break;
int64_t dim;
// if the dim of the size int op doesn't match, fail
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
dim != leftMatchEnd)
break;
}

int64_t rightMatchEnd = 0;
// compare input sizes with provided dims from right
for (; rightMatchEnd < std::min(outRank, inRank) - leftMatchEnd;
rightMatchEnd++) {
int64_t providedSize;
bool providedStatic = matchPattern(viewSizes[outRank - 1 - rightMatchEnd],
m_TorchConstantInt(&providedSize));
// static dim case
if (sizes[inRank - 1 - rightMatchEnd] != Torch::kUnknownSize) {
// if can't infer equality of dims, set end index and break
if (!providedStatic ||
providedSize != sizes[inRank - 1 - rightMatchEnd])
break;
continue;
}
// the remaining assumes sizes[inRank - 1 - rightMatchEnd] is dynamic
// if provided dim is static, we can't match.
if (providedStatic)
return rewriter.notifyMatchFailure(
op, "unexpected static view dim corresponding to dynamic input dim "
"at position : " +
std::to_string(i));
auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, fail
break;
auto sizeIntOp =
viewSizes[outRank - 1 - rightMatchEnd].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, break
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
return rewriter.notifyMatchFailure(
op, "expected dynamic view dim to come from a corresponding "
"size.int op. See position : " +
std::to_string(i));
break;
int64_t dim;
// if the dim of the size int op doesn't match, fail
// if the dim of the size int op doesn't match, break
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
dim != i)
return rewriter.notifyMatchFailure(
op,
"size int op dim cannot be matched to current dim at position : " +
std::to_string(i));
// passing the previous checks means viewSizes[i] = aten.size.int(self,
// i), so continue
dim != inRank - 1 - rightMatchEnd)
break;
}
// if all dims match and the ranks are equal, fold
if (endMatchingDim == -1 && inRank == outRank) {
rewriter.replaceOp(op, op.getSelf());
// the unmatched input dims start at leftMatchEnd, and end before inRank -
// rightMatchEnd
int64_t inputUnmatched = (inRank - rightMatchEnd) - leftMatchEnd;
int64_t outputUnmatched = (outRank - rightMatchEnd) - leftMatchEnd;
// if too many dims are unmatched in input/output, cannot canonicalize.
if (inputUnmatched > 1 && outputUnmatched > 1)
return rewriter.notifyMatchFailure(
op,
"View op is not simple enough to canonicalize.\n# Unmatched Input "
"dims = " +
std::to_string(inputUnmatched) +
"\n# Unmatched Output Dims = " + std::to_string(outputUnmatched) +
"\nStarting unmatched index = " + std::to_string(leftMatchEnd));

// if all dims match, return self.
if (inputUnmatched == outputUnmatched &&
(inputUnmatched == 1 || inputUnmatched == 0)) {
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
op, op.getType(), op.getSelf());
return success();
}
if (endMatchingDim > -1 && inRank > outRank) {
// only support flattening last dim
if (endMatchingDim != outRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: output has more than back dim mismatching");
// flatten
Value start =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value end =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
op, resultTy, op.getSelf(), start, end);
// if input has 1 unmatched dim, and output has multiple, unflatten
if (inputUnmatched == 1 && outputUnmatched > 1) {
Value dimVal =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
viewSizes.end() - rightMatchEnd);
Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(), unflattenSizes);
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
op, op.getType(), op.getSelf(), dimVal, unflattenList);
return success();
}
if (endMatchingDim > -1 && inRank < outRank) {
// only support unflattening last dim
if (endMatchingDim != inRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: input has more than back dim mismatching");
// unflatten
Value dim =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value primList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(),
ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
op, resultTy, op.getSelf(), dim, primList);
// if multiple unmatched input dims map to one output dim, flatten
if (inputUnmatched > 1 && outputUnmatched == 1) {
Value startDim =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
// note: flatten end is inclusive for some reason.
int64_t endInt = inRank - rightMatchEnd - 1;
Value endDim = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endInt);
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
op, op.getType(), op.getSelf(), startDim, endDim);
return success();
}
// examples that might reach this:
// input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
// input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
// input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
// the remaining cases involve maximal matching dims, but mismatched ranks.
// This could only occur if squeezing or unsqueezing.
return rewriter.notifyMatchFailure(
op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
", inRank=" + std::to_string(inRank) +
", outRank=" + std::to_string(outRank));
op, "unhandled view op canonicalization to squeeze/unsqueeze.");
}
};
} // namespace
Expand Down
19 changes: 19 additions & 0 deletions test/Dialect/Torch/scalarize-shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,25 @@ func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !tor
return %3 : !torch.vtensor<[?,?,?],f32>
}

// -----

// CHECK-LABEL: @view_as_flatten_mid
func.func @view_as_flatten_mid(%arg0: !torch.vtensor<[?,?,?,?,2,4],f32>) -> !torch.vtensor<[?,?,?,4],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[FOUR:.*]] = torch.constant.int 4
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[FOUR]] : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,4],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?,4],f32>
%int-1 = torch.constant.int -1
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int4 = torch.constant.int 4
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int-1, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,4],f32>
return %3 : !torch.vtensor<[?,?,?,4],f32>
}


// -----

Expand Down

0 comments on commit 8519ecc

Please sign in to comment.