aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorBenjamin Maxwell <benjamin.maxwell@arm.com>2024-06-20 10:27:07 +0100
committerGitHub <noreply@github.com>2024-06-20 10:27:07 +0100
commite2296d8295516e9991cd6ca99ba193fbd232b6da (patch)
treed7eb98ffcf897ff3d0ddd1f4adeadd05c0183bec /mlir
parent94fdfc1ca859d5802bee70853913e8d0400ad9d1 (diff)
downloadllvm-e2296d8295516e9991cd6ca99ba193fbd232b6da.zip
llvm-e2296d8295516e9991cd6ca99ba193fbd232b6da.tar.gz
llvm-e2296d8295516e9991cd6ca99ba193fbd232b6da.tar.bz2
[mlir][ArmSME] Lower extract from 2D scalable create_mask to psel (#96066)
Example: ```mlir %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> ``` Becomes: ```mlir %mask_rows = vector.create_mask %a : vector<[4]xi1> %mask_cols = vector.create_mask %b : vector<[8]xi1> %slice = arm_sve.psel %mask_cols, %mask_rows[%index] : vector<[8]xi1>, vector<[4]xi1> ``` Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Conversion/Passes.td2
-rw-r--r--mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp87
-rw-r--r--mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp1
-rw-r--r--mlir/test/Conversion/VectorToArmSME/unsupported.mlir51
-rw-r--r--mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir36
6 files changed, 168 insertions, 10 deletions
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index db67d6a..9ab5faf 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
Pass that converts vector dialect operations into equivalent ArmSME dialect
operations.
}];
- let dependentDialects = ["arm_sme::ArmSMEDialect"];
+ let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index b062f65..6a81a09 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
LINK_LIBS PUBLIC
MLIRArmSMEDialect
+ MLIRArmSVEDialect
MLIRLLVMCommonConversion
)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 56ae46a..ee52b9e 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"
@@ -719,16 +720,86 @@ struct FoldTransferWriteOfExtractTileSlice
}
};
+/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
+/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
+/// SVE 2.1), so this is currently the most logical place for this lowering.
+///
+/// Example:
+/// ```mlir
+/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+/// %slice = vector.extract %mask[%index]
+/// : vector<[8]xi1> from vector<[4]x[8]xi1>
+/// ```
+/// Becomes:
+/// ```
+/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
+/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
+/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
+/// : vector<[8]xi1>, vector<[4]xi1>
+/// ```
+struct ExtractFromCreateMaskToPselLowering
+ : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ if (extractOp.getNumIndices() != 1)
+ return rewriter.notifyMatchFailure(extractOp, "not single extract index");
+
+ auto resultType = extractOp.getResult().getType();
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
+ if (!resultVectorType)
+ return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
+
+ auto createMaskOp =
+ extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
+
+ auto maskType = createMaskOp.getVectorType();
+ if (maskType.getRank() != 2 || !maskType.allDimsScalable())
+ return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
+
+ auto isSVEPredicateSize = [](int64_t size) {
+ return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
+ };
+
+ auto rowsBaseSize = maskType.getDimSize(0);
+ auto colsBaseSize = maskType.getDimSize(1);
+ if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "mask dimensions not SVE predicate-sized");
+
+ auto loc = extractOp.getLoc();
+ VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
+ VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
+
+ // Create the two 1-D masks at the location of the 2-D create_mask (which is
+ // usually outside a loop). This prevents the need for later hoisting.
+ rewriter.setInsertionPoint(createMaskOp);
+ auto rowMask = rewriter.create<vector::CreateMaskOp>(
+ loc, rowMaskType, createMaskOp.getOperand(0));
+ auto colMask = rewriter.create<vector::CreateMaskOp>(
+ loc, colMaskType, createMaskOp.getOperand(1));
+
+ rewriter.setInsertionPoint(extractOp);
+ auto position =
+ vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
+ rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
+ position[0]);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns
- .add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
- VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
- VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
- &ctx);
+ patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+ VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
+ ExtractFromCreateMaskToPselLowering>(&ctx);
}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
index 2601f31..cc00bf4 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 8ed52cd..ff7b4bc 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
+
+// -----
+
+/// Not SVE predicate-sized.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_0
+func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
+{
+ // CHECK-NOT: arm_sve.psel
+ %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
+ %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
+ return %slice : vector<[32]xi1>
+}
+
+// -----
+
+/// Source not 2-D scalable mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_1
+func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+ // CHECK-NOT: arm_sve.psel
+ %mask = vector.create_mask %a, %b : vector<4x[8]xi1>
+ %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Source not vector.create_mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_2
+func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
+{
+ // CHECK-NOT: arm_sve.psel
+ %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Not psel-like extract.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_3
+func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1
+{
+ // CHECK-NOT: arm_sve.psel
+ %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+ %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
+ return %el : i1
+}
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 8aeffb0..068fd0d 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -1124,7 +1124,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect
}
//===----------------------------------------------------------------------===//
-// vector.extract
+// vector.extract --> arm_sme.move_tile_slice_to_vector
//===----------------------------------------------------------------------===//
// -----
@@ -1320,3 +1320,37 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
%el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
return %el : f64
}
+
+//===----------------------------------------------------------------------===//
+// vector.extract --> arm_sve.psel
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[INDEX:.*]]: index)
+func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+ // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
+ // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
+ // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
+ %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+ %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_mask_to_psel(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: index)
+func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
+{
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1>
+ // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1>
+ // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1>
+ %mask = vector.create_mask %a, %b : vector<[16]x[2]xi1>
+ %slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
+ return %slice : vector<[2]xi1>
+}