aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAman LaChapelle <aman.lachapelle@gmail.com>2023-12-08 11:22:08 -0800
committerGitHub <noreply@github.com>2023-12-08 11:22:08 -0800
commit46708a5bcba28955b2ddeddf5c0e64398223642b (patch)
treeb07c290bc795779e87334c11d661ebf727886978
parent687e63a2bddf7eb9e91d9d718e48942aca47c631 (diff)
downloadllvm-46708a5bcba28955b2ddeddf5c0e64398223642b.zip
llvm-46708a5bcba28955b2ddeddf5c0e64398223642b.tar.gz
llvm-46708a5bcba28955b2ddeddf5c0e64398223642b.tar.bz2
[mlir][Pass] Move PassExecutionAction to Pass.h, NFC. (#74850)
This patch moves PassExecutionAction to Pass.h so that it can be used by the action framework to introspect and intercede in pass managers that might be set up opaquely. This provides for a very particular use case, which essentially involves being able to intercede in a PassManager and skip or apply individual passes. Because of this, this patch also adds a test for this use case to verify that it could in fact work.
-rw-r--r--mlir/include/mlir/Pass/Pass.h47
-rw-r--r--mlir/lib/Pass/Pass.cpp10
-rw-r--r--mlir/lib/Pass/PassDetail.h20
-rw-r--r--mlir/unittests/Pass/CMakeLists.txt1
-rw-r--r--mlir/unittests/Pass/PassManagerTest.cpp100
5 files changed, 158 insertions, 20 deletions
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 5a4df43..121b253 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -9,6 +9,7 @@
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H
+#include "mlir/IR/Action.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
@@ -457,6 +458,52 @@ protected:
}
};
+/// This class encapsulates the "action" of executing a single pass. This allows
+/// a user of the Action infrastructure to query information about an action in
+/// (for example) a breakpoint context. You could use it like this:
+///
+/// auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+/// if (auto passExec = dyn_cast<PassExecutionAction>(anAction))
+/// record(passExec.getPass());
+/// return ExecutionContext::Apply;
+/// };
+/// ExecutionContext exeCtx(onBreakpoint);
+///
+class PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
+ using Base = tracing::ActionImpl<PassExecutionAction>;
+
+public:
+ /// Define a TypeID for this PassExecutionAction.
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PassExecutionAction)
+ /// Construct a PassExecutionAction. This is called by the OpToOpPassAdaptor
+ /// when it calls `executeAction`.
+ PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass);
+
+ /// The tag required by ActionImpl to identify this action.
+ static constexpr StringLiteral tag = "pass-execution";
+
+ /// Print a textual version of this action to `os`.
+ void print(raw_ostream &os) const override;
+
+ /// Get the pass that will be executed by this action. This is not a class of
+ /// passes, or all instances of a pass kind, this is a single pass.
+ const Pass &getPass() const { return pass; }
+
+ /// Get the operation that is the base of this pass. For example, an
+ /// OperationPass<ModuleOp> would return a ModuleOp.
+ Operation *getOp() const;
+
+public:
+ /// Reference to the pass being run. Notice that this will *not* extend the
+ /// lifetime of the pass, and so this class is therefore unsafe to keep past
+ /// the lifetime of the `executeAction` call.
+ const Pass &pass;
+
+ /// The base op for this pass. For an OperationPass<ModuleOp>, we would have a
+ /// ModuleOp here.
+ Operation *op;
+};
+
} // namespace mlir
#endif // MLIR_PASS_PASS_H
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 658f884..810d6a3 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -36,11 +36,21 @@ using namespace mlir::detail;
// PassExecutionAction
//===----------------------------------------------------------------------===//
+PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits,
+ const Pass &pass)
+ : Base(irUnits), pass(pass) {}
+
void PassExecutionAction::print(raw_ostream &os) const {
os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag,
pass.getName(), getOp()->getName());
}
+Operation *PassExecutionAction::getOp() const {
+ ArrayRef<IRUnit> irUnits = getContextIRUnits();
+ return irUnits.empty() ? nullptr
+ : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
+}
+
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 7276071..0e964b6 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -15,26 +15,6 @@
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
-/// Encapsulate the "action" of executing a single pass, used for the MLIR
-/// tracing infrastructure.
-struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
- using Base = tracing::ActionImpl<PassExecutionAction>;
- PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
- : Base(irUnits), pass(pass) {}
- static constexpr StringLiteral tag = "pass-execution";
- void print(raw_ostream &os) const override;
- const Pass &getPass() const { return pass; }
- Operation *getOp() const {
- ArrayRef<IRUnit> irUnits = getContextIRUnits();
- return irUnits.empty() ? nullptr
- : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
- }
-
-public:
- const Pass &pass;
- Operation *op;
-};
-
namespace detail {
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
index 65f0774..802b3bb 100644
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -5,5 +5,6 @@ add_mlir_unittest(MLIRPassTests
)
target_link_libraries(MLIRPassTests
PRIVATE
+ MLIRDebug
MLIRFuncDialect
MLIRPass)
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 9a30f64..7ceed3b 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Pass/PassManager.h"
+#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
+#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -86,6 +88,104 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
}
}
+/// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
+struct AddAttrFunctionPass
+ : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
+
+ void runOnOperation() override {
+ func::FuncOp op = getOperation();
+ Builder builder(op->getParentOfType<ModuleOp>());
+ if (op->hasAttr("didProcess"))
+ op->setAttr("didProcessAgain", builder.getUnitAttr());
+
+ // We always want to set this one.
+ op->setAttr("didProcess", builder.getUnitAttr());
+ }
+};
+
+/// Simple pass to annotate a func::FuncOp with a single attribute
+/// `didProcess2`.
+struct AddSecondAttrFunctionPass
+ : public PassWrapper<AddSecondAttrFunctionPass,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
+
+ void runOnOperation() override {
+ func::FuncOp op = getOperation();
+ Builder builder(op->getParentOfType<ModuleOp>());
+ op->setAttr("didProcess2", builder.getUnitAttr());
+ }
+};
+
+TEST(PassManagerTest, ExecutionAction) {
+ MLIRContext context;
+ context.loadDialect<func::FuncDialect>();
+ Builder builder(&context);
+
+ // Create a module with 2 functions.
+ OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+ auto f =
+ func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
+ builder.getFunctionType(std::nullopt, std::nullopt));
+ f.setPrivate();
+ module->push_back(f);
+
+ // Instantiate our passes.
+ auto pm = PassManager::on<ModuleOp>(&context);
+ auto pass = std::make_unique<AddAttrFunctionPass>();
+ auto *passPtr = pass.get();
+ pm.addNestedPass<func::FuncOp>(std::move(pass));
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
+ // Duplicate the first pass to ensure that we *only* run the *first* pass, not
+ // all instances of this pass kind. Notice that this pass (and the test as a
+ // whole) are built to ensure that we can run just a single pass out of a
+ // pipeline that may contain duplicates.
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+
+ // Use the action manager to only hit the first pass, not the second one.
+ auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
+ -> tracing::ExecutionContext::Control {
+ // Not a PassExecutionAction, apply the action.
+ auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
+ if (!passExec)
+ return tracing::ExecutionContext::Next;
+
+ // If this isn't a function, apply the action.
+ if (!isa<func::FuncOp>(passExec->getOp()))
+ return tracing::ExecutionContext::Next;
+
+ // Only apply the first function pass. Not all instances of the first pass,
+ // only the first pass.
+ if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
+ return tracing::ExecutionContext::Next;
+
+ // Do not apply any other passes in the pass manager.
+ return tracing::ExecutionContext::Skip;
+ };
+
+ // Set up our breakpoint manager.
+ tracing::TagBreakpointManager simpleManager;
+ tracing::ExecutionContext executionCtx(onBreakpoint);
+ executionCtx.addBreakpointManager(&simpleManager);
+ simpleManager.addBreakpoint(PassExecutionAction::tag);
+
+ // Register the execution context in the MLIRContext.
+ context.registerActionHandler(executionCtx);
+
+ // Run the pass manager, expecting our handler to be called.
+ LogicalResult result = pm.run(module.get());
+ EXPECT_TRUE(succeeded(result));
+
+ // Verify that each function got annotated with `didProcess` and *not*
+ // `didProcess2`.
+ for (func::FuncOp func : module->getOps<func::FuncOp>()) {
+ ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
+ ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
+ ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
+ }
+}
+
namespace {
struct InvalidPass : Pass {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)