diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 13f555c146b4..12d8683bc9d1 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3688,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
   patterns.onOp(
       "NonMaxSuppression", 10,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         SmallVector<Value> operands;
         int64_t centerPointBox;
@@ -3702,34 +3703,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.op, "unimplemented: expected center_point_box "
                          "attribute value to be 0");
 
-        // TODO: Add support for optional arguments to be absent.
-        if (operands.size() < 4)
-          return rewriter.notifyMatchFailure(
-              binder.op, "unimplemented: expected at least 4 arguments");
-
+        // TODO: Support multiple batches and classes
         // Squeeze the boxes and scores tensor.
         // In Onnx, the shape of boxes is [BxNx4] while the
         // torchvision expects it to be of shape [Nx4]. Similarly, for
         // the scores tensor shape in Onnx is [BxCxN] while the
         // torchvision expects it to be of shape [N].
         Value boxes = operands[0], scores = operands[1];
-        FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
-            rewriter, binder.op, binder.getLoc(), 0, boxes);
+        FailureOr<Value> squeezedBoxes =
+            Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
         if (failed(squeezedBoxes))
           return rewriter.notifyMatchFailure(binder.op,
                                              "failed to squeeze boxes tensor");
-
-        FailureOr<Value> squeezedScores = Torch::squeezeTensor(
-            rewriter, binder.op, binder.getLoc(), 0, scores);
+        FailureOr<Value> squeezedScores =
+            Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
         if (failed(squeezedScores))
          return rewriter.notifyMatchFailure(binder.op,
                                             "failed to squeeze scores tensor");
-        squeezedScores = Torch::squeezeTensor(
-            rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
+        squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
+                                              squeezedScores.value());
         if (failed(squeezedScores))
           return rewriter.notifyMatchFailure(binder.op,
                                              "failed to squeeze scores tensor");
-
         boxes = squeezedBoxes.value();
         scores = squeezedScores.value();
 
@@ -3737,61 +3732,103 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // Filter out the boxes if the score < score_threshold
         if (operands.size() == 5) {
           Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
-              operands[4]);
+              loc, rewriter.getType<Torch::FloatType>(), operands[4]);
           Value minScores = rewriter.create<Torch::AtenMinOp>(
-              binder.getLoc(),
+              loc,
               Torch::ValueTensorType::get(binder.op->getContext(),
                                           SmallVector<int64_t>{},
                                           rewriter.getF32Type()),
               scores);
           minScores = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
+              loc, rewriter.getType<Torch::FloatType>(), minScores);
           Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
-              binder.getLoc(), minScores, scoreThreshold);
+              loc, minScores, scoreThreshold);
           rewriter.create<Torch::RuntimeAssertOp>(
-              binder.getLoc(), scoresCond,
+              loc, scoresCond,
              rewriter.getStringAttr(
                  "unimplemented: score_threshold should be <= min(scores)"));
         }
-        // TODO: Support default iou_threshold
-        Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
-            binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
+        // Get max_output_boxes_per_class and iou_threshold
+        Value cst0 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(0));
+        Value cst1 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        Value maxOutputBoxesPerClass = cst0;
+        Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(0.0));
+        if (operands.size() > 3 &&
+            !isa<Torch::NoneType>(operands[3].getType())) {
+          iouThreshold = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::FloatType>(), operands[3]);
+        }
+        if (operands.size() > 2 &&
+            !isa<Torch::NoneType>(operands[2].getType())) {
+          maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::IntType>(), operands[2]);
+        }
+        auto nmsTy = Torch::ValueTensorType::get(
+            binder.op->getContext(), SmallVector<int64_t>{-1},
+            rewriter.getIntegerType(64, /*signed=*/true));
+        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
+            loc, nmsTy, boxes, scores, iouThreshold);
+
+        // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
+        Value numOutputBoxes =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
+        Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
+            loc, numOutputBoxes, maxOutputBoxesPerClass);
+
-        auto nmsTy = Torch::ValueTensorType::get(
+        auto nmsResultTy = Torch::ValueTensorType::get(
             binder.op->getContext(),
             SmallVector<int64_t>{resultType.getSizes()[0]},
             rewriter.getIntegerType(64, /*signed=*/true));
-        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
-            binder.getLoc(), nmsTy, boxes, scores, iouThreshold);
+        auto ifSlice = rewriter.create<Torch::PrimIfOp>(
+            loc, TypeRange({nmsResultTy}), boxesCond);
+        {
+          PatternRewriter::InsertionGuard guard(rewriter);
+          rewriter.createBlock(&ifSlice.getThenRegion(),
+                               ifSlice.getThenRegion().begin());
+
+          Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
+              loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
+              /*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
+          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
+        }
+        {
+          PatternRewriter::InsertionGuard guard(rewriter);
+          rewriter.createBlock(&ifSlice.getElseRegion(),
+                               ifSlice.getElseRegion().begin());
+
+          Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
+              loc, nmsResultTy, result);
+          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
+        }
+        result = ifSlice.getResult(0);
 
         // The result generated by torchvision.nms op is of shape [n], while the
         // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
         // and make it of shape [n, 1] and then concatenate it with a zero
         // tensor of shape [n, 2] to make it of shape [n, 3].
-        Value dim = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(1));
         FailureOr<Value> unsqueezedResult =
-            Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
+            Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
         if (failed(unsqueezedResult))
           return rewriter.notifyMatchFailure(
               binder.op, "failed to unsqueeze result tensor");
         result = unsqueezedResult.value();
-        Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
-            binder.getLoc(), result,
-            rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(0)));
+        numOutputBoxes =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
         SmallVector<Value> zerosShapeValues{numOutputBoxes};
         zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(2)));
+            loc, rewriter.getI64IntegerAttr(2)));
         Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             rewriter.getType<Torch::ListType>(
                 rewriter.getType<Torch::IntType>()),
             zerosShapeValues);
-
         std::optional<ArrayRef<int64_t>> resultShape =
             cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
         if (!resultShape.has_value())
@@ -3800,10 +3837,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
         auto zerosTy = Torch::ValueTensorType::get(
             resultType.getContext(), zerosShape, resultType.getOptionalDtype());
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
         Value zeros = rewriter.create<Torch::AtenZerosOp>(
-            binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
-            cstNone);
+            loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
 
         Type listElemType =
             cast<Torch::BaseTensorType>(resultType)
@@ -3811,22 +3847,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                                   /*optionalDtype=*/nullptr);
         Type listType = Torch::ListType::get(listElemType);
         Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(), listType, SmallVector<Value>{zeros, result});
-
-        // TODO: Support max_output_boxes_per_class input
-        // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
-        Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
-        Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
-            binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
-        rewriter.create<Torch::RuntimeAssertOp>(
-            binder.getLoc(), boxesCond,
-            rewriter.getStringAttr(
-                "unimplemented: number of output boxes per class should be "
-                "<= max_output_boxes_per_class"));
-
+            loc, listType, SmallVector<Value>{zeros, result});
         rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
-                                                      tensorList, dim);
+                                                      tensorList, cst1);
         return success();
       });
 }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 7f1e63d83ccd..30b85e63ab0f 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -2057,22 +2057,30 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4]
   // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
   // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[1],si64>
-  // CHECK: %[[VAL_26:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK: %[[VAL_28:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_30:.*]] = torch.constant.int 2
-  // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[VAL_32:.*]] = torch.constant.none
-  // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
-  // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
+  // CHECK: %[[VAL_24:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_25:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64>
+  // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
+  // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
+  // CHECK: } else {
+  // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
+  // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
+  // CHECK: }
+  // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
+  // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_35:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_37:.*]] = torch.constant.none
+  // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
+  // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
+  // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
@@ -2109,23 +2117,30 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
   // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
   // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],si64>
-  // CHECK: %[[VAL_26:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK: %[[VAL_28:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_30:.*]] = torch.constant.int 2
-  // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[VAL_32:.*]] = torch.constant.none
-  // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
-  // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
-  // CHECK: }
+  // CHECK: %[[VAL_24:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_25:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
+  // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
+  // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
+  // CHECK: } else {
+  // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
+  // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
+  // CHECK: }
+  // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
+  // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_35:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_37:.*]] = torch.constant.none
+  // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
+  // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
+  // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
  %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
  return %0 : !torch.vtensor<[1,3],si64>
}
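For context, a rough eager-mode Python sketch (not part of the patch; the function name and its defaults are illustrative only) of the single-batch, single-class computation the updated lowering emits: squeeze the batch/class dimensions, run torchvision's nms, truncate the kept indices to max_output_boxes_per_class, and pad them out to the ONNX [n, 3] (batch_index, class_index, box_index) layout.

import torch
from torchvision.ops import nms

def onnx_nms_single_class(boxes, scores, max_output_boxes_per_class, iou_threshold):
    """Mimics onnx.NonMaxSuppression for boxes [1, N, 4] and scores [1, 1, N]."""
    flat_boxes = boxes.squeeze(0)               # [N, 4]
    flat_scores = scores.squeeze(0).squeeze(0)  # [N]
    keep = nms(flat_boxes, flat_scores, iou_threshold)  # kept box indices, best score first
    # Same effect as the prim.If + slice.Tensor emitted by the lowering.
    if keep.numel() > max_output_boxes_per_class:
        keep = keep[:max_output_boxes_per_class]
    # Rows are (batch_index, class_index, box_index); batch and class are always 0
    # here, which is what concatenating the zeros tensor with the indices produces.
    zeros = torch.zeros((keep.numel(), 2), dtype=torch.int64)
    return torch.cat([zeros, keep.unsqueeze(1)], dim=1)  # [n, 3]

For example, with the first test's shapes ([1,10,4] boxes, [1,1,10] scores) and max_output_boxes_per_class set to 1, this returns a single (0, 0, box_index) row, matching the !torch.vtensor<[1,3],si64> result type in the updated CHECK lines.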