diff options
author | Sean Silva <silvasean@google.com> | 2020-10-26 17:29:18 -0700 |
---|---|---|
committer | Sean Silva <silvasean@google.com> | 2020-10-27 11:46:33 -0700 |
commit | 83154c541806468802d687a8b3c8f1a65e92199c (patch) | |
tree | 4c642eb7a6cd211ff321e90675a12419e8a00d5d | |
parent | 4b1ea84a1d04820248a4247cb1efa54c18765a6a (diff) | |
download | llvm-83154c541806468802d687a8b3c8f1a65e92199c.zip llvm-83154c541806468802d687a8b3c8f1a65e92199c.tar.gz llvm-83154c541806468802d687a8b3c8f1a65e92199c.tar.bz2 |
[mlir] Add bufferization for std.select op.
Differential Revision: https://reviews.llvm.org/D90204
-rw-r--r-- | mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp | 38 | ||||
-rw-r--r-- | mlir/test/Dialect/Standard/bufferize.mlir | 14 |
2 files changed, 48 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index a1b1f0a..9056fbc 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -89,6 +89,24 @@ public: } // namespace namespace { +class BufferizeSelectOp : public OpConversionPattern<SelectOp> { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(SelectOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (!op.condition().getType().isa<IntegerType>()) + return rewriter.notifyMatchFailure(op, "requires scalar condition"); + + SelectOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp<SelectOp>( + op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); + return success(); + } +}; +} // namespace + +namespace { class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> { public: using OpConversionPattern::OpConversionPattern; @@ -128,10 +146,15 @@ public: void mlir::populateStdBufferizePatterns(MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns - .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp, - BufferizeTensorCastOp, BufferizeTensorFromElementsOp>( - typeConverter, context); + patterns.insert< + // clang-format off + BufferizeDynamicTensorFromElementsOp, + BufferizeExtractElementOp, + BufferizeSelectOp, + BufferizeTensorCastOp, + BufferizeTensorFromElementsOp + // clang-format on + >(typeConverter, context); } namespace { @@ -148,6 +171,13 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> { populateStdBufferizePatterns(context, typeConverter, patterns); target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp, TensorCastOp, TensorFromElementsOp>(); + // We only bufferize the case of tensor selected type and scalar condition, + // as that boils down to a select over memref descriptors (don't need to + // touch the data). + target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) { + return typeConverter.isLegal(op.getType()) || + !op.condition().getType().isa<IntegerType>(); + }); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir index 6125998..b2cefe3 100644 --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -61,6 +61,20 @@ func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 { return %0 : f32 } +// CHECK-LABEL: func @select( +// CHECK-SAME: %[[PRED:.*]]: i1, +// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>, +// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> { +// CHECK: %[[TRUE_VAL_MEMREF:.*]] = tensor_to_memref %[[TRUE_VAL]] : memref<f32> +// CHECK: %[[FALSE_VAL_MEMREF:.*]] = tensor_to_memref %[[FALSE_VAL]] : memref<f32> +// CHECK: %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32> +// CHECK: %[[RET:.*]] = tensor_load %[[RET_MEMREF]] : memref<f32> +// CHECK: return %[[RET]] : tensor<f32> +func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> { + %0 = select %arg0, %arg1, %arg2 : tensor<f32> + return %0 : tensor<f32> +} + // CHECK-LABEL: func @tensor_cast( // CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> { // CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] |