From d319fc41d0e35bfea8368ad91dc15ab319cddcb7 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Mon, 10 Jun 2024 12:02:16 +0100 Subject: [mlir][ArmSME] Add option to only enable streaming mode for scalable code (#94759) This adds a new option `-enable-arm-streaming=if-contains-scalable-vectors`, which only applies the selected streaming/ZA modes if the function contains scalable vector types. As a NFC this patch also removes the `only-` prefix from the `if-required-by-ops` mode. --- .../mlir/Dialect/ArmSME/Transforms/Passes.h | 3 +- .../mlir/Dialect/ArmSME/Transforms/Passes.td | 10 +++-- .../ArmSME/Transforms/EnableArmStreaming.cpp | 49 +++++++++++++++++----- .../ArmSME/enable-arm-streaming-invalid.mlir | 4 ++ mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir | 17 +++++++- .../CPU/ArmSME/multi-tile-matmul-mixed-types.mlir | 2 +- mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp | 2 +- 7 files changed, 69 insertions(+), 18 deletions(-) create mode 100644 mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir (limited to 'mlir') diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h index 156744b..167e5b7 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h @@ -27,7 +27,8 @@ namespace arm_sme { /// Pass to enable Armv9 Streaming SVE mode. std::unique_ptr createEnableArmStreamingPass( const ArmStreamingMode = ArmStreamingMode::Streaming, - const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false); + const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false, + bool ifContainsScalableVectors = false); /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening /// variants. diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td index 869a031..c1f016d 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td @@ -116,10 +116,14 @@ def EnableArmStreaming "not be used for input and/or output and the " "function must return with ZA unchanged") )}]>, - Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool", + Option<"ifRequiredByOps", "if-required-by-ops", "bool", /*default=*/"false", - "Only apply the selected streaming/ZA modes if the function " - " contains ops that require them."> + "Only apply the selected streaming/ZA modes if the function contains" + " ops that implement the ArmSMETileOpInterface.">, + Option<"ifContainsScalableVectors", "if-contains-scalable-vectors", + "bool", /*default=*/"false", + "Only apply the selected streaming/ZA modes if the function contains" + " operations that use scalable vector types."> ]; let dependentDialects = ["func::FuncDialect"]; } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp index 79a6caf..fb4bb41 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp @@ -58,17 +58,25 @@ constexpr StringLiteral struct EnableArmStreamingPass : public arm_sme::impl::EnableArmStreamingBase { EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode, - bool onlyIfRequiredByOps) { + bool ifRequiredByOps, bool ifContainsScalableVectors) { this->streamingMode = streamingMode; this->zaMode = zaMode; - this->onlyIfRequiredByOps = onlyIfRequiredByOps; + this->ifRequiredByOps = ifRequiredByOps; + this->ifContainsScalableVectors = ifContainsScalableVectors; } void runOnOperation() override { - auto op = getOperation(); + auto function = getOperation(); - if (onlyIfRequiredByOps) { + if (ifRequiredByOps && ifContainsScalableVectors) { + function->emitOpError( + "enable-arm-streaming: `if-required-by-ops` and " + "`if-contains-scalable-vectors` are mutually exclusive"); + return signalPassFailure(); + } + + if (ifRequiredByOps) { bool foundTileOp = false; - op.walk([&](Operation *op) { + function.walk([&](Operation *op) { if (llvm::isa(op)) { foundTileOp = true; return WalkResult::interrupt(); @@ -79,27 +87,46 @@ struct EnableArmStreamingPass return; } - if (op->getAttr(kEnableArmStreamingIgnoreAttr) || + if (ifContainsScalableVectors) { + bool foundScalableVector = false; + auto isScalableVector = [&](Type type) { + if (auto vectorType = dyn_cast(type)) + return vectorType.isScalable(); + return false; + }; + function.walk([&](Operation *op) { + if (llvm::any_of(op->getOperandTypes(), isScalableVector) || + llvm::any_of(op->getResultTypes(), isScalableVector)) { + foundScalableVector = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (!foundScalableVector) + return; + } + + if (function->getAttr(kEnableArmStreamingIgnoreAttr) || streamingMode == ArmStreamingMode::Disabled) return; auto unitAttr = UnitAttr::get(&getContext()); - op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr); + function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr); // The pass currently only supports enabling ZA when in streaming-mode, but // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth // supporting this later. if (zaMode != ArmZaMode::Disabled) - op->setAttr(stringifyArmZaMode(zaMode), unitAttr); + function->setAttr(stringifyArmZaMode(zaMode), unitAttr); } }; } // namespace std::unique_ptr mlir::arm_sme::createEnableArmStreamingPass( const ArmStreamingMode streamingMode, const ArmZaMode zaMode, - bool onlyIfRequiredByOps) { - return std::make_unique(streamingMode, zaMode, - onlyIfRequiredByOps); + bool ifRequiredByOps, bool ifContainsScalableVectors) { + return std::make_unique( + streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors); } diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir new file mode 100644 index 0000000..da70b63 --- /dev/null +++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir @@ -0,0 +1,4 @@ +// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics + +// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}} +func.func @test() { return } diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir index 6b58d8f..2011802 100644 --- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir +++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir @@ -2,7 +2,8 @@ // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA -// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED +// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED +// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE // CHECK-LABEL: @arm_streaming // CHECK-SAME: attributes {arm_streaming} @@ -38,3 +39,17 @@ func.func @requires_arm_streaming() { // IF-REQUIRED: @does_not_require_arm_streaming // IF-REQUIRED-NOT: arm_streaming func.func @does_not_require_arm_streaming() { return } + +// IF-SCALABLE-LABEL: @contains_scalable_vectors +// IF-SCALABLE-SAME: attributes {arm_streaming} +func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> { + %0 = arith.addf %vec, %vec : vector<[4]xf32> + return %0 : vector<[4]xf32> +} + +// IF-SCALABLE-LABEL: @no_scalable_vectors +// IF-SCALABLE-NOT: arm_streaming +func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> { + %0 = arith.addf %vec, %vec : vector<4xf32> + return %0 : vector<4xf32> +} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir index 10ffed2..aabd9d2 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir @@ -4,7 +4,7 @@ // RUN: -arm-sme-vector-legalization -canonicalize -cse \ // RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \ // RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \ -// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ +// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \ // RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \ // RUN: -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp index d3dabaf..a220791 100644 --- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp +++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp @@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm, // Enable streaming-mode and ZA. pm.addPass(arm_sme::createEnableArmStreamingPass( arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA, - /*onlyIfRequiredByOps=*/true)); + /*ifRequiredByOps=*/true)); // Convert SCF to CF (required for ArmSME tile allocation). pm.addPass(createConvertSCFToCFPass()); -- cgit v1.1