aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2023-12-20 11:08:54 +0000
committerGitHub <noreply@github.com>2023-12-20 11:08:54 +0000
commit354adb44c94c3660dcc709726e980d58e44afabc (patch)
tree52b70e9a33bd2d4e5858687a38a147f50ce17bd7
parent83f8caeab476646eea21bdde619b0beb84ebd70b (diff)
downloadllvm-354adb44c94c3660dcc709726e980d58e44afabc.zip
llvm-354adb44c94c3660dcc709726e980d58e44afabc.tar.gz
llvm-354adb44c94c3660dcc709726e980d58e44afabc.tar.bz2
[mlir][vector] Extend `CreateMaskFolder` (#75842)
Extends `CreateMaskFolder` pattern so that the following: ```mlir %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %0 = vector.vscale %1 = arith.muli %0, %c16 : index %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> ``` is folded as: ```mlir %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> ```
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp95
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir13
2 files changed, 87 insertions, 21 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8a23ada..f257728 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5660,30 +5660,76 @@ LogicalResult CreateMaskOp::verify() {
namespace {
-// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+///
+/// Ex 1:
+/// %c2 = arith.constant 2 : index
+/// %c3 = arith.constant 3 : index
+/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+/// Becomes:
+/// vector.constant_mask [3, 2] : vector<4x3xi1>
+///
+/// Ex 2:
+/// %c_neg_1 = arith.constant -1 : index
+/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
+/// becomes:
+/// vector.constant_mask [0] : vector<[8]xi1>
+///
+/// Ex 3:
+/// %c8 = arith.constant 8 : index
+/// %c16 = arith.constant 16 : index
+/// %0 = vector.vscale
+/// %1 = arith.muli %0, %c16 : index
+/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+/// becomes:
+/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
- // Return if any of 'createMaskOp' operands are not defined by a constant.
- auto isNotDefByConstant = [](Value operand) {
- return !getConstantIntValue(operand).has_value();
- };
- if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
- return failure();
+ VectorType retTy = createMaskOp.getResult().getType();
+ bool isScalable = retTy.isScalable();
+
+ // Check every mask operand
+ for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
+ if (auto cst = getConstantIntValue(operand)) {
+ // Most basic case - this operand is a constant value. Note that for
+ // scalable dimensions, CreateMaskOp can be folded only if the
+ // corresponding operand is negative or zero.
+ if (retTy.getScalableDims()[opIdx] && *cst > 0)
+ return failure();
- // CreateMaskOp for scalable vectors can be folded only if all dimensions
- // are negative or zero.
- if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
- if (vType.isScalable())
- for (auto opDim : createMaskOp.getOperands()) {
- APInt intVal;
- if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
- intVal.isStrictlyPositive())
- return failure();
- }
+ continue;
+ }
+
+ // Non-constant operands are not allowed for non-scalable vectors.
+ if (!isScalable)
+ return failure();
+
+ // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
+ // true" mask, so can also be treated as constant.
+ auto mul = operand.getDefiningOp<arith::MulIOp>();
+ if (!mul)
+ return failure();
+ auto mulLHS = mul.getRhs();
+ auto mulRHS = mul.getLhs();
+ bool isOneOpVscale =
+ (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
+ isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
+
+ auto isConstantValMatchingDim =
+ [=, dim = retTy.getShape()[opIdx]](Value operand) {
+ auto constantVal = getConstantIntValue(operand);
+ return (constantVal.has_value() && constantVal.value() == dim);
+ };
+
+ bool isOneOpConstantMatchingDim =
+ isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
+
+ if (!isOneOpVscale || !isOneOpConstantMatchingDim)
+ return failure();
}
// Gather constant mask dimension sizes.
@@ -5691,18 +5737,25 @@ public:
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
- int64_t dimSize = getConstantIntValue(operand).value();
- dimSize = std::min(dimSize, maxDimSize);
+ std::optional dimSize = getConstantIntValue(operand);
+ if (!dimSize) {
+ // Although not a constant, it is safe to assume that `operand` is
+ // "vscale * maxDimSize".
+ maskDimSizes.push_back(maxDimSize);
+ continue;
+ }
+ int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
if (dimSize <= 0) {
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
break;
}
- maskDimSizes.push_back(dimSize);
+ maskDimSizes.push_back(dimSizeVal);
}
+
// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
- createMaskOp, createMaskOp.getResult().getType(),
+ createMaskOp, retTy,
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
return success();
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b5164b6..feefb0c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
// -----
+// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true
+func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) {
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %0 = vector.vscale
+ %1 = arith.muli %0, %c16 : index
+ // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+ %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+ return %10 : vector<8x[16]xi1>
+}
+
+// -----
+
// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
func.func @create_mask_transpose_to_transposed_create_mask(