Skip to content

Commit

Permalink
[MLIR][TORCH] Add support for 1-d group convolution (#3904)
Browse files Browse the repository at this point in the history
This commit adds the support for 1-d group convolution by transforming
it into a 2-d group convolution which is already supported.

This commit also refactors the unsqueeze and squeeze tensor utility.

---------

Signed-off-by: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 authored Dec 13, 2024
1 parent 2c72a82 commit 8e0eafd
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 93 deletions.
9 changes: 9 additions & 0 deletions include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize);

// Helper function to unsqueeze the input tensor at given dim.
// Returns the unsqueezed tensor or failure.
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim);

// Helper function to squeeze the input tensor at given dim.
// Returns the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim);
} // namespace Torch
} // namespace torch
} // namespace mlir
Expand Down
98 changes: 13 additions & 85 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();

if (inputRank == 0) {
return rewriter.notifyMatchFailure(
op, "zero input rank should have been handled by the folder");
}

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

const TypeConverter *typeConverter = getTypeConverter();
auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
int64_t resultRank = resultType.getRank();

// If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
auto squeezeTensorInfo =
squeezeTensor(rewriter, op, adaptor.getSelf(), dim);
if (failed(squeezeTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}

SmallVector<ReassociationIndices> reassociationMap(resultRank);
bool alreadyCrossedSqueezedDim = false;
for (int i = 0; i != resultRank; i++) {
if (alreadyCrossedSqueezedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (dim != 0 && i != dim - 1)
continue;

alreadyCrossedSqueezedDim = true;
if (dim == 0)
reassociationMap[0].push_back(1);
if (i == dim - 1)
reassociationMap[i].push_back(dim);
}
}
// Note: In case the operand tensor type is of unit rank and is statically
// shaped with unit dimension, the `reassociationMap` will be empty and the
// input will be collapsed to a 0-D tensor.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
reassociationMap);
rewriter.replaceOp(op, squeezeTensorInfo.value());
return success();
}
};
Expand All @@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
auto inputRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

SmallVector<ReassociationIndices> reassociationMap(inputRank);
// From the perspective of the reassociation map, the situation of
// unsqueezing before or after the last dimension is symmetrical.
// Normalize it to the "before" case.
// The 0 case is special here, since there is no last dimension to insert
// before -- we simply rely on the loop below iterating 0 times.
if (dim == inputRank && inputRank != 0)
dim = inputRank - 1;
bool alreadyCrossedExpandedDim = false;
for (int i = 0; i != inputRank; i++) {
if (alreadyCrossedExpandedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (i == dim) {
reassociationMap[i].push_back(i + 1);
alreadyCrossedExpandedDim = true;
}
}
auto unsqueezeTensorInfo =
unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim);
if (failed(unsqueezeTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, resultType, adaptor.getSelf(), reassociationMap);

rewriter.replaceOp(op, unsqueezeTensorInfo.value());
return success();
}
};
Expand Down
72 changes: 64 additions & 8 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");

// Checks for valid group size
int64_t numGroups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());

// Adding support for 1d group convolution by converting the 1d-conv to
// 2d-conv.
// TODO: Replace this logic with the appropriate linalg op for 1-d group
// convolution once that support is added.
bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1);
if (is1DGroupConv) {
// Unsqueezing the last dim of input and weight. Also extending the
// dilation, stride, padding, and output padding lists.
auto unsqueezeInputInfo =
unsqueezeTensor(rewriter, op, input, /*dim=*/-1);
if (failed(unsqueezeInputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
input = unsqueezeInputInfo.value();

auto unsqueezeWeightInfo =
unsqueezeTensor(rewriter, op, weight, /*dim=*/-1);
if (failed(unsqueezeWeightInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
weight = unsqueezeWeightInfo.value();

Value cstZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(0));
paddingIntValues.push_back(cstZero);
outputPaddingIntValues.push_back(cstZero);
strideInts.push_back(1);
dilationInts.push_back(1);

inRank++;
numSpatialDims++;
}

Value inBatch = getDimOp(rewriter, loc, input, 0);
Value inChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inDims;
Expand All @@ -861,13 +903,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

// Checks for valid group size
int64_t numGroups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());

auto validate = [&](Value toValidate, std::string err) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Expand Down Expand Up @@ -1280,13 +1315,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}

if (is1DGroupConv) {
// Squeezing the last dim of the result of conv.
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
if (failed(squeezeOutputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate squeeze tensor");
}
conv = squeezeOutputInfo.value();
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
op, "unimplemented: only 1D and 2D grouped convolution supported");

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
Expand Down Expand Up @@ -1371,6 +1417,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}

if (is1DGroupConv) {
// Squeezing the last dim of the result of conv.
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
if (failed(squeezeOutputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate squeeze tensor");
}
conv = squeezeOutputInfo.value();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down
113 changes: 113 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,119 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
return castIntToIndex(rewriter, loc, boundedByDimSize);
}

// Helper function to unsqueeze the input tensor at given dim.
// Returns the unsqueezed tensor or failure.
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim) {
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();
ArrayRef<int64_t> inputShape = inputType.getShape();

// `input` has a reduced rank. Hence add 1.
int64_t unsqueezedRank = inputShape.size() + 1;
dim = toPositiveDim(dim, unsqueezedRank);
if (!isValidDim(dim, unsqueezedRank)) {
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
}

SmallVector<int64_t> unsqueezedShape{inputShape};
unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1);
Type unsqueezedType =
RankedTensorType::get(unsqueezedShape, inputType.getElementType());

SmallVector<ReassociationIndices> reassociationMap(inputRank);
// From the perspective of the reassociation map, the situation of
// unsqueezing before or after the last dimension is symmetrical.
// Normalize it to the "before" case.
// The 0 case is special here, since there is no last dimension to insert
// before -- we simply rely on the loop below iterating 0 times.
if (dim == inputRank && inputRank != 0)
dim = inputRank - 1;
bool alreadyCrossedExpandedDim = false;
for (int i = 0; i != inputRank; i++) {
if (alreadyCrossedExpandedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (i == dim) {
reassociationMap[i].push_back(i + 1);
alreadyCrossedExpandedDim = true;
}
}
}
Value unsqueezed = rewriter.create<tensor::ExpandShapeOp>(
op->getLoc(), unsqueezedType, input, reassociationMap);
return unsqueezed;
}

// Helper function to squeeze the input tensor at given dim.
// Returns the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim) {
Location loc = op->getLoc();
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();

// No scope for squeezing the input.
if (inputRank == 0)
return input;

dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value dimVal = rewriter.create<tensor::DimOp>(loc, input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
dimVal, cstOne);
rewriter.create<cf::AssertOp>(
loc, cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

ArrayRef<int64_t> inputShape = inputType.getShape();
SmallVector<int64_t> squeezedShape;
squeezedShape.append(inputShape.begin(), inputShape.begin() + dim);
squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end());
int64_t squeezedRank = inputRank - 1;
Type squeezedType =
RankedTensorType::get(squeezedShape, inputType.getElementType());

// If the dim(th) dimension of operand tensor type is not statically unit,
// squeeze will behave as an identity operation.
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
return input;
}

SmallVector<ReassociationIndices> reassociationMap(squeezedRank);
bool alreadyCrossedSqueezedDim = false;
for (int i = 0; i != squeezedRank; i++) {
if (alreadyCrossedSqueezedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (dim != 0 && i != dim - 1)
continue;

alreadyCrossedSqueezedDim = true;
if (dim == 0)
reassociationMap[0].push_back(1);
if (i == dim - 1)
reassociationMap[i].push_back(dim);
}
}
// Note: In case the operand tensor type is of unit rank and is statically
// shaped with unit dimension, the `reassociationMap` will be empty and the
// input will be collapsed to a 0-D tensor.
Value squeezed = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), squeezedType, input, reassociationMap);
return squeezed;
}

} // namespace Torch
} // namespace torch
} // namespace mlir
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2731,6 +2731,7 @@
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv1dGroupModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
Expand Down Expand Up @@ -2886,6 +2887,7 @@
"Conv1dModule_basic",
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv1dGroupModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
Expand Down Expand Up @@ -3593,6 +3595,7 @@
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv1dGroupModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
Expand Down Expand Up @@ -4186,6 +4189,7 @@
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv1dGroupModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
Expand Down
Loading

0 comments on commit 8e0eafd

Please sign in to comment.