diff options
author | Tom Eccles <tom.eccles@arm.com> | 2025-06-24 17:45:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-24 17:45:10 +0100 |
commit | cc756716cf69a16701f0dfeb583127ea4124533b (patch) | |
tree | 2a76dc203fac3b2e4498ac31356cfd0dc761adc2 | |
parent | f6973baf289abf2eda5bbad41bdce1a80b05f051 (diff) | |
download | llvm-cc756716cf69a16701f0dfeb583127ea4124533b.zip llvm-cc756716cf69a16701f0dfeb583127ea4124533b.tar.gz llvm-cc756716cf69a16701f0dfeb583127ea4124533b.tar.bz2 |
[mlir][NFC] Move LLVM::ModuleTranslation::SaveStack to a shared header (#144897)
This is so that we can re-use the same code in Flang.
-rw-r--r-- | mlir/include/mlir/Support/StateStack.h | 117 | ||||
-rw-r--r-- | mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h | 72 | ||||
-rw-r--r-- | mlir/lib/Support/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Support/StateStack.cpp | 15 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 |
6 files changed, 141 insertions, 70 deletions
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h new file mode 100644 index 0000000..ac70d05 --- /dev/null +++ b/mlir/include/mlir/Support/StateStack.h @@ -0,0 +1,117 @@ +//===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines utilities for storing a stack of generic context. +// The context can be arbitrary data, possibly including file-scoped types. Data +// must be derived from StateStackFrameBase and implement MLIR TypeID. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_STACKFRAME_H +#define MLIR_SUPPORT_STACKFRAME_H + +#include "mlir/IR/Visitors.h" +#include "mlir/Support/TypeID.h" +#include <memory> + +namespace mlir { + +/// Common CRTP base class for StateStack frames. +class StateStackFrame { +public: + virtual ~StateStackFrame() = default; + TypeID getTypeID() const { return typeID; } + +protected: + explicit StateStackFrame(TypeID typeID) : typeID(typeID) {} + +private: + const TypeID typeID; + virtual void anchor(); +}; + +/// Concrete CRTP base class for StateStack frames. This is used for keeping a +/// stack of common state useful for recursive IR conversions. For example, when +/// translating operations with regions, users of StateStack can store state on +/// StateStack before entering the region and inspect it when converting +/// operations nested within that region. Users are expected to derive this +/// class and put any relevant information into fields of the derived class. The +/// usual isa/dyn_cast functionality is available for instances of derived +/// classes. +template <typename Derived> +class StateStackFrameBase : public StateStackFrame { +public: + explicit StateStackFrameBase() : StateStackFrame(TypeID::get<Derived>()) {} +}; + +class StateStack { +public: + /// Creates a stack frame of type `T` on StateStack. `T` must + /// be derived from `StackFrameBase<T>` and constructible from the provided + /// arguments. Doing this before entering the region of the op being + /// translated makes the frame available when translating ops within that + /// region. + template <typename T, typename... Args> + void stackPush(Args &&...args) { + static_assert(std::is_base_of<StateStackFrame, T>::value, + "can only push instances of StackFrame on StateStack"); + stack.push_back(std::make_unique<T>(std::forward<Args>(args)...)); + } + + /// Pops the last element from the StateStack. + void stackPop() { stack.pop_back(); } + + /// Calls `callback` for every StateStack frame of type `T` + /// starting from the top of the stack. + template <typename T> + WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) { + static_assert(std::is_base_of<StateStackFrame, T>::value, + "expected T derived from StackFrame"); + if (!callback) + return WalkResult::skip(); + for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) { + if (T *ptr = dyn_cast_or_null<T>(frame.get())) { + WalkResult result = callback(*ptr); + if (result.wasInterrupted()) + return result; + } + } + return WalkResult::advance(); + } + +private: + SmallVector<std::unique_ptr<StateStackFrame>> stack; +}; + +/// RAII object calling stackPush/stackPop on construction/destruction. +/// HostClass could be a StateStack or some other class which forwards calls to +/// one. +template <typename T, typename HostClass = StateStack> +struct SaveStateStack { + template <typename... Args> + explicit SaveStateStack(HostClass &host, Args &&...args) : host(host) { + host.template stackPush<T>(std::forward<Args>(args)...); + } + ~SaveStateStack() { host.stackPop(); } + +private: + HostClass &host; +}; + +} // namespace mlir + +namespace llvm { +template <typename T> +struct isa_impl<T, ::mlir::StateStackFrame> { + static inline bool doit(const ::mlir::StateStackFrame &frame) { + return frame.getTypeID() == ::mlir::TypeID::get<T>(); + } +}; +} // namespace llvm + +#endif // MLIR_SUPPORT_STACKFRAME_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 0f136c5..79e8bb6 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -18,6 +18,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" +#include "mlir/Support/StateStack.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" @@ -271,33 +272,6 @@ public: /// it if it does not exist. llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name); - /// Common CRTP base class for ModuleTranslation stack frames. - class StackFrame { - public: - virtual ~StackFrame() = default; - TypeID getTypeID() const { return typeID; } - - protected: - explicit StackFrame(TypeID typeID) : typeID(typeID) {} - - private: - const TypeID typeID; - virtual void anchor(); - }; - - /// Concrete CRTP base class for ModuleTranslation stack frames. When - /// translating operations with regions, users of ModuleTranslation can store - /// state on ModuleTranslation stack before entering the region and inspect - /// it when converting operations nested within that region. Users are - /// expected to derive this class and put any relevant information into fields - /// of the derived class. The usual isa/dyn_cast functionality is available - /// for instances of derived classes. - template <typename Derived> - class StackFrameBase : public StackFrame { - public: - explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {} - }; - /// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must /// be derived from `StackFrameBase<T>` and constructible from the provided /// arguments. Doing this before entering the region of the op being @@ -305,46 +279,22 @@ public: /// region. template <typename T, typename... Args> void stackPush(Args &&...args) { - static_assert( - std::is_base_of<StackFrame, T>::value, - "can only push instances of StackFrame on ModuleTranslation stack"); - stack.push_back(std::make_unique<T>(std::forward<Args>(args)...)); + stack.stackPush<T>(std::forward<Args>(args)...); } /// Pops the last element from the ModuleTranslation stack. - void stackPop() { stack.pop_back(); } + void stackPop() { stack.stackPop(); } /// Calls `callback` for every ModuleTranslation stack frame of type `T` /// starting from the top of the stack. template <typename T> WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) { - static_assert(std::is_base_of<StackFrame, T>::value, - "expected T derived from StackFrame"); - if (!callback) - return WalkResult::skip(); - for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) { - if (T *ptr = dyn_cast_or_null<T>(frame.get())) { - WalkResult result = callback(*ptr); - if (result.wasInterrupted()) - return result; - } - } - return WalkResult::advance(); + return stack.stackWalk(callback); } /// RAII object calling stackPush/stackPop on construction/destruction. template <typename T> - struct SaveStack { - template <typename... Args> - explicit SaveStack(ModuleTranslation &m, Args &&...args) - : moduleTranslation(m) { - moduleTranslation.stackPush<T>(std::forward<Args>(args)...); - } - ~SaveStack() { moduleTranslation.stackPop(); } - - private: - ModuleTranslation &moduleTranslation; - }; + using SaveStack = SaveStateStack<T, ModuleTranslation>; SymbolTableCollection &symbolTable() { return symbolTableCollection; } @@ -468,7 +418,7 @@ private: /// Stack of user-specified state elements, useful when translating operations /// with regions. - SmallVector<std::unique_ptr<StackFrame>> stack; + StateStack stack; /// A cache for the symbol tables constructed during symbols lookup. SymbolTableCollection symbolTableCollection; @@ -510,14 +460,4 @@ llvm::CallInst *createIntrinsicCall( } // namespace LLVM } // namespace mlir -namespace llvm { -template <typename T> -struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> { - static inline bool - doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) { - return frame.getTypeID() == ::mlir::TypeID::get<T>(); - } -}; -} // namespace llvm - #endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt index 488decd..02b6c69 100644 --- a/mlir/lib/Support/CMakeLists.txt +++ b/mlir/lib/Support/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_library(MLIRSupport FileUtilities.cpp InterfaceSupport.cpp RawOstreamExtras.cpp + StateStack.cpp StorageUniquer.cpp Timing.cpp ToolUtilities.cpp diff --git a/mlir/lib/Support/StateStack.cpp b/mlir/lib/Support/StateStack.cpp new file mode 100644 index 0000000..a9bb3ff --- /dev/null +++ b/mlir/lib/Support/StateStack.cpp @@ -0,0 +1,15 @@ +//===- StateStack.cpp - Utility for storing a stack of state --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/StateStack.h" + +namespace mlir { + +void StateStackFrame::anchor() {} + +} // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 90ce06a..e29e3d8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -71,7 +71,7 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) { /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the /// insertion points for allocas. class OpenMPAllocaStackFrame - : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> { + : public StateStackFrameBase<OpenMPAllocaStackFrame> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame) @@ -84,7 +84,7 @@ public: /// collapsed canonical loop information corresponding to an \c omp.loop_nest /// operation. class OpenMPLoopInfoStackFrame - : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> { + : public StateStackFrameBase<OpenMPLoopInfoStackFrame> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame) llvm::CanonicalLoopInfo *loopInfo = nullptr; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 3eaa24e..e8ce528 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -2225,8 +2225,6 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) { return llvmModule->getOrInsertNamedMetadata(name); } -void ModuleTranslation::StackFrame::anchor() {} - static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { |