Skip to content

Commit

Permalink
Add attribute support for onnx.nms (#3920)
Browse files Browse the repository at this point in the history
- Set default attribute values
- Support `max_output_boxes_per_class` attribute
- e2e test `test_nonmaxsuppression_limit_output_size` passes
  • Loading branch information
jinchen62 authored Dec 19, 2024
1 parent 71cb942 commit e68560d
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 85 deletions.
127 changes: 75 additions & 52 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3688,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
patterns.onOp(
"NonMaxSuppression", 10,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
SmallVector<Value> operands;
int64_t centerPointBox;
Expand All @@ -3702,96 +3703,132 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");

// TODO: Add support for optional arguments to be absent.
if (operands.size() < 4)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected at least 4 arguments");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
// torchvision expects it to be of shape [Nx4]. Similarly, for
// the scores tensor shape in Onnx is [BxCxN] while the
// torchvision expects it to be of shape [N].
Value boxes = operands[0], scores = operands[1];
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, boxes);
FailureOr<Value> squeezedBoxes =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze boxes tensor");

FailureOr<Value> squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, scores);
FailureOr<Value> squeezedScores =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
squeezedScores.value());
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");

boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
if (operands.size() == 5) {
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
operands[4]);
loc, rewriter.getType<Torch::FloatType>(), operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
loc,
Torch::ValueTensorType::get(binder.op->getContext(),
SmallVector<int64_t>{},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
loc, rewriter.getType<Torch::FloatType>(), minScores);

Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
loc, minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
loc, scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));
}

// TODO: Support default iou_threshold
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
// Get max_output_boxes_per_class and iou_threshold
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value maxOutputBoxesPerClass = cst0;
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0));
if (operands.size() > 3 &&
!isa<Torch::NoneType>(operands[3].getType())) {
iouThreshold = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::FloatType>(), operands[3]);
}
if (operands.size() > 2 &&
!isa<Torch::NoneType>(operands[2].getType())) {
maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), operands[2]);
}

auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{-1},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxes, scores, iouThreshold);

// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
loc, numOutputBoxes, maxOutputBoxesPerClass);

auto nmsResultTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
binder.getLoc(), nmsTy, boxes, scores, iouThreshold);
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({nmsResultTy}), boxesCond);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getThenRegion(),
ifSlice.getThenRegion().begin());

Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getElseRegion(),
ifSlice.getElseRegion().begin());

Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
loc, nmsResultTy, result);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
result = ifSlice.getResult(0);

// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
// and make it of shape [n, 1] and then concatenate it with a zero
// tensor of shape [n, 2] to make it of shape [n, 3].
Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
FailureOr<Value> unsqueezedResult =
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
if (failed(unsqueezedResult))
return rewriter.notifyMatchFailure(
binder.op, "failed to unsqueeze result tensor");
result = unsqueezedResult.value();

Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), result,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
SmallVector<Value> zerosShapeValues{numOutputBoxes};
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
loc, rewriter.getI64IntegerAttr(2)));
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);

std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
if (!resultShape.has_value())
Expand All @@ -3800,33 +3837,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
auto zerosTy = Torch::ValueTensorType::get(
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value zeros = rewriter.create<Torch::AtenZerosOp>(
binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
cstNone);
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);

Type listElemType =
cast<Torch::BaseTensorType>(resultType)
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), listType, SmallVector<Value>{zeros, result});

// TODO: Support max_output_boxes_per_class input
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), boxesCond,
rewriter.getStringAttr(
"unimplemented: number of output boxes per class should be "
"<= max_output_boxes_per_class"));

loc, listType, SmallVector<Value>{zeros, result});
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
tensorList, dim);
tensorList, cst1);
return success();
});
}
Loading

0 comments on commit e68560d

Please sign in to comment.