aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Raoux <thomasraoux@google.com>2023-01-06 16:09:21 +0000
committerThomas Raoux <thomasraoux@google.com>2023-01-06 16:20:17 +0000
commitf41abcda5ee0cf9d6a99bae5db08c60cbbafa760 (patch)
tree6396d5f1c8ca0750a69232cdba074cfb32df409b
parent8e20cb6bb8d0b6cf91cc25204eb29620a5040ba4 (diff)
downloadllvm-f41abcda5ee0cf9d6a99bae5db08c60cbbafa760.zip
llvm-f41abcda5ee0cf9d6a99bae5db08c60cbbafa760.tar.gz
llvm-f41abcda5ee0cf9d6a99bae5db08c60cbbafa760.tar.bz2
[mlir][vector] Relax restriction on reduction distribution
Relax unnecessary restriction when distribution a vector.reduce op. All the float and integer types can be supported by user's lambda. Differential Revision: https://reviews.llvm.org/D141094
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp9
1 files changed, 3 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 16b6000..08841e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1179,13 +1179,10 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
return rewriter.notifyMatchFailure(
warpOp, "Reduction vector dimension must match was size.");
- // Only f32, i32, f16, i8 element types are supported.
- if (!reductionOp.getType().isF32() &&
- !reductionOp.getType().isSignlessInteger(32) &&
- !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8))
+ if (!reductionOp.getType().isIntOrFloat())
return rewriter.notifyMatchFailure(
- warpOp, "Reduction distribution currently only supports 32bits, f16, "
- "and i8 types.");
+ warpOp, "Reduction distribution currently only supports floats and "
+ "integer types.");
int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
// Return vector that will be reduced from the WarpExecuteOnLane0Op.