diff options
author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2024-06-20 10:27:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-20 10:27:07 +0100 |
commit | e2296d8295516e9991cd6ca99ba193fbd232b6da (patch) | |
tree | d7eb98ffcf897ff3d0ddd1f4adeadd05c0183bec /mlir | |
parent | 94fdfc1ca859d5802bee70853913e8d0400ad9d1 (diff) | |
download | llvm-e2296d8295516e9991cd6ca99ba193fbd232b6da.zip llvm-e2296d8295516e9991cd6ca99ba193fbd232b6da.tar.gz llvm-e2296d8295516e9991cd6ca99ba193fbd232b6da.tar.bz2 |
[mlir][ArmSME] Lower extract from 2D scalable create_mask to psel (#96066)
Example:
```mlir
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%slice = vector.extract %mask[%index]
: vector<[8]xi1> from vector<[4]x[8]xi1>
```
Becomes:
```mlir
%mask_rows = vector.create_mask %a : vector<[4]xi1>
%mask_cols = vector.create_mask %b : vector<[8]xi1>
%slice = arm_sve.psel %mask_cols, %mask_rows[%index]
: vector<[8]xi1>, vector<[4]xi1>
```
Note: Although `arm_sve.psel` is defined in the ArmSVE dialect, it requires SME
(or SVE 2.1), so the Vector-to-ArmSME conversion is currently the most logical
place for this lowering.
Diffstat (limited to 'mlir')
6 files changed, 168 insertions, 10 deletions
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index db67d6a..9ab5faf 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> { Pass that converts vector dialect operations into equivalent ArmSME dialect operations. }]; - let dependentDialects = ["arm_sme::ArmSMEDialect"]; + let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt index b062f65..6a81a09 100644 --- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt @@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME LINK_LIBS PUBLIC MLIRArmSMEDialect + MLIRArmSVEDialect MLIRLLVMCommonConversion ) diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 56ae46a..ee52b9e 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Casting.h" @@ -719,16 +720,86 @@ struct FoldTransferWriteOfExtractTileSlice } }; +/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to +/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or +/// SVE 2.1), so this is currently the most logical place for this lowering. 
+/// +/// Example: +/// ```mlir +/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> +/// %slice = vector.extract %mask[%index] +/// : vector<[8]xi1> from vector<[4]x[8]xi1> +/// ``` +/// Becomes: +/// ``` +/// %mask_rows = vector.create_mask %a : vector<[4]xi1> +/// %mask_cols = vector.create_mask %b : vector<[8]xi1> +/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index] +/// : vector<[8]xi1>, vector<[4]xi1> +/// ``` +struct ExtractFromCreateMaskToPselLowering + : public OpRewritePattern<vector::ExtractOp> { + using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + if (extractOp.getNumIndices() != 1) + return rewriter.notifyMatchFailure(extractOp, "not single extract index"); + + auto resultType = extractOp.getResult().getType(); + auto resultVectorType = dyn_cast<VectorType>(resultType); + if (!resultVectorType) + return rewriter.notifyMatchFailure(extractOp, "result not VectorType"); + + auto createMaskOp = + extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); + if (!createMaskOp) + return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp"); + + auto maskType = createMaskOp.getVectorType(); + if (maskType.getRank() != 2 || !maskType.allDimsScalable()) + return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask"); + + auto isSVEPredicateSize = [](int64_t size) { + return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size)); + }; + + auto rowsBaseSize = maskType.getDimSize(0); + auto colsBaseSize = maskType.getDimSize(1); + if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize)) + return rewriter.notifyMatchFailure( + createMaskOp, "mask dimensions not SVE predicate-sized"); + + auto loc = extractOp.getLoc(); + VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1); + VectorType colMaskType = VectorType::Builder(maskType).dropDim(0); + + // Create the two 1-D masks 
at the location of the 2-D create_mask (which is + // usually outside a loop). This prevents the need for later hoisting. + rewriter.setInsertionPoint(createMaskOp); + auto rowMask = rewriter.create<vector::CreateMaskOp>( + loc, rowMaskType, createMaskOp.getOperand(0)); + auto colMask = rewriter.create<vector::CreateMaskOp>( + loc, colMaskType, createMaskOp.getOperand(1)); + + rewriter.setInsertionPoint(extractOp); + auto position = + vector::getAsValues(rewriter, loc, extractOp.getMixedPosition()); + rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask, + position[0]); + return success(); + } +}; + } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns - .add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, - TransferReadToArmSMELowering, TransferWriteToArmSMELowering, - TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, - VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, - VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, - VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>( - &ctx); + patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, + TransferReadToArmSMELowering, TransferWriteToArmSMELowering, + TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, + VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, + VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, + VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice, + ExtractFromCreateMaskToPselLowering>(&ctx); } diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp index 2601f31..cc00bf4 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include 
"mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir index 8ed52cd..ff7b4bc 100644 --- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir +++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir @@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () } + +// ----- + +/// Not SVE predicate-sized. + +// CHECK-LABEL: @negative_vector_extract_to_psel_0 +func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1> +{ + // CHECK-NOT: arm_sve.psel + %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1> + %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1> + return %slice : vector<[32]xi1> +} + +// ----- + +/// Source not 2-D scalable mask. + +// CHECK-LABEL: @negative_vector_extract_to_psel_1 +func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1> +{ + // CHECK-NOT: arm_sve.psel + %mask = vector.create_mask %a, %b : vector<4x[8]xi1> + %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1> + return %slice : vector<[8]xi1> +} + +// ----- + +/// Source not vector.create_mask. + +// CHECK-LABEL: @negative_vector_extract_to_psel_2 +func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1> +{ + // CHECK-NOT: arm_sve.psel + %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> + return %slice : vector<[8]xi1> +} + +// ----- + +/// Not psel-like extract. 
+ +// CHECK-LABEL: @negative_vector_extract_to_psel_3 +func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1 +{ + // CHECK-NOT: arm_sve.psel + %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> + %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1> + return %el : i1 +} diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir index 8aeffb0..068fd0d 100644 --- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir +++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir @@ -1124,7 +1124,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect } //===----------------------------------------------------------------------===// -// vector.extract +// vector.extract --> arm_sme.move_tile_slice_to_vector //===----------------------------------------------------------------------===// // ----- @@ -1320,3 +1320,37 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 { %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64> return %el : f64 } + +//===----------------------------------------------------------------------===// +// vector.extract --> arm_sve.psel +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel( +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[INDEX:.*]]: index) +func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1> +{ + // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1> + // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1> + // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1> + %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> + %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> + 
return %slice : vector<[8]xi1> +} + +// ----- + +// CHECK-LABEL: @vector_extract_mask_to_psel( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: index) +func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1> +{ + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1> + // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1> + // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1> + %mask = vector.create_mask %a, %b : vector<[16]x[2]xi1> + %slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1> + return %slice : vector<[2]xi1> +} |