aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Silva <silvasean@google.com>2020-10-26 17:29:18 -0700
committerSean Silva <silvasean@google.com>2020-10-27 11:46:33 -0700
commit83154c541806468802d687a8b3c8f1a65e92199c (patch)
tree4c642eb7a6cd211ff321e90675a12419e8a00d5d
parent4b1ea84a1d04820248a4247cb1efa54c18765a6a (diff)
downloadllvm-83154c541806468802d687a8b3c8f1a65e92199c.zip
llvm-83154c541806468802d687a8b3c8f1a65e92199c.tar.gz
llvm-83154c541806468802d687a8b3c8f1a65e92199c.tar.bz2
[mlir] Add bufferization for std.select op.
Differential Revision: https://reviews.llvm.org/D90204
-rw-r--r--mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp38
-rw-r--r--mlir/test/Dialect/Standard/bufferize.mlir14
2 files changed, 48 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index a1b1f0a..9056fbc 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -89,6 +89,24 @@ public:
} // namespace
namespace {
+class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!op.condition().getType().isa<IntegerType>())
+ return rewriter.notifyMatchFailure(op, "requires scalar condition");
+
+ SelectOp::Adaptor adaptor(operands);
+ rewriter.replaceOpWithNewOp<SelectOp>(
+ op, adaptor.condition(), adaptor.true_value(), adaptor.false_value());
+ return success();
+ }
+};
+} // namespace
+
+namespace {
class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -128,10 +146,15 @@ public:
void mlir::populateStdBufferizePatterns(MLIRContext *context,
BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns
- .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp,
- BufferizeTensorCastOp, BufferizeTensorFromElementsOp>(
- typeConverter, context);
+ patterns.insert<
+ // clang-format off
+ BufferizeDynamicTensorFromElementsOp,
+ BufferizeExtractElementOp,
+ BufferizeSelectOp,
+ BufferizeTensorCastOp,
+ BufferizeTensorFromElementsOp
+ // clang-format on
+ >(typeConverter, context);
}
namespace {
@@ -148,6 +171,13 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
populateStdBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
TensorCastOp, TensorFromElementsOp>();
+ // We only bufferize the case of tensor selected type and scalar condition,
+ // as that boils down to a select over memref descriptors (don't need to
+ // touch the data).
+ target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
+ return typeConverter.isLegal(op.getType()) ||
+ !op.condition().getType().isa<IntegerType>();
+ });
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 6125998..b2cefe3 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -61,6 +61,20 @@ func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
return %0 : f32
}
+// CHECK-LABEL: func @select(
+// CHECK-SAME: %[[PRED:.*]]: i1,
+// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
+// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[TRUE_VAL_MEMREF:.*]] = tensor_to_memref %[[TRUE_VAL]] : memref<f32>
+// CHECK: %[[FALSE_VAL_MEMREF:.*]] = tensor_to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK: %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
+// CHECK: %[[RET:.*]] = tensor_load %[[RET_MEMREF]] : memref<f32>
+// CHECK: return %[[RET]] : tensor<f32>
+func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+ %0 = select %arg0, %arg1, %arg2 : tensor<f32>
+ return %0 : tensor<f32>
+}
+
// CHECK-LABEL: func @tensor_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]