diff options
author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2024-06-19 13:33:23 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-19 13:33:23 +0100 |
commit | 781133037387eefa4080aa31c73554cc0452e6e6 (patch) | |
tree | e1455541789bb83e79d704851ae60e4de6fffd86 | |
parent | 6244d87f42775e8d49cf758eeb1909f2ce144e3c (diff) | |
download | llvm-781133037387eefa4080aa31c73554cc0452e6e6.zip llvm-781133037387eefa4080aa31c73554cc0452e6e6.tar.gz llvm-781133037387eefa4080aa31c73554cc0452e6e6.tar.bz2 |
[mlir][ArmSVE] Add `arm_sve.psel` operation (#95764)
This adds a new operation for the SME/SVE2.1 psel instruction. This
allows selecting a predicate based on a bit within another predicate,
essentially allowing for 2-D predication. Informally, the semantics are:
```mlir
%pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
```
=>
```
if p2[index % num_elements(p2)] == 1:
pd = p1 : type(p1)
else:
pd = all-false : type(p1)
```
-rw-r--r-- | mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 55 | ||||
-rw-r--r-- | mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp | 26 | ||||
-rw-r--r-- | mlir/test/Dialect/ArmSVE/invalid.mlir | 8 | ||||
-rw-r--r-- | mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir | 32 | ||||
-rw-r--r-- | mlir/test/Dialect/ArmSVE/roundtrip.mlir | 29 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/arm-sve.mlir | 19 |
6 files changed, 166 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index aea5583..d7e8b22 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -37,10 +37,16 @@ def ArmSVE_Dialect : Dialect { //===----------------------------------------------------------------------===// def SVBool : ScalableVectorOfRankAndLengthAndType< - [1], [16], [I1]>; + [1], [16], [I1]> +{ + let summary = "vector<[16]xi1>"; +} def SVEPredicate : ScalableVectorOfRankAndLengthAndType< - [1], [16, 8, 4, 2, 1], [I1]>; + [1], [16, 8, 4, 2, 1], [I1]> +{ + let summary = "vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>"; +} // Generalizations of SVBool and SVEPredicate to ranks >= 1. // These are masks with a single trailing scalable dimension. @@ -442,6 +448,43 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [ }]; } +def PselOp : ArmSVE_Op<"psel", [ + Pure, + AllTypesMatch<["p1", "result"]>, +]> { + let summary = "Predicate select"; + + let description = [{ + This operation returns the input predicate `p1` or an all-false predicate + based on the bit at `p2[index]`. Informally, the semantics are: + ``` + if p2[index % num_elements(p2)] == 1: + return p1 : type(p1) + return all-false : type(p1) + ``` + + Example: + ```mlir + // Note: p1 and p2 can have different sizes. + %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1> + ``` + + Note: This requires SME or SVE2.1 (`+sme` or `+sve2p1` in LLVM target features). 
+ }]; + + let arguments = (ins SVEPredicate:$p1, SVEPredicate:$p2, Index:$index); + let results = (outs SVEPredicate:$result); + + let builders = [ + OpBuilder<(ins "Value":$p1, "Value":$p2, "Value":$index), [{ + build($_builder, $_state, p1.getType(), p1, p2, index); + }]>]; + + let assemblyFormat = [{ + $p1 `,` $p2 `[` $index `]` attr-dict `:` type($p1) `,` type($p2) + }]; +} + def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition", [Commutative]>; @@ -552,6 +595,14 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4", Arg<AnyScalableVector, "v3">:$v3, Arg<AnyScalableVector, "v3">:$v4)>; +// Note: This intrinsic requires SME or SVE2.1. +def PselIntrOp : ArmSVE_IntrOp<"psel", + /*traits=*/[Pure, TypeIs<"res", SVBool>], + /*overloadedOperands=*/[1]>, + Arguments<(ins Arg<SVBool, "p1">:$p1, + Arg<SVEPredicate, "p2">:$p2, + Arg<I32, "index">:$index)>; + def WhileLTIntrOp : ArmSVE_IntrOp<"whilelt", [TypeIs<"res", SVEPredicate>, Pure], diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index ed4f4cc..10f39a0 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -140,6 +140,28 @@ using ConvertFromSvboolOpLowering = using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>; using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>; +/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion +/// but first input (P1) and result predicates need conversion to/from svbool. 
+struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto svboolType = VectorType::get(16, rewriter.getI1Type(), true); + auto loc = pselOp.getLoc(); + auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType, + adaptor.getP1()); + auto indexI32 = rewriter.create<arith::IndexCastOp>( + loc, rewriter.getI32Type(), pselOp.getIndex()); + auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1, + pselOp.getP2(), indexI32); + rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>( + pselOp, adaptor.getP1().getType(), pselIntr); + return success(); + } +}; + /// Converts `vector.create_mask` ops that match the size of an SVE predicate /// to the `whilelt` intrinsic. This produces more canonical codegen than the /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840 @@ -202,7 +224,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( ConvertToSvboolOpLowering, ConvertFromSvboolOpLowering, ZipX2OpLowering, - ZipX4OpLowering>(converter); + ZipX4OpLowering, + PselOpLowering>(converter); // Add vector.create_mask conversion with a high benefit as it produces much // nicer code than the generic lowering. 
patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096); @@ -229,6 +252,7 @@ void mlir::configureArmSVELegalizeForExportTarget( ConvertFromSvboolIntrOp, ZipX2IntrOp, ZipX4IntrOp, + PselIntrOp, WhileLTIntrOp>(); target.addIllegalOp<SdotOp, SmmlaOp, diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir index 1258d35..a021d43 100644 --- a/mlir/test/Dialect/ArmSVE/invalid.mlir +++ b/mlir/test/Dialect/ArmSVE/invalid.mlir @@ -64,3 +64,11 @@ func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) { arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64> return } + +// ----- + +func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) { + // expected-error@+1 {{op operand #0 must be vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>, but got 'vector<[7]xi1>'}} + arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1> + return +} diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir index 3fc5e6e..31d5376 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -239,3 +239,35 @@ func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, v %2 = vector.create_mask %index : vector<[32]xi1> return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1> } + +// ----- + +// CHECK-LABEL: @arm_sve_psel_matching_predicate_types( +// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[4]xi1>, +// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[4]xi1>, +// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64 +func.func @arm_sve_psel_matching_predicate_types(%p0: vector<[4]xi1>, %p1: vector<[4]xi1>, %index: index) -> vector<[4]xi1> +{ + // CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32 + // CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[4]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], 
%[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1> + // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[4]xi1> + %0 = arm_sve.psel %p0, %p1[%index] : vector<[4]xi1>, vector<[4]xi1> + return %0 : vector<[4]xi1> +} + +// ----- + +// CHECK-LABEL: @arm_sve_psel_mixed_predicate_types( +// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[8]xi1>, +// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[16]xi1>, +// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64 +func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[16]xi1>, %index: index) -> vector<[8]xi1> +{ + // CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32 + // CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[8]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1> + // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[8]xi1> + %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1> + return %0 : vector<[8]xi1> +} diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index f7b79aa..0f0c5a8 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -225,3 +225,32 @@ func.func @arm_sve_zip_x4( %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8> return } + +// ----- + +func.func @arm_sve_psel( + %p0: vector<[2]xi1>, + %p1: vector<[4]xi1>, + %p2: vector<[8]xi1>, + %p3: vector<[16]xi1>, + %index: index +) { + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[2]xi1> + %0 = arm_sve.psel %p0, %p0[%index] : vector<[2]xi1>, vector<[2]xi1> + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[4]xi1> + %1 = arm_sve.psel %p1, %p1[%index] : vector<[4]xi1>, 
vector<[4]xi1> + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[8]xi1> + %2 = arm_sve.psel %p2, %p2[%index] : vector<[8]xi1>, vector<[8]xi1> + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[16]xi1> + %3 = arm_sve.psel %p3, %p3[%index] : vector<[16]xi1>, vector<[16]xi1> + /// Some mixed predicate type examples: + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[4]xi1> + %4 = arm_sve.psel %p0, %p1[%index] : vector<[2]xi1>, vector<[4]xi1> + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[8]xi1> + %5 = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1> + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[16]xi1> + %6 = arm_sve.psel %p2, %p3[%index] : vector<[8]xi1>, vector<[16]xi1> + // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[2]xi1> + %7 = arm_sve.psel %p3, %p0[%index] : vector<[16]xi1>, vector<[2]xi1> + return +} diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir index 34413d4..ed5a1fc 100644 --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -371,3 +371,22 @@ llvm.func @arm_sve_whilelt(%base: i64, %n: i64) { %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1> llvm.return } + +// CHECK-LABEL: arm_sve_psel( +// CHECK-SAME: <vscale x 16 x i1> %[[PN:[0-9]+]], +// CHECK-SAME: <vscale x 2 x i1> %[[P1:[0-9]+]], +// CHECK-SAME: <vscale x 4 x i1> %[[P2:[0-9]+]], +// CHECK-SAME: <vscale x 8 x i1> %[[P3:[0-9]+]], +// CHECK-SAME: <vscale x 16 x i1> %[[P4:[0-9]+]], +// CHECK-SAME: i32 %[[INDEX:[0-9]+]]) +llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[4]xi1>, %p3: vector<[8]xi1>, %p4: vector<[16]xi1>, %index: i32) { + // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %[[PN]], <vscale x 2 x i1> %[[P1]], i32 %[[INDEX]]) + "arm_sve.intr.psel"(%pn, 
%p1, %index) : (vector<[16]xi1>, vector<[2]xi1>, i32) -> vector<[16]xi1> + // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %[[PN]], <vscale x 4 x i1> %[[P2]], i32 %[[INDEX]]) + "arm_sve.intr.psel"(%pn, %p2, %index) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1> + // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %[[PN]], <vscale x 8 x i1> %[[P3]], i32 %[[INDEX]]) + "arm_sve.intr.psel"(%pn, %p3, %index) : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1> + // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv16i1(<vscale x 16 x i1> %[[PN]], <vscale x 16 x i1> %[[P4]], i32 %[[INDEX]]) + "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1> + llvm.return +} |