diff options
author | Thomas Raoux <thomasraoux@google.com> | 2023-01-06 16:09:21 +0000 |
---|---|---|
committer | Thomas Raoux <thomasraoux@google.com> | 2023-01-06 16:20:17 +0000 |
commit | f41abcda5ee0cf9d6a99bae5db08c60cbbafa760 (patch) | |
tree | 6396d5f1c8ca0750a69232cdba074cfb32df409b | |
parent | 8e20cb6bb8d0b6cf91cc25204eb29620a5040ba4 (diff) | |
download | llvm-f41abcda5ee0cf9d6a99bae5db08c60cbbafa760.zip llvm-f41abcda5ee0cf9d6a99bae5db08c60cbbafa760.tar.gz llvm-f41abcda5ee0cf9d6a99bae5db08c60cbbafa760.tar.bz2 |
[mlir][vector] Relax restriction on reduction distribution
Relax unnecessary restriction when distribution a vector.reduce op.
All the float and integer types can be supported by user's lambda.
Differential Revision: https://reviews.llvm.org/D141094
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 16b6000..08841e3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1179,13 +1179,10 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> { if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) return rewriter.notifyMatchFailure( warpOp, "Reduction vector dimension must match was size."); - // Only f32, i32, f16, i8 element types are supported. - if (!reductionOp.getType().isF32() && - !reductionOp.getType().isSignlessInteger(32) && - !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8)) + if (!reductionOp.getType().isIntOrFloat()) return rewriter.notifyMatchFailure( - warpOp, "Reduction distribution currently only supports 32bits, f16, " - "and i8 types."); + warpOp, "Reduction distribution currently only supports floats and " + "integer types."); int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); // Return vector that will be reduced from the WarpExecuteOnLane0Op. |