diff options
author | aartbik <ajcbik@google.com> | 2020-02-11 11:09:14 -0800 |
---|---|---|
committer | aartbik <ajcbik@google.com> | 2020-02-11 11:31:59 -0800 |
commit | e83b7b99da2e0385c567cd3883cad66fb5ce271c (patch) | |
tree | 62cb962cb9fbe9ea75654dec5320cd3120b0b192 | |
parent | 0cecafd647ccd9d0acc5968d4d6e80c1cbdee275 (diff) | |
download | llvm-e83b7b99da2e0385c567cd3883cad66fb5ce271c.zip llvm-e83b7b99da2e0385c567cd3883cad66fb5ce271c.tar.gz llvm-e83b7b99da2e0385c567cd3883cad66fb5ce271c.tar.bz2 |
[mlir] [VectorOps] Implement vector.reduce operation
Summary:
This new operation operates on 1-D vectors and
forms the bridge between vector.contract and
llvm intrinsics for vector reductions.
Reviewers: nicolasvasilache, andydavis1, ftynse
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74370
-rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 33 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 80 | ||||
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 27 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 44 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 28 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 34 |
6 files changed, 237 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index de7007e..074a6d0 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -183,6 +183,39 @@ def Vector_ContractionOp : }]; } +def Vector_ReductionOp : + Vector_Op<"reduction", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins StrAttr:$kind, AnyVector:$vector)>, + Results<(outs AnyType:$dest)> { + let summary = "reduction operation"; + let description = [{ + Reduces an 1-D vector "horizontally" into a scalar using the given + operation (add/mul/min/max for int/fp and and/or/xor for int only). + Note that these operations are restricted to 1-D vectors to remain + close to the corresponding LLVM intrinsics: + + http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics + + Examples: + ``` + %1 = vector.reduction "add", %0 : vector<16xf32> into f32 + + %3 = vector.reduction "xor", %2 : vector<4xi32> into i32 + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let assemblyFormat = [{ + $kind `,` $vector attr-dict `:` type($vector) `into` type($dest) + }]; + let extraClassDeclaration = [{ + VectorType getVectorType() { + return vector().getType().cast<VectorType>(); + } + }]; +} + def Vector_BroadcastOp : Vector_Op<"broadcast", [NoSideEffect, PredOpTrait<"source operand and result have same element type", diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a3d724b..9fcad2f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -124,6 +124,7 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, } namespace { + class VectorBroadcastOpConversion : public LLVMOpLowering { public: explicit VectorBroadcastOpConversion(MLIRContext *context, @@ -272,6 +273,73 @@ private: } }; +class VectorReductionOpConversion : public LLVMOpLowering { +public: + explicit VectorReductionOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ReductionOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto reductionOp = cast<vector::ReductionOp>(op); + auto kind = reductionOp.kind(); + Type eltType = reductionOp.dest().getType(); + Type llvmType = lowering.convertType(eltType); + if (eltType.isInteger(32) || eltType.isInteger(64)) { + // Integer reductions: add/mul/min/max/and/or/xor. + if (kind == "add") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>( + op, llvmType, operands[0]); + else if (kind == "mul") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>( + op, llvmType, operands[0]); + else if (kind == "min") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>( + op, llvmType, operands[0]); + else if (kind == "max") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>( + op, llvmType, operands[0]); + else if (kind == "and") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>( + op, llvmType, operands[0]); + else if (kind == "or") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>( + op, llvmType, operands[0]); + else if (kind == "xor") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>( + op, llvmType, operands[0]); + else + return matchFailure(); + return matchSuccess(); + + } else if (eltType.isF32() || eltType.isF64()) { + // Floating-point reductions: add/mul/min/max + if (kind == "add") { + Value zero = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmType, rewriter.getZeroAttr(eltType)); + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>( + op, llvmType, zero, operands[0]); + } else if (kind == "mul") { + Value one = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0)); + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>( + op, llvmType, one, operands[0]); + } else if (kind == "min") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>( + op, llvmType, operands[0]); + else if (kind == "max") + rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>( + op, llvmType, operands[0]); + else + return matchFailure(); + return matchSuccess(); + } + return matchFailure(); + } +}; + class VectorShuffleOpConversion : public LLVMOpLowering { public: explicit VectorShuffleOpConversion(MLIRContext *context, @@ -1056,12 +1124,12 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorInsertStridedSliceOpDifferentRankRewritePattern, VectorInsertStridedSliceOpSameRankRewritePattern, VectorStridedSliceOpConversion>(ctx); - patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, - VectorInsertOpConversion, VectorOuterProductOpConversion, - VectorTypeCastOpConversion, VectorPrintOpConversion>( - ctx, converter); + patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion, + VectorShuffleOpConversion, VectorExtractElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, + VectorInsertElementOpConversion, VectorInsertOpConversion, + VectorOuterProductOpConversion, VectorTypeCastOpConversion, + VectorPrintOpConversion>(ctx, converter); } namespace { diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index a987a54..b4d7aee 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -61,6 +61,33 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, } //===----------------------------------------------------------------------===// +// ReductionOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ReductionOp op) { + // Verify for 1-D vector. + int64_t rank = op.getVectorType().getRank(); + if (rank != 1) + return op.emitOpError("unsupported reduction rank: ") << rank; + + // Verify supported reduction kind. + auto kind = op.kind(); + Type eltType = op.dest().getType(); + if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") { + if (eltType.isF32() || eltType.isF64() || eltType.isInteger(32) || + eltType.isInteger(64)) + return success(); + return op.emitOpError("unsupported reduction type"); + } + if (kind == "and" || kind == "or" || kind == "xor") { + if (eltType.isInteger(32) || eltType.isInteger(64)) + return success(); + return op.emitOpError("unsupported reduction type"); + } + return op.emitOpError("unknown reduction kind: ") << kind; +} + +//===----------------------------------------------------------------------===// // ContractionOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index d1535d5..5159031 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -645,7 +645,7 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect // CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) : // CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>"> %0 = vector.fma %a, %a, %a : vector<8xf32> - + // CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> // CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> // CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> @@ -659,7 +659,45 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect // CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> // CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]"> %1 = vector.fma %b, %b, %b : vector<2x4xf32> - + return %0, %1: vector<8xf32>, vector<2x4xf32> } - + +func @reduce_f32(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32 + return %0 : f32 +} +// CHECK-LABEL: llvm.func @reduce_f32 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>"> +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.float + +func @reduce_f64(%arg0: vector<16xf64>) -> f64 { + %0 = vector.reduction "add", %arg0 : vector<16xf64> into f64 + return %0 : f64 +} +// CHECK-LABEL: llvm.func @reduce_f64 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>"> +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.double + +func @reduce_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: llvm.func @reduce_i32 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>"> +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.i32 + +func @reduce_i64(%arg0: vector<16xi64>) -> i64 { + %0 = vector.reduction "add", %arg0 : vector<16xi64> into i64 + return %0 : i64 +} +// CHECK-LABEL: llvm.func @reduce_i64 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>"> +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.i64 + diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index a1fee2c..2a45820 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -990,3 +990,31 @@ func @shape_cast_different_tuple_sizes( %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to tuple<vector<20x2xf32>> } + +// ----- + +func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 { + // expected-error@+1 {{'vector.reduction' op unknown reduction kind: joho}} + %0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32 +} + +// ----- + +func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 { + // expected-error@+1 {{'vector.reduction' op failed to verify that source operand and result have same element type}} + %0 = vector.reduction "add", %arg0 : vector<16xf32> into i32 +} + +// ----- + +func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 { + // expected-error@+1 {{'vector.reduction' op unsupported reduction type}} + %0 = vector.reduction "xor", %arg0 : vector<16xf32> into f32 +} + +// ----- + +func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 { + // expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}} + %0 = vector.reduction "add", %arg0 : vector<4x16xf32> into f32 +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index ff00783..bb5ca6e 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -277,3 +277,37 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) { vector.fma %b, %b, %b : vector<8x4xf32> return } + +// CHECK-LABEL: reduce_fp +func @reduce_fp(%arg0: vector<16xf32>) -> f32 { + // CHECK: vector.reduction "add", %{{.*}} : vector<16xf32> into f32 + vector.reduction "add", %arg0 : vector<16xf32> into f32 + // CHECK: vector.reduction "mul", %{{.*}} : vector<16xf32> into f32 + vector.reduction "mul", %arg0 : vector<16xf32> into f32 + // CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32 + vector.reduction "min", %arg0 : vector<16xf32> into f32 + // CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32 + %0 = vector.reduction "max", %arg0 : vector<16xf32> into f32 + // CHECK: return %[[X]] : f32 + return %0 : f32 +} + +// CHECK-LABEL: reduce_int +func @reduce_int(%arg0: vector<16xi32>) -> i32 { + // CHECK: vector.reduction "add", %{{.*}} : vector<16xi32> into i32 + vector.reduction "add", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "mul", %{{.*}} : vector<16xi32> into i32 + vector.reduction "mul", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "min", %{{.*}} : vector<16xi32> into i32 + vector.reduction "min", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "max", %{{.*}} : vector<16xi32> into i32 + vector.reduction "max", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "and", %{{.*}} : vector<16xi32> into i32 + vector.reduction "and", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "or", %{{.*}} : vector<16xi32> into i32 + vector.reduction "or", %arg0 : vector<16xi32> into i32 + // CHECK: %[[X:.*]] = vector.reduction "xor", %{{.*}} : vector<16xi32> into i32 + %0 = vector.reduction "xor", %arg0 : vector<16xi32> into i32 + // CHECK: return %[[X]] : i32 + return %0 : i32 +} |