aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Benjamin Maxwell <benjamin.maxwell@arm.com> 2024-06-19 13:33:23 +0100
committer: GitHub <noreply@github.com> 2024-06-19 13:33:23 +0100
commit: 781133037387eefa4080aa31c73554cc0452e6e6 (patch)
tree: e1455541789bb83e79d704851ae60e4de6fffd86
parent: 6244d87f42775e8d49cf758eeb1909f2ce144e3c (diff)
download: llvm-781133037387eefa4080aa31c73554cc0452e6e6.zip
download: llvm-781133037387eefa4080aa31c73554cc0452e6e6.tar.gz
download: llvm-781133037387eefa4080aa31c73554cc0452e6e6.tar.bz2
[mlir][ArmSVE] Add `arm_sve.psel` operation (#95764)
This adds a new operation for the SME/SVE2.1 psel instruction. This allows
selecting a predicate based on a bit within another predicate, essentially
allowing for 2-D predication. Informally, the semantics are:

```mlir
%pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
```

=>

```
if p2[index % num_elements(p2)] == 1:
  pd = p1 : type(p1)
else:
  pd = all-false : type(p1)
```
-rw-r--r-- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 55
-rw-r--r-- mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp | 26
-rw-r--r-- mlir/test/Dialect/ArmSVE/invalid.mlir | 8
-rw-r--r-- mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir | 32
-rw-r--r-- mlir/test/Dialect/ArmSVE/roundtrip.mlir | 29
-rw-r--r-- mlir/test/Target/LLVMIR/arm-sve.mlir | 19
6 files changed, 166 insertions, 3 deletions
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index aea5583..d7e8b22 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -37,10 +37,16 @@ def ArmSVE_Dialect : Dialect {
//===----------------------------------------------------------------------===//
def SVBool : ScalableVectorOfRankAndLengthAndType<
- [1], [16], [I1]>;
+ [1], [16], [I1]>
+{
+ // Readable description of the single accepted type: a rank-1 scalable
+ // vector of 16 x i1.
+ let summary = "vector<[16]xi1>";
+}
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
- [1], [16, 8, 4, 2, 1], [I1]>;
+ [1], [16, 8, 4, 2, 1], [I1]>
+{
+ // Lists every accepted predicate type. This exact text is what the op
+ // verifier prints after "must be" when an operand has another type (the
+ // psel test in invalid.mlir matches against it).
+ let summary = "vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>";
+}
// Generalizations of SVBool and SVEPredicate to ranks >= 1.
// These are masks with a single trailing scalable dimension.
@@ -442,6 +448,43 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [
}];
}
+// Predicate select: yields `p1` or an all-false predicate depending on a single
+// bit of `p2`. `p1` and `result` always share a type (AllTypesMatch); `p2` may
+// be a different SVE predicate type.
+def PselOp : ArmSVE_Op<"psel", [
+  Pure,
+  AllTypesMatch<["p1", "result"]>,
+]> {
+  let summary = "Predicate select";
+
+  let description = [{
+    This operation returns the input predicate `p1` or an all-false predicate
+    based on the bit at `p2[index]`. Informally, the semantics are:
+    ```
+    if p2[index % num_elements(p2)] == 1:
+      return p1 : type(p1)
+    else:
+      return all-false : type(p1)
+    ```
+
+    Example:
+    ```mlir
+    // Note: p1 and p2 can have different sizes.
+    %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+    ```
+
+    Note: This requires SME or SVE2.1 (`+sme` or `+sve2p1` in LLVM target features).
+  }];
+
+  let arguments = (ins SVEPredicate:$p1, SVEPredicate:$p2, Index:$index);
+  let results = (outs SVEPredicate:$result);
+
+  // Convenience builder: the result type is taken from `p1`, so callers only
+  // pass the three operands.
+  let builders = [
+    OpBuilder<(ins "Value":$p1, "Value":$p2, "Value":$index), [{
+      build($_builder, $_state, p1.getType(), p1, p2, index);
+    }]>];
+
+  // Only the two predicate operand types are spelled out in the assembly; the
+  // result type is not printed because it must match `p1`.
+  let assemblyFormat = [{
+    $p1 `,` $p2 `[` $index `]` attr-dict `:` type($p1) `,` type($p2)
+  }];
+}
+
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;
@@ -552,6 +595,14 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
Arg<AnyScalableVector, "v3">:$v3,
Arg<AnyScalableVector, "v3">:$v4)>;
+// Note: This intrinsic requires SME or SVE2.1.
+// Intrinsic form of psel. `p1` and the result are always svbool
+// (vector<[16]xi1>), `p2` may be any SVE predicate type, and `index` is an i32.
+// Operand 1 (`p2`) is the overloaded operand, so its type picks the intrinsic
+// variant (e.g. `llvm.aarch64.sve.psel.nxv4i1` for a vector<[4]xi1> `p2`).
+def PselIntrOp : ArmSVE_IntrOp<"psel",
+ /*traits=*/[Pure, TypeIs<"res", SVBool>],
+ /*overloadedOperands=*/[1]>,
+ Arguments<(ins Arg<SVBool, "p1">:$p1,
+ Arg<SVEPredicate, "p2">:$p2,
+ Arg<I32, "index">:$index)>;
+
def WhileLTIntrOp :
ArmSVE_IntrOp<"whilelt",
[TypeIs<"res", SVEPredicate>, Pure],
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index ed4f4cc..10f39a0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,28 @@ using ConvertFromSvboolOpLowering =
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
+/// but first input (P1) and result predicates need conversion to/from svbool.
+///
+/// The emitted sequence is:
+///   convert.to.svbool(P1) -> intr.psel(svbool, P2, i32 index)
+///                         -> convert.from.svbool (back to the type of P1)
+struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = pselOp.getLoc();
+    // svbool: the vector<[16]xi1> predicate type used by the intrinsic for P1
+    // and for its result.
+    auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
+    auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
+                                                           adaptor.getP1());
+    // The intrinsic takes an i32 index. The cast is deliberately built on the
+    // original `index`-typed operand rather than the adaptor's value (which
+    // has already been converted to an integer type): `arith.index_cast`
+    // operates on an index-typed value, and the cast itself is left for a
+    // later arith lowering (it shows up as `llvm.trunc` in the tests).
+    auto indexI32 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(), pselOp.getIndex());
+    // Take P2 from the adaptor, like P1, so the remapped operand is used.
+    auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
+                                                adaptor.getP2(), indexI32);
+    rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
+        pselOp, adaptor.getP1().getType(), pselIntr);
+    return success();
+  }
+};
+
/// Converts `vector.create_mask` ops that match the size of an SVE predicate
/// to the `whilelt` intrinsic. This produces more canonical codegen than the
/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
@@ -202,7 +224,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ConvertToSvboolOpLowering,
ConvertFromSvboolOpLowering,
ZipX2OpLowering,
- ZipX4OpLowering>(converter);
+ ZipX4OpLowering,
+ PselOpLowering>(converter);
// Add vector.create_mask conversion with a high benefit as it produces much
// nicer code than the generic lowering.
patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
@@ -229,6 +252,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
ConvertFromSvboolIntrOp,
ZipX2IntrOp,
ZipX4IntrOp,
+ PselIntrOp,
WhileLTIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index 1258d35..a021d43 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -64,3 +64,11 @@ func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
return
}
+
+// -----
+
+// `arm_sve.psel` only accepts SVE predicate types, so a vector<[7]xi1> operand
+// is expected to be rejected by the verifier.
+func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
+ // expected-error@+1 {{op operand #0 must be vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>, but got 'vector<[7]xi1>'}}
+ arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 3fc5e6e..31d5376 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -239,3 +239,35 @@ func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, v
%2 = vector.create_mask %index : vector<[32]xi1>
return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
}
+
+// -----
+
+// psel where both predicates are vector<[4]xi1>: the first predicate is widened
+// to svbool, the index is truncated to i32, the intrinsic is called, and the
+// svbool result is narrowed back to vector<[4]xi1>.
+// CHECK-LABEL: @arm_sve_psel_matching_predicate_types(
+// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_matching_predicate_types(%p0: vector<[4]xi1>, %p1: vector<[4]xi1>, %index: index) -> vector<[4]xi1>
+{
+ // CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+ // CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[4]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[4]xi1>
+ %0 = arm_sve.psel %p0, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
+ return %0 : vector<[4]xi1>
+}
+
+// -----
+
+// psel with different predicate types: only the first predicate
+// (vector<[8]xi1>) is converted to/from svbool; the second predicate
+// (vector<[16]xi1>) is passed to the intrinsic unchanged.
+// CHECK-LABEL: @arm_sve_psel_mixed_predicate_types(
+// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[8]xi1>,
+// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[16]xi1>,
+// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[16]xi1>, %index: index) -> vector<[8]xi1>
+{
+ // CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+ // CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[8]xi1>
+ %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
+ return %0 : vector<[8]xi1>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index f7b79aa..0f0c5a8 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -225,3 +225,32 @@ func.func @arm_sve_zip_x4(
%a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
return
}
+
+// -----
+
+// Round-trips `arm_sve.psel` for predicate sizes [2], [4], [8] and [16], first
+// with both predicates of the same type and then with mixed types.
+func.func @arm_sve_psel(
+ %p0: vector<[2]xi1>,
+ %p1: vector<[4]xi1>,
+ %p2: vector<[8]xi1>,
+ %p3: vector<[16]xi1>,
+ %index: index
+) {
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[2]xi1>
+ %0 = arm_sve.psel %p0, %p0[%index] : vector<[2]xi1>, vector<[2]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[4]xi1>
+ %1 = arm_sve.psel %p1, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[8]xi1>
+ %2 = arm_sve.psel %p2, %p2[%index] : vector<[8]xi1>, vector<[8]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[16]xi1>
+ %3 = arm_sve.psel %p3, %p3[%index] : vector<[16]xi1>, vector<[16]xi1>
+ /// Some mixed predicate type examples:
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[4]xi1>
+ %4 = arm_sve.psel %p0, %p1[%index] : vector<[2]xi1>, vector<[4]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[8]xi1>
+ %5 = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[16]xi1>
+ %6 = arm_sve.psel %p2, %p3[%index] : vector<[8]xi1>, vector<[16]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[2]xi1>
+ %7 = arm_sve.psel %p3, %p0[%index] : vector<[16]xi1>, vector<[2]xi1>
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 34413d4..ed5a1fc 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -371,3 +371,22 @@ llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
%4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
llvm.return
}
+
+// Translation of `arm_sve.intr.psel` to LLVM IR. The first operand and the
+// result are always <vscale x 16 x i1>; the intrinsic name suffix (nxv2i1,
+// nxv4i1, nxv8i1, nxv16i1) follows the type of the second predicate operand.
+// CHECK-LABEL: arm_sve_psel(
+// CHECK-SAME: <vscale x 16 x i1> %[[PN:[0-9]+]],
+// CHECK-SAME: <vscale x 2 x i1> %[[P1:[0-9]+]],
+// CHECK-SAME: <vscale x 4 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME: <vscale x 8 x i1> %[[P3:[0-9]+]],
+// CHECK-SAME: <vscale x 16 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME: i32 %[[INDEX:[0-9]+]])
+llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[4]xi1>, %p3: vector<[8]xi1>, %p4: vector<[16]xi1>, %index: i32) {
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %[[PN]], <vscale x 2 x i1> %[[P1]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p1, %index) : (vector<[16]xi1>, vector<[2]xi1>, i32) -> vector<[16]xi1>
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %[[PN]], <vscale x 4 x i1> %[[P2]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p2, %index) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %[[PN]], <vscale x 8 x i1> %[[P3]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p3, %index) : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1>
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv16i1(<vscale x 16 x i1> %[[PN]], <vscale x 16 x i1> %[[P4]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+ llvm.return
+}