diff options
author | Andrzej Warzyński <andrzej.warzynski@arm.com> | 2023-12-20 11:08:54 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-20 11:08:54 +0000 |
commit | 354adb44c94c3660dcc709726e980d58e44afabc (patch) | |
tree | 52b70e9a33bd2d4e5858687a38a147f50ce17bd7 | |
parent | 83f8caeab476646eea21bdde619b0beb84ebd70b (diff) | |
download | llvm-354adb44c94c3660dcc709726e980d58e44afabc.zip llvm-354adb44c94c3660dcc709726e980d58e44afabc.tar.gz llvm-354adb44c94c3660dcc709726e980d58e44afabc.tar.bz2 |
[mlir][vector] Extend `CreateMaskFolder` (#75842)
Extends `CreateMaskFolder` pattern so that the following:
```mlir
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
```
is folded as:
```mlir
%0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
```
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 95 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 13 |
2 files changed, 87 insertions, 21 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8a23ada..f257728 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5660,30 +5660,76 @@ LogicalResult CreateMaskOp::verify() { namespace { -// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. +/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. +/// +/// Ex 1: +/// %c2 = arith.constant 2 : index +/// %c3 = arith.constant 3 : index +/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1> +/// Becomes: +/// vector.constant_mask [3, 2] : vector<4x3xi1> +/// +/// Ex 2: +/// %c_neg_1 = arith.constant -1 : index +/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1> +/// becomes: +/// vector.constant_mask [0] : vector<[8]xi1> +/// +/// Ex 3: +/// %c8 = arith.constant 8 : index +/// %c16 = arith.constant 16 : index +/// %0 = vector.vscale +/// %1 = arith.muli %0, %c16 : index +/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> +/// becomes: +/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { - // Return if any of 'createMaskOp' operands are not defined by a constant. - auto isNotDefByConstant = [](Value operand) { - return !getConstantIntValue(operand).has_value(); - }; - if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant)) - return failure(); + VectorType retTy = createMaskOp.getResult().getType(); + bool isScalable = retTy.isScalable(); + + // Check every mask operand + for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) { + if (auto cst = getConstantIntValue(operand)) { + // Most basic case - this operand is a constant value. 
Note that for + // scalable dimensions, CreateMaskOp can be folded only if the + // corresponding operand is negative or zero. + if (retTy.getScalableDims()[opIdx] && *cst > 0) + return failure(); - // CreateMaskOp for scalable vectors can be folded only if all dimensions - // are negative or zero. - if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) { - if (vType.isScalable()) - for (auto opDim : createMaskOp.getOperands()) { - APInt intVal; - if (matchPattern(opDim, m_ConstantInt(&intVal)) && - intVal.isStrictlyPositive()) - return failure(); - } + continue; + } + + // Non-constant operands are not allowed for non-scalable vectors. + if (!isScalable) + return failure(); + + // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all + // true" mask, so can also be treated as constant. + auto mul = operand.getDefiningOp<arith::MulIOp>(); + if (!mul) + return failure(); + auto mulLHS = mul.getRhs(); + auto mulRHS = mul.getLhs(); + bool isOneOpVscale = + (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) || + isa<vector::VectorScaleOp>(mulRHS.getDefiningOp())); + + auto isConstantValMatchingDim = + [=, dim = retTy.getShape()[opIdx]](Value operand) { + auto constantVal = getConstantIntValue(operand); + return (constantVal.has_value() && constantVal.value() == dim); + }; + + bool isOneOpConstantMatchingDim = + isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS); + + if (!isOneOpVscale || !isOneOpConstantMatchingDim) + return failure(); } // Gather constant mask dimension sizes. 
@@ -5691,18 +5737,25 @@ public: maskDimSizes.reserve(createMaskOp->getNumOperands()); for (auto [operand, maxDimSize] : llvm::zip_equal( createMaskOp.getOperands(), createMaskOp.getType().getShape())) { - int64_t dimSize = getConstantIntValue(operand).value(); - dimSize = std::min(dimSize, maxDimSize); + std::optional dimSize = getConstantIntValue(operand); + if (!dimSize) { + // Although not a constant, it is safe to assume that `operand` is + // "vscale * maxDimSize". + maskDimSizes.push_back(maxDimSize); + continue; + } + int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize); // If one of dim sizes is zero, set all dims to zero. if (dimSize <= 0) { maskDimSizes.assign(createMaskOp.getType().getRank(), 0); break; } - maskDimSizes.push_back(dimSize); + maskDimSizes.push_back(dimSizeVal); } + // Replace 'createMaskOp' with ConstantMaskOp. rewriter.replaceOpWithNewOp<ConstantMaskOp>( - createMaskOp, createMaskOp.getResult().getType(), + createMaskOp, retTy, vector::getVectorSubscriptAttr(rewriter, maskDimSizes)); return success(); } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index b5164b6..feefb0c 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x // ----- +// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true +func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) { + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %0 = vector.vscale + %1 = arith.muli %0, %c16 : index + // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1> + %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> + return %10 : vector<8x[16]xi1> +} + +// ----- + // CHECK-LABEL: create_mask_transpose_to_transposed_create_mask // CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index func.func 
@create_mask_transpose_to_transposed_create_mask( |