Diffstat (limited to 'mlir/lib')
111 files changed, 6087 insertions, 892 deletions
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 509f520..65df355 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { solver.load<LivenessAnalysis>(symbolTable); LDBG() << "Initializing and running solver"; (void)solver.initializeAndRun(op); - LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName(); + LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName() + << " check on unreachable code now:"; + // The framework doesn't visit operations in dead blocks, so we need to + // explicitly mark them as dead. + op->walk([&](Operation *op) { + if (op->getNumResults() == 0) + return; + for (auto result : llvm::enumerate(op->getResults())) { + if (getLiveness(result.value())) + continue; + LDBG() << "Result: " << result.index() << " of " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " has no liveness info (unreachable), mark dead"; + solver.getOrCreateState<Liveness>(result.value()); + } + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto blockArg : llvm::enumerate(block.getArguments())) { + if (getLiveness(blockArg.value())) + continue; + LDBG() << "Block argument: " << blockArg.index() << " of " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " has no liveness info, mark dead"; + solver.getOrCreateState<Liveness>(blockArg.value()); + } + } + } + }); } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index e625f62..13a3e14 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -19,12 +19,15 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <optional> using namespace mlir; using namespace mlir::dataflow; +#define DEBUG_TYPE "dataflow" + //===----------------------------------------------------------------------===// // AbstractSparseLattice //===----------------------------------------------------------------------===// @@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) { LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { + LDBG() << "Initializing recursively for operation: " << op->getName(); + // Initialize the analysis by visiting every owner of an SSA value (all // operations and blocks). 
- if (failed(visitOperation(op))) + if (failed(visitOperation(op))) { + LDBG() << "Failed to visit operation: " << op->getName(); return failure(); + } for (Region ®ion : op->getRegions()) { + LDBG() << "Processing region with " << region.getBlocks().size() + << " blocks"; for (Block &block : region) { + LDBG() << "Processing block with " << block.getNumArguments() + << " arguments"; getOrCreate<Executable>(getProgramPointBefore(&block)) ->blockContentSubscribe(this); visitBlock(&block); - for (Operation &op : block) - if (failed(initializeRecursively(&op))) + for (Operation &op : block) { + LDBG() << "Recursively initializing nested operation: " << op.getName(); + if (failed(initializeRecursively(&op))) { + LDBG() << "Failed to initialize nested operation: " << op.getName(); return failure(); + } + } } } + LDBG() << "Successfully completed recursive initialization for operation: " + << op->getName(); return success(); } @@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) { LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { + LDBG() << "Visiting operation: " << op->getName() << " with " + << op->getNumOperands() << " operands and " << op->getNumResults() + << " results"; + // If we're in a dead block, bail out. if (op->getBlock() != nullptr && - !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) + ->isLive()) { + LDBG() << "Operation is in dead block, bailing out"; return success(); + } + LDBG() << "Creating lattice elements for " << op->getNumOperands() + << " operands and " << op->getNumResults() << " results"; SmallVector<AbstractSparseLattice *> operandLattices = getLatticeElements(op->getOperands()); SmallVector<const AbstractSparseLattice *> resultLattices = @@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // Block arguments of region branch operations flow back into the operands // of the parent op if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() << "Processing RegionBranchOpInterface operation"; visitRegionSuccessors(branch, operandLattices); return success(); } if (auto branch = dyn_cast<BranchOpInterface>(op)) { + LDBG() << "Processing BranchOpInterface operation with " + << op->getNumSuccessors() << " successors"; + // Block arguments of successor blocks flow back into our operands. // We remember all operands not forwarded to any block in a BitVector. @@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // For function calls, connect the arguments of the entry blocks to the // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast<CallOpInterface>(op)) { + LDBG() << "Processing CallOpInterface operation"; Operation *callableOp = call.resolveCallableInTable(&symbolTable); if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) { // Not all operands of a call op forward to arguments. Such operands are @@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // of this op itself and the operands of the terminators of the regions of // this op. 
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) { + LDBG() << "Processing RegionBranchTerminatorOpInterface operation"; if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { visitRegionSuccessorsFromTerminator(terminator, branch); return success(); @@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { } if (op->hasTrait<OpTrait::ReturnLike>()) { + LDBG() << "Processing ReturnLike operation"; // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. - if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) + if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { + LDBG() << "Callable parent found, visiting callable operation"; return visitCallableOperation(op, callable, operandLattices); + } } + LDBG() << "Using default visitOperationImpl for operation: " << op->getName(); return visitOperationImpl(op, operandLattices, resultLattices); } diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index f4b02b4..30ce1fb 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -60,7 +60,7 @@ private: AffineExpr localExpr) override { SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); // Update localVarCst. - localVarCst.addLocalFloorDiv(dividend, divisor); + (void)localVarCst.addLocalFloorDiv(dividend, divisor); } LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 5c4d4d1..0dcdd5b 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1500,12 +1500,13 @@ void IntegerRelation::addBound(BoundType type, ArrayRef<DynamicAPInt> expr, /// respect to a positive constant 'divisor'. Two constraints are added to the /// system to capture equivalence with the floordiv. /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. -void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend, - const DynamicAPInt &divisor) { +/// Returns the column position of the new local variable. 
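Note: for reference, the constraints emitted by these two helpers can be written out in full; this just restates the surrounding comments in equation form. With dividend expression $e$, positive constant $c$, and the new local $q$, addLocalFloorDiv encodes

    $q = \lfloor e / c \rfloor \iff c\,q \le e \le c\,q + c - 1$

and the new addLocalModulo below (with modulus $m > 0$) reuses the floor-division local $q$ and adds a second local $r$ together with the equality

    $e - m\,q - r = 0, \quad\text{i.e.}\quad r = e \bmod m$

returning the column position of $r$.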
+unsigned IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend, + const DynamicAPInt &divisor) { assert(dividend.size() == getNumCols() && "incorrect dividend size"); assert(divisor > 0 && "positive divisor expected"); - appendVar(VarKind::Local); + unsigned newVar = appendVar(VarKind::Local); SmallVector<DynamicAPInt, 8> dividendCopy(dividend); dividendCopy.insert(dividendCopy.end() - 1, DynamicAPInt(0)); @@ -1513,6 +1514,28 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend, getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2)); addInequality( getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2)); + return newVar; +} + +unsigned IntegerRelation::addLocalModulo(ArrayRef<DynamicAPInt> exprs, + const DynamicAPInt &modulus) { + assert(exprs.size() == getNumCols() && "incorrect exprs size"); + assert(modulus > 0 && "positive modulus expected"); + + /// Add a local variable for q = expr floordiv modulus + addLocalFloorDiv(exprs, modulus); + + /// Add a local var to represent the result + auto resultIndex = appendVar(VarKind::Local); + + SmallVector<DynamicAPInt, 8> exprsCopy(exprs); + /// Insert the two new locals before the constant + /// Add locals that correspond to `q` and `result` to compute + /// 0 = (expr - modulus * q) - result + exprsCopy.insert(exprsCopy.end() - 1, + {DynamicAPInt(-modulus), DynamicAPInt(-1)}); + addEquality(exprsCopy); + return resultIndex; } int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const { diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index 08290db..51e2007 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -433,7 +433,7 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { normalizeDiv(divCoeffs, divDenom); domainSimplex.addDivisionVariable(divCoeffs, divDenom); - domainPoly.addLocalFloorDiv(divCoeffs, divDenom); + (void)domainPoly.addLocalFloorDiv(divCoeffs, divDenom); // Update `this` to account for the additional symbol we just added. appendSymbol(); diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 81dada3..4885d62c 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/ExecutionEngine.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; @@ -125,6 +125,17 @@ NB_MODULE(_mlirExecutionEngine, m) { nb::arg("name"), nb::arg("callback"), "Register `callback` as the runtime symbol `name`.") .def( + "initialize", + [](PyExecutionEngine &executionEngine) { + mlirExecutionEngineInitialize(executionEngine.get()); + }, + "Initialize the ExecutionEngine. Global constructors specified by " + "`llvm.mlir.global_ctors` will be run. One common scenario is that " + "kernel binary compiled from `gpu.module` gets loaded during " + "initialization. 
Make sure all symbols are resolvable before " + "initialization by calling `register_runtime` or including " + "shared libraries.") + .def( "dump_to_object_file", [](PyExecutionEngine &executionEngine, const std::string &fileName) { mlirExecutionEngineDumpToObjectFile( diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 826a34a..71a051c 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -10,15 +10,19 @@ #define MLIR_BINDINGS_PYTHON_GLOBALS_H #include <optional> +#include <regex> #include <string> +#include <unordered_set> #include <vector> #include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Regex.h" namespace mlir { namespace python { @@ -114,6 +118,39 @@ public: std::optional<nanobind::object> lookupOperationClass(llvm::StringRef operationName); + class TracebackLoc { + public: + bool locTracebacksEnabled(); + + void setLocTracebacksEnabled(bool value); + + size_t locTracebackFramesLimit(); + + void setLocTracebackFramesLimit(size_t value); + + void registerTracebackFileInclusion(const std::string &file); + + void registerTracebackFileExclusion(const std::string &file); + + bool isUserTracebackFilename(llvm::StringRef file); + + static constexpr size_t kMaxFrames = 512; + + private: + nanobind::ft_mutex mutex; + bool locTracebackEnabled_ = false; + size_t locTracebackFramesLimit_ = 10; + std::unordered_set<std::string> userTracebackIncludeFiles; + std::unordered_set<std::string> userTracebackExcludeFiles; + std::regex userTracebackIncludeRegex; + bool rebuildUserTracebackIncludeRegex = false; + std::regex userTracebackExcludeRegex; + bool rebuildUserTracebackExcludeRegex = false; + llvm::StringMap<bool> isUserTracebackFilenameCache; + }; + + TracebackLoc &getTracebackLoc() { return tracebackLoc; } + private: static PyGlobals *instance; @@ -134,6 +171,8 @@ private: /// Set of dialect namespaces that we have attempted to import implementation /// modules for. 
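A minimal Python usage sketch for the `initialize` binding added above. The module text, entry point, and callback names are placeholders, and the module is assumed to already be lowered to the LLVM dialect:

    from mlir import ir
    from mlir.execution_engine import ExecutionEngine

    with ir.Context():
        module = ir.Module.parse(ASM)  # ASM: placeholder for LLVM-dialect IR text
        engine = ExecutionEngine(module, opt_level=2)
        # Resolve runtime symbols before global constructors run.
        engine.register_runtime("host_callback", host_callback)  # ctypes-wrapped callable (placeholder)
        engine.initialize()  # runs llvm.mlir.global_ctors, e.g. loading a gpu.module binary
        engine.invoke("main")  # placeholder entry point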
llvm::StringSet<> loadedDialectModules; + + TracebackLoc tracebackLoc; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5feed95..4b3a06c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -20,11 +20,8 @@ #include "nanobind/nanobind.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include <optional> -#include <system_error> -#include <utility> namespace nb = nanobind; using namespace nb::literals; @@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name, llvm::ArrayRef<MlirValue> operands, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - int regions, DefaultingPyLocation location, + int regions, PyLocation &location, const nb::object &maybeIp, bool inferType) { llvm::SmallVector<MlirType, 4> mlirResults; llvm::SmallVector<MlirBlock, 4> mlirSuccessors; @@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name, if (!operation.ptr) throw nb::value_error("Operation creation failed"); PyOperationRef created = - PyOperation::createDetached(location->getContext(), operation); + PyOperation::createDetached(location.getContext(), operation); maybeInsertOperation(created, maybeIp); return created.getObject(); @@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric( std::optional<nb::list> resultTypeList, nb::list operandList, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, PyLocation &location, const nb::object &maybeIp) { - PyMlirContextRef context = location->getContext(); + PyMlirContextRef context = location.getContext(); // Class level operation construction metadata. // Operand and result segment specs are either none, which does no @@ -2789,6 +2786,156 @@ private: PyOperationRef operation; }; +// see +// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h + +#ifndef _Py_CAST +#define _Py_CAST(type, expr) ((type)(expr)) +#endif + +// Static inline functions should use _Py_NULL rather than using directly NULL +// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, +// _Py_NULL is defined as nullptr. 
+#ifndef _Py_NULL +#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \ + (defined(__cplusplus) && __cplusplus >= 201103) +#define _Py_NULL nullptr +#else +#define _Py_NULL NULL +#endif +#endif + +// Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 + +// bpo-42262 added Py_XNewRef() +#if !defined(Py_XNewRef) +[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) { + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj)) +#endif + +// bpo-42262 added Py_NewRef() +#if !defined(Py_NewRef) +[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) { + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) +#endif + +#endif // Python 3.10.0a3 + +// Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) + +// bpo-40429 added PyThreadState_GetFrame() +PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { + assert(tstate != _Py_NULL && "expected tstate != _Py_NULL"); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} + +// bpo-40421 added PyFrame_GetBack() +PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back)); +} + +// bpo-40421 added PyFrame_GetCode() +PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL"); + return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code)); +} + +#endif // Python 3.9.0b1 + +MlirLocation tracebackToLocation(MlirContext ctx) { + size_t framesLimit = + PyGlobals::get().getTracebackLoc().locTracebackFramesLimit(); + // Use a thread_local here to avoid requiring a large amount of space. + thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames> + frames; + size_t count = 0; + + nb::gil_scoped_acquire acquire; + PyThreadState *tstate = PyThreadState_GET(); + PyFrameObject *next; + PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate); + // In the increment expression: + // 1. get the next prev frame; + // 2. decrement the ref count on the current frame (in order that it can get + // gc'd, along with any objects in its closure and etc); + // 3. set current = next. 
+ for (; pyFrame != nullptr && count < framesLimit; + next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) { + PyCodeObject *code = PyFrame_GetCode(pyFrame); + auto fileNameStr = + nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename)); + llvm::StringRef fileName(fileNameStr); + if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName)) + continue; + + // co_qualname and PyCode_Addr2Location added in py3.11 +#if PY_VERSION_HEX < 0x030B00F0 + std::string name = + nb::cast<std::string>(nb::borrow<nb::str>(code->co_name)); + llvm::StringRef funcName(name); + int startLine = PyFrame_GetLineNumber(pyFrame); + MlirLocation loc = + mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0); +#else + std::string name = + nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname)); + llvm::StringRef funcName(name); + int startLine, startCol, endLine, endCol; + int lasti = PyFrame_GetLasti(pyFrame); + if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine, + &endCol)) { + throw nb::python_error(); + } + MlirLocation loc = mlirLocationFileLineColRangeGet( + ctx, wrap(fileName), startLine, startCol, endLine, endCol); +#endif + + frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc); + ++count; + } + // When the loop breaks (after the last iter), current frame (if non-null) + // is leaked without this. + Py_XDECREF(pyFrame); + + if (count == 0) + return mlirLocationUnknownGet(ctx); + + MlirLocation callee = frames[0]; + assert(!mlirLocationIsNull(callee) && "expected non-null callee location"); + if (count == 1) + return callee; + + MlirLocation caller = frames[count - 1]; + assert(!mlirLocationIsNull(caller) && "expected non-null caller location"); + for (int i = count - 2; i >= 1; i--) + caller = mlirLocationCallSiteGet(frames[i], caller); + + return mlirLocationCallSiteGet(callee, caller); +} + +PyLocation +maybeGetTracebackLocation(const std::optional<PyLocation> &location) { + if (location.has_value()) + return location.value(); + if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) + return DefaultingPyLocation::resolve(); + + PyMlirContext &ctx = DefaultingPyMlirContext::resolve(); + MlirLocation mlirLoc = tracebackToLocation(ctx.get()); + PyMlirContextRef ref = PyMlirContext::forContext(ctx.get()); + return {ref, mlirLoc}; +} + } // namespace //------------------------------------------------------------------------------ @@ -3052,10 +3199,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) .def_prop_ro_static( "current", - [](nb::object & /*class*/) { + [](nb::object & /*class*/) -> std::optional<PyLocation *> { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw nb::value_error("No current Location"); + return std::nullopt; return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -3240,8 +3387,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { kModuleParseDocstring) .def_static( "create", - [](DefaultingPyLocation loc) { - MlirModule module = mlirModuleCreateEmpty(loc); + [](const std::optional<PyLocation> &loc) { + PyLocation pyLoc = maybeGetTracebackLocation(loc); + MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); }, nb::arg("loc").none() = nb::none(), "Creates an empty module") @@ -3442,6 +3590,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { return operation.createOpView(); }, "Detaches the operation from 
its parent block.") + .def_prop_ro( + "attached", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + operation.checkValid(); + return operation.isAttached(); + }, + "Reports if the operation is attached to its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) .def("walk", &PyOperationBase::walk, nb::arg("callback"), nb::arg("walk_order") = MlirWalkPostOrder); @@ -3454,8 +3610,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional<std::vector<PyValue *>> operands, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, int regions, - DefaultingPyLocation location, const nb::object &maybeIp, - bool inferType) { + const std::optional<PyLocation> &location, + const nb::object &maybeIp, bool inferType) { // Unpack/validate operands. llvm::SmallVector<MlirValue, 4> mlirOperands; if (operands) { @@ -3467,8 +3623,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { } } + PyLocation pyLoc = maybeGetTracebackLocation(location); return PyOperation::create(name, results, mlirOperands, attributes, - successors, regions, location, maybeIp, + successors, regions, pyLoc, maybeIp, inferType); }, nb::arg("name"), nb::arg("results").none() = nb::none(), @@ -3512,12 +3669,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional<nb::list> resultTypeList, nb::list operandList, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, + const std::optional<PyLocation> &location, const nb::object &maybeIp) { + PyLocation pyLoc = maybeGetTracebackLocation(location); new (self) PyOpView(PyOpView::buildGeneric( name, opRegionSpec, operandSegmentSpecObj, resultSegmentSpecObj, resultTypeList, operandList, - attributes, successors, regions, location, maybeIp)); + attributes, successors, regions, pyLoc, maybeIp)); }, nb::arg("name"), nb::arg("opRegionSpec"), nb::arg("operandSegmentSpecObj").none() = nb::none(), @@ -3551,17 +3710,18 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](nb::handle cls, std::optional<nb::list> resultTypeList, nb::list operandList, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, std::optional<PyLocation> location, const nb::object &maybeIp) { std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME")); std::tuple<int, bool> opRegionSpec = nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); + PyLocation pyLoc = maybeGetTracebackLocation(location); return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, resultSegmentSpec, resultTypeList, operandList, attributes, successors, - regions, location, maybeIp); + regions, pyLoc, maybeIp); }, nb::arg("cls"), nb::arg("results").none() = nb::none(), nb::arg("operands").none() = nb::none(), diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index e600f1b..0de2f17 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -13,9 +13,9 @@ #include "Globals.h" #include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" -#include 
"mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. namespace nb = nanobind; using namespace mlir; @@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Not found and loading did not yield a registration. return std::nullopt; } + +bool PyGlobals::TracebackLoc::locTracebacksEnabled() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackEnabled_; +} + +void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackEnabled_ = value; +} + +size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackFramesLimit_; +} + +void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackFramesLimit_ = std::min(value, kMaxFrames); +} + +void PyGlobals::TracebackLoc::registerTracebackFileInclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackIncludeFiles.insert(reg).second) + rebuildUserTracebackIncludeRegex = true; + if (userTracebackExcludeFiles.count(reg)) { + if (userTracebackExcludeFiles.erase(reg)) + rebuildUserTracebackExcludeRegex = true; + } +} + +void PyGlobals::TracebackLoc::registerTracebackFileExclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackExcludeFiles.insert(reg).second) + rebuildUserTracebackExcludeRegex = true; + if (userTracebackIncludeFiles.count(reg)) { + if (userTracebackIncludeFiles.erase(reg)) + rebuildUserTracebackIncludeRegex = true; + } +} + +bool PyGlobals::TracebackLoc::isUserTracebackFilename( + const llvm::StringRef file) { + nanobind::ft_lock_guard lock(mutex); + if (rebuildUserTracebackIncludeRegex) { + userTracebackIncludeRegex.assign( + llvm::join(userTracebackIncludeFiles, "|")); + rebuildUserTracebackIncludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + if (rebuildUserTracebackExcludeRegex) { + userTracebackExcludeRegex.assign( + llvm::join(userTracebackExcludeFiles, "|")); + rebuildUserTracebackExcludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + if (!isUserTracebackFilenameCache.contains(file)) { + std::string fileStr = file.str(); + bool include = std::regex_search(fileStr, userTracebackIncludeRegex); + bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex); + isUserTracebackFilenameCache[file] = include || !exclude; + } + return isUserTracebackFilenameCache[file]; +} diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9c22dea..fa16ae3 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -192,16 +192,6 @@ public: PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; - /// For the case of a python __init__ (nanobind::init) method, pybind11 is - /// quite strict about needing to return a pointer that is not yet associated - /// to an nanobind::object. Since the forContext() method acts like a pool, - /// possibly returning a recycled context, it does not satisfy this need. The - /// usual way in python to accomplish such a thing is to override __new__, but - /// that is also not supported by pybind11. Instead, we use this entry - /// point which always constructs a fresh context (which cannot alias an - /// existing one because it is fresh). 
- static PyMlirContext *createNewContextForInit(); - /// Returns a context reference for the singleton PyMlirContext wrapper for /// the given context. static PyMlirContextRef forContext(MlirContext context); @@ -722,8 +712,7 @@ public: llvm::ArrayRef<MlirValue> operands, std::optional<nanobind::dict> attributes, std::optional<std::vector<PyBlock *>> successors, int regions, - DefaultingPyLocation location, const nanobind::object &ip, - bool inferType); + PyLocation &location, const nanobind::object &ip, bool inferType); /// Creates an OpView suitable for this operation. nanobind::object createOpView(); @@ -781,7 +770,7 @@ public: nanobind::list operandList, std::optional<nanobind::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, PyLocation &location, const nanobind::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 6f49431..278847e 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// - #include "Globals.h" #include "IRModule.h" #include "NanobindUtils.h" @@ -44,7 +43,27 @@ NB_MODULE(_mlir, m) { .def("_register_operation_impl", &PyGlobals::registerOperationImpl, "operation_name"_a, "operation_class"_a, nb::kw_only(), "replace"_a = false, - "Testing hook for directly registering an operation"); + "Testing hook for directly registering an operation") + .def("loc_tracebacks_enabled", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebacksEnabled(); + }) + .def("set_loc_tracebacks_enabled", + [](PyGlobals &self, bool enabled) { + self.getTracebackLoc().setLocTracebacksEnabled(enabled); + }) + .def("set_loc_tracebacks_frame_limit", + [](PyGlobals &self, int n) { + self.getTracebackLoc().setLocTracebackFramesLimit(n); + }) + .def("register_traceback_file_inclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileInclusion(filename); + }) + .def("register_traceback_file_exclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileExclusion(filename); + }); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 306cebd..2dbb993 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, return wrap(jitOrError->release()); } +extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) { + unwrap(jit)->initialize(); +} + extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) { delete (unwrap(jit)); } @@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, void *sym) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; - symbolMap[interner(unwrap(name))] = - { llvm::orc::ExecutorAddr::fromPtr(sym), - llvm::JITSymbolFlags::Exported }; + symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), + 
llvm::JITSymbolFlags::Exported}; return symbolMap; }); } diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 18e857c..cb0c829 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { ConversionPatternRewriter &rewriter) const override; }; +struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(arith::SelectOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -480,6 +490,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, } //===----------------------------------------------------------------------===// +// SelectOpOneToNLowering +//===----------------------------------------------------------------------===// + +/// Pattern for arith.select where the true/false values lower to multiple +/// SSA values (1:N conversion). This pattern generates multiple arith.select +/// than can be lowered by the 1:1 arith.select pattern. +LogicalResult SelectOpOneToNLowering::matchAndRewrite( + arith::SelectOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // In case of a 1:1 conversion, the 1:1 pattern will match. + if (llvm::hasSingleElement(adaptor.getTrueValue())) + return rewriter.notifyMatchFailure( + op, "not a 1:N conversion, 1:1 pattern will match"); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure(op, + "non-i1 conditions are not supported"); + SmallVector<Value> results; + for (auto [trueValue, falseValue] : + llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue())) + results.push_back(arith::SelectOp::create( + rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue)); + rewriter.replaceOpWithMultiple(op, {results}); + return success(); +} + +//===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns( RemSIOpLowering, RemUIOpLowering, SelectOpLowering, + SelectOpOneToNLowering, ShLIOpLowering, ShRSIOpLowering, ShRUIOpLowering, diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 785cb82..171f716 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -68,6 +68,7 @@ add_subdirectory(TosaToSCF) add_subdirectory(TosaToTensor) add_subdirectory(UBToLLVM) add_subdirectory(UBToSPIRV) +add_subdirectory(VectorToAMX) add_subdirectory(VectorToArmSME) add_subdirectory(VectorToGPU) add_subdirectory(VectorToLLVM) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 35ad99c..b3d6d59 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,14 +64,6 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); 
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( - patterns.getContext(), "__ocml_carg_f32"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( - patterns.getContext(), "__ocml_carg_f64"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( - patterns.getContext(), "__ocml_conj_f32"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( - patterns.getContext(), "__ocml_conj_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( patterns.getContext(), "__ocml_ccos_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( @@ -84,10 +76,6 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_clog_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( patterns.getContext(), "__ocml_clog_f64"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( - patterns.getContext(), "__ocml_cpow_f32"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( - patterns.getContext(), "__ocml_cpow_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( patterns.getContext(), "__ocml_csin_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( @@ -122,9 +110,8 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, - complex::CosOp, complex::ExpOp, complex::LogOp, - complex::PowOp, complex::SinOp, complex::SqrtOp, + target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp, + complex::LogOp, complex::SinOp, complex::SqrtOp, complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index ff6d369..798d8b0 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, return rewriter.applySignatureConversion(block, *conversion, converter); } +/// Flatten the given value ranges into a single vector of values. +static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { + SmallVector<Value> result; + for (const ValueRange &vals : values) + llvm::append_range(result, vals); + return result; +} + /// Convert the destination block signature (if necessary) and lower the branch /// op to llvm.br. 
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, + matchAndRewrite(cf::BranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands()); FailureOr<Block *> convertedBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), - TypeRange(adaptor.getOperands())); + TypeRange(ValueRange(flattenedAdaptor))); if (failed(convertedBlock)) return failure(); DictionaryAttr attrs = op->getAttrDictionary(); Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( - op, adaptor.getOperands(), *convertedBlock); + op, flattenedAdaptor, *convertedBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. newOp->setAttrs(attrs); @@ -152,29 +163,37 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { /// branch op to llvm.cond_br. struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::CondBranchOp op, - typename cf::CondBranchOp::Adaptor adaptor, + matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedAdaptorTrue = + flattenValues(adaptor.getTrueDestOperands()); + SmallVector<Value> flattenedAdaptorFalse = + flattenValues(adaptor.getFalseDestOperands()); + if (!llvm::hasSingleElement(adaptor.getCondition())) + return rewriter.notifyMatchFailure(op, + "expected single element condition"); FailureOr<Block *> convertedTrueBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), - TypeRange(adaptor.getTrueDestOperands())); + TypeRange(ValueRange(flattenedAdaptorTrue))); if (failed(convertedTrueBlock)) return failure(); FailureOr<Block *> convertedFalseBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), - TypeRange(adaptor.getFalseDestOperands())); + TypeRange(ValueRange(flattenedAdaptorFalse))); if (failed(convertedFalseBlock)) return failure(); - DictionaryAttr attrs = op->getAttrDictionary(); + DictionaryAttr attrs = op->getDiscardableAttrDictionary(); auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( - op, adaptor.getCondition(), adaptor.getTrueDestOperands(), - adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(), + op, llvm::getSingleElement(adaptor.getCondition()), + flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(), *convertedTrueBlock, *convertedFalseBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. 
- newOp->setAttrs(attrs); + newOp->setDiscardableAttrs(attrs); return success(); } }; diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index ed5d6d4..cdb7150 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -31,7 +31,8 @@ namespace { class ConvertToLLVMPassInterface { public: ConvertToLLVMPassInterface(MLIRContext *context, - ArrayRef<std::string> filterDialects); + ArrayRef<std::string> filterDialects, + bool allowPatternRollback = true); virtual ~ConvertToLLVMPassInterface() = default; /// Get the dependent dialects used by `convert-to-llvm`. @@ -60,6 +61,9 @@ protected: MLIRContext *context; /// List of dialects names to use as filters. ArrayRef<std::string> filterDialects; + /// An experimental flag to disallow pattern rollback. This is more efficient + /// but not supported by all lowering patterns. + bool allowPatternRollback; }; /// This DialectExtension can be attached to the context, which will invoke the @@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { /// Apply the conversion driver. LogicalResult transform(Operation *op, AnalysisManager manager) const final { - if (failed(applyPartialConversion(op, *target, *patterns))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, *target, *patterns, config))) return failure(); return success(); } @@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { patterns); // Apply the conversion. - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, target, std::move(patterns), config))) return failure(); return success(); } @@ -206,9 +214,11 @@ public: std::shared_ptr<ConvertToLLVMPassInterface> impl; // Choose the pass implementation. 
if (useDynamic) - impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects); + impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects, + allowPatternRollback); else - impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects); + impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects, + allowPatternRollback); if (failed(impl->initialize())) return failure(); this->impl = impl; @@ -228,8 +238,10 @@ public: //===----------------------------------------------------------------------===// ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( - MLIRContext *context, ArrayRef<std::string> filterDialects) - : context(context), filterDialects(filterDialects) {} + MLIRContext *context, ArrayRef<std::string> filterDialects, + bool allowPatternRollback) + : context(context), filterDialects(filterDialects), + allowPatternRollback(allowPatternRollback) {} void ConvertToLLVMPassInterface::getDependentDialects( DialectRegistry ®istry) { diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 67bb1c1..42c76ed 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; using Super = CallOpInterfaceLowering<CallOpType>; using Base = ConvertOpToLLVMPattern<CallOpType>; + using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor; - LogicalResult matchAndRewriteImpl(CallOpType callOp, - typename CallOpType::Adaptor adaptor, + LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor, ConversionPatternRewriter &rewriter, bool useBarePtrCallConv = false) const { // Pack the result types into a struct. Type packedResult = nullptr; + SmallVector<SmallVector<Type>> groupedResultTypes; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - + int64_t numConvertedTypes = 0; if (numResults != 0) { if (!(packedResult = this->getTypeConverter()->packFunctionResults( - resultTypes, useBarePtrCallConv))) + resultTypes, useBarePtrCallConv, &groupedResultTypes, + &numConvertedTypes))) return failure(); } @@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { static_cast<int32_t>(promoted.size()), 0}; newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); - SmallVector<Value, 4> results; - if (numResults < 2) { - // If < 2 results, packing did not do anything and we can just return. - results.append(newOp.result_begin(), newOp.result_end()); - } else { - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - results.push_back(LLVM::ExtractValueOp::create( - rewriter, callOp.getLoc(), newOp->getResult(0), i)); + // Helper function that extracts an individual result from the return value + // of the new call op. llvm.call ops support only 0 or 1 result. In case of + // 2 or more results, the results are packed into a structure. + // + // The new call op may have more than 2 results because: + // a. The original call op has more than 2 results. + // b. An original op result type-converted to more than 1 result. 
+ auto getUnpackedResult = [&](unsigned i) -> Value { + assert(numConvertedTypes > 0 && "convert op has no results"); + if (numConvertedTypes == 1) { + assert(i == 0 && "out of bounds: converted op has only one result"); + return newOp->getResult(0); } + // Results have been converted to a structure. Extract individual results + // from the structure. + return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(), + newOp->getResult(0), i); + }; + + // Group the results into a vector of vectors, such that it is clear which + // original op result is replaced with which range of values. (In case of a + // 1:N conversion, there can be multiple replacements for a single result.) + SmallVector<SmallVector<Value>> results; + results.reserve(numResults); + unsigned counter = 0; + for (unsigned i = 0; i < numResults; ++i) { + SmallVector<Value> &group = results.emplace_back(); + for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) + group.push_back(getUnpackedResult(counter++)); } - if (useBarePtrCallConv) { - // For the bare-ptr calling convention, promote memref results to - // descriptors. - assert(results.size() == resultTypes.size() && - "The number of arguments and types doesn't match"); - this->getTypeConverter()->promoteBarePtrsToDescriptors( - rewriter, callOp.getLoc(), resultTypes, results); - } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), - resultTypes, results, - /*toDynamic=*/false))) { - return failure(); + // Special handling for MemRef types. + for (unsigned i = 0; i < numResults; ++i) { + Type origType = resultTypes[i]; + auto memrefType = dyn_cast<MemRefType>(origType); + auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType); + if (useBarePtrCallConv && memrefType) { + // For the bare-ptr calling convention, promote memref results to + // descriptors. 
+ assert(results[i].size() == 1 && "expected one converted result"); + results[i].front() = MemRefDescriptor::fromStaticShape( + rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType, + results[i].front()); + } + if (unrankedMemrefType) { + assert(!useBarePtrCallConv && "unranked memref is not supported in the " + "bare-ptr calling convention"); + assert(results[i].size() == 1 && "expected one converted result"); + Value desc = this->copyUnrankedDescriptor( + rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(), + /*toDynamic=*/false); + if (!desc) + return failure(); + results[i].front() = desc; + } } - rewriter.replaceOp(callOp, results); + rewriter.replaceOpWithMultiple(callOp, results); return success(); } }; @@ -606,7 +638,7 @@ public: symbolTables(symbolTables) {} LogicalResult - matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool useBarePtrCallConv = false; if (getTypeConverter()->getOptions().useBarePtrCallConv) { @@ -636,7 +668,7 @@ struct CallIndirectOpLowering using Super::Super; LogicalResult - matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor, + matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); } @@ -679,41 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - unsigned numArguments = op.getNumOperands(); SmallVector<Value, 4> updatedOperands; auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); bool useBarePtrCallConv = shouldUseBarePtrCallConv(funcOp, this->getTypeConverter()); - if (useBarePtrCallConv) { - // For the bare-ptr calling convention, extract the aligned pointer to - // be returned from the memref descriptor. - for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { - Type oldTy = std::get<0>(it).getType(); - Value newOperand = std::get<1>(it); - if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( - cast<BaseMemRefType>(oldTy))) { - MemRefDescriptor memrefDesc(newOperand); - newOperand = memrefDesc.allocatedPtr(rewriter, loc); - } else if (isa<UnrankedMemRefType>(oldTy)) { + + for (auto [oldOperand, newOperands] : + llvm::zip_equal(op->getOperands(), adaptor.getOperands())) { + Type oldTy = oldOperand.getType(); + if (auto memRefType = dyn_cast<MemRefType>(oldTy)) { + assert(newOperands.size() == 1 && "expected one converted result"); + if (useBarePtrCallConv && + getTypeConverter()->canConvertToBarePtr(memRefType)) { + // For the bare-ptr calling convention, extract the aligned pointer to + // be returned from the memref descriptor. + MemRefDescriptor memrefDesc(newOperands.front()); + updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc)); + continue; + } + } else if (auto unrankedMemRefType = + dyn_cast<UnrankedMemRefType>(oldTy)) { + assert(newOperands.size() == 1 && "expected one converted result"); + if (useBarePtrCallConv) { // Unranked memref is not supported in the bare pointer calling // convention. 
return failure(); } - updatedOperands.push_back(newOperand); + Value updatedDesc = + copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType, + newOperands.front(), /*toDynamic=*/true); + if (!updatedDesc) + return failure(); + updatedOperands.push_back(updatedDesc); + continue; } - } else { - updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); - (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), - updatedOperands, - /*toDynamic=*/true); + + llvm::append_range(updatedOperands, newOperands); } // If ReturnOp has 0 or 1 operand, create it and return immediately. - if (numArguments <= 1) { + if (updatedOperands.size() <= 1) { rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( op, TypeRange(), updatedOperands, op->getAttrs()); return success(); diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index d22364e..e6fbcf9 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -79,17 +79,30 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { return canBeBare; } -static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, - const unsigned indexBitwidth) { +static Value getLaneId(RewriterBase &rewriter, Location loc) { auto int32Type = IntegerType::get(rewriter.getContext(), 32); Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); - Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, - ValueRange{minus1, zero}); - Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, - ValueRange{minus1, mbcntLo}); + NamedAttribute noundef = rewriter.getNamedAttr( + LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr()); + NamedAttribute lowRange = rewriter.getNamedAttr( + LLVM::LLVMDialect::getRangeAttrName(), + LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32), + APInt(32, 32))); + NamedAttribute highRange = rewriter.getNamedAttr( + LLVM::LLVMDialect::getRangeAttrName(), + LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32), + APInt(32, 64))); + Value mbcntLo = ROCDL::MbcntLoOp::create( + rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{}, + /*res_attrs=*/ + rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange}))); + Value laneId = ROCDL::MbcntHiOp::create( + rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{}, + rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange}))); return laneId; } + static constexpr StringLiteral amdgcnDataLayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32" "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:" @@ -104,18 +117,16 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { LogicalResult matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); + Location loc = op.getLoc(); MLIRContext *context = rewriter.getContext(); - // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0) - // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) - - Type intTy = IntegerType::get(context, 32); - Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); - Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); - Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy, - ValueRange{minus1, zero}); - Value laneId = 
ROCDL::MbcntHiOp::create(rewriter, loc, intTy, - ValueRange{minus1, mbcntLo}); + // convert to: + // %mlo = call noundef range(i32 0, 32) + // @llvm.amdgcn.mbcnt.lo(-1, 0) + // followed by: + // %lid = call noundef range(i32 0, 64) + // @llvm.amdgcn.mbcnt.hi(-1, %mlo) + + Value laneId = getLaneId(rewriter, loc); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); @@ -185,8 +196,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { Location loc = op->getLoc(); Value initShflValue = adaptor.getValue(); - const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); - Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth); + Value srcLaneId = getLaneId(rewriter, loc); auto int32Type = IntegerType::get(rewriter.getContext(), 32); Value width = adaptor.getWidth(); diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index fce7a3f..522e914 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, results.push_back(d.memRefDescPtr(builder, loc)); } -void UnrankedMemRefDescriptor::computeSizes( +Value UnrankedMemRefDescriptor::computeSize( OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, - ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces, - SmallVectorImpl<Value> &sizes) { - if (values.empty()) - return; - assert(values.size() == addressSpaces.size() && - "must provide address space for each descriptor"); + UnrankedMemRefDescriptor desc, unsigned addressSpace) { // Cache the index type. Type indexType = typeConverter.getIndexType(); @@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes( builder, loc, indexType, llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8)); - sizes.reserve(sizes.size() + values.size()); - for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) { - // Emit IR computing the memory necessary to store the descriptor. This - // assumes the descriptor to be - // { type*, type*, index, index[rank], index[rank] } - // and densely packed, so the total size is - // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). - // TODO: consider including the actual size (including eventual padding due - // to data layout) into the unranked descriptor. - Value pointerSize = createIndexAttrConstant( - builder, loc, indexType, - llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); - Value doublePointerSize = - LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); - - // (1 + 2 * rank) * sizeof(index) - Value rank = desc.rank(builder, loc); - Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); - Value doubleRankIncremented = - LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); - Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, - doubleRankIncremented, indexSize); - - // Total allocation size. - Value allocationSize = LLVM::AddOp::create( - builder, loc, indexType, doublePointerSize, rankIndexSize); - sizes.push_back(allocationSize); - } + // Emit IR computing the memory necessary to store the descriptor. 
This + // assumes the descriptor to be + // { type*, type*, index, index[rank], index[rank] } + // and densely packed, so the total size is + // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). + // TODO: consider including the actual size (including eventual padding due + // to data layout) into the unranked descriptor. + Value pointerSize = createIndexAttrConstant( + builder, loc, indexType, + llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); + Value doublePointerSize = + LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); + + // (1 + 2 * rank) * sizeof(index) + Value rank = desc.rank(builder, loc); + Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); + Value doubleRankIncremented = + LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); + Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, + doubleRankIncremented, indexSize); + + // Total allocation size. + Value allocationSize = LLVM::AddOp::create(builder, loc, indexType, + doublePointerSize, rankIndexSize); + return allocationSize; } Value UnrankedMemRefDescriptor::allocatedPtr( diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 2568044..48a0319 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -216,34 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( return memRefDescriptor; } -LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( - OpBuilder &builder, Location loc, TypeRange origTypes, - SmallVectorImpl<Value> &operands, bool toDynamic) const { - assert(origTypes.size() == operands.size() && - "expected as may original types as operands"); - - // Find operands of unranked memref type and store them. - SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs; - SmallVector<unsigned> unrankedAddressSpaces; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { - unrankedMemrefs.emplace_back(operands[i]); - FailureOr<unsigned> addressSpace = - getTypeConverter()->getMemRefAddressSpace(memRefType); - if (failed(addressSpace)) - return failure(); - unrankedAddressSpaces.emplace_back(*addressSpace); - } - } - - if (unrankedMemrefs.empty()) - return success(); - - // Compute allocation sizes. - SmallVector<Value> sizes; - UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), - unrankedMemrefs, unrankedAddressSpaces, - sizes); +Value ConvertToLLVMPattern::copyUnrankedDescriptor( + OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, + Value operand, bool toDynamic) const { + // Convert memory space. + FailureOr<unsigned> addressSpace = + getTypeConverter()->getMemRefAddressSpace(memRefType); + if (failed(addressSpace)) + return {}; // Get frequently used types. 
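For reference, the allocation size that computeSize now returns follows directly from the packed descriptor layout described above. A host-side sketch of the same arithmetic, assuming 64-bit pointers and a 64-bit index type (both sizes are assumptions for illustration, not implied by the patch):

  // Size of { type*, type*, index, index[rank], index[rank] }, densely packed.
  constexpr unsigned ptrBytes = 8, indexBytes = 8;
  constexpr unsigned descriptorBytes(unsigned rank) {
    return 2 * ptrBytes + (1 + 2 * rank) * indexBytes;
  }
  static_assert(descriptorBytes(3) == 72, "rank-3 descriptor: 2*8 + 7*8 bytes");

In the lowering itself these factors are materialized as index constants and combined with LLVM::MulOp/LLVM::AddOp, as the surrounding hunk shows.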
Type indexType = getTypeConverter()->getIndexType(); @@ -254,52 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( if (toDynamic) { mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType); if (failed(mallocFunc)) - return failure(); + return {}; } if (!toDynamic) { freeFunc = LLVM::lookupOrCreateFreeFn(builder, module); if (failed(freeFunc)) - return failure(); + return {}; } - unsigned unrankedMemrefPos = 0; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - Type type = origTypes[i]; - if (!isa<UnrankedMemRefType>(type)) - continue; - Value allocationSize = sizes[unrankedMemrefPos++]; - UnrankedMemRefDescriptor desc(operands[i]); - - // Allocate memory, copy, and free the source if necessary. - Value memory = - toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(), - allocationSize) - .getResult() - : LLVM::AllocaOp::create(builder, loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); - Value source = desc.memRefDescPtr(builder, loc); - LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); - if (!toDynamic) - LLVM::CallOp::create(builder, loc, freeFunc.value(), source); - - // Create a new descriptor. The same descriptor can be returned multiple - // times, attempting to modify its pointer can lead to memory leaks - // (allocated twice and overwritten) or double frees (the caller does not - // know if the descriptor points to the same memory). - Type descriptorType = getTypeConverter()->convertType(type); - if (!descriptorType) - return failure(); - auto updatedDesc = - UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); - Value rank = desc.rank(builder, loc); - updatedDesc.setRank(builder, loc, rank); - updatedDesc.setMemRefDescPtr(builder, loc, memory); + UnrankedMemRefDescriptor desc(operand); + Value allocationSize = UnrankedMemRefDescriptor::computeSize( + builder, loc, *getTypeConverter(), desc, *addressSpace); + + // Allocate memory, copy, and free the source if necessary. + Value memory = toDynamic + ? LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); + Value source = desc.memRefDescPtr(builder, loc); + LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); + if (!toDynamic) + LLVM::CallOp::create(builder, loc, freeFunc.value(), source); + + // Create a new descriptor. The same descriptor can be returned multiple + // times, attempting to modify its pointer can lead to memory leaks + // (allocated twice and overwritten) or double frees (the caller does not + // know if the descriptor points to the same memory). 
+ Type descriptorType = getTypeConverter()->convertType(memRefType); + if (!descriptorType) + return {}; + auto updatedDesc = + UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); + Value rank = desc.rank(builder, loc); + updatedDesc.setRank(builder, loc, rank); + updatedDesc.setMemRefDescPtr(builder, loc, memory); + return updatedDesc; +} - operands[i] = updatedDesc; +LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( + OpBuilder &builder, Location loc, TypeRange origTypes, + SmallVectorImpl<Value> &operands, bool toDynamic) const { + assert(origTypes.size() == operands.size() && + "expected as may original types as operands"); + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { + Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType, + operands[i], toDynamic); + if (!updatedDesc) + return failure(); + operands[i] = updatedDesc; + } } - return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 1a9bf56..cb9dea1 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl( useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv; auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; + // Convert argument types one by one and check for errors. for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) { SmallVector<Type, 8> converted; @@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const { /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. -Type LLVMTypeConverter::convertCallingConventionType( - Type type, bool useBarePtrCallConv) const { - if (useBarePtrCallConv) - if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) - return convertMemRefToBarePtr(memrefTy); - - return convertType(type); -} +LogicalResult LLVMTypeConverter::convertCallingConventionType( + Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const { + if (useBarePtrCallConv) { + if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) { + Type converted = convertMemRefToBarePtr(memrefTy); + if (!converted) + return failure(); + result.push_back(converted); + return success(); + } + } -/// Promote the bare pointers in 'values' that resulted from memrefs to -/// descriptors. 'stdTypes' holds they types of 'values' before the conversion -/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). -void LLVMTypeConverter::promoteBarePtrsToDescriptors( - ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes, - SmallVectorImpl<Value> &values) const { - assert(stdTypes.size() == values.size() && - "The number of types and values doesn't match"); - for (unsigned i = 0, end = values.size(); i < end; ++i) - if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i])) - values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, - memrefTy, values[i]); + return convertType(type, result); } /// Convert a non-empty list of types of values produced by an operation into an @@ -706,23 +699,35 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const { /// LLVM-compatible type. 
In particular, if more than one value is returned, /// create an LLVM dialect structure type with elements that correspond to each /// of the types converted with `convertCallingConventionType`. -Type LLVMTypeConverter::packFunctionResults(TypeRange types, - bool useBarePtrCallConv) const { +Type LLVMTypeConverter::packFunctionResults( + TypeRange types, bool useBarePtrCallConv, + SmallVector<SmallVector<Type>> *groupedTypes, + int64_t *numConvertedTypes) const { assert(!types.empty() && "expected non-empty list of type"); + assert((!groupedTypes || groupedTypes->empty()) && + "expected groupedTypes to be empty"); useBarePtrCallConv |= options.useBarePtrCallConv; - if (types.size() == 1) - return convertCallingConventionType(types.front(), useBarePtrCallConv); - SmallVector<Type> resultTypes; resultTypes.reserve(types.size()); + size_t sizeBefore = 0; for (auto t : types) { - auto converted = convertCallingConventionType(t, useBarePtrCallConv); - if (!converted || !LLVM::isCompatibleType(converted)) + if (failed( + convertCallingConventionType(t, resultTypes, useBarePtrCallConv))) return {}; - resultTypes.push_back(converted); + if (groupedTypes) { + SmallVector<Type> &group = groupedTypes->emplace_back(); + llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore)); + } + sizeBefore = resultTypes.size(); } + if (numConvertedTypes) + *numConvertedTypes = resultTypes.size(); + if (resultTypes.size() == 1) + return resultTypes.front(); + if (resultTypes.empty()) + return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); } @@ -740,40 +745,50 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, return allocated; } -SmallVector<Value, 4> -LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, OpBuilder &builder, - bool useBarePtrCallConv) const { +SmallVector<Value, 4> LLVMTypeConverter::promoteOperands( + Location loc, ValueRange opOperands, ValueRange adaptorOperands, + OpBuilder &builder, bool useBarePtrCallConv) const { + SmallVector<ValueRange> ranges; + for (size_t i = 0, e = adaptorOperands.size(); i < e; i++) + ranges.push_back(adaptorOperands.slice(i, 1)); + return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv); +} + +SmallVector<Value, 4> LLVMTypeConverter::promoteOperands( + Location loc, ValueRange opOperands, ArrayRef<ValueRange> adaptorOperands, + OpBuilder &builder, bool useBarePtrCallConv) const { SmallVector<Value, 4> promotedOperands; - promotedOperands.reserve(operands.size()); + promotedOperands.reserve(adaptorOperands.size()); useBarePtrCallConv |= options.useBarePtrCallConv; - for (auto it : llvm::zip(opOperands, operands)) { - auto operand = std::get<0>(it); - auto llvmOperand = std::get<1>(it); - + for (auto [operand, llvmOperand] : + llvm::zip_equal(opOperands, adaptorOperands)) { if (useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. 
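To make the new result bookkeeping concrete: with the default convention, packing the results of an op returning (memref<2xf32>, i64) yields a two-element literal struct, the memref descriptor struct followed by the converted integer; groupedTypes then holds one group per original result (a single entry each here, but a 1:N type conversion may contribute several), and numConvertedTypes the total element count. A hedged sketch of a caller inspecting that mapping through the new out-parameters; the result types and the caller itself are an assumed example:

  SmallVector<SmallVector<Type>> groupedTypes;
  int64_t numConverted = 0;
  Type packed = typeConverter.packFunctionResults(
      op->getResultTypes(), /*useBarePtrCallConv=*/false, &groupedTypes,
      &numConverted);
  if (!packed)
    return failure();
  // groupedTypes[i] lists the LLVM types produced for the i-th original
  // result; their concatenation has numConverted entries and forms the
  // struct body whenever more than one converted type is produced.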
if (isa<MemRefType>(operand.getType())) { - MemRefDescriptor desc(llvmOperand); - llvmOperand = desc.alignedPtr(builder, loc); + assert(llvmOperand.size() == 1 && "Expected a single operand"); + MemRefDescriptor desc(llvmOperand.front()); + promotedOperands.push_back(desc.alignedPtr(builder, loc)); + continue; } else if (isa<UnrankedMemRefType>(operand.getType())) { llvm_unreachable("Unranked memrefs are not supported"); } } else { if (isa<UnrankedMemRefType>(operand.getType())) { - UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, + assert(llvmOperand.size() == 1 && "Expected a single operand"); + UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(), promotedOperands); continue; } if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, + assert(llvmOperand.size() == 1 && "Expected a single operand"); + MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType, promotedOperands); continue; } } - promotedOperands.push_back(llvmOperand); + llvm::append_range(promotedOperands, llvmOperand); } return promotedOperands; } @@ -802,11 +817,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, result.append(converted.begin(), converted.end()); return success(); } - auto converted = converter.convertType(type); - if (!converted) - return failure(); - result.push_back(converted); - return success(); + return converter.convertType(type, result); } /// Callback to convert function argument types. It converts MemRef function @@ -814,11 +825,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, LogicalResult mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl<Type> &result) { - auto llvmTy = converter.convertCallingConventionType( - type, /*useBarePointerCallConv=*/true); - if (!llvmTy) - return failure(); - - result.push_back(llvmTy); - return success(); + return converter.convertCallingConventionType( + type, result, + /*useBarePointerCallConv=*/true); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d..2b7bdc9 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -17,11 +17,13 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include <cstdint> +#include <numeric> using namespace mlir; @@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, + OpBuilder &builder) { + assert(isMemRefTypeLegalForEmitC(memrefType) && + "incompatible memref type for EmitC conversion"); + emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create( + builder, loc, emitc::SizeTType::get(builder.getContext()), + builder.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(builder.getContext(), + {TypeAttr::get(memrefType.getElementType())})); + + IndexType indexType = builder.getIndexType(); + int64_t numElements = std::accumulate(memrefType.getShape().begin(), + memrefType.getShape().end(), int64_t{1}, + std::multiplies<int64_t>()); + emitc::ConstantOp numElementsValue = 
emitc::ConstantOp::create( + builder, loc, indexType, builder.getIndexAttr(numElements)); + + Type sizeTType = emitc::SizeTType::get(builder.getContext()); + emitc::MulOp totalSizeBytes = emitc::MulOp::create( + builder, loc, sizeTType, elementSize.getResult(0), numElementsValue); + + return totalSizeBytes.getResult(); +} + +static emitc::ApplyOp +createPointerFromEmitcArray(Location loc, OpBuilder &builder, + TypedValue<emitc::ArrayType> arrayValue) { + + emitc::ConstantOp zeroIndex = emitc::ConstantOp::create( + builder, loc, builder.getIndexType(), builder.getIndexAttr(0)); + + emitc::ArrayType arrayType = arrayValue.getType(); + llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex); + emitc::SubscriptOp subPtr = + emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); + emitc::ApplyOp ptr = emitc::ApplyOp::create( + builder, loc, emitc::PointerType::get(arrayType.getElementType()), + builder.getStringAttr("&"), subPtr); + + return ptr; +} + struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -112,19 +156,21 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); Type elementType = memrefType.getElementType(); IndexType indexType = rewriter.getIndexType(); - emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( - loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create( + rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"), + ValueRange{}, ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); int64_t numElements = 1; for (int64_t dimSize : memrefType.getShape()) { numElements *= dimSize; } - Value numElementsValue = rewriter.create<emitc::ConstantOp>( - loc, indexType, rewriter.getIndexAttr(numElements)); + Value numElementsValue = emitc::ConstantOp::create( + rewriter, loc, indexType, rewriter.getIndexAttr(numElements)); - Value totalSizeBytes = rewriter.create<emitc::MulOp>( - loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + Value totalSizeBytes = + emitc::MulOp::create(rewriter, loc, sizeTType, + sizeofElementOp.getResult(0), numElementsValue); emitc::CallOpaqueOp allocCall; StringAttr allocFunctionName; @@ -132,8 +178,8 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { SmallVector<Value, 2> argsVec; if (allocOp.getAlignment()) { allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); - alignmentValue = rewriter.create<emitc::ConstantOp>( - loc, sizeTType, + alignmentValue = emitc::ConstantOp::create( + rewriter, loc, sizeTType, rewriter.getIntegerAttr(indexType, allocOp.getAlignment().value_or(0))); argsVec.push_back(alignmentValue); @@ -144,21 +190,62 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { argsVec.push_back(totalSizeBytes); ValueRange args(argsVec); - allocCall = rewriter.create<emitc::CallOpaqueOp>( - loc, + allocCall = emitc::CallOpaqueOp::create( + rewriter, loc, emitc::PointerType::get( emitc::OpaqueType::get(rewriter.getContext(), "void")), allocFunctionName, args); emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); - emitc::CastOp castOp = rewriter.create<emitc::CastOp>( - loc, targetPointerType, allocCall.getResult(0)); + emitc::CastOp castOp = emitc::CastOp::create( + rewriter, loc, targetPointerType, allocCall.getResult(0)); 
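Once printed by the EmitC emitter, the allocation sequence built above amounts to ordinary C/C++. The sketch below assumes a memref<4x8xf32> with an alignment attribute of 16; shape, alignment, and names are illustrative only, not taken from the patch:

  #include <cstdlib>
  float *allocExample() {
    // sizeof(float) * 32 == 128 bytes; aligned_alloc(alignment, size) is used
    // because the memref.alloc carries an alignment, otherwise plain malloc.
    void *raw = std::aligned_alloc(16, sizeof(float) * 32);
    return static_cast<float *>(raw);
  }

The trailing emitc.cast corresponds to the cast back to the element pointer type.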
rewriter.replaceOp(allocOp, castOp); return success(); } }; +struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = copyOp.getLoc(); + MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType()); + MemRefType targetMemrefType = + cast<MemRefType>(copyOp.getTarget().getType()); + + if (!isMemRefTypeLegalForEmitC(srcMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible source memref type for EmitC conversion"); + + if (!isMemRefTypeLegalForEmitC(targetMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible target memref type for EmitC conversion"); + + auto srcArrayValue = + cast<TypedValue<emitc::ArrayType>>(operands.getSource()); + emitc::ApplyOp srcPtr = + createPointerFromEmitcArray(loc, rewriter, srcArrayValue); + + auto targetArrayValue = + cast<TypedValue<emitc::ArrayType>>(operands.getTarget()); + emitc::ApplyOp targetPtr = + createPointerFromEmitcArray(loc, rewriter, targetArrayValue); + + emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( + rewriter, loc, TypeRange{}, "memcpy", + ValueRange{ + targetPtr.getResult(), srcPtr.getResult(), + calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); + + rewriter.replaceOp(copyOp, memCpyCall.getResults()); + + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -320,6 +407,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, - ConvertLoad, ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal, + ConvertGetGlobal, ConvertLoad, ConvertStore>( + converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index e78dd76..a073a9a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -18,6 +18,8 @@ #include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC @@ -27,6 +29,15 @@ namespace mlir { using namespace mlir; namespace { + +emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module, + StringRef headerName) { + StringAttr includeAttr = builder.getStringAttr(headerName); + return emitc::IncludeOp::create( + builder, module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); +} + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { using Base::Base; @@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass return signalPassFailure(); mlir::ModuleOp module = getOperation(); + llvm::SmallSet<StringRef, 4> existingHeaders; + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + module.walk([&](mlir::emitc::IncludeOp includeOp) { + if (includeOp.getIsStandardInclude()) + existingHeaders.insert(includeOp.getInclude()); + }); + 
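Similarly, the ConvertCopy pattern added above (in MemRefToEmitC.cpp) takes the address of the all-zero element of both EmitC arrays and forwards them to memcpy together with the byte count from calculateMemrefTotalSizeBytes. Printed out, it is roughly the following, assuming memref<4x8xf32> operands (an illustrative shape, not from the patch):

  #include <cstring>
  void copyExample(float dst[4][8], const float src[4][8]) {
    // &dst[0][0] / &src[0][0] mirror the emitc.subscript + emitc.apply "&" ops.
    std::memcpy(&dst[0][0], &src[0][0], sizeof(float) * 4 * 8);
  }

This is also why the pass below now makes sure the string-library header is included whenever a memcpy call is emitted.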
module.walk([&](mlir::emitc::CallOpaqueOp callOp) { - if (callOp.getCallee() != alignedAllocFunctionName && - callOp.getCallee() != mallocFunctionName) { + StringRef expectedHeader; + if (callOp.getCallee() == alignedAllocFunctionName || + callOp.getCallee() == mallocFunctionName) + expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader; + else if (callOp.getCallee() == memcpyFunctionName) + expectedHeader = + options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader; + else return mlir::WalkResult::advance(); + if (!existingHeaders.contains(expectedHeader)) { + addStandardHeader(builder, module, expectedHeader); + existingHeaders.insert(expectedHeader); } - - for (auto &op : *module.getBody()) { - emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); - if (!includeOp) { - continue; - } - if (includeOp.getIsStandardInclude() && - ((options.lowerToCpp && - includeOp.getInclude() == cppStandardLibraryHeader) || - (!options.lowerToCpp && - includeOp.getInclude() == cStandardLibraryHeader))) { - return mlir::WalkResult::interrupt(); - } - } - - mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); - StringAttr includeAttr = - builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader - : cStandardLibraryHeader); - builder.create<mlir::emitc::IncludeOp>( - module.getLoc(), includeAttr, - /*is_standard_include=*/builder.getUnitAttr()); - return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); }); } }; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index d6bdd34..262e0e7 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering auto result = UnrankedMemRefDescriptor::poison( rewriter, loc, typeConverter->convertType(resultTypeU)); result.setRank(rewriter, loc, rank); - SmallVector<Value, 1> sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), - result, resultAddrSpace, sizes); - Value resultUnderlyingSize = sizes.front(); + Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize( + rewriter, loc, *getTypeConverter(), result, resultAddrSpace); Value resultUnderlyingDesc = LLVM::AllocaOp::create(rewriter, loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize); @@ -1530,12 +1528,11 @@ private: auto targetDesc = UnrankedMemRefDescriptor::poison( rewriter, loc, typeConverter->convertType(targetType)); targetDesc.setRank(rewriter, loc, resultRank); - SmallVector<Value, 4> sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), - targetDesc, addressSpace, sizes); + Value allocationSize = UnrankedMemRefDescriptor::computeSize( + rewriter, loc, *getTypeConverter(), targetDesc, addressSpace); Value underlyingDescPtr = LLVM::AllocaOp::create( rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8), - sizes.front()); + allocationSize); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); // Extract pointers and offset from the source memref. 
@@ -1872,6 +1869,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umin; case arith::AtomicRMWKind::ori: return LLVM::AtomicBinOp::_or; + case arith::AtomicRMWKind::xori: + return LLVM::AtomicBinOp::_xor; case arith::AtomicRMWKind::andi: return LLVM::AtomicBinOp::_and; default: diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 2549a9c..c6c5ab3 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -283,11 +283,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); + auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8); Value ldMatrixResult = NVVM::LdMatrixOp::create( b, ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col - : NVVM::MMALayout::row); + : NVVM::MMALayout::row, + /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16); // The ldmatrix operation returns either a single i32 value or a struct of // i32 values. Here we unpack those values and cast them back to their @@ -1104,12 +1106,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LDBG() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + LDBG() << "Generating warpgroup.descriptor: " << "leading_off:" + << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle + << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); @@ -1399,14 +1399,12 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LDBG() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" - << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM - << "][" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN - << "])"; + LDBG() << "\t wgmma." 
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK + << "(A[" << (iterationM * wgmmaM) << ":" + << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK) + << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B[" + << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK) + << "][" << 0 << ":" << wgmmaN << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 91788f9..e0144bf 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -61,7 +61,7 @@ struct PtxLowering op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { - LDBG() << asmValue << "\t Modifier : " << &modifier; + LDBG() << asmValue << "\t Modifier : " << modifier; generator.insertValue(asmValue, modifier); } diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index ba448e4..37cfc9f 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -382,8 +382,11 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = arith::CmpIOp::create( - rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound); + arith::CmpIPredicate predicate = forOp.getUnsignedCmp() + ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt; + auto comparison = + arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound); cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock, ArrayRef<Value>()); diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 84cbd86..1f239aa 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -154,6 +154,10 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = forOp.getLoc(); + if (forOp.getUnsignedCmp()) + return rewriter.notifyMatchFailure(forOp, + "unsigned loops are not supported"); + // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the loop body. SmallVector<Value> resultVariables; diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index dc92367f..55ed31e 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -178,8 +178,14 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { // Generate the rest of the loop header. 
rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), - newIndVar, adaptor.getUpperBound()); + Value cmpOp; + if (forOp.getUnsignedCmp()) { + cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(), + newIndVar, adaptor.getUpperBound()); + } else { + cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), + newIndVar, adaptor.getUpperBound()); + } spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt new file mode 100644 index 0000000..2d4b2b6 --- /dev/null +++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRVectorToAMX + VectorToAMX.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRAMXDialect + MLIRAffineUtils + MLIRArithDialect + MLIRLinalgUtils + MLIRMemRefDialect + MLIRSCFDialect + MLIRTransforms + MLIRVectorDialect + ) diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp new file mode 100644 index 0000000..a11e9b2 --- /dev/null +++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp @@ -0,0 +1,283 @@ +//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToAMX/VectorToAMX.h" + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include <numeric> + +namespace mlir { +#define GEN_PASS_DEF_CONVERTVECTORTOAMX +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +/// Return true if vector shape is compatible with AMX tiles. +/// The validation accounts for VNNI packing. +static bool verifyAmxShape(VectorType vec) { + // Check overall shape: + // - 2D for plain layout input or output + // - 3D for VNNI packed input + if (vec.getRank() != 2 && vec.getRank() != 3) + return false; + + ArrayRef<int64_t> shape = vec.getShape(); + int64_t rows = shape[0]; + int64_t cols = shape[1]; + unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth(); + + // 3D shape indicates VNNI packed layout. + if (vec.getRank() == 3) { + int64_t vnniFactor = 32 / elemBitWidth; + if (shape.back() != vnniFactor) + return false; + cols *= vnniFactor; + } + + // AMX tile supports up to 16 rows of 64 bytes each. + constexpr unsigned maxRows = 16; + constexpr unsigned maxBitsPerRow = 64 * 8; + return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow; +} + +/// Checks if contraction operands are in AMX-compatible packed VNNI layout. 
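As a concrete instance of these tile limits: bf16 elements are 16 bits wide, so the VNNI factor is 32 / 16 = 2, and a vector<16x16x2xbf16> becomes 16 rows of 32 effective columns, i.e. 512 bits (64 bytes) per row, exactly the AMX maximum. The same arithmetic as a compile-time check (the shape is chosen for illustration):

  constexpr unsigned elemBits = 16;              // bf16
  constexpr unsigned vnniFactor = 32 / elemBits; // == 2
  constexpr unsigned rows = 16, cols = 16 * vnniFactor;
  static_assert(rows <= 16 && cols * elemBits <= 64 * 8,
                "vector<16x16x2xbf16> fits a single AMX tile");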
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter, + vector::ContractionOp contractOp) { + VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType()); + if (!accType || accType.getRank() != 2) + return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector"); + + // Expect 3D inputs for VNNI packed data. + VectorType lhsType = contractOp.getLhs().getType(); + VectorType rhsType = contractOp.getRhs().getType(); + if (lhsType.getRank() != 3 || rhsType.getRank() != 3) + return rewriter.notifyMatchFailure(contractOp, + "Expects lhs and rhs 3D vectors"); + + // Check if shapes are compatible with AMX tile. + if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) || + !verifyAmxShape(accType)) + return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape"); + + // Validate affine maps. + // + // Iterators can be ordered arbitrarily. Indexing map positions are based on + // operands' target shapes. + // The matrix layouts must match the following: + // - matrix A - [M]x[K/vnniFactor]x[vnniFactor] + // - matrix B - [K/vnniFactor]x[N]x[vnniFactor] + // - matrix C - [M]x[N] + SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray(); + AffineMap mapA = indexingMaps[0]; + AffineMap mapB = indexingMaps[1]; + if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 || + mapB.getNumResults() != 3) + return rewriter.notifyMatchFailure(contractOp, + "Invalid input indexing maps"); + FailureOr<linalg::ContractionDimensions> dims = + linalg::inferContractionDims(indexingMaps); + if (failed(dims)) + return rewriter.notifyMatchFailure(contractOp, + "Failed to infer contraction dims"); + // Two reduction dimensions are expected: + // - one for the K dimension + // - one for the VNNI factor + if (dims->k.size() != 2) + return rewriter.notifyMatchFailure(contractOp, + "Expected two reduction dims"); + assert(dims->m.size() == 1 && dims->n.size() == 1 && + "Invalid parallel contraction dims"); + + SmallVector<vector::IteratorType> iteratorTypes = + contractOp.getIteratorTypesArray(); + // Check VNNI dim maps - the innermost dim for A and B inputs. + auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2)); + auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2)); + if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB || + iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction) + return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map"); + // Check K dim maps - non-transposed row-major layout. + auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1)); + auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0)); + if (!redDimA || !redDimB || redDimA != redDimB || + iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction) + return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map"); + // Check M and N dim maps - map to non-transposed output. 
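Spelled out for iterators (m, n, k, v), the accepted indexing maps are A: (m, n, k, v) -> (m, k, v), B: (m, n, k, v) -> (k, n, v), and C: (m, n, k, v) -> (m, n), with k and v being the two reduction iterators. A small sketch of building such maps with the C++ API, purely to illustrate the expected shape; the iterator order is an assumption:

  MLIRContext ctx;
  AffineExpr m, n, k, v;
  bindDims(&ctx, m, n, k, v);
  AffineMap mapA = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k, v}, &ctx);
  AffineMap mapB = AffineMap::get(4, 0, {k, n, v}, &ctx); // [K/vnni] x [N] x [vnni]
  AffineMap mapC = AffineMap::get(4, 0, {m, n}, &ctx);    // [M] x [N]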
+ AffineMap mapC = indexingMaps[2]; + auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0)); + auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1)); + if (!mDimC || !nDimC) + return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps"); + auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0)); + if (!parallelDimA || + iteratorTypes[parallelDimA.getPosition()] != + vector::IteratorType::parallel || + parallelDimA != mDimC) + return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map"); + auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1)); + if (!parallelDimB || + iteratorTypes[parallelDimB.getPosition()] != + vector::IteratorType::parallel || + parallelDimB != nDimC) + return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map"); + + return success(); +} + +/// Validate contraction operands for AMX lowering. +static LogicalResult validateOperands(PatternRewriter &rewriter, + vector::ContractionOp contractOp) { + VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType()); + if (!accType) + return rewriter.notifyMatchFailure(contractOp, "Expects vector acc"); + + // Check if operand types are compatible with AMX compute ops. + bool validElemTypes = false; + Type lhsElemType = contractOp.getLhs().getType().getElementType(); + Type rhsElemType = contractOp.getRhs().getType().getElementType(); + Type accElemType = accType.getElementType(); + if (accElemType.isInteger(32)) { + validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8); + } else if (accElemType.isF32()) { + validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) || + (lhsElemType.isBF16() && rhsElemType.isBF16()); + } + if (!validElemTypes) + return rewriter.notifyMatchFailure(contractOp, + "Invalid combination of operand types"); + + if (failed(isAmxVnniLayout(rewriter, contractOp))) + return failure(); + + return success(); +} + +/// Collapses the two innermost dimensions together. +static Value collapseLastDim(PatternRewriter &rewriter, + TypedValue<MemRefType> memref) { + int64_t rank = memref.getType().getRank(); + SmallVector<ReassociationIndices> reassocIndices; + for (auto i : llvm::seq<int64_t>(0, rank - 2)) + reassocIndices.push_back({i}); + reassocIndices.push_back({rank - 2, rank - 1}); + return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref, + reassocIndices); +} + +/// Loads vector values to an AMX tile. +static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter, + TypedValue<VectorType> vec) { + Location loc = vec.getLoc(); + Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); + + // Transfer the vector to a tile through an intermediate buffer. + VectorType vecTy = vec.getType(); + Value buf = memref::AllocaOp::create( + rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType())); + SmallVector<Value> indices(vecTy.getRank(), zeroIndex); + vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices); + + // Collapse the VNNI dimension in case of packing. + bool isPacked = vecTy.getRank() == 3; + if (isPacked) + buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf)); + + ArrayRef<int64_t> shape = vecTy.getShape(); + int64_t rows = shape[0]; + int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1, + std::multiplies<int64_t>()); + auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); + + return amx::TileLoadOp::create(rewriter, loc, tileType, buf, + {zeroIndex, zeroIndex}); +} + +/// Stores an AMX tile in a vector. 
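To illustrate loadTile on a packed operand: a vector<16x8x2xbf16> is first spilled to a matching buffer, the trailing VNNI dimension is collapsed, and the result is loaded as a 16x16 tile. The shape arithmetic, with the input shape being an assumed example:

  // vector<16x8x2xbf16> -> (collapse last dim) -> memref<16x16xbf16>
  constexpr long long shape[] = {16, 8, 2};
  constexpr long long rows = shape[0];            // 16
  constexpr long long cols = shape[1] * shape[2]; // 8 * 2 = 16
  static_assert(rows == 16 && cols == 16, "collapsed AMX tile shape");
  // loadTile then emits an amx.tile_load of !amx.tile<16x16xbf16> at (0, 0).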
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter, + TypedValue<amx::TileType> tile) { + Location loc = tile.getLoc(); + Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); + + // Transfer the tile to a vector through an intermediate buffer. + amx::TileType tileTy = tile.getType(); + Value buf = memref::AllocaOp::create( + rewriter, loc, + MemRefType::get(tileTy.getShape(), tileTy.getElementType())); + SmallVector<Value> indices(2, zeroIndex); + amx::TileStoreOp::create(rewriter, loc, buf, indices, tile); + + auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType()); + return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {}); +} + +struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + Location loc = contractOp.getLoc(); + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind"); + if (failed(validateOperands(rewriter, contractOp))) + return failure(); + + TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs()); + TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs()); + auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc()); + assert(acc && "Invalid accumulator type"); + TypedValue<amx::TileType> accTile = loadTile(rewriter, acc); + + TypedValue<amx::TileType> tileMul; + if (acc.getType().getElementType().isFloat()) { + tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(), + lhsTile, rhsTile, accTile); + } else { + tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(), + lhsTile, rhsTile, accTile); + } + + Value res = storeTile(rewriter, tileMul); + rewriter.replaceOp(contractOp, res); + + return success(); + } +}; + +struct ConvertVectorToAMXPass + : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> { + void runOnOperation() override { + MLIRContext &ctx = getContext(); + RewritePatternSet patterns(&ctx); + populateVectorToAMXConversionPatterns(patterns); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) { + patterns.add<ContractionToAMX>(patterns.getContext()); +} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index f9e2a01..afc3d1b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1891,15 +1891,21 @@ struct VectorFromElementsLowering ConversionPatternRewriter &rewriter) const override { Location loc = fromElementsOp.getLoc(); VectorType vectorType = fromElementsOp.getType(); - // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>. - // Such ops should be handled in the same way as vector.insert. + // Only support 1-D vectors. Multi-dimensional vectors should have been + // transformed to 1-D vectors by the vector-to-vector transformations before + // this. 
if (vectorType.getRank() > 1) return rewriter.notifyMatchFailure(fromElementsOp, "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); + Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) - result = vector::InsertOp::create(rewriter, loc, val, result, idx); + for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) { + auto constIdx = + LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx); + result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result, + val, constIdx); + } rewriter.replaceOp(fromElementsOp, result); return success(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index cf10869..9852df6 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + populateVectorFromElementsLoweringPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt index 567083d..e9ad67c5 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt @@ -13,4 +13,5 @@ add_mlir_conversion_library(MLIRVectorToXeGPU MLIRTransforms MLIRVectorDialect MLIRXeGPUDialect + MLIRXeGPUUtils ) diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 8010755..819c2e5 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -14,9 +14,11 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" @@ -68,11 +70,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter, if (!srcTy) return rewriter.notifyMatchFailure(xferOp, "Expects memref source"); - // Perform common data transfer checks. - VectorType vecTy = xferOp.getVectorType(); - if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy))) - return failure(); - // Validate further transfer op semantics. SmallVector<int64_t> strides; int64_t offset; @@ -80,6 +77,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter, return rewriter.notifyMatchFailure( xferOp, "Buffer must be contiguous in the innermost dimension"); + VectorType vecTy = xferOp.getVectorType(); unsigned vecRank = vecTy.getRank(); if (xferOp.hasOutOfBoundsDim() && vecRank < 2) return rewriter.notifyMatchFailure( @@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, return ndDesc; } +// Adjusts the strides of a memref according to a given permutation map for +// vector operations. 
+// +// This function updates the innermost strides in the `strides` array to +// reflect the permutation specified by `permMap`. The permutation is computed +// using the inverse and broadcasting-aware version of the permutation map, +// and is applied to the relevant strides. This ensures that memory accesses +// are consistent with the logical permutation of vector elements. +// +// Example: +// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`. +// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1, +// 0]), then after calling this function, the last two strides will be +// swapped: +// Original strides: [s0, s1, s2, s3] +// After permutation: [s0, s1, s3, s2] +// +static void adjustStridesForPermutation(AffineMap permMap, + SmallVectorImpl<Value> &strides) { + + AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap); + SmallVector<unsigned> perms; + invMap.isPermutationOfMinorIdentityWithBroadcasting(perms); + SmallVector<int64_t> perms64(perms.begin(), perms.end()); + strides = applyPermutation(strides, perms64); +} + +// Computes memory strides for vector transfer operations, handling both +// static and dynamic memrefs while applying permutation transformations +// for XeGPU lowering. +static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp, + PatternRewriter &rewriter) { + SmallVector<Value> strides; + Value baseMemref = xferOp.getBase(); + AffineMap permMap = xferOp.getPermutationMap(); + MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType()); + + Location loc = xferOp.getLoc(); + if (memrefType.hasStaticShape()) { + int64_t offset; + SmallVector<int64_t> intStrides; + if (failed(memrefType.getStridesAndOffset(intStrides, offset))) + return {}; + // Wrap static strides as MLIR values + for (int64_t s : intStrides) + strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s)); + } else { + // For dynamic shape memref, use memref.extract_strided_metadata to get + // stride values + unsigned rank = memrefType.getRank(); + Type indexType = rewriter.getIndexType(); + + // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1, + // size0, size1, ..., sizeN-1] + SmallVector<Type> resultTypes; + resultTypes.push_back(MemRefType::get( + {}, memrefType.getElementType())); // base memref (unranked) + resultTypes.push_back(indexType); // offset + + for (unsigned i = 0; i < rank; ++i) + resultTypes.push_back(indexType); // strides + + for (unsigned i = 0; i < rank; ++i) + resultTypes.push_back(indexType); // sizes + + auto meta = memref::ExtractStridedMetadataOp::create( + rewriter, loc, resultTypes, baseMemref); + strides.append(meta.getStrides().begin(), meta.getStrides().end()); + } + // Adjust strides according to the permutation map (e.g., for transpose) + adjustStridesForPermutation(permMap, strides); + return strides; +} + +// This function compute the vectors of localOffsets for scattered load/stores. 
+// It is used in the lowering of vector.transfer_read/write to +// load_gather/store_scatter Example: +// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], +// %cst {in_bounds = [true, true, true, true]}>} : +// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16> +// +// %6 = vector.step: vector<4xindex> +// %7 = vector.step: vector<2xindex> +// %8 = vector.step: vector<6xindex> +// %9 = vector.step: vector<32xindex> +// %10 = arith.mul %6, 384 +// %11 = arith.mul %7, 192 +// %12 = arith.mul %8, 32 +// %13 = arith.mul %9, 1 +// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16> +// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16> +// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16> +// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16> +// %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex> +// %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex> +// %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex> +// %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex> +// %22 = arith.add %18, %19 +// %23 = arith.add %20, %21 +// %local_offsets = arith.add %22, %23 +// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map +// %offsets = orig_offset + local_offsets +static Value computeOffsets(VectorTransferOpInterface xferOp, + PatternRewriter &rewriter, + ArrayRef<Value> strides) { + Location loc = xferOp.getLoc(); + VectorType vectorType = xferOp.getVectorType(); + SmallVector<Value> indices(xferOp.getIndices().begin(), + xferOp.getIndices().end()); + ArrayRef<int64_t> vectorShape = vectorType.getShape(); + + // Create vector.step operations for each dimension + SmallVector<Value> stepVectors; + llvm::map_to_vector(vectorShape, [&](int64_t dim) { + auto stepType = VectorType::get({dim}, rewriter.getIndexType()); + auto stepOp = vector::StepOp::create(rewriter, loc, stepType); + stepVectors.push_back(stepOp); + return stepOp; + }); + + // Multiply step vectors by corresponding strides + size_t memrefRank = strides.size(); + size_t vectorRank = vectorShape.size(); + SmallVector<Value> strideMultiplied; + for (size_t i = 0; i < vectorRank; ++i) { + size_t memrefDim = memrefRank - vectorRank + i; + Value strideValue = strides[memrefDim]; + auto mulType = dyn_cast<VectorType>(stepVectors[i].getType()); + auto bcastOp = + vector::BroadcastOp::create(rewriter, loc, mulType, strideValue); + auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp); + strideMultiplied.push_back(mulOp); + } + + // Shape cast each multiplied vector to add singleton dimensions + SmallVector<Value> shapeCasted; + for (size_t i = 0; i < vectorRank; ++i) { + SmallVector<int64_t> newShape(vectorRank, 1); + newShape[i] = vectorShape[i]; + auto newType = VectorType::get(newShape, rewriter.getIndexType()); + auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType, + strideMultiplied[i]); + shapeCasted.push_back(castOp); + } + + // Broadcast each shape-casted vector to full vector shape + SmallVector<Value> broadcasted; + auto fullIndexVectorType = + VectorType::get(vectorShape, rewriter.getIndexType()); + for (Value shapeCastVal : shapeCasted) { + auto broadcastOp = vector::BroadcastOp::create( + rewriter, loc, fullIndexVectorType, shapeCastVal); + broadcasted.push_back(broadcastOp); + } + + // Add all broadcasted vectors together to compute local offsets + Value localOffsets = broadcasted[0]; + for (size_t i = 
1; i < broadcasted.size(); ++i) + localOffsets = + arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]); + + // Compute base offset from transfer read indices + Value baseOffset = nullptr; + if (!indices.empty()) { + baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); + for (size_t i = 0; i < indices.size(); ++i) { + Value strideVal = strides[i]; + Value offsetContrib = + arith::MulIOp::create(rewriter, loc, indices[i], strideVal); + baseOffset = + arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); + } + // Broadcast base offset to match vector shape + Value bcastBase = vector::BroadcastOp::create( + rewriter, loc, fullIndexVectorType, baseOffset); + localOffsets = + arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets); + } + return localOffsets; +} + +// Collapse memref shape to 1D +static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp, + PatternRewriter &rewriter) { + Location loc = xferOp.getLoc(); + + Value baseMemref = xferOp.getBase(); + MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType()); + Type elementType = memrefType.getElementType(); + + // Compute the total number of elements in the memref + MemRefType flatMemrefType; + if (memrefType.hasStaticShape()) { + auto totalElements = memrefType.getNumElements(); + flatMemrefType = MemRefType::get({totalElements}, elementType); + } else { + flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType); + } + + SmallVector<ReassociationIndices> reassociation; + ReassociationIndices allDims = + llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank())); + reassociation.push_back(allDims); + + auto collapseOp = memref::CollapseShapeOp::create( + rewriter, loc, flatMemrefType, baseMemref, reassociation); + return collapseOp; +} + +static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, + PatternRewriter &rewriter) { + + Location loc = readOp.getLoc(); + VectorType vectorType = readOp.getVectorType(); + ArrayRef<int64_t> vectorShape = vectorType.getShape(); + auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType()); + if (!memrefType) + return rewriter.notifyMatchFailure(readOp, "Expected memref source"); + + SmallVector<Value> strides = computeStrides(readOp, rewriter); + if (strides.empty()) + return rewriter.notifyMatchFailure(readOp, "Failed to compute strides"); + + Value localOffsets = computeOffsets(readOp, rewriter, strides); + + Value flatMemref = collapseMemrefTo1D(readOp, rewriter); + + Value mask = vector::ConstantMaskOp::create( + rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), + vectorShape); + auto gatherOp = xegpu::LoadGatherOp::create( + rewriter, loc, vectorType, flatMemref, localOffsets, mask, + /*chunk_size=*/IntegerAttr{}, + /*l1_hint=*/xegpu::CachePolicyAttr{}, + /*l2_hint=*/xegpu::CachePolicyAttr{}, + /*l3_hint=*/xegpu::CachePolicyAttr{}); + + rewriter.replaceOp(readOp, gatherOp.getResult()); + return success(); +} + +static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) { + + Location loc = writeOp.getLoc(); + VectorType vectorType = writeOp.getVectorType(); + ArrayRef<int64_t> vectorShape = vectorType.getShape(); + + auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType()); + if (!memrefType) + return rewriter.notifyMatchFailure(writeOp, "Expected memref source"); + + SmallVector<Value> strides = computeStrides(writeOp, rewriter); + + Value localOffsets = computeOffsets(writeOp, rewriter, strides); + + Value flatMemref = 
collapseMemrefTo1D(writeOp, rewriter); + + Value mask = vector::ConstantMaskOp::create( + rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), + vectorShape); + xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref, + localOffsets, mask, + /*chunk_size=*/IntegerAttr{}, + /*l1_hint=*/xegpu::CachePolicyAttr{}, + /*l2_hint=*/xegpu::CachePolicyAttr{}, + /*l3_hint=*/xegpu::CachePolicyAttr{}); + rewriter.eraseOp(writeOp); + return success(); +} + struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; @@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { if (failed(transferPreconditions(rewriter, readOp))) return failure(); + // TODO:This check needs to be replaced with proper uArch capability check + auto chip = xegpu::getChipStr(readOp); + if (chip != "pvc" && chip != "bmg") { + // lower to scattered load Op if the target HW doesn't have 2d block load + // support + // TODO: add support for OutOfBound access + if (readOp.hasOutOfBoundsDim()) + return failure(); + return lowerToScatteredLoadOp(readOp, rewriter); + } + + // Perform common data transfer checks. + VectorType vecTy = readOp.getVectorType(); + if (failed(storeLoadPreconditions(rewriter, readOp, vecTy))) + return failure(); + bool isOutOfBounds = readOp.hasOutOfBoundsDim(); if (isOutOfBounds && !isZeroConstant(readOp.getPadding())) return rewriter.notifyMatchFailure( @@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { AffineMap readMap = readOp.getPermutationMap(); bool isTransposeLoad = !readMap.isMinorIdentity(); - VectorType vecTy = readOp.getVectorType(); Type elementType = vecTy.getElementType(); unsigned minTransposeBitWidth = 32; if (isTransposeLoad && @@ -221,11 +505,26 @@ struct TransferWriteLowering if (failed(transferPreconditions(rewriter, writeOp))) return failure(); + // TODO:This check needs to be replaced with proper uArch capability check + auto chip = xegpu::getChipStr(writeOp); + if (chip != "pvc" && chip != "bmg") { + // lower to scattered store Op if the target HW doesn't have 2d block + // store support + // TODO: add support for OutOfBound access + if (writeOp.hasOutOfBoundsDim()) + return failure(); + return lowerToScatteredStoreOp(writeOp, rewriter); + } + + // Perform common data transfer checks. + VectorType vecTy = writeOp.getVectorType(); + if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy))) + return failure(); + AffineMap map = writeOp.getPermutationMap(); if (!map.isMinorIdentity()) return rewriter.notifyMatchFailure(writeOp, "Expects identity map"); - VectorType vecTy = writeOp.getVectorType(); auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(), diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp index 86edc2b..b405ec2 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -93,13 +93,13 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) { int64_t lb = forOp.getConstantLowerBound(); dividend[pos] = 1; dividend.back() -= lb; - addLocalFloorDiv(dividend, step); + unsigned qPos = addLocalFloorDiv(dividend, step); // Second constraint: (iv - lb) - step * q = 0. 
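    // Worked example (added for illustration): for `affine.for %iv = 5 to 101
    // step 3`, the floordiv above introduces the local q = (%iv - 5) floordiv 3
    // at column `qPos`, and the equality built below encodes
    // (%iv - 5) - 3 * q == 0 using that column rather than assuming the new
    // local is always at `getNumCols() - 2`.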
SmallVector<int64_t, 8> eq(getNumCols(), 0); eq[pos] = 1; eq.back() -= lb; // For the local var just added above. - eq[getNumCols() - 2] = -step; + eq[qPos] = -step; addEquality(eq); } } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 488c3c3..7d4d818 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, case AtomicRMWKind::addi: case AtomicRMWKind::maxu: case AtomicRMWKind::ori: + case AtomicRMWKind::xori: return builder.getZeroAttr(resultType); case AtomicRMWKind::andi: return builder.getIntegerAttr( @@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { // Integer operations. .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; }) .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; }) - .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; }) + .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; }) .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; }) .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; }) .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; }) @@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return arith::OrIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::andi: return arith::AndIOp::create(builder, loc, lhs, rhs); + case AtomicRMWKind::xori: + return arith::XOrIOp::create(builder, loc, lhs, rhs); // TODO: Add remaining reduction operations. default: (void)emitOptionalError(loc, "Reduction operation type not supported"); diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e79da92..5359826 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1395,6 +1395,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) { //===----------------------------------------------------------------------===// // FieldOp //===----------------------------------------------------------------------===// + static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue) { @@ -1452,6 +1453,15 @@ LogicalResult FieldOp::verify() { //===----------------------------------------------------------------------===// // GetFieldOp //===----------------------------------------------------------------------===// + +LogicalResult GetFieldOp::verify() { + auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>(); + if (!parentClassOp.getOperation()) + return emitOpError(" must be nested within an emitc.class operation"); + + return success(); +} + LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) { mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr(); FieldOp fieldOp = diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index c55e26e..06d7e07 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -64,8 +64,8 @@ public: TypeAttr typeAttr = TypeAttr::get(val.getType()); fields.push_back({fieldName, typeAttr}); - FieldOp fieldop = rewriter.create<emitc::FieldOp>( - funcOp->getLoc(), fieldName, typeAttr, nullptr); + FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(), + fieldName, typeAttr, nullptr); if (argAttrs && idx < 
argAttrs->size()) { fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx)); diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index d4978ca..97adad6 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -431,8 +431,7 @@ private: if (std::optional<SymbolTable::UseRange> symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - StringRef symbolName = - cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue(); + StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference(); if (symbolTable.lookup(symbolName)) continue; diff --git a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp index e9cf493..6da76e9 100644 --- a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVM/XeVM/Target.h" #include "llvm/Support/Regex.h" namespace mlir { diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp index 384d1a0..be71bd0 100644 --- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp +++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" #include <numeric> @@ -57,26 +58,29 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns( warpOp.getResultTypes().end()); auto yield = cast<gpu::YieldOp>( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(), - yield.getOperands().end()); + SmallVector<Value> yieldValues(yield.getOperands().begin(), + yield.getOperands().end()); + llvm::SmallDenseMap<Value, unsigned> indexLookup; + // Record the value -> first index mapping for faster lookup. + for (auto [i, v] : llvm::enumerate(yieldValues)) { + if (!indexLookup.count(v)) + indexLookup[v] = i; + } + for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) { - if (yieldValues.insert(value)) { + // If the value already exists in the yield, don't create a new output. + if (indexLookup.count(value)) { + indices.push_back(indexLookup[value]); + } else { + // If the value is new, add it to the yield and to the types. + yieldValues.push_back(value); types.push_back(type); indices.push_back(yieldValues.size() - 1); - } else { - // If the value already exit the region don't create a new output. 
- for (auto [idx, yieldOperand] : - llvm::enumerate(yieldValues.getArrayRef())) { - if (yieldOperand == value) { - indices.push_back(idx); - break; - } - } } } - yieldValues.insert_range(newYieldedValues); + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldValues.getArrayRef(), types); + rewriter, warpOp, yieldValues, types); rewriter.replaceOp(warpOp, newWarpOp.getResults().take_front(warpOp.getNumResults())); return newWarpOp; diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index 894de44..e004d5f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -107,11 +107,32 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { ss << getModifier() << getRegisterType(v) << ","; } +/// Check if the operation needs to pack and unpack results. +static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) { + return interfaceOp->getNumResults() > 1; +} + +/// Pack the result types of the interface operation. +/// If the operation has multiple results, it packs them into a struct +/// type. Otherwise, it returns the original result types. +static SmallVector<Type> packResultTypes(MLIRContext *ctx, + BasicPtxBuilderInterface interfaceOp) { + TypeRange results = interfaceOp->getResultTypes(); + + if (!needsPackUnpack(interfaceOp)) + return llvm::to_vector<1>(results); + + SmallVector<mlir::Type> elems(results.begin(), results.end()); + auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false); + return {sTy}; +} + LLVM::InlineAsmOp PtxBuilder::build() { + MLIRContext *ctx = interfaceOp->getContext(); auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), LLVM::AsmDialect::AD_ATT); - auto resultTypes = interfaceOp->getResultTypes(); + SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp); // Remove the last comma from the constraints string. if (!registerConstraints.empty() && @@ -136,7 +157,7 @@ LLVM::InlineAsmOp PtxBuilder::build() { rewriter, interfaceOp->getLoc(), /*result types=*/resultTypes, /*operands=*/ptxOperands, - /*asm_string=*/llvm::StringRef(ptxInstruction), + /*asm_string=*/ptxInstruction, /*constraints=*/registerConstraints.data(), /*has_side_effects=*/interfaceOp.hasSideEffect(), /*is_align_stack=*/false, LLVM::TailCallKind::None, @@ -147,9 +168,34 @@ LLVM::InlineAsmOp PtxBuilder::build() { void PtxBuilder::buildAndReplaceOp() { LLVM::InlineAsmOp inlineAsmOp = build(); LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n"); - if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { - rewriter.replaceOp(interfaceOp, inlineAsmOp); - } else { + + // Case 1: no result + if (inlineAsmOp->getNumResults() == 0) { rewriter.eraseOp(interfaceOp); + return; + } + + // Case 2: single result, forward it directly + if (!needsPackUnpack(interfaceOp)) { + rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); + return; } + + // Case 3: multiple results were packed; unpack the struct. 
+ assert(mlir::LLVM::LLVMStructType::classof( + inlineAsmOp.getResultTypes().front()) && + "Expected result type to be LLVMStructType when unpacking multiple " + "results"); + auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>( + inlineAsmOp.getResultTypes().front()); + + SmallVector<mlir::Value> unpacked; + Value structVal = inlineAsmOp.getResult(0); + for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) { + Value unpackedValue = LLVM::ExtractValueOp::create( + rewriter, interfaceOp->getLoc(), structVal, idx); + unpacked.push_back(unpackedValue); + } + + rewriter.replaceOp(interfaceOp, unpacked); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 422039f..a6e89f6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) { return success(); } +static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder, + bool isExpandLoad, + uint64_t alignment = 1) { + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The pointer alignment defaults to 1. + if (alignment == 1) { + return nullptr; + } + + auto emptyDictAttr = builder.getDictionaryAttr({}); + auto alignmentAttr = builder.getI64IntegerAttr(alignment); + auto namedAttr = + builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr); + SmallVector<mlir::NamedAttribute> attrs = {namedAttr}; + auto alignDictAttr = builder.getDictionaryAttr(attrs); + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The align parameter attribute can be provided for [expandload]'s first + // argument. The align parameter attribute can be provided for + // [compressstore]'s second argument. + int pos = isExpandLoad ? 0 : 1; + return pos == 0 ? builder.getArrayAttr( + {alignDictAttr, emptyDictAttr, emptyDictAttr}) + : builder.getArrayAttr( + {emptyDictAttr, alignDictAttr, emptyDictAttr}); +} + //===----------------------------------------------------------------------===// // Operand bundle helpers. 
//===----------------------------------------------------------------------===// @@ -4117,6 +4149,32 @@ LogicalResult LLVM::masked_scatter::verify() { } //===----------------------------------------------------------------------===// +// masked_expandload (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state, + mlir::TypeRange resTys, Value ptr, + Value mask, Value passthru, + uint64_t align) { + ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align); + build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + +//===----------------------------------------------------------------------===// +// masked_compressstore (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_compressstore::build(OpBuilder &builder, + OperationState &state, Value value, + Value ptr, Value mask, uint64_t align) { + ArrayAttr argAttrs = + getLLVMAlignParamForCompressExpand(builder, false, align); + build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + +//===----------------------------------------------------------------------===// // InlineAsmOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e0977f5..dbcc738 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -189,6 +189,26 @@ LogicalResult BulkStoreOp::verify() { return success(); } +LogicalResult PMEventOp::verify() { + auto eventId = getEventId(); + auto maskedEventId = getMaskedEventId(); + if (!maskedEventId && !eventId) { + return emitOpError() << "either `id` or `mask` must be set"; + } + + if (maskedEventId && eventId) { + return emitOpError() << "`id` and `mask` cannot be set at the same time"; + } + + if (eventId) { + if (eventId < 0 || eventId > 15) { + return emitOpError() << "`id` must be between 0 and 15"; + } + } + + return llvm::success(); +} + // Given the element type of an operand and whether or not it is an accumulator, // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the // operand's element type. 
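A minimal usage sketch for the alignment-aware intrinsic builders added above (illustrative only: the standalone helper, `b`, `loc`, and the SSA values are assumptions, and it assumes the generated `create` wrappers forward to these new `build` overloads):

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;

// Expand-load a vector through a 16-byte-aligned pointer and compress-store it
// back. With align == 1 (the default) no arg_attrs are attached; otherwise the
// align parameter attribute lands on the pointer operand, roughly:
//   expandload:    arg_attrs = [{llvm.align = 16 : i64}, {}, {}]
//   compressstore: arg_attrs = [{}, {llvm.align = 16 : i64}, {}]
static void emitAlignedExpandCompress(OpBuilder &b, Location loc, Value ptr,
                                      Value mask, Value passthru) {
  Type resTy = passthru.getType();
  auto load = LLVM::masked_expandload::create(b, loc, TypeRange(resTy), ptr,
                                              mask, passthru, /*align=*/16);
  LLVM::masked_compressstore::create(b, loc, load->getResult(0), ptr, mask,
                                     /*align=*/16);
}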
@@ -791,24 +811,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() { } LogicalResult NVVM::LdMatrixOp::verify() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - - if (getNum() != 1 && getNum() != 2 && getNum() != 4) - return emitOpError("expected num attribute to be 1, 2 or 4"); + uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 " + "matrix"); + } + if (getEltType() != LdStMatrixEltType::B16) { + return emitOpError("expected element type to be b16 for 8x8 matrix"); + } + } else if (m == 8 && n == 16) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 " + "matrix"); + } + if (getLayout() != MMALayout::row) { + return emitOpError("expected layout to be row for 8x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 8x16 matrix"); + } + } else if (m == 16 && n == 16) { + if (num != 1 && num != 2) { + return emitOpError("expected num attribute to be 1 or 2 for 16x16 " + "matrix"); + } + if (getLayout() != MMALayout::col) { + return emitOpError("expected layout to be col for 16x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8 && + getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 16x16 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8, 8x16 or 16x16"); + } Type i32 = IntegerType::get(getContext(), 32); - if (getNum() == 1 && getType() != i32) + uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num); + if (numElements == 1 && getType() != i32) return emitOpError("expected destination type is i32"); - if (getNum() == 2 || getNum() == 4) { + if (numElements == 2 || numElements == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( - getContext(), SmallVector<Type>(getNum(), i32)); + getContext(), SmallVector<Type>(numElements, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") - << getNum() << " elements of type i32"; + << numElements << " elements of type i32"; } + return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 34c63d3..e0e3716 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -194,9 +194,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state, ArrayRef<AffineMap> indexingMaps) { // Initialize indexingMaps attribute, for MatmulOp. 
SmallVector<Attribute, 3> indexingMapsAttrVal;
-  indexingMapsAttrVal = llvm::map_to_vector(
-      MatmulOp::getDefaultIndexingMaps(b.getContext()),
-      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
   state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                            attributes, regionBuilder);
@@ -1569,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+  // Check if the body can be printed in short form. The following 4 conditions
+  // must be satisfied:
+
+  // 1) The body must contain exactly 2 operations: the payload op and a yield.
   if (body->getOperations().size() != 2)
-    return nullptr;
+    return false;
   Operation &payload = body->getOperations().front();
-  assert(isa<YieldOp>(body->getOperations().back()));
 
+  // 2) The payload op must have the same number of operands as the number of
+  // block arguments.
   if (payload.getNumOperands() == 0 ||
       payload.getNumOperands() != body->getNumArguments())
-    return nullptr;
+    return false;
+
+  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+  // argument must be the first operand of the payload op; otherwise, the
+  // operands must match the block arguments in order.
   if (initFirst) {
     // check init
     if (payload.getOperands().back() != body->getArgument(0))
-      return nullptr;
+      return false;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
-        return nullptr;
+        return false;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
-        return nullptr;
+        return false;
    }
  }
-  return &payload;
+
+  // 4) The `yield` operand must be the result of the payload op.
+  auto yieldOp = cast<YieldOp>(body->getTerminator());
+  return yieldOp.getNumOperands() == 1 &&
+         yieldOp.getOperand(0).getDefiningOp() &&
+         yieldOp.getOperand(0).getDefiningOp() == &payload;
 }
 
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
   SmallVector<StringRef> elidedAttrs;
   std::string attrToElide;
   p << " { " << payloadOp->getName().getStringRef();
@@ -1621,15 +1632,15 @@ void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
+  bool useShortForm = canUseShortForm(mapper);
+  if (useShortForm) {
+    printShortForm(p, &mapper->getOperations().front());
   }
 
   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   p.printOptionalAttrDict((*this)->getAttrs());
 
-  if (!payloadOp) {
+  if (!useShortForm) {
     // Print region if the payload op was not detected.
p.increaseIndent(); p.printNewline(); @@ -1828,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, void ReduceOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); - if (payloadOp) { - printShortForm(p, payloadOp); + bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true); + if (useShortForm) { + printShortForm(p, &mapper->getOperations().front()); } printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); - if (!payloadOp) { + if (!useShortForm) { // Print region if the payload op was not detected. p.increaseIndent(); p.printNewline(); @@ -3749,6 +3760,25 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) { // MatMulOp //===----------------------------------------------------------------------===// +static FailureOr<SmallVector<SmallVector<int64_t>>> +getAffineResultPositions(ArrayAttr maps) { + SmallVector<SmallVector<int64_t>> positions; + for (auto map : maps) { + AffineMapAttr attr = dyn_cast<AffineMapAttr>(map); + if (!attr) + return failure(); + SmallVector<int64_t> pos; + for (auto result : attr.getAffineMap().getResults()) { + auto dim = dyn_cast<AffineDimExpr>(result); + if (!dim) + return failure(); + pos.push_back(dim.getPosition()); + } + positions.push_back(pos); + } + return positions; +} + /// Returns a list of AffineMap with the typical matmul indexing charactristic. SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) { AffineExpr d0, d1, d2; @@ -3760,6 +3790,20 @@ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool MatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 2} && + (*positions)[1] == SmallVector<int64_t>{2, 1} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() { return SmallVector<utils::IteratorType>{utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -3912,6 +3956,380 @@ Speculation::Speculatability MatmulOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } +SmallVector<AffineMap> +MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2); + AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context); + AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context); + AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{2, 0} && + (*positions)[1] == SmallVector<int64_t>{2, 1} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + ValueRange inputs, 
ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool MatmulTransposeAOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::MatmulOp>(op) && + MatmulTransposeAOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2); + AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context); + AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context); + AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 2} && + (*positions)[1] == SmallVector<int64_t>{1, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { 
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool MatmulTransposeBOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::MatmulOp>(op) && + MatmulTransposeBOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2, d3; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2, d3); + AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context); + AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context); + AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputs, ArrayRef<NamedAttribute> attributes) { + 
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool BatchMatmulTransposeAOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::BatchMatmulOp>(op) && + BatchMatmulTransposeAOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2, d3; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2, d3); + AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context); + AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context); + AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 2, 3} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, ValueRange 
inputs, + ValueRange outputs, ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool BatchMatmulTransposeBOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::BatchMatmulOp>(op) && + BatchMatmulTransposeBOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + //===----------------------------------------------------------------------===// // ContractOp //===----------------------------------------------------------------------===// @@ -4120,6 +4538,20 @@ BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() { return SmallVector<utils::IteratorType>{ utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -5345,11 +5777,18 @@ 
ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
-  auto packedShape = getSourceType().getShape();
+  SmallVector<int64_t> outerDims(getAllOuterDims());
   SmallVector<int64_t> res;
+
+  // Recover the original order of the outer dims.
+  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+  invertPermutationVector(outerDimPermInv);
+  if (!outerDimPermInv.empty())
+    applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
   for (auto index : innerDimsPos)
-    res.push_back(packedShape[index]);
+    res.push_back(outerDims[index]);
 
   return res;
 }
@@ -5646,6 +6085,19 @@ BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
   return indexingMaps;
 }
 
+bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+  if (!maps)
+    return false;
+  if (maps.size() != 3)
+    return false;
+  auto positions = getAffineResultPositions(maps);
+  if (failed(positions))
+    return false;
+  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+         (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+         (*positions)[2] == SmallVector<int64_t>{1, 2};
+}
 unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
 
 std::string BatchReduceMatmulOp::getLibraryCallName() {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8754743..639e0fe 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,14 +1987,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
 
   // Convert the padding values to attributes.
   SmallVector<Attribute> paddingValues;
-  for (auto const &it :
+  for (auto const &[untypedAttr, elementOrTensorType] :
        llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
-    auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+    if (isa<ub::PoisonAttr>(untypedAttr)) {
+      paddingValues.push_back(untypedAttr);
+      continue;
+    }
+    auto attr = dyn_cast<TypedAttr>(untypedAttr);
     if (!attr) {
-      emitOpError("expects padding values to be typed attributes");
+      emitOpError("expects padding values to be typed attributes or poison");
       return DiagnosedSilenceableFailure::definiteFailure();
     }
-    Type elementType = getElementTypeOrSelf(std::get<1>(it));
+    Type elementType = getElementTypeOrSelf(elementOrTensorType);
     // Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = dyn_cast<StringAttr>(attr)) { auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute( @@ -2000,7 +2007,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter, /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); if (!parsedAttr || parsedAttr.getType() != elementType) { auto diag = this->emitOpError("expects a padding that parses to ") - << elementType << ", got " << std::get<0>(it); + << elementType << ", got " << untypedAttr; diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } @@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) { auto attr = dyn_cast<TypedAttr>(untypedAttr); Type elementType = getElementTypeOrSelf(elementOrTensorType); + + if (isa<ub::PoisonAttr>(untypedAttr)) { + paddingValues.push_back(untypedAttr); + continue; + } if (!attr) { - emitOpError("expects padding values to be typed attributes"); + emitOpError("expects padding values to be typed attributes or poison"); return DiagnosedSilenceableFailure::definiteFailure(); } // Try to parse string attributes to obtain an attribute of element type. diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 3908d73..b4507a9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns( RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { patterns.add<BlockPackMatmul<linalg::GenericOp>, BlockPackMatmul<linalg::MatmulOp>, - BlockPackMatmul<linalg::BatchMatmulOp>, - BlockPackMatmul<linalg::MatmulTransposeAOp>, - BlockPackMatmul<linalg::BatchMatmulTransposeAOp>, - BlockPackMatmul<linalg::MatmulTransposeBOp>, - BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>( - patterns.getContext(), controlFn); + BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(), + controlFn); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index bf66ed0..22690da 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -691,9 +691,9 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { auto newResultType = RankedTensorType::get( newResultShape, padOp.getResultType().getElementType()); - auto newPadOp = rewriter.create<tensor::PadOp>( - padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad, - newHighPad, paddingVal, padOp.getNofold()); + auto newPadOp = tensor::PadOp::create( + rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource, + newLowPad, newHighPad, paddingVal, padOp.getNofold()); Value dest = padOp.getResult(); if (options.rankReductionStrategy == @@ -1052,12 +1052,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> { static bool constexpr reduceLeft = (std::is_same_v<FromOpTy, BatchMatmulOp> && std::is_same_v<ToOpTy, BatchVecmatOp>) || - (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> && - std::is_same_v<ToOpTy, BatchVecmatOp>) || (std::is_same_v<FromOpTy, MatmulOp> && std::is_same_v<ToOpTy, VecmatOp>) || - (std::is_same_v<FromOpTy, MatmulTransposeAOp> && - std::is_same_v<ToOpTy, VecmatOp>) || (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>); /// Look for non-batch 
spatial dims to collapse. @@ -1113,27 +1109,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns( MLIRContext *context = patterns.getContext(); // Unbatching patterns for unit batch size patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context); - patterns - .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>( - context); - patterns - .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>( - context); patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context); patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context); // Non-batch rank 1 reducing patterns patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context); patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context); - patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context); - patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context); // Batch rank 1 reducing patterns patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context); patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context); - patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>( - context); - patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>( - context); // Non-batch rank 0 reducing patterns patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index fd530f2..9436f1c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl( auto clonedForOp = scf::ForOp::create( rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()), bvm.lookupOrDefault(forOp.getUpperBound()), - bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor); + bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); // Map the induction var, region args and results to the `clonedForOp`. bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 58986a6..922b7d6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp newLoop = scf::ForOp::create( rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), - loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); + loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}, + loop.getUnsignedCmp()); // Generate the new yield with the replaced operand. 
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 2e62523..3d12bc3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" @@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, Value paddingValue; if (auto complexTy = dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) { - auto complexAttr = cast<ArrayAttr>(paddingValueAttr); - paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), - complexTy, complexAttr); - } else { - paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), - cast<TypedAttr>(paddingValueAttr)); + if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) { + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), + complexTy, complexAttr); + } + } else if (isa<ub::PoisonAttr>(paddingValueAttr)) { + paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(), + getElementTypeOrSelf(v.getType())); + } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) { + paddingValue = + arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr); } + assert(paddingValue && "failed to create value from padding attribute"); // Pad the operand to the bounding box defined by `paddedShape`. SmallVector<int64_t> tensorShape; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 455e1a6..35ba4f15 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter, /// Codegen the different matmul variants. 
if (numOfBatchDims) { - if (a == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter, - genericOp); - if (b == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter, - genericOp); return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp); } - - if (a == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp); - if (b == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp); return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index a2a4335..2650488 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -59,12 +59,12 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter, ArrayRef<int64_t>{1, 0}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = linalg::MatmulTransposeAOp::create( + newMatmulOp = MatmulTransposeAOp::create( rewriter, loc, matmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]}, matmulOp.getOutputs()); } else { - newMatmulOp = linalg::MatmulTransposeBOp::create( + newMatmulOp = MatmulTransposeBOp::create( rewriter, loc, matmulOp.getResultTypes(), ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)}, matmulOp.getOutputs()); @@ -116,12 +116,12 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, ArrayRef<int64_t>{0, 2, 1}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = linalg::BatchMatmulTransposeAOp::create( + newMatmulOp = BatchMatmulTransposeAOp::create( rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]}, batchMatmulOp.getOutputs()); } else { - newMatmulOp = linalg::BatchMatmulTransposeBOp::create( + newMatmulOp = BatchMatmulTransposeBOp::create( rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)}, batchMatmulOp.getOutputs()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index cf65e67..406f05c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -2563,7 +2563,7 @@ vectorizeScalableVectorPrecondition(Operation *op, "vectorization"; return failure(); } - if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) { + if (isa<linalg::MatmulOp>(op)) { LDBG() << "Scalable vectorization of the reduction dim in Matmul-like ops " "is not supported"; @@ -2604,17 +2604,12 @@ vectorizeScalableVectorPrecondition(Operation *op, return failure(); } - // Check to not let go the matmul with extended semantic, through this - // transform. 
- if (linalgOp.hasUserDefinedMaps()) - return failure(); - // Cond 4: Only the following ops are supported in the // presence of scalable vectors return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) || - isa<linalg::MatmulTransposeAOp>(op) || isa<linalg::DepthwiseConv1DNwcWcOp>(op) || isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) || + isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp)); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 74b968c..b59d73d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() { case arith::AtomicRMWKind::minu: case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: + case arith::AtomicRMWKind::xori: case arith::AtomicRMWKind::andi: if (!llvm::isa<IntegerType>(getValue().getType())) return emitOpError() << "with kind '" diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 34c95e3..8474244 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -422,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( << descMemref << " != " << dstMemref; } + int lastDimBytes = + descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8; + if (lastDimBytes % 16 != 0) { + return op->emitError() << "the bytes in the last dimension of the tensor " + "map must be a multiple of 16"; + } return std::nullopt; } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 485bb73..d7c8916 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1390,6 +1390,20 @@ void acc::ParallelOp::addPrivatization(MLIRContext *context, setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::ParallelOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + getFirstprivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getFirstprivatizationRecipesAttr()) + llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} static ParseResult parseNumGangs( mlir::OpAsmParser &parser, llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, @@ -2041,6 +2055,21 @@ void acc::SerialOp::addPrivatization(MLIRContext *context, setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::SerialOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + getFirstprivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getFirstprivatizationRecipesAttr()) + llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + //===----------------------------------------------------------------------===// // KernelsOp //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c1c1767..fa94219 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3874,6 +3874,107 @@ LogicalResult AllocateDirOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +mlir::Type omp::TargetAllocMemOp::getAllocatedType() { + return getInTypeAttr().getValue(); +} + +/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype, +/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +/// attr-dict-without-keyword +static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + + // Parse device number as a new operand + mlir::OpAsmParser::UnresolvedOperand deviceOperand; + mlir::Type deviceType; + if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) + return mlir::failure(); + if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) + return mlir::failure(); + if (parser.parseComma()) + return mlir::failure(); + + mlir::Type intype; + if (parser.parseType(intype)) + return mlir::failure(); + result.addAttribute("in_type", mlir::TypeAttr::get(intype)); + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; + llvm::SmallVector<mlir::Type> typeVec; + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. (<params> : <types>) + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeVec) || parser.parseRParen()) + return mlir::failure(); + typeparamsSize = operands.size(); + hasOperands = true; + } + std::int32_t shapeSize = 0; + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) + return mlir::failure(); + shapeSize = operands.size() - typeparamsSize; + auto idxTy = builder.getIndexType(); + for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) + typeVec.push_back(idxTy); + hasOperands = true; + } + if (hasOperands && + parser.resolveOperands(operands, typeVec, parser.getNameLoc(), + result.operands)) + return mlir::failure(); + + mlir::Type restype = builder.getIntegerType(64); + if (!restype) { + parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; + return mlir::failure(); + } + llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize}; + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr(segmentSizes)); + if (parser.parseOptionalAttrDict(result.attributes) || + parser.addTypeToList(restype, result.types)) + return mlir::failure(); + return mlir::success(); +} + +mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseTargetAllocMemOp(parser, result); +} + +void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getDevice()); + p << " : "; + p << getDevice().getType(); + p << ", "; + p << getInType(); + if (!getTypeparams().empty()) { + p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; + } + for (auto sh : getShape()) { + p << ", "; + p.printOperand(sh); 
+ } + p.printOptionalAttrDict((*this)->getAttrs(), + {"in_type", "operandSegmentSizes"}); +} + +llvm::LogicalResult omp::TargetAllocMemOp::verify() { + mlir::Type outType = getType(); + if (!mlir::dyn_cast<IntegerType>(outType)) + return emitOpError("must be a integer type"); + return mlir::success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 0262a1b..0dbc041 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -318,9 +318,12 @@ void ConditionOp::getSuccessorRegions( void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, Value ub, Value step, ValueRange initArgs, - BodyBuilderFn bodyBuilder) { + BodyBuilderFn bodyBuilder, bool unsignedCmp) { OpBuilder::InsertionGuard guard(builder); + if (unsignedCmp) + result.addAttribute(getUnsignedCmpAttrName(result.name), + builder.getUnitAttr()); result.addOperands({lb, ub, step}); result.addOperands(initArgs); for (Value v : initArgs) @@ -450,6 +453,9 @@ static void printInitializationList(OpAsmPrinter &p, } void ForOp::print(OpAsmPrinter &p) { + if (getUnsignedCmp()) + p << " unsigned"; + p << " " << getInductionVar() << " = " << getLowerBound() << " to " << getUpperBound() << " step " << getStep(); @@ -462,7 +468,8 @@ void ForOp::print(OpAsmPrinter &p) { p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/!getInitArgs().empty()); - p.printOptionalAttrDict((*this)->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/getUnsignedCmpAttrName().strref()); } ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { @@ -472,6 +479,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::Argument inductionVariable; OpAsmParser::UnresolvedOperand lb, ub, step; + if (succeeded(parser.parseOptionalKeyword("unsigned"))) + result.addAttribute(getUnsignedCmpAttrName(result.name), + builder.getUnitAttr()); + // Parse the induction variable followed by '='. if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || // Parse loop bounds. @@ -562,7 +573,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, inits.append(newInitOperands.begin(), newInitOperands.end()); scf::ForOp newLoop = scf::ForOp::create( rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, - [](OpBuilder &, Location, Value, ValueRange) {}); + [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp()); newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); // Generate the new yield values and append them to the scf.yield operation. @@ -806,7 +817,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, // 2. Create the new forOp shell. 
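Note: the new trailing `unsignedCmp` argument threads through every `scf::ForOp::create` call site updated below. A minimal sketch of creating an unsigned loop (only the builder signature and printer behavior are taken from the patch; `builder`, `loc`, `lb`, `ub`, `step` are assumed to be in scope):
```
// Sketch, not part of the patch: create an scf.for that compares its
// induction variable with unsigned semantics.
scf::ForOp loop =
    scf::ForOp::create(builder, loc, lb, ub, step, /*initArgs=*/ValueRange{},
                       /*bodyBuilder=*/nullptr, /*unsignedCmp=*/true);
// The custom printer renders the attribute as a keyword, so the op
// round-trips as:  scf.for unsigned %iv = %lb to %ub step %step { ... }
```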
scf::ForOp newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newIterOperands); + forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(), @@ -931,7 +943,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { scf::ForOp newForOp = scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), - forOp.getUpperBound(), forOp.getStep(), newIterArgs); + forOp.getUpperBound(), forOp.getStep(), newIterArgs, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); @@ -989,12 +1002,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { /// Util function that tries to compute a constant diff between u and l. /// Returns std::nullopt when the difference between two AffineValueMap is /// dynamic. -static std::optional<int64_t> computeConstDiff(Value l, Value u) { +static std::optional<APInt> computeConstDiff(Value l, Value u) { IntegerAttr clb, cub; if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { llvm::APInt lbValue = clb.getValue(); llvm::APInt ubValue = cub.getValue(); - return (ubValue - lbValue).getSExtValue(); + return ubValue - lbValue; } // Else a simple pattern match for x + c or c + x @@ -1003,7 +1016,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) { u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) || matchPattern( u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l)))) - return diff.getSExtValue(); + return diff; return std::nullopt; } @@ -1022,13 +1035,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { return success(); } - std::optional<int64_t> diff = + std::optional<APInt> diff = computeConstDiff(op.getLowerBound(), op.getUpperBound()); if (!diff) return failure(); // If the loop is known to have 0 iterations, remove it. - if (*diff <= 0) { + bool zeroOrLessIterations = + diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative()); + if (zeroOrLessIterations) { rewriter.replaceOp(op, op.getInitArgs()); return success(); } @@ -3384,9 +3399,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { if (functionType.getNumInputs() != operands.size()) { return parser.emitError(typeLoc) - << "expected as many input types as operands " - << "(expected " << operands.size() << " got " - << functionType.getNumInputs() << ")"; + << "expected as many input types as operands " << "(expected " + << operands.size() << " got " << functionType.getNumInputs() << ")"; } // Resolve input operands. 
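Note: returning the raw `APInt` from `computeConstDiff` matters once loops can compare unsigned: a difference that is negative when read as a signed value may be a huge positive trip count under unsigned comparison. A standalone sketch of the `zeroOrLessIterations` condition used above:
```
// Sketch of the fold condition in SimplifyTrivialLoops, using llvm::APInt.
#include "llvm/ADT/APInt.h"

static bool isZeroTripCount(const llvm::APInt &diff, bool unsignedCmp) {
  // Only signed loops may treat a negative difference as "no iterations";
  // for `scf.for unsigned`, e.g. lb = 2, ub = 1 on i64 gives diff = -1,
  // which is 2^64 - 1 iterations and must not be folded away.
  return diff.isZero() || (!unsignedCmp && diff.isNegative());
}
```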
@@ -4222,14 +4236,15 @@ LogicalResult scf::IndexSwitchOp::verify() { << "see yield operation here"; } for (auto [idx, result, operand] : - llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(), - yield.getOperandTypes())) { - if (result == operand) + llvm::enumerate(getResultTypes(), yield.getOperands())) { + if (!operand) + return yield.emitOpError() << "operand " << idx << " is null\n"; + if (result == operand.getType()) continue; return (emitOpError("expected result #") << idx << " of each region to be " << result) .attachNote(yield.getLoc()) - << name << " returns " << operand << " here"; + << name << " returns " << operand.getType() << " here"; } return success(); }; diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index f8799c5..fb179e6 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -769,7 +769,8 @@ struct ForOpInterface // Construct a new scf.for op with memref instead of tensor values. auto newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), castedInitArgs); + forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block *loopBody = newForOp.getBody(); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index bee7780..ae52af5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { auto *beforeBlock = rewriter.createBlock( &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); - auto cmpOp = arith::CmpIOp::create( - rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt, - beforeBlock->getArgument(0), forOp.getUpperBound()); + arith::CmpIPredicate predicate = forOp.getUnsignedCmp() + ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt; + auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate, + beforeBlock->getArgument(0), + forOp.getUpperBound()); scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(), beforeBlock->getArguments()); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 1130538..7e7fba4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, bool *modifiedIR) { if (modifiedIR) *modifiedIR = false; + + // TODO: Add support for unsigned loops. 
+ if (forOp.getUnsignedCmp()) + return failure(); + LoopPipelinerInternal pipeliner; if (!pipeliner.initializeLoopInfo(forOp, options)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index 4752c08..f1203b2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -256,6 +256,10 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> { LogicalResult matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const override { + if (forOp.getUnsignedCmp()) + return rewriter.notifyMatchFailure(forOp, + "unsigned loops are not supported"); + // Do not peel already peeled loops. if (forOp->hasAttr(kPeeledLoopLabel)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 1b07b77..3b75970 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -116,7 +116,8 @@ public: llvm::getSingleElement(adaptor.getLowerBound()), llvm::getSingleElement(adaptor.getUpperBound()), llvm::getSingleElement(adaptor.getStep()), - flattenValues(adaptor.getInitArgs())); + flattenValues(adaptor.getInitArgs()), + /*bodyBuilder=*/nullptr, op.getUnsignedCmp()); // Reserve whatever attributes in the original op. newOp->setAttrs(op->getAttrs()); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index c0e47ee..250c413 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -797,7 +797,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( inits.append(newInitOperands.begin(), newInitOperands.end()); auto newLoop = scf::ForOp::create( rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(), - loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); + loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}, + loopOp.getUnsignedCmp()); // Move the loop body to the new op. Block *loopBody = loopOp.getBody(); @@ -935,7 +936,8 @@ static LogicalResult addInitOperandsToLoopNest( auto newLoop = scf::ForOp::create( rewriter, forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), forLoop.getStep(), newInits, - [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); + [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}, + forLoop.getUnsignedCmp()); // Merge the body of the new loop with the body of the old loops. 
SmallVector<Value> sourceBlockArgs; diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 5731795..4910258 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1233,6 +1233,7 @@ static void getPerfectlyNestedLoopsImpl( static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef<scf::ForOp> targets) { + assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported"); auto originalStep = forOp.getStep(); auto iv = forOp.getInductionVar(); @@ -1241,6 +1242,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor, Loops innerLoops; for (auto t : targets) { + assert(!t.getUnsignedCmp() && "unsigned loops are not supported"); + // Save information for splicing ops out of t when done auto begin = t.getBody()->begin(); auto nOps = t.getBody()->getOperations().size(); @@ -1415,6 +1418,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { + assert(source.getUnsignedCmp() == target.getUnsignedCmp() && + "incompatible signedness"); unsigned numTargetOuts = target.getNumResults(); unsigned numSourceOuts = source.getNumResults(); @@ -1428,7 +1433,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, rewriter.setInsertionPointAfter(source); scf::ForOp fusedLoop = scf::ForOp::create( rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), fusedInitArgs); + source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr, + source.getUnsignedCmp()); // Map original induction variables and operands to those of the fused loop. IRMapping mapping; diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index 3b97786..dabbea1 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -71,7 +71,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createLowerAffinePass()); pm.addPass( createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass()); pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass()); pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass()); @@ -79,12 +78,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createConvertComplexToLibm()); pm.addPass( createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createConvertComplexToLLVMPass()); - pm.addPass( - createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createConvertFuncToLLVMPass()); - pm.addPass(createArithToLLVMConversionPass()); - pm.addPass(createConvertControlFlowToLLVMPass()); // Finalize GPU code generation. if (gpuCodegen) { @@ -99,8 +92,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions)); } - // Convert poison values. - pm.addPass(createUBToLLVMConversionPass()); + // Convert to LLVM. + pm.addPass(createConvertToLLVMPass()); // Ensure all casts are realized. 
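Note: the sparsifier pipeline change above folds several per-dialect `*-to-llvm` passes into the generic ConvertToLLVM pass. A minimal sketch of the resulting pipeline tail (the header paths are the usual upstream locations and are an assumption here, not part of the patch):
```
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Pass/PassManager.h"

// Sketch: one interface-driven conversion pass now covers the dialects that
// previously needed dedicated func/arith/cf/complex/ub lowerings, followed
// by cast reconciliation.
static void buildLLVMLoweringTail(mlir::OpPassManager &pm) {
  pm.addPass(mlir::createConvertToLLVMPass());
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
}
```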
pm.addPass(createReconcileUnrealizedCastsPass()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 134aef3..0e88d31d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -730,9 +730,9 @@ public: {tensor, lvlCoords, values, filled, added, count}, EmitCInterface::On); Operation *parent = getTop(op); + rewriter.setInsertionPointAfter(parent); rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the loop nest. - rewriter.setInsertionPointAfter(parent); memref::DeallocOp::create(rewriter, loc, values); memref::DeallocOp::create(rewriter, loc, filled); memref::DeallocOp::create(rewriter, loc, added); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 4464450..febec6d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -533,8 +533,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, VectorType vtp = vectorType(vl, init.getType()); Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0), forOp.getRegionIterArg(0), init, vtp); - forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), - forOp.getUpperBound(), step, vinit); + forOpNew = + scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), + forOp.getUpperBound(), step, vinit, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); forOpNew->setAttr( LoopEmitter::getLoopEmitterLoopAttrName(), forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName())); @@ -605,8 +607,8 @@ public: ForOpRewriter(MLIRContext *context, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32) - : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization, - enableSIMDIndex32} {} + : OpRewritePattern(context), + vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {} LogicalResult matchAndRewrite(scf::ForOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index e3cba388..fce61f2 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { } if (rhsTy == resultTy) { - if (isSplatZero(resultETy, lhsAttr)) + if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape()) + // constant values can only be resized if resulting type is static return lhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, lhsAttr, shift)) return rhs; } if (lhsTy == resultTy) { - if (isSplatZero(resultETy, rhsAttr)) + if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape()) return rhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, rhsAttr, shift)) return lhs; diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index e6ef028..34385d7 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -276,7 +276,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, if (!ubConstant) return std::nullopt; std::optional<int64_t> stepConstant = getConstantIntValue(step); - if (!stepConstant) + if 
(!stepConstant || *stepConstant == 0) return std::nullopt; return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a450056..74e48b5 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2402,6 +2402,16 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, return foldToElementsFromElements(*this, results); } +LogicalResult +ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, + ToElementsOp::Adaptor adaptor, + SmallVectorImpl<Type> &inferredReturnTypes) { + auto vecType = cast<VectorType>(adaptor.getSource().getType()); + Type elType = vecType.getElementType(); + inferredReturnTypes.append(vecType.getNumElements(), elType); + return success(); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// @@ -2841,9 +2851,47 @@ LogicalResult BroadcastOp::verify() { llvm_unreachable("unexpected vector.broadcast op error"); } +// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible +// with broadcast's result type and shape_cast only adds or removes ones in the +// leading dimensions. +static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) { + auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>(); + if (!srcShapeCast) + return failure(); + + VectorType srcType = srcShapeCast.getSourceVectorType(); + VectorType destType = broadcastOp.getResultVectorType(); + // Check type compatibility. + if (vector::isBroadcastableTo(srcType, destType) != + BroadcastableToResult::Success) + return failure(); + + ArrayRef<int64_t> srcShape = srcType.getShape(); + ArrayRef<int64_t> shapecastShape = + srcShapeCast.getResultVectorType().getShape(); + // Trailing dimensions should be the same if shape_cast only alters the + // leading dimensions. 
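Note: with the `inferReturnTypes` hook added for `vector.to_elements` above, the op can be created without listing its scalar result types. A small sketch (the helper name is illustrative, and it assumes the op's ODS declaration opts into `InferTypeOpInterface`, which this hook implements):
```
// Sketch: decompose a vector into its scalars; for a vector<4xf32> operand
// the created op has four f32 results, all inferred from the operand type.
static SmallVector<Value> decomposeVector(OpBuilder &b, Location loc,
                                          Value vec) {
  auto toElems = vector::ToElementsOp::create(b, loc, vec);
  return llvm::to_vector_of<Value>(toElems.getResults());
}
```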
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size()); + if (!llvm::equal(srcShape.take_back(numTrailingDims), + shapecastShape.take_back(numTrailingDims))) + return failure(); + + assert(all_of(srcShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + all_of(shapecastShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + "ill-formed shape_cast"); + + broadcastOp.getSourceMutable().assign(srcShapeCast.getSource()); + return success(); +} + OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { if (getSourceType() == getResultVectorType()) return getSource(); + if (succeeded(foldBroadcastOfShapeCast(*this))) + return getResult(); + if (!adaptor.getSource()) return {}; auto vectorType = getResultVectorType(); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2d5cc07..fe066dc 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( vector::populateVectorGatherLoweringPatterns(patterns); } +void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorFromElementsLoweringPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 9e287fc..acbf2b7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp + LowerVectorFromElements.cpp LowerVectorGather.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp new file mode 100644 index 0000000..c22fd54 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp @@ -0,0 +1,65 @@ +//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.from_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-from-elements" + +using namespace mlir; + +namespace { + +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. 
For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// ops. +struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } +}; + +} // namespace + +void mlir::vector::populateVectorFromElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<UnrollFromElements>(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index e062f55..90f21c5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { - VectorType resultTy = op.getType(); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(op, "already 1-D"); - - // Unrolling doesn't take vscale into account. Pattern is disabled for - // vectors with leading scalable dim(s). 
- if (resultTy.getScalableDims().front()) - return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); - - Location loc = op.getLoc(); Value indexVec = op.getIndexVec(); Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); - Value result = arith::ConstantOp::create(rewriter, loc, resultTy, - rewriter.getZeroAttr(resultTy)); - - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - int64_t thisIdx[1] = {i}; + auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + int64_t thisIdx[1] = {index}; Value indexSubVec = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); @@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); Value passThruSubVec = vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); - Value subGather = vector::GatherOp::create( - rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec, - maskSubVec, passThruSubVec); - result = - vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx); - } + return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(), + op.getIndices(), indexSubVec, maskSubVec, + passThruSubVec); + }; - rewriter.replaceOp(op, result); - return success(); + return unrollVectorOp(op, rewriter, unrollGatherFn); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index bb0f339..be0d28a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1826,7 +1826,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newForOpOperands); + forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were // contained within the original `ForOp` body. 
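Note: both the gather refactor above and the new from_elements lowering route through the shared `vector::unrollVectorOp` helper (defined in VectorUtils.cpp below). A sketch of applying the new patterns; the driver wiring is illustrative, not part of the patch:
```
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch: multi-dimensional vector.from_elements / vector.gather ops are
// peeled one leading dimension at a time until only 1-D ops remain.
static LogicalResult unrollVectorCreationOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorFromElementsLoweringPatterns(patterns);
  vector::populateVectorGatherLoweringPatterns(patterns);
  return applyPatternsGreedily(root, std::move(patterns));
}
```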
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 2269a40..023c4da 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> { LogicalResult matchAndRewrite(MulOpType mulOp, PatternRewriter &rewriter) const override { - auto resType = llvm::cast<VectorType>(mulOp.getResult().getType()); + auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType()); if (!resType) return failure(); if (resType.getRank() != 2) diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 6e2fa35..841e138 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -392,3 +392,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, } return success(); } + +LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, + vector::UnrollVectorOpFn unrollFn) { + assert(op->getNumResults() == 1 && "expected single result"); + assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type"); + VectorType resultTy = cast<VectorType>(op->getResult(0).getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (resultTy.getScalableDims().front()) + return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); + + Location loc = op->getLoc(); + Value result = ub::PoisonOp::create(rewriter, loc, resultTy); + VectorType subTy = VectorType::Builder(resultTy).dropDim(0); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + Value subVector = unrollFn(rewriter, loc, subTy, i); + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + + rewriter.replaceOp(op, result); + return success(); +} diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt index 242a97c..7869a28 100644 --- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt @@ -7,13 +7,18 @@ add_mlir_dialect_library(MLIRXeGPUDialect DEPENDS MLIRXeGPUIncGen + MLIRXeGPUAttrInterfaceIncGen MLIRXeGPUAttrsIncGen MLIRXeGPUEnumsIncGen LINK_LIBS PUBLIC MLIRArithDialect + MLIRIndexDialect + MLIRAffineUtils MLIRArithUtils MLIRDialectUtils + MLIRGPUDialect + MLIRXeVMDialect MLIRIR MLIRViewLikeInterface MLIRVectorDialect diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 3c0ca114..8ea8cb1 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -6,12 +6,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" using std::optional; @@ -33,6 +37,57 @@ void XeGPUDialect::initialize() { >(); } +/// Generates instructions to compute offsets for a subgroup identified by 
+/// its multidimensional indices (sgId), using the specified subgroup layout +/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data +/// dimensions (sizePerWg). +static SmallVector<SmallVector<Value>> +genOffsetsComputingInsts(OpBuilder &builder, Location loc, + SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, + ArrayRef<int64_t> sizePerSg, + ArrayRef<int64_t> sizePerWg) { + + SmallVector<SmallVector<Value>> offsets; + + // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] + SmallVector<Value> localOffsets = llvm::map_to_vector( + llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { + return builder.createOrFold<index::MulOp>( + loc, std::get<0>(t), + builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); + }); + + // distUnit[i] is the minimum value between sizePerWg[i] and + // sgLayout[i] * sizePerSg[i] + SmallVector<int64_t> distUnit = llvm::map_to_vector( + llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), + [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); + + for (SmallVector<int64_t> unitOffs : + StaticTileOffsetRange(sizePerWg, distUnit)) { + SmallVector<Value> base = + llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { + return arith::ConstantIndexOp::create(builder, loc, d); + }); + + SmallVector<Value> adds = llvm::map_to_vector( + llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { + return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), + std::get<1>(t)); + }); + + SmallVector<Value> mods = llvm::map_to_vector( + llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { + return builder.createOrFold<index::RemUOp>( + loc, std::get<0>(t), + arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); + }); + + offsets.push_back(mods); + } + return offsets; +} + // Checks if the given shape can be evenly distributed based on the layout // and data factors provided by the LayoutAttr. bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, @@ -211,6 +266,148 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, return success(); } +FailureOr<SmallVector<Value>> +LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, + Value linearId) { + // delinearizeSubgroupId is only available for + // workgroup-level layout attribute + if (!isWgLayout()) + return failure(); + + // TODO: handle order attribute + auto hasDefaultOrder = [&]() { + DenseI32ArrayAttr order = getOrder(); + return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( + llvm::reverse(order.asArrayRef()))); + }; + if (!hasDefaultOrder()) + return mlir::emitError(loc, "order attribute is currently not supported."); + + auto dims = llvm::map_to_vector(*getSgLayoutAsInt(), [&](int64_t d) -> Value { + return builder.createOrFold<arith::ConstantIndexOp>(loc, d); + }); + + return affine::delinearizeIndex(builder, loc, linearId, dims); +} + +/// Implements LayoutTrait::getOffsets to generate instructions for +/// computing multi-dimensional offsets when distributed by LayoutAttr. 
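Note: the emitted index computations follow a simple closed form: localOffset[i] = sgId[i] * sizePerSg[i], distUnit[i] = min(sizePerWg[i], sgLayout[i] * sizePerSg[i]), and one offset vector (base[i] + localOffset[i]) % sizePerWg[i] per tile base enumerated by StaticTileOffsetRange(sizePerWg, distUnit). A plain-integer sketch of the same arithmetic for one concrete configuration (the numbers are illustrative, not from the patch):
```
// Model of genOffsetsComputingInsts for:
//   sgLayout = {2, 2}, sizePerSg = {8, 8}, sizePerWg = {16, 32}
// distUnit = {min(16, 2*8), min(32, 2*8)} = {16, 16}, so the enumerated
// tile bases are {0, 0} and {0, 16}.
#include <array>
#include <cstdint>
#include <vector>

static std::vector<std::array<int64_t, 2>> subgroupOffsets(int64_t sgIdRow,
                                                           int64_t sgIdCol) {
  const int64_t sizePerSg[2] = {8, 8}, sizePerWg[2] = {16, 32};
  const int64_t bases[2][2] = {{0, 0}, {0, 16}};
  std::vector<std::array<int64_t, 2>> offsets;
  for (const auto &base : bases)
    offsets.push_back({(base[0] + sgIdRow * sizePerSg[0]) % sizePerWg[0],
                       (base[1] + sgIdCol * sizePerSg[1]) % sizePerWg[1]});
  // Subgroup {1, 0} gets {8, 0} and {8, 16}: two 8x8 tiles of the 16x32 data.
  return offsets;
}
```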
+FailureOr<SmallVector<SmallVector<Value>>> +LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, + ArrayRef<int64_t> shape) { + if (!isWgLayout()) + return failure(); + + SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value(); + SmallVector<int64_t> sgShape; + if (auto maybeSgShape = getSgDataAsInt()) + sgShape = maybeSgShape.value(); + else if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + + // delinearize Ids + auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + if (failed(maybeIds)) + return failure(); + SmallVector<Value> sgIds = *maybeIds; + + return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, + shape); +} + +//===----------------------------------------------------------------------===// +// XeGPU_SliceAttr +//===----------------------------------------------------------------------===// +LogicalResult +SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError, + xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) { + if (!parent || !dims) + return emitError() << "expected parent layout and dims attribute"; + + int64_t rank = parent.getRank(); + + // check every element in dims is unique and smaller than rank + llvm::SmallDenseSet<int64_t> seen; + for (int64_t dim : dims.asArrayRef()) { + if (dim < 0 || dim >= rank) + return emitError() << "invalid dim (" << dim << ") in slice attribute."; + if (!seen.insert(dim).second) + return emitError() << "repeated dim (" << dim << ") in slice attribute."; + } + return success(); +} + +SliceAttr SliceAttr::flatten() const { + xegpu::LayoutTrait parent = getParent(); + SmallVector<DenseI64ArrayAttr> slicedDims({getDims()}); + + while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) { + parent = sliceAttr.getParent(); + slicedDims.push_back(sliceAttr.getDims()); + } + + auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent); + SmallVector<int64_t> indices = + llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank())); + + // get remaining dims (flattend) by applying slice ops with all slicedDims + SmallVector<int64_t> remainingDims(indices); + for (auto dim : llvm::reverse(slicedDims)) + remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims), + dim.asArrayRef()); + + // get flattend sliced dims by applying slice ops with the remaining dims + SmallVector<int64_t> flattendDims = XeGPUDialect::slice( + llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims)); + + return xegpu::SliceAttr::get( + getContext(), layoutAttr, + DenseI64ArrayAttr::get(getContext(), flattendDims)); +} + +FailureOr<SmallVector<Value>> +SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, + Value linearId) { + SliceAttr attr = flatten(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + return parent.delinearizeSubgroupId(builder, loc, linearId); +} + +/// Implements LayoutTrait::getOffsets to generate instructions for +/// computing multi-dimensional offsets when distributed by SliceAttr. 
+FailureOr<SmallVector<SmallVector<Value>>> +SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, + ArrayRef<int64_t> shape) { + assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape."); + if (!isWgLayout()) + return failure(); + + SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value(); + SmallVector<int64_t> sgShape; + if (auto maybeSgShape = getSgDataAsInt()) + sgShape = maybeSgShape.value(); + else if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + + // delinearize Ids + auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + if (failed(maybeIds)) + return failure(); + + // The effective sgIds for offsets computing correspond + // to the dims that are not sliced. + ArrayRef<int64_t> dims = flatten().getDims().asArrayRef(); + SmallVector<Value> sgIds = + XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); + + return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, + shape); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// @@ -230,7 +427,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, // XeGPU_TensorDescType //===----------------------------------------------------------------------===// -mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { +mlir::Type TensorDescType::parse(AsmParser &parser) { llvm::SmallVector<int64_t> shape; mlir::Type elementType; mlir::FailureOr<mlir::Attribute> encoding; @@ -280,7 +477,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { layout.value_or(mlir::Attribute())); } -void TensorDescType::print(::mlir::AsmPrinter &printer) const { +void TensorDescType::print(AsmPrinter &printer) const { printer << "<"; auto shape = getShape(); @@ -325,10 +522,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, return Base::get(context, shape, elementType, attr, layout); } -LogicalResult TensorDescType::verify( - llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - llvm::ArrayRef<int64_t> shape, mlir::Type elementType, - mlir::Attribute encoding, mlir::Attribute layout) { +LogicalResult +TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError, + llvm::ArrayRef<int64_t> shape, mlir::Type elementType, + mlir::Attribute encoding, mlir::Attribute layout) { size_t rank = shape.size(); if (rank == 0) @@ -394,6 +591,119 @@ LogicalResult TensorDescType::verify( return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_MemDescType +//===----------------------------------------------------------------------===// +mlir::Type MemDescType::parse(AsmParser &parser) { + llvm::SmallVector<int64_t> shape; + mlir::Type elementType; + mlir::FailureOr<MemLayoutAttr> layout; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + auto shapeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseDimensionList(shape, false, true))) { + parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); + return {}; + } + + auto elemTypeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseType(elementType))) { + parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); + return {}; + } + + // parse optional attributes + if (mlir::succeeded(parser.parseOptionalComma())) { + MemLayoutAttr attr; + ParseResult res = 
parser.parseAttribute(attr); + if (mlir::failed(res)) + return {}; + layout = attr; + } + + // Parse literal '>' + if (parser.parseGreater()) + return {}; + + MLIRContext *ctxt = parser.getContext(); + return MemDescType::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape, + elementType, layout.value_or(MemLayoutAttr())); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + + printer.printDimensionList(getShape()); + printer << 'x'; + printer << getElementType(); + + if (auto layout = getMemLayout()) + printer << ", " << layout; + + printer << ">"; +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescType +//===----------------------------------------------------------------------===// + +Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) { + + auto context = parser.getContext(); + llvm::SMLoc loc = parser.getCurrentLocation(); + + llvm::SmallDenseSet<StringRef> seenKeys; + SmallVector<NamedAttribute> attributes; + + auto parseElt = [&]() -> ParseResult { + StringRef nameId; + if (failed(parser.parseKeyword(&nameId))) + return parser.emitError(loc, "expected valid attribute name"); + + if (!seenKeys.insert(nameId).second) + return parser.emitError(loc, "duplicate key '") + << nameId << " in mem layout attribute"; + + if (failed(parser.parseEqual())) + return failure(); + + Attribute attr; + if (failed(parser.parseAttribute(attr))) + return failure(); + attributes.emplace_back(nameId, attr); + return success(); + }; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + if (failed(parser.parseCommaSeparatedList(parseElt))) + return {}; + + // Parse literal '>' + if (parser.parseGreater()) + return {}; + + return parser.getChecked<MemLayoutAttr>( + loc, context, DictionaryAttr::get(context, attributes)); +} + +void MemLayoutAttr::print(AsmPrinter &printer) const { + printer << "<"; + ArrayRef<NamedAttribute> attrs = getAttrs().getValue(); + for (size_t i = 0; i < attrs.size(); i++) { + printer << attrs[i].getName().str() << " = " << attrs[i].getValue(); + if (i < attrs.size() - 1) + printer << ", "; + } + printer << ">"; +} + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 33450f3..906c71d 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" @@ -21,6 +23,17 @@ namespace mlir { namespace xegpu { +bool isSharedMemory(const MemRefType &memrefTy) { + Attribute attr = memrefTy.getMemorySpace(); + if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) + return intAttr.getInt() == 3; + if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr)) + return memrefSpace.getValue() == MemorySpace::SLM; + if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr)) + return xevmSpace.getValue() == xevm::AddrSpace::SHARED; + return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr); +} + template <typename T> static std::string makeString(T array, bool breakline = false) { std::string buf; @@ -121,12 +134,20 @@ isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, 
auto maskShape = getShapeOf(maskTy); auto valueShape = getShapeOf(valueTy); - // a valid shape for SIMT case - if (valueTy.getRank() == 1) { - if (valueTy.getNumElements() != chunkSize) - return emitError() << "value elements must match chunk size " << chunkSize - << " for SIMT code."; - return success(); + auto maskVecTy = dyn_cast<VectorType>(maskTy); + if (!maskVecTy) + return emitError() << "Expecting a vector type mask."; + int64_t maskSize = maskVecTy.getNumElements(); + + auto valueSize = valueTy.getNumElements(); + if (chunkSize > 1) { + if ((valueTy.getRank() == 1) && (valueSize != chunkSize)) + return emitError() << "value elements must match chunk size " + << chunkSize; + } else { + if (valueSize != maskSize) + return emitError() + << "Mask should match value except the chunk size dim."; } llvm::SmallVector<int64_t> expectedMaskShape(valueShape); @@ -156,41 +177,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, } void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue<MemRefType> source, + Type tdesc, Value source, llvm::ArrayRef<OpFoldResult> shape, llvm::ArrayRef<OpFoldResult> strides) { - assert(shape.size() && strides.size() && shape.size() == strides.size() && - "Shape and strides must be present and of equal size for ui64 " - "initialization."); + Type srcTy = source.getType(); + assert((isa<IntegerType, MemRefType>(srcTy)) && + "Source has to be either int or memref."); - llvm::SmallVector<int64_t> staticShape; - llvm::SmallVector<int64_t> staticStrides; llvm::SmallVector<Value> dynamicShape; llvm::SmallVector<Value> dynamicStrides; - dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - - auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); - auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); - - build(builder, state, tdesc, source, ValueRange({}), dynamicShape, - dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, - staticStridesAttr); -} - -void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue<IntegerType> source, - llvm::ArrayRef<OpFoldResult> shape, - llvm::ArrayRef<OpFoldResult> strides) { - assert(shape.size() && strides.size() && shape.size() == strides.size() && - "Shape and strides must be present and of equal size for ui64 " - "initialization."); - llvm::SmallVector<int64_t> staticShape; llvm::SmallVector<int64_t> staticStrides; - llvm::SmallVector<Value> dynamicShape; - llvm::SmallVector<Value> dynamicStrides; dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); @@ -198,6 +196,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); + if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) { + auto memrefShape = memrefTy.getShape(); + auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); + + // if shape and strides are from Memref, we don't need attributes for them + // to keep the IR print clean. 
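Note: the unified `CreateNdDescOp` builder above takes a plain `Value` source. A sketch of calling it (the descriptor type and source value are assumed to exist; only the builder signature and the verifier/elision behavior are taken from the patch):
```
// Sketch: one entry point now serves both memref and raw-pointer sources.
static xegpu::CreateNdDescOp
createDesc(OpBuilder &b, Location loc, xegpu::TensorDescType tdescTy,
           Value source, ArrayRef<OpFoldResult> shape,
           ArrayRef<OpFoldResult> strides) {
  // For an integer (pointer) source the verifier requires explicit shape and
  // strides; for a memref source whose static shape/strides match, the
  // builder elides the attributes to keep the printed IR clean.
  return xegpu::CreateNdDescOp::create(b, loc, tdescTy, source, shape,
                                       strides);
}
```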
+ if (staticShape == memrefShape && staticStrides == memrefStrides) { + staticShapeAttr = DenseI64ArrayAttr(); + staticStridesAttr = DenseI64ArrayAttr(); + } + } + build(builder, state, tdesc, source, ValueRange({}), dynamicShape, dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, staticStridesAttr); @@ -265,8 +275,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, } LogicalResult CreateNdDescOp::verify() { - auto rank = (int64_t)getMixedOffsets().size(); - bool invalidRank = false; + size_t rank = getMixedSizes().size(); + bool invalidRank = rank != getMixedStrides().size(); bool invalidElemTy = false; // Memory space of created TensorDesc should match with the source. @@ -280,31 +290,28 @@ LogicalResult CreateNdDescOp::verify() { << " Source: " << srcMemorySpace << ", TensorDesc: " << tdescMemorySpace; + if (size_t offsetRank = getMixedOffsets().size()) + invalidRank |= (offsetRank != rank); + // check source type matches the rank if it is a memref. // It also should have the same ElementType as TensorDesc. - auto memrefTy = dyn_cast<MemRefType>(getSourceType()); - if (memrefTy) { - invalidRank |= (memrefTy.getRank() != rank); + if (auto memrefTy = dyn_cast<MemRefType>(getSourceType())) invalidElemTy |= memrefTy.getElementType() != getElementType(); - } if (llvm::isa<IntegerType>(getSourceType())) { // strides and shape must present for integer source. if (getMixedStrides().empty() || getMixedSizes().empty()) - return emitOpError("Expecting strides and shape to be present for " + return emitOpError("expecting strides and shape to be present for " "integer source."); } - // mismatches among shape, strides, and offsets are - // already handeled by OffsetSizeAndStrideOpInterface. - // So they are not check here. 
if (invalidRank) return emitOpError( "Expecting the rank of shape, strides, offsets, and source (if source " "is a memref) should match with each other."); // check result TensorDesc rank - if (getType().getRank() > rank) + if (getType().getRank() > (int64_t)rank) return emitOpError( "Expecting the TensorDesc rank is not greater than the " "ranks of shape, strides, offsets or the memref source."); @@ -360,13 +367,10 @@ ParseResult parseOptionalDynamicIndexList( void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers) { - - if (!integers) + if (!integers || integers.empty()) return; - - return printDynamicIndexList(printer, op, values, integers, - /*scalableFlags=*/{}, {}, - AsmParser::Delimiter::Square); + printDynamicIndexList(printer, op, values, integers, + /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square); } //===----------------------------------------------------------------------===// // XeGPU_PrefetchNdOp @@ -381,6 +385,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, l1_hint, l2_hint, l3_hint); } +void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, + l2_hint, l3_hint); +} + LogicalResult PrefetchNdOp::verify() { auto tdescTy = getTensorDescType(); if (tdescTy.isScattered()) @@ -423,6 +442,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, l3_hint); } +void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + UnitAttr packed, DenseI64ArrayAttr transpose, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, + packed, transpose, l1_hint, l2_hint, l3_hint); +} + LogicalResult LoadNdOp::verify() { auto tdescTy = getTensorDescType(); auto valueTy = getType(); @@ -529,6 +564,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); } +void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, + l1_hint, l2_hint, l3_hint); +} + LogicalResult StoreNdOp::verify() { auto dstTy = getTensorDescType(); // Tile auto valTy = getValueType(); // Vector @@ -674,7 +724,7 @@ LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); if 
(tdescTy && !tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); + return emitOpError("Expects a scattered TensorDesc."); if (!tdescTy && getRankOf(getSource()) > 1) return emitOpError( @@ -755,7 +805,7 @@ LogicalResult StoreScatterOp::verify() { auto valueTy = getValueType(); if (tdescTy && !tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); + return emitOpError("Expects a scattered TensorDesc."); if (!tdescTy && getRankOf(getDest()) > 1) return emitOpError( @@ -928,9 +978,107 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add<FoldConvertLayoutOp>(context); } +//===----------------------------------------------------------------------===// +// XeGPU_LoadMatrixOp +//===----------------------------------------------------------------------===// +void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, + TypedValue<MemDescType> memDesc, + llvm::ArrayRef<OpFoldResult> offsets, + LayoutTrait layout) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult LoadMatrixOp::verify() { + VectorType resTy = getRes().getType(); + MemDescType mdescTy = getMemDesc().getType(); + + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef<int64_t> valueShape = resTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed mem_desc shape."); + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_StoreMatrixOp +//===----------------------------------------------------------------------===// +void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, + TypedValue<MemDescType> memDesc, + llvm::ArrayRef<OpFoldResult> offsets, + LayoutTrait layout) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult StoreMatrixOp::verify() { + VectorType dataTy = getData().getType(); + MemDescType mdescTy = getMemDesc().getType(); + + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef<int64_t> dataShape = dataTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("data shape must not exceed mem_desc shape."); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescSubviewOp +//===----------------------------------------------------------------------===// + +void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, + Type resTy, Value src, + llvm::ArrayRef<OpFoldResult> offsets) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + 
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); +} + +LogicalResult MemDescSubviewOp::verify() { + MemDescType srcTy = getSrc().getType(); + MemDescType resTy = getRes().getType(); + ArrayRef<int64_t> srcShape = srcTy.getShape(); + ArrayRef<int64_t> resShape = resTy.getShape(); + + if (srcTy.getRank() < resTy.getRank()) + return emitOpError("result rank must not exceed source rank."); + + if (llvm::any_of( + llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed source shape."); + + if (srcTy.getStrides() != resTy.getStrides()) + return emitOpError("result must inherit the source strides."); + + return success(); +} + } // namespace xegpu } // namespace mlir +namespace mlir { +#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> +} // namespace mlir #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc> #define GET_OP_CLASSES #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 850f70c..8f1208e 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -125,42 +125,15 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern; - // Calculate offset for each subgroup - static SmallVector<OpFoldResult> - calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc, - const SmallVector<OpFoldResult> &originalOffsets, - const SmallVector<Value> &localOffset, - const SmallVector<int64_t> &distUnitBaseAddr, - const SmallVector<int64_t> &distUnitShape) { - assert(localOffset.size() == distUnitBaseAddr.size() && - "localOffset and distUnitBaseAddr must have the same rank"); - - SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(), - originalOffsets.end()); - size_t rank = localOffset.size(); - for (size_t i = 0; i < rank; ++i) { - size_t dimIdx = originalOffsets.size() - rank + i; - Value constOffset = - arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]); - Value offset = - rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset); - Value modValue = - arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]); - Value offsetMod = - rewriter.createOrFold<index::RemUOp>(loc, offset, modValue); - Value origOffset = getValueOrCreateConstantIndexOp( - rewriter, loc, originalOffsets[dimIdx]); - Value globalOffset = - rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod); - globalOffsets[dimIdx] = globalOffset; - } - - return globalOffsets; - } - LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + + // Ensure that the op has explicit offsets specified (either dynamic or + // constant). 
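+    // For example (illustrative), a descriptor created with explicit offsets,
+    //   xegpu.create_nd_tdesc %src[%off0, %off1] : ... -> !xegpu.tensor_desc<...>,
+    // is rewritten here; the offset-less form is handled by the
+    // WgToSgCreateNdOpNoOffset pattern added below.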
+ if (op.getMixedOffsets().empty()) + return failure(); + Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); @@ -177,73 +150,98 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { return rewriter.notifyMatchFailure( op, "sgLayout attribute is required in layout"); - SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; - - // TODO : Handle order attribute // Get the subgroup ID - auto linearSgId = + Value linearSgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - // Create constants for layout dimensions - SmallVector<Value> sgLayoutDim(sgLayout.size()); - SmallVector<Value> sgDataDim(sgShape.size()); - - for (size_t i = 0; i < sgLayout.size(); i++) { - sgLayoutDim[i] = - arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]); - sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); - } - int64_t startOfRange = -1, endOfRange = -1; bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange); - Value adjustedSgId = linearSgId; if (sgIdRangeSpecified) { int64_t sgCount = endOfRange - startOfRange; if (computeProduct(sgLayout) != sgCount) return rewriter.notifyMatchFailure( op, "sg_layout size must match the sg_id_range"); - // Subtract startOfRange from the original subgroup id to get the adjusted - // sg id + // Subtract startOfRange from the original subgroup id to get + // the adjusted sg id Value startOfRangeVal = arith::ConstantIndexOp::create(rewriter, loc, startOfRange); - adjustedSgId = + linearSgId = rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); } - auto deLinearizeSgId = - affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim); - if (failed(deLinearizeSgId)) + auto maybeTdescOffsets = + layout.getOffsets(rewriter, loc, linearSgId, wgShape); + if (failed(maybeTdescOffsets)) return failure(); - SmallVector<Value> sgIds = *deLinearizeSgId; - - // Calculate distribution unit shape and local offsets for subgroup - SmallVector<int64_t> distUnitShape(sgLayout.size()); - SmallVector<Value> localOffset(sgLayout.size()); - for (size_t i = 0; i < sgLayout.size(); i++) { - distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]); - localOffset[i] = - rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]); - } - - SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets(); + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), layout.dropSgLayoutAndData()); + SmallVector<Value> newCreateNdOps; - for (SmallVector<int64_t> distUnitBaseAddr : - StaticTileOffsetRange(wgShape, distUnitShape)) { - SmallVector<OpFoldResult> globalOffsets = - calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset, - distUnitBaseAddr, distUnitShape); - - auto newCreateNdOp = xegpu::CreateNdDescOp::create( - rewriter, loc, newTdescTy, op.getSource(), globalOffsets, + SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets(); + + for (auto tdescOffsets : *maybeTdescOffsets) { + SmallVector<OpFoldResult> sgOffsets; + size_t rank = tdescOffsets.size(); + for (size_t i = 0; i < rank; i++) { + size_t idx = origOffsets.size() - rank + i; + Value add = rewriter.createOrFold<index::AddOp>( + loc, tdescOffsets[i], + getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx])); + sgOffsets.push_back(add); + } + + auto newOp = xegpu::CreateNdDescOp::create( 
+ rewriter, loc, newTdescTy, op.getSource(), sgOffsets, op.getMixedSizes(), op.getMixedStrides()); - newCreateNdOps.push_back(newCreateNdOp); + newCreateNdOps.push_back(newOp); } + rewriter.replaceOpWithMultiple(op, {newCreateNdOps}); + return success(); + } +}; + +// This pattern transforms the CreateNdDescOp without offsets to create a +// subgroup descriptor from a workgroup descriptor +struct WgToSgCreateNdOpNoOffset + : public OpConversionPattern<xegpu::CreateNdDescOp> { + using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check no offsets are specified. + if (!op.getMixedOffsets().empty()) + return failure(); + + Location loc = op.getLoc(); + MLIRContext *ctx = op.getContext(); + xegpu::TensorDescType tdescTy = op.getType(); + auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout()); + if (!layout || !layout.isWgLayout()) + return failure(); + + Type elemTy = tdescTy.getElementType(); + ArrayRef<int64_t> wgShape = tdescTy.getShape(); + + SmallVector<int64_t> sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); + xegpu::TensorDescType newTdescTy = + xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), + layout.dropSgLayoutAndData()); + + SmallVector<Value> newCreateNdOps(count); + std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() { + return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, + op.getSource(), op.getMixedSizes(), + op.getMixedStrides()); + }); rewriter.replaceOpWithMultiple(op, {newCreateNdOps}); return success(); @@ -298,6 +296,205 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> { } }; +// Utility function to compute global offsets for subgroup operations. +// Returns a vector of new offsets for each subgroup, given the original op's +// offsets and subgroup relative offsets. +static SmallVector<SmallVector<OpFoldResult>> +computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList, + ArrayRef<OpFoldResult> origOffsets, + ConversionPatternRewriter &rewriter) { + SmallVector<SmallVector<OpFoldResult>> finalOffsets; + Location loc = op->getLoc(); + for (const auto &sgOffsets : sgOffsetsList) { + SmallVector<OpFoldResult> newOffsets; + size_t rank = sgOffsets.size(); + for (size_t i = 0; i < rank; i++) { + size_t idx = origOffsets.size() - rank + i; + Value add = rewriter.createOrFold<index::AddOp>( + loc, sgOffsets[i], + getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx])); + newOffsets.push_back(add); + } + finalOffsets.push_back(std::move(newOffsets)); + } + return finalOffsets; +} + +// Utility function to get sgShape, sgOffsetList for a given +// op. 
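+// For example (illustrative), for a tensor_desc carrying
+//   #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16]>
+// over a 32x32 workgroup shape, sgShape becomes [16, 16] and sgOffsetList
+// holds the per-subgroup relative offsets returned by layout.getOffsets for
+// the linear subgroup id.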
+template <typename OpTy, typename AdaptorTy> +LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor, + ConversionPatternRewriter &rewriter, + SmallVector<int64_t> &sgShape, + SmallVector<SmallVector<Value>> &sgOffsetList) { + int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); + if (offsetSize == 0 && (!op.getConstOffsetsAttr())) + return failure(); + + Location loc = op.getLoc(); + Value tdesc = op.getTensorDesc(); + auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType()); + if (!tdescTy) + return failure(); + auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout()); + if (!layout) + return failure(); + + SmallVector<int64_t> sgLayout; + auto sgLayoutAttr = layout.getSgLayout(); + if (!sgLayoutAttr) + return rewriter.notifyMatchFailure( + op, "sgLayout attribute is required in layout"); + sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); + + ArrayRef<int64_t> wgShape = tdescTy.getShape(); + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); + + // Get the subgroup ID + Value linearSgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + + int64_t startOfRange = -1, endOfRange = -1; + bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange); + + if (sgIdRangeSpecified) { + int64_t sgCount = endOfRange - startOfRange; + if (computeProduct(sgLayout) != sgCount) + return rewriter.notifyMatchFailure( + op, "sg_layout size must match the sg_id_range"); + Value startOfRangeVal = + rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); + linearSgId = + rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); + } + + auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + sgOffsetList = *sgOffsets; + return success(); +} + +template <typename OpTy> +SmallVector<OpFoldResult> getOffsets(OpTy op, + ConversionPatternRewriter &rewriter) { + SmallVector<OpFoldResult> origOffsets; + if (auto constOffsets = op.getConstOffsetsAttr()) { + for (auto attr : constOffsets.asArrayRef()) + origOffsets.push_back(rewriter.getIndexAttr(attr)); + } + for (auto v : op.getOffsets()) + origOffsets.push_back(v); + return origOffsets; +} + +// This pattern transforms the LoadNdOp with explicit offsets to load +// subgroup data. 
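+// For example (illustrative), with sg_layout = [2, 2] and sg_data = [16, 16],
+// a workgroup-level load of a 32x32 tile becomes, in each subgroup, an
+// xegpu.load_nd of a 16x16 tile whose offsets are the original offsets plus
+// that subgroup's relative offset.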
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { + using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<int64_t> sgShape; + SmallVector<SmallVector<Value>> sgOffsetList; + + // Do the distribution from workgroup to subgroup and get subgroup offsets + if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) + return failure(); + + // Get the original workgroup offsets + SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter); + + // Calculate the final offsets for each subgroup + auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter); + + SmallVector<Value> newLoadOps; + for (auto [offsets, tdesc] : + llvm::zip(finalOffsets, adaptor.getTensorDesc())) { + VectorType newResTy = VectorType::get( + sgShape, + dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType()); + auto newLoadOp = rewriter.create<xegpu::LoadNdOp>( + op.getLoc(), newResTy, tdesc, offsets, + /*packed=*/nullptr, + /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + newLoadOps.push_back(newLoadOp); + } + rewriter.replaceOpWithMultiple(op, {newLoadOps}); + return success(); + } +}; + +// This pattern transforms the StoreNdOp with explicit offsets to store +// subgroup data. +struct WgToSgStoreNdOpWithOffset + : public OpConversionPattern<xegpu::StoreNdOp> { + using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<int64_t> sgShape; + SmallVector<SmallVector<Value>> sgOffsetList; + + // Do the distribution from workgroup to subgroup and get subgroup offsets + if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) + return failure(); + + // Get the original workgroup offsets + SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter); + + // Calculate the final offsets for each subgroup + auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter); + + for (auto [offsets, tdesc, value] : + llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) { + rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets, + op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + } + rewriter.eraseOp(op); + return success(); + } +}; + +// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch +// subgroup data. 
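+// Unlike the load pattern above, prefetch produces no results, so the
+// original op is simply erased once the per-subgroup prefetches are created.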
+struct WgToSgPrefetchNdOpWithOffset + : public OpConversionPattern<xegpu::PrefetchNdOp> { + using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<int64_t> sgShape; + SmallVector<SmallVector<Value>> sgOffsetList; + + // Do the distribution from workgroup to subgroup and get subgroup offsets + if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) + return failure(); + + // Get the original workgroup offsets + SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter); + + // Calculate the final offsets for each subgroup + auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter); + + for (auto [offsets, tdesc] : + llvm::zip(finalOffsets, adaptor.getTensorDesc())) { + rewriter.create<xegpu::PrefetchNdOp>( + op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + } + rewriter.eraseOp(op); + return success(); + } +}; + /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the /// offsets of the new subgroup src tensor descriptors. @@ -526,8 +723,8 @@ struct WgToSgElementwiseOp : public ConversionPattern { // is lowered to: // #a = #xegpu.layout<inst_data = [16, 16]> // #b = #xegpu.layout<inst_data = [8, 16]> -// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32> -// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32> +// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> +// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> // clang-format on struct WgToSgConvertLayoutOp @@ -649,16 +846,56 @@ struct UnrealizedConversionCastOpPattern } }; +// This pattern distributes arith.constant op into subgroup-level constants +struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { + using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue()); + auto vecType = dyn_cast<VectorType>(op.getType()); + if (!vecAttr || !vecAttr.isSplat() || !vecType) + return failure(); + + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); + if (!layout || !layout.getSgLayout()) + return failure(); + + ArrayRef<int64_t> wgShape = vecType.getShape(); + SmallVector<int64_t> sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); + + // Current limitation: constant of vector with single value. + // TODO: support more complex cases, e.g., vector with multiple values. 
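+    // For example (illustrative), arith.constant dense<1.0> : vector<32x32xf32>
+    // with sg_layout = [2, 2] and sg_data = [16, 16] becomes a single
+    // arith.constant dense<1.0> : vector<16x16xf32>, replicated `count` times
+    // for the 1:N replacement.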
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>(); + + auto newType = VectorType::get(sgShape, vecType.getElementType()); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = + arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); + if (auto newLayout = layout.dropSgLayoutAndData()) + xegpu::setLayoutAttr(cstOp->getResult(0), newLayout); + SmallVector<Value> newConsts(count, cstOp); + + rewriter.replaceOpWithMultiple(op, {newConsts}); + return success(); + } +}; + } // namespace namespace mlir { namespace xegpu { void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { - patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, - WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, - UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, - WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>( - patterns.getContext()); + patterns + .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp, + WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset, + WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, + WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, + WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, + WgToSgArithConstantOp>(patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -770,6 +1007,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(xegpu::getLayoutAttr(op.getResult())); }); + target.addDynamicallyLegalOp<arith::ConstantOp>( + [=](arith::ConstantOp op) -> bool { + auto vecType = dyn_cast<VectorType>(op.getType()); + if (!vecType) + return true; + return isLegal(xegpu::getLayoutAttr(op.getResult())); + }); + target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index 98e84a4..d9bf4a1 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -7,5 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUUtils LINK_LIBS PUBLIC MLIRIR MLIRSCFTransforms + MLIRGPUDialect + MLIRXeVMDialect MLIRXeGPUDialect ) diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 2cf21fb..19eedba 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" @@ -404,3 +406,21 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( (void)mlir::applyPartialConversion(op, target, std::move(patterns)); } } + +std::optional<std::string> xegpu::getChipStr(Operation *op) { + auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>(); + + if (!gpuModuleOp) + return std::nullopt; + + auto targetAttrs = gpuModuleOp.getTargets(); + if (targetAttrs) { + for (auto &attr : *targetAttrs) { + auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr); + if (xevmAttr) + return xevmAttr.getChip().str(); + } + } + + return std::nullopt; +} diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 
f704fbf..52162a4 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) { } // Compilation is lazy and it doesn't populate object cache unless requested. // In case object dump is requested before cache is populated, we need to - // force compilation manually. + // force compilation manually. if (cache->isEmpty()) { for (std::string &functionName : functionNames) { auto result = lookupPacked(functionName); @@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options, return symbolMap; }; engine->registerSymbols(runtimeSymbolMap); - - // Execute the global constructors from the module being processed. - // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a - // crash for AArch64 see related issue #71963. - if (!engine->jit->getTargetTriple().isAArch64()) - cantFail(engine->jit->initialize(engine->jit->getMainJITDylib())); - return std::move(engine); } @@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const { Error ExecutionEngine::invokePacked(StringRef name, MutableArrayRef<void *> args) { + initialize(); auto expectedFPtr = lookupPacked(name); if (!expectedFPtr) return expectedFPtr.takeError(); @@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name, return Error::success(); } + +void ExecutionEngine::initialize() { + if (isInitialized) + return; + // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a + // crash for AArch64 see related issue #71963. + if (!jit->getTargetTriple().isAArch64()) + cantFail(jit->initialize(jit->getMainJITDylib())); + isInitialized = true; +} diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index 2107df3..0ada4cc 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint, auto engine = std::move(*expectedEngine); + engine->initialize(); + auto expectedFPtr = engine->lookupPacked(entryPoint); if (!expectedFPtr) return expectedFPtr.takeError(); diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index 2f47939..af4ea5a 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -290,8 +290,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, DivisionFixupFn fixup) { const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - if (!rhsMin.isZero()) { + if (!rhsMin.isZero() && !rhsMax.isZero()) { auto udiv = [&fixup](const APInt &a, const APInt &b) -> std::optional<APInt> { return fixup(a, b, a.udiv(b)); diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp index 950b85e2..258fed1 100644 --- a/mlir/lib/RegisterAllDialects.cpp +++ b/mlir/lib/RegisterAllDialects.cpp @@ -102,6 +102,7 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Target/LLVM/NVVM/Target.h" #include "mlir/Target/LLVM/ROCDL/Target.h" +#include "mlir/Target/LLVM/XeVM/Target.h" #include "mlir/Target/SPIRV/Target.h" /// Add all the MLIR dialects to the provided registry. 
@@ -199,6 +200,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) { NVVM::registerNVVMTargetInterfaceExternalModels(registry); ROCDL::registerROCDLTargetInterfaceExternalModels(registry); spirv::registerSPIRVTargetInterfaceExternalModels(registry); + xevm::registerXeVMTargetInterfaceExternalModels(registry); } /// Append all the MLIR dialects to the registry contained in the given context. diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index 8f7c67c..232ddaf 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -58,6 +58,7 @@ #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" /// This function may be called to register all MLIR dialect extensions with the /// provided registry. diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index 6eb0abc..f0c3ac4 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(SPIRV) add_subdirectory(LLVMIR) add_subdirectory(LLVM) add_subdirectory(SMTLIB) +add_subdirectory(Wasm) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 8e83e45..a5ee64c 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1447,7 +1447,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { } if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) { if (auto iType = dyn_cast<IntegerType>( - cast<TensorType>(dense.getType()).getElementType())) { + cast<ShapedType>(dense.getType()).getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, shouldMapToUnsigned(iType.getSignedness())); @@ -1456,7 +1456,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return success(); } if (auto iType = dyn_cast<IndexType>( - cast<TensorType>(dense.getType()).getElementType())) { + cast<ShapedType>(dense.getType()).getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, false); }); diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt index f6e44c6..9a0e4d4 100644 --- a/mlir/lib/Target/LLVM/CMakeLists.txt +++ b/mlir/lib/Target/LLVM/CMakeLists.txt @@ -210,3 +210,27 @@ if(MLIR_ENABLE_ROCM_CONVERSIONS) ) endif() +if ("SPIRV" IN_LIST LLVM_TARGETS_TO_BUILD) + set(SPIRV_LIBS + SPIRVCodeGen + SPIRVDesc + SPIRVInfo + ) +endif() + +add_mlir_dialect_library(MLIRXeVMTarget + XeVM/Target.cpp + + OBJECT + + LINK_COMPONENTS + ${SPIRV_LIBS} + + LINK_LIBS PUBLIC + MLIRIR + MLIRExecutionEngineUtils + MLIRSupport + MLIRGPUDialect + MLIRTargetLLVM + MLIRXeVMToLLVMIRTranslation +) diff --git a/mlir/lib/Target/LLVM/XeVM/Target.cpp b/mlir/lib/Target/LLVM/XeVM/Target.cpp new file mode 100644 index 0000000..1e6784a2 --- /dev/null +++ b/mlir/lib/Target/LLVM/XeVM/Target.cpp @@ -0,0 +1,418 @@ +//===- Target.cpp - MLIR LLVM XeVM target compilation -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines XeVM target-related functions including registration +// calls for the `#xevm.target` compilation attribute. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVM/XeVM/Target.h" + +#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/Target/LLVM/XeVM/Utils.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Target/TargetMachine.h" + +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/Config/Targets.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +#include <cstdint> +#include <cstdlib> + +using namespace mlir; +using namespace mlir::xevm; + +namespace { +// XeVM implementation of the gpu::TargetAttrInterface. +class XeVMTargetAttrImpl + : public gpu::TargetAttrInterface::FallbackModel<XeVMTargetAttrImpl> { +public: + std::optional<SmallVector<char, 0>> + serializeToObject(Attribute attribute, Operation *module, + const gpu::TargetOptions &options) const; + + Attribute createObject(Attribute attribute, Operation *module, + const SmallVector<char, 0> &object, + const gpu::TargetOptions &options) const; +}; +} // namespace + +void mlir::xevm::registerXeVMTargetInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) { + XeVMTargetAttr::attachInterface<XeVMTargetAttrImpl>(*ctx); + }); +} + +void mlir::xevm::registerXeVMTargetInterfaceExternalModels( + MLIRContext &context) { + DialectRegistry registry; + registerXeVMTargetInterfaceExternalModels(registry); + context.appendDialectRegistry(registry); +} + +SerializeGPUModuleBase::SerializeGPUModuleBase( + Operation &module, XeVMTargetAttr xeTarget, + const gpu::TargetOptions &targetOptions) + : ModuleToObject(module, xeTarget.getTriple(), "", {}, xeTarget.getO()), + xeTarget(xeTarget), librariesToLink(targetOptions.getLibrariesToLink()), + targetOptions(targetOptions) { + if (xeTarget.getLinkFiles()) + librariesToLink.append(xeTarget.getLinkFiles().begin(), + xeTarget.getLinkFiles().end()); +} + +XeVMTargetAttr SerializeGPUModuleBase::getTarget() const { return xeTarget; } + +std::optional<SmallVector<std::unique_ptr<llvm::Module>>> +SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) { + if (librariesToLink.empty()) + return SmallVector<std::unique_ptr<llvm::Module>>(); + SmallVector<std::unique_ptr<llvm::Module>> bcFiles; + if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink, + bcFiles))) + return std::nullopt; + return std::move(bcFiles); +} + +gpu::GPUModuleOp 
SerializeGPUModuleBase::getGPUModuleOp() { + return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation()); +} + +// There is 1 way to finalize IL to native code: IGC +// There are 2 ways to access IGC: AOT (ocloc) and JIT (L0 runtime). +// - L0 runtime consumes IL and is external to MLIR codebase (rt wrappers). +// - `ocloc` tool can be "queried" from within MLIR. +std::optional<SmallVector<char, 0>> +SerializeGPUModuleBase::compileToBinary(const std::string &asmStr, + StringRef inputFormat) { + using TmpFile = std::pair<llvm::SmallString<128>, llvm::FileRemover>; + // Find the `ocloc` tool. + std::optional<std::string> oclocCompiler = findTool("ocloc"); + if (!oclocCompiler) + return std::nullopt; + Location loc = getGPUModuleOp().getLoc(); + std::string basename = llvm::formatv( + "mlir-{0}-{1}-{2}", getGPUModuleOp().getNameAttr().getValue(), + getTarget().getTriple(), getTarget().getChip()); + + auto createTemp = [&](StringRef name, + StringRef suffix) -> std::optional<TmpFile> { + llvm::SmallString<128> filePath; + if (auto ec = llvm::sys::fs::createTemporaryFile(name, suffix, filePath)) { + getGPUModuleOp().emitError() + << "Couldn't create the temp file: `" << filePath + << "`, error message: " << ec.message(); + return std::nullopt; + } + return TmpFile(filePath, llvm::FileRemover(filePath.c_str())); + }; + // Create temp file + std::optional<TmpFile> asmFile = createTemp(basename, "asm"); + std::optional<TmpFile> binFile = createTemp(basename, ""); + std::optional<TmpFile> logFile = createTemp(basename, "log"); + if (!logFile || !asmFile || !binFile) + return std::nullopt; + // Dump the assembly to a temp file + std::error_code ec; + { + llvm::raw_fd_ostream asmStream(asmFile->first, ec); + if (ec) { + emitError(loc) << "Couldn't open the file: `" << asmFile->first + << "`, error message: " << ec.message(); + return std::nullopt; + } + asmStream << asmStr; + if (asmStream.has_error()) { + emitError(loc) << "An error occurred while writing the assembly to: `" + << asmFile->first << "`."; + return std::nullopt; + } + asmStream.flush(); + } + // Set cmd options + std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts = + targetOptions.tokenizeCmdOptions(); + // Example: --gpu-module-to-binary="opts='opt1 opt2'" + const std::string cmdOptsStr = "\"" + llvm::join(cmdOpts.second, " ") + "\""; + SmallVector<StringRef, 12> oclocArgs( + {"ocloc", "compile", "-file", asmFile->first, inputFormat, "-device", + getTarget().getChip(), "-output", binFile->first, "-output_no_suffix", + "-options", cmdOptsStr}); + +// Dump tool invocation commands. +#define DEBUG_TYPE "serialize-to-binary" + LLVM_DEBUG({ + llvm::dbgs() << "Tool invocation for module: " + << getGPUModuleOp().getNameAttr() << "\n"; + llvm::interleave(oclocArgs, llvm::dbgs(), " "); + llvm::dbgs() << "\n"; + }); +#undef DEBUG_TYPE + // Helper function for printing tool error logs. + std::string message; + auto emitLogError = + [&](StringRef toolName) -> std::optional<SmallVector<char, 0>> { + if (message.empty()) { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr = + llvm::MemoryBuffer::getFile(logFile->first); + if (toolStderr) + emitError(loc) << toolName << " invocation failed. 
Log:\n" + << toolStderr->get()->getBuffer(); + else + emitError(loc) << toolName << " invocation failed."; + return std::nullopt; + } + emitError(loc) << toolName + << " invocation failed, error message: " << message; + return std::nullopt; + }; + std::optional<StringRef> redirects[] = { + std::nullopt, + logFile->first, + logFile->first, + }; + // Invoke ocloc. + if (llvm::sys::ExecuteAndWait(oclocCompiler.value(), oclocArgs, std::nullopt, + redirects, 0, 0, &message)) + return emitLogError("`ocloc`"); + binFile->first.append(".bin"); + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer = + llvm::MemoryBuffer::getFile(binFile->first); + if (!binaryBuffer) { + emitError(loc) << "Couldn't open the file: `" << binFile->first + << "`, error message: " << binaryBuffer.getError().message(); + return std::nullopt; + } + StringRef bin = (*binaryBuffer)->getBuffer(); + return SmallVector<char, 0>(bin.begin(), bin.end()); +} + +std::optional<std::string> SerializeGPUModuleBase::findTool(StringRef tool) { + // 1. Check the toolkit path given in the command line. + StringRef pathRef = targetOptions.getToolkitPath(); + SmallVector<char, 256> path; + if (!pathRef.empty()) { + path.insert(path.begin(), pathRef.begin(), pathRef.end()); + llvm::sys::path::append(path, "bin", tool); + if (llvm::sys::fs::can_execute(path)) + return StringRef(path.data(), path.size()).str(); + } + // 2. Check PATH. + if (std::optional<std::string> toolPath = + llvm::sys::Process::FindInEnvPath("PATH", tool)) + return *toolPath; + + getGPUModuleOp().emitError() + << "Couldn't find the `" << tool + << "` binary. Please specify the toolkit " + "path via GpuModuleToBinaryPass or add the compiler to $PATH`."; + return std::nullopt; +} + +namespace { +class SPIRVSerializer : public SerializeGPUModuleBase { +public: + SPIRVSerializer(Operation &module, XeVMTargetAttr xeTarget, + const gpu::TargetOptions &targetOptions) + : SerializeGPUModuleBase(module, xeTarget, targetOptions) {} + + static void init(); + + /// Serializes the LLVM module to an object format, depending on the + /// compilation target selected in target options. + std::optional<SmallVector<char, 0>> + moduleToObject(llvm::Module &llvmModule) override; + +private: + /// Translates the LLVM module to SPIR-V binary using LLVM's + /// SPIR-V target. + std::optional<std::string> + translateToSPIRVBinary(llvm::Module &llvmModule, + llvm::TargetMachine &targetMachine); +}; +} // namespace + +void SPIRVSerializer::init() { + static llvm::once_flag initializeBackendOnce; + llvm::call_once(initializeBackendOnce, []() { +#if LLVM_HAS_SPIRV_TARGET + LLVMInitializeSPIRVTarget(); + LLVMInitializeSPIRVTargetInfo(); + LLVMInitializeSPIRVTargetMC(); + LLVMInitializeSPIRVAsmPrinter(); +#endif + }); +} + +std::optional<SmallVector<char, 0>> +SPIRVSerializer::moduleToObject(llvm::Module &llvmModule) { +#define DEBUG_TYPE "serialize-to-llvm" + LLVM_DEBUG({ + llvm::dbgs() << "LLVM IR for module: " << getGPUModuleOp().getNameAttr() + << "\n"; + llvm::dbgs() << llvmModule << "\n"; + llvm::dbgs().flush(); + }); +#undef DEBUG_TYPE + + // Return LLVM IR if the compilation target is `offload`. + if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload) + return SerializeGPUModuleBase::moduleToObject(llvmModule); + +#if !LLVM_HAS_SPIRV_TARGET + getGPUModuleOp()->emitError("The `SPIRV` target was not built. 
Please enable " + "it when building LLVM."); + return std::nullopt; +#endif // LLVM_HAS_SPIRV_TARGET + + std::optional<llvm::TargetMachine *> targetMachine = + getOrCreateTargetMachine(); + if (!targetMachine) { + getGPUModuleOp().emitError() << "Target Machine unavailable for triple " + << triple << ", can't optimize with LLVM\n"; + return std::nullopt; + } + + // Return SPIRV if the compilation target is `assembly`. + if (targetOptions.getCompilationTarget() == + gpu::CompilationTarget::Assembly) { + std::optional<std::string> serializedISA = + translateToISA(llvmModule, **targetMachine); + if (!serializedISA) { + getGPUModuleOp().emitError() << "Failed translating the module to ISA." + << triple << ", can't compile with LLVM\n"; + return std::nullopt; + } + +#define DEBUG_TYPE "serialize-to-isa" + LLVM_DEBUG({ + llvm::dbgs() << "SPIR-V for module: " << getGPUModuleOp().getNameAttr() + << "\n"; + llvm::dbgs() << *serializedISA << "\n"; + llvm::dbgs().flush(); + }); +#undef DEBUG_TYPE + + // Make sure to include the null terminator. + StringRef bin(serializedISA->c_str(), serializedISA->size() + 1); + return SmallVector<char, 0>(bin.begin(), bin.end()); + } + + // Level zero runtime is set up to accept SPIR-V binary + // translateToSPIRVBinary translates the LLVM module to SPIR-V binary + // using LLVM's SPIRV target. + // compileToBinary can be used in the future if level zero runtime + // implementation switches to native XeVM binary format. + std::optional<std::string> serializedSPIRVBinary = + translateToSPIRVBinary(llvmModule, **targetMachine); + if (!serializedSPIRVBinary) { + getGPUModuleOp().emitError() << "Failed translating the module to Binary."; + return std::nullopt; + } + if (serializedSPIRVBinary->size() % 4) { + getGPUModuleOp().emitError() << "SPIRV code size must be a multiple of 4."; + return std::nullopt; + } + StringRef bin(serializedSPIRVBinary->c_str(), serializedSPIRVBinary->size()); + return SmallVector<char, 0>(bin.begin(), bin.end()); +} + +std::optional<std::string> +SPIRVSerializer::translateToSPIRVBinary(llvm::Module &llvmModule, + llvm::TargetMachine &targetMachine) { + std::string targetISA; + llvm::raw_string_ostream stream(targetISA); + + { // Drop pstream after this to prevent the ISA from being stuck buffering + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager codegenPasses; + if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr, + llvm::CodeGenFileType::ObjectFile)) + return std::nullopt; + + codegenPasses.run(llvmModule); + } + return targetISA; +} + +std::optional<SmallVector<char, 0>> +XeVMTargetAttrImpl::serializeToObject(Attribute attribute, Operation *module, + const gpu::TargetOptions &options) const { + if (!module) + return std::nullopt; + auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module); + if (!gpuMod) { + module->emitError("expected to be a gpu.module op"); + return std::nullopt; + } + auto xeTarget = cast<XeVMTargetAttr>(attribute); + if (xeTarget.getTriple().starts_with("spirv")) { + gpuMod.walk([&](LLVM::LLVMFuncOp funcOp) { + if (funcOp->hasAttr(gpu::GPUDialect::getKernelFuncAttrName())) { + funcOp.setIntelReqdSubGroupSize(16); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + SPIRVSerializer serializer(*module, cast<XeVMTargetAttr>(attribute), + options); + serializer.init(); + +#if !LLVM_HAS_SPIRV_TARGET + module->emitError("Cannot run `TargetRegistry::lookupTarget()` for SPIRV " + "without having the target built."); +#endif + + return serializer.run(); + } + 
module->emitError("Unsupported XeVM target triple: ") << xeTarget.getTriple(); + return std::nullopt; +} + +Attribute +XeVMTargetAttrImpl::createObject(Attribute attribute, Operation *module, + const SmallVector<char, 0> &object, + const gpu::TargetOptions &options) const { + Builder builder(attribute.getContext()); + gpu::CompilationTarget format = options.getCompilationTarget(); + auto xeTarget = cast<XeVMTargetAttr>(attribute); + SmallVector<NamedAttribute, 2> properties; + if (format == gpu::CompilationTarget::Assembly) + properties.push_back( + builder.getNamedAttr("O", builder.getI32IntegerAttr(xeTarget.getO()))); + + DictionaryAttr objectProps; + if (!properties.empty()) + objectProps = builder.getDictionaryAttr(properties); + + return builder.getAttr<gpu::ObjectAttr>( + attribute, format, + builder.getStringAttr(StringRef(object.data(), object.size())), + objectProps, /*kernels=*/nullptr); +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 90462d1..e67cfed 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { llvm_unreachable("unsupported vote kind"); } -/// Return the intrinsic ID associated with ldmatrix for the given paramters. -static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, - int32_t num) { - if (layout == NVVM::MMALayout::row) { +static llvm::Intrinsic::ID +getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { + if (shape.getM() == 8 && shape.getN() == 8) { switch (num) { case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; - default: - llvm_unreachable("unsupported number of matrix"); + return (layout == NVVM::MMALayout::row) + ? 
llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; } - - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; - default: - llvm_unreachable("unsupported number of matrix"); + } else if (shape.getM() == 8 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; + } + } + } else if (shape.getM() == 16 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64; + } } } + llvm_unreachable("unknown ldmatrix kind"); } /// Return the intrinsic ID associated with stmatrix for the given paramters. 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 2cdd502..6694de8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4356,9 +4356,11 @@ createAlteredByCaptureMap(MapInfoData &mapData, if (!isPtrTy) { auto curInsert = builder.saveIP(); + llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation(); builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation)); auto *memTempAlloc = builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted"); + builder.SetCurrentDebugLocation(DbgLoc); builder.restoreIP(curInsert); builder.CreateStore(newV, memTempAlloc); @@ -5865,6 +5867,10 @@ static bool isTargetDeviceOp(Operation *op) { if (mlir::isa<omp::ThreadprivateOp>(op)) return true; + if (mlir::isa<omp::TargetAllocMemOp>(op) || + mlir::isa<omp::TargetFreeMemOp>(op)) + return true; + if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) if (auto declareTargetIface = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( @@ -5877,6 +5883,85 @@ static bool isTargetDeviceOp(Operation *op) { return false; } +static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *i64Ty = builder.getInt64Ty(); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *returnType = builder.getPtrTy(0); + llvm::FunctionType *fnType = + llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false); + llvm::Function *func = cast<llvm::Function>( + llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst); + if (!allocMemOp) + return failure(); + + // Get "omp_target_alloc" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = allocMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the allocation size. + llvm::DataLayout dataLayout = llvmModule->getDataLayout(); + mlir::Type heapTy = allocMemOp.getAllocatedType(); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy); + llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); + for (auto typeParam : allocMemOp.getTypeparams()) + allocSize = + builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam)); + // Create call to "omp_target_alloc" with the args as translated llvm values. 
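+  // For example (illustrative), for a 32-bit element type with no typeparams
+  // this lowers to roughly:
+  //   %p = call ptr @omp_target_alloc(i64 4, i32 %device)
+  //   %h = ptrtoint ptr %p to i64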
+ llvm::CallInst *call = + builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); + llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty()); + + // Map the result + moduleTranslation.mapValue(allocMemOp.getResult(), resultI64); + return success(); +} + +static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *ptrTy = builder.getPtrTy(0); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *voidTy = builder.getVoidTy(); + llvm::FunctionType *fnType = + llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false); + llvm::Function *func = dyn_cast<llvm::Function>( + llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst); + if (!freeMemOp) + return failure(); + + // Get "omp_target_free" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTargetFreeFunc = getOmpTargetFree(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = freeMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the corresponding heapref value in llvm + mlir::Value heapref = freeMemOp.getHeapref(); + llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref); + // Convert heapref int to ptr and call "omp_target_free" + llvm::Value *intToPtr = + builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0)); + builder.CreateCall(ompTargetFreeFunc, {intToPtr, llvmDeviceNum}); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including /// OpenMP runtime calls). static LogicalResult @@ -6051,6 +6136,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // the omp.canonical_loop. 
return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TargetAllocMemOp) { + return convertTargetAllocMemOp(*op, builder, moduleTranslation); + }) + .Case([&](omp::TargetFreeMemOp) { + return convertTargetFreeMemOp(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index c967e86..d8c54ec 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { } auto resultID = operands[1]; - if (auto shapedType = dyn_cast<ShapedType>(resultType)) { + if (auto tensorType = dyn_cast<TensorArmType>(resultType)) { + SmallVector<Attribute> flattenedElems; + for (Attribute element : elements) { + if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) { + for (auto value : denseElemAttr.getValues<Attribute>()) + flattenedElems.push_back(value); + } else { + flattenedElems.push_back(element); + } + } + auto attr = DenseElementsAttr::get(tensorType, flattenedElems); + constantMap.try_emplace(resultID, attr, tensorType); + } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) { auto attr = DenseElementsAttr::get(shapedType, elements); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index c049574..7fc7795 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -112,7 +113,9 @@ LogicalResult Serializer::serialize() { // TODO: handle the other sections processCapability(); - processExtension(); + if (failed(processExtension())) { + return failure(); + } processMemoryModel(); processDebugInfo(); @@ -204,13 +207,24 @@ void Serializer::processDebugInfo() { // TODO: Encode more debug instructions. 
} -void Serializer::processExtension() { +LogicalResult Serializer::processExtension() { llvm::SmallVector<uint32_t, 16> extName; - for (spirv::Extension ext : module.getVceTriple()->getExtensions()) { + llvm::SmallSet<Extension, 4> deducedExts( + llvm::from_range, module.getVceTriple()->getExtensions()); + auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info; + if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) { + TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module); + if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt)) + return module.emitError( + "SPV_KHR_non_semantic_info extension not available"); + deducedExts.insert(nonSemanticInfoExt); + } + for (spirv::Extension ext : deducedExts) { extName.clear(); spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); } + return success(); } void Serializer::processMemoryModel() { @@ -956,6 +970,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, uint32_t resultID = getNextID(); SmallVector<uint32_t, 4> operands = {typeID, resultID}; auto elementType = cast<spirv::CompositeType>(constType).getElementType(0); + if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) { + ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front(); + if (!innerShape.empty()) + elementType = spirv::TensorArmType::get(innerShape, elementType); + } // "If the Result Type is a cooperative matrix type, then there must be only // one Constituent, with scalar type matching the cooperative matrix Component @@ -979,30 +998,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, } else { return 0; } - } else if (isa<spirv::TensorArmType>(constType)) { - if (isZeroValue(valueAttr)) { - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, - {typeID, resultID}); - return resultID; - } - numberOfConstituents = shapedType.getNumElements(); - operands.reserve(numberOfConstituents + 2); - for (int i = 0; i < numberOfConstituents; ++i) { - uint32_t elementID = 0; - if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { - elementID = - elementType.isInteger(1) - ? 
prepareConstantBool(loc, attr.getValues<BoolAttr>()[i]) - : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]); - } - if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { - elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]); - } - if (!elementID) { - return 0; - } - operands.push_back(elementID); - } + } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) { + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, + {typeID, resultID}); + return resultID; } else { operands.reserve(numberOfConstituents + 2); for (int i = 0; i < numberOfConstituents; ++i) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index 7047869..fb2cecd 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -102,7 +102,7 @@ private: void processDebugInfo(); - void processExtension(); + LogicalResult processExtension(); void processMemoryModel(); diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp index ac338d55..796354e 100644 --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -21,8 +21,11 @@ #include "mlir/Target/SPIRV/Serialization.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" using namespace mlir; @@ -76,24 +79,66 @@ void registerFromSPIRVTranslation() { // Serialization registration //===----------------------------------------------------------------------===// -static LogicalResult serializeModule(spirv::ModuleOp module, - raw_ostream &output) { +static LogicalResult +serializeModule(spirv::ModuleOp moduleOp, raw_ostream &output, + const spirv::SerializationOptions &options) { SmallVector<uint32_t, 0> binary; - if (failed(spirv::serialize(module, binary))) + if (failed(spirv::serialize(moduleOp, binary))) return failure(); - output.write(reinterpret_cast<char *>(binary.data()), - binary.size() * sizeof(uint32_t)); + size_t sizeInBytes = binary.size() * sizeof(uint32_t); + + output.write(reinterpret_cast<char *>(binary.data()), sizeInBytes); + + if (options.saveModuleForValidation) { + size_t dirSeparator = + options.validationFilePrefix.find(llvm::sys::path::get_separator()); + // If file prefix includes directory check if that directory exists. 
+ if (dirSeparator != std::string::npos) { + llvm::StringRef parentDir = + llvm::sys::path::parent_path(options.validationFilePrefix); + if (!llvm::sys::fs::is_directory(parentDir)) + return moduleOp.emitError( + "validation prefix directory does not exist\n"); + } + + SmallString<128> filename; + int fd = 0; + + std::error_code errorCode = llvm::sys::fs::createUniqueFile( + options.validationFilePrefix + "%%%%%%.spv", fd, filename); + if (errorCode) + return moduleOp.emitError("error creating validation output file: ") + << errorCode.message() << "\n"; + + llvm::raw_fd_ostream validationOutput(fd, /*shouldClose=*/true); + validationOutput.write(reinterpret_cast<char *>(binary.data()), + sizeInBytes); + validationOutput.flush(); + } return mlir::success(); } namespace mlir { void registerToSPIRVTranslation() { + static llvm::cl::opt<std::string> validationFilesPrefix( + "spirv-save-validation-files-with-prefix", + llvm::cl::desc( + "When non-empty string is passed each serialized SPIR-V module is " + "saved to an additional file that starts with the given prefix. This " + "is used to generate separate binaries for validation, where " + "`--split-input-file` normally combines all outputs into one. The " + "one combined output (`-o`) is still written. Created files need to " + "be removed manually once processed."), + llvm::cl::init("")); + TranslateFromMLIRRegistration toBinary( "serialize-spirv", "serialize SPIR-V dialect", - [](spirv::ModuleOp module, raw_ostream &output) { - return serializeModule(module, output); + [](spirv::ModuleOp moduleOp, raw_ostream &output) { + return serializeModule(moduleOp, output, + {true, false, !validationFilesPrefix.empty(), + validationFilesPrefix}); }, [](DialectRegistry ®istry) { registry.insert<spirv::SPIRVDialect>(); diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt new file mode 100644 index 0000000..890fc0ec --- /dev/null +++ b/mlir/lib/Target/Wasm/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_translation_library(MLIRTargetWasmImport + TranslateRegistration.cpp + TranslateFromWasm.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm + + LINK_LIBS PUBLIC + MLIRWasmSSADialect + MLIRIR + MLIRSupport + MLIRTranslateLib +) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp new file mode 100644 index 0000000..8d45052 --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -0,0 +1,1245 @@ +//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the WebAssembly importer. 
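
Returning briefly to the SPIR-V translation registration above (before the new WebAssembly importer sources that follow): llvm::sys::fs::createUniqueFile substitutes each '%' in the model string with a random character, which is what keeps the per-module validation outputs from colliding under --split-input-file. A minimal sketch of that step in isolation (function name and prefix are illustrative):

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"

// Write one serialized module to "<prefix>%%%%%%.spv", with the '%'
// placeholders made unique, mirroring the saveModuleForValidation path.
static bool writeValidationCopy(llvm::StringRef prefix, llvm::StringRef bytes) {
  llvm::SmallString<128> path;
  int fd = 0;
  if (llvm::sys::fs::createUniqueFile(prefix + "%%%%%%.spv", fd, path))
    return false; // non-zero error_code: file creation failed
  llvm::raw_fd_ostream out(fd, /*shouldClose=*/true);
  out << bytes; // e.g. ends up in out/val-1a2b3c.spv for prefix "out/val-"
  return true;
}
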
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/Target/Wasm/WasmBinaryEncoding.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LEB128.h" + +#include <climits> +#include <cstdint> +#include <variant> + +#define DEBUG_TYPE "wasm-translate" + +static_assert(CHAR_BIT == 8, + "This code expects std::byte to be exactly 8 bits"); + +using namespace mlir; +using namespace mlir::wasm; +using namespace mlir::wasmssa; + +namespace { +using section_id_t = uint8_t; +enum struct WasmSectionType : section_id_t { + CUSTOM = 0, + TYPE = 1, + IMPORT = 2, + FUNCTION = 3, + TABLE = 4, + MEMORY = 5, + GLOBAL = 6, + EXPORT = 7, + START = 8, + ELEMENT = 9, + CODE = 10, + DATA = 11, + DATACOUNT = 12 +}; + +constexpr section_id_t highestWasmSectionID{ + static_cast<section_id_t>(WasmSectionType::DATACOUNT)}; + +#define APPLY_WASM_SEC_TRANSFORM \ + WASM_SEC_TRANSFORM(CUSTOM) \ + WASM_SEC_TRANSFORM(TYPE) \ + WASM_SEC_TRANSFORM(IMPORT) \ + WASM_SEC_TRANSFORM(FUNCTION) \ + WASM_SEC_TRANSFORM(TABLE) \ + WASM_SEC_TRANSFORM(MEMORY) \ + WASM_SEC_TRANSFORM(GLOBAL) \ + WASM_SEC_TRANSFORM(EXPORT) \ + WASM_SEC_TRANSFORM(START) \ + WASM_SEC_TRANSFORM(ELEMENT) \ + WASM_SEC_TRANSFORM(CODE) \ + WASM_SEC_TRANSFORM(DATA) \ + WASM_SEC_TRANSFORM(DATACOUNT) + +template <WasmSectionType> +constexpr const char *wasmSectionName = ""; + +#define WASM_SEC_TRANSFORM(section) \ + template <> \ + [[maybe_unused]] constexpr const char \ + *wasmSectionName<WasmSectionType::section> = #section; +APPLY_WASM_SEC_TRANSFORM +#undef WASM_SEC_TRANSFORM + +constexpr bool sectionShouldBeUnique(WasmSectionType secType) { + return secType != WasmSectionType::CUSTOM; +} + +template <std::byte... Bytes> +struct ByteSequence {}; + +/// Template class for representing a byte sequence of only one byte +template <std::byte Byte> +struct UniqueByte : ByteSequence<Byte> {}; + +[[maybe_unused]] constexpr ByteSequence< + WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64, + WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64, + WasmBinaryEncoding::Type::v128> valueTypesEncodings{}; + +template <std::byte... allowedFlags> +constexpr bool isValueOneOf(std::byte value, + ByteSequence<allowedFlags...> = {}) { + return ((value == allowedFlags) | ... | false); +} + +template <std::byte... 
flags> +constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) { + return !isValueOneOf<flags...>(value); +} + +struct GlobalTypeRecord { + Type type; + bool isMutable; +}; + +struct TypeIdxRecord { + size_t id; +}; + +struct SymbolRefContainer { + FlatSymbolRefAttr symbol; +}; + +struct GlobalSymbolRefContainer : SymbolRefContainer { + Type globalType; +}; + +struct FunctionSymbolRefContainer : SymbolRefContainer { + FunctionType functionType; +}; + +using ImportDesc = + std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>; + +using parsed_inst_t = FailureOr<SmallVector<Value>>; + +struct WasmModuleSymbolTables { + SmallVector<FunctionSymbolRefContainer> funcSymbols; + SmallVector<GlobalSymbolRefContainer> globalSymbols; + SmallVector<SymbolRefContainer> memSymbols; + SmallVector<SymbolRefContainer> tableSymbols; + SmallVector<FunctionType> moduleFuncTypes; + + std::string getNewSymbolName(StringRef prefix, size_t id) const { + return (prefix + Twine{id}).str(); + } + + std::string getNewFuncSymbolName() const { + auto id = funcSymbols.size(); + return getNewSymbolName("func_", id); + } + + std::string getNewGlobalSymbolName() const { + auto id = globalSymbols.size(); + return getNewSymbolName("global_", id); + } + + std::string getNewMemorySymbolName() const { + auto id = memSymbols.size(); + return getNewSymbolName("mem_", id); + } + + std::string getNewTableSymbolName() const { + auto id = tableSymbols.size(); + return getNewSymbolName("table_", id); + } +}; + +class ParserHead; + +/// Wrapper around SmallVector to only allow access as push and pop on the +/// stack. Makes sure that there are no "free accesses" on the stack to preserve +/// its state. +class ValueStack { +private: + struct LabelLevel { + size_t stackIdx; + LabelLevelOpInterface levelOp; + }; + +public: + bool empty() const { return values.empty(); } + + size_t size() const { return values.size(); } + + /// Pops values from the stack because they are being used in an operation. + /// @param operandTypes The list of expected types of the operation, used + /// to know how many values to pop and check if the types match the + /// expectation. + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + /// @return Failure or the vector of popped values. + FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes, + Location *opLoc); + + /// Push the results of an operation to the stack so they can be used in a + /// following operation. + /// @param results The list of results of the operation + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + LogicalResult pushResults(ValueRange results, Location *opLoc); + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// A simple dump function for debugging. + /// Writes output to llvm::dbgs(). 
+ LLVM_DUMP_METHOD void dump() const; +#endif + +private: + SmallVector<Value> values; +}; + +using local_val_t = TypedValue<wasmssa::LocalRefType>; + +class ExpressionParser { +public: + using locals_t = SmallVector<local_val_t>; + ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols, + ArrayRef<local_val_t> initLocal) + : parser{parser}, symbols{symbols}, locals{initLocal} {} + +private: + template <std::byte opCode> + inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder); + + template <typename valueT> + parsed_inst_t + parseConstInst(OpBuilder &builder, + std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr); + + /// This function generates a dispatch tree to associate an opcode with a + /// parser. Parsers are registered by specialising the + /// `parseSpecificInstruction` function for the op code to handle. + /// + /// The dispatcher is generated by recursively creating all possible patterns + /// for an opcode and calling the relevant parser on the leaf. + /// + /// @tparam patternBitSize is the first bit for which the pattern is not fixed + /// + /// @tparam highBitPattern is the fixed pattern that this instance handles for + /// the 8-patternBitSize bits + template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}> + inline parsed_inst_t dispatchToInstParser(std::byte opCode, + OpBuilder &builder) { + static_assert(patternBitSize <= 8, + "PatternBitSize is outside of range of opcode space! " + "(expected at most 8 bits)"); + if constexpr (patternBitSize < 8) { + constexpr std::byte bitSelect{1 << (7 - patternBitSize)}; + constexpr std::byte nextHighBitPatternStem = highBitPattern << 1; + constexpr size_t nextPatternBitSize = patternBitSize + 1; + if ((opCode & bitSelect) != std::byte{0}) + return dispatchToInstParser<nextPatternBitSize, + nextHighBitPatternStem | std::byte{1}>( + opCode, builder); + return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>( + opCode, builder); + } else { + return parseSpecificInstruction<highBitPattern>(builder); + } + } + + struct ParseResultWithInfo { + SmallVector<Value> opResults; + std::byte endingByte; + }; + +public: + template <std::byte ParseEndByte = WasmBinaryEncoding::endByte> + parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {}); + + template <std::byte... 
ExpressionParseEnd> + FailureOr<ParseResultWithInfo> + parse(OpBuilder &builder, + ByteSequence<ExpressionParseEnd...> parsingEndFilters); + + FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) { + return valueStack.popOperands(operandTypes, ¤tOpLoc.value()); + } + + LogicalResult pushResults(ValueRange results) { + return valueStack.pushResults(results, ¤tOpLoc.value()); + } + +private: + std::optional<Location> currentOpLoc; + ParserHead &parser; + [[maybe_unused]] WasmModuleSymbolTables const &symbols; + locals_t locals; + ValueStack valueStack; +}; + +class ParserHead { +public: + ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {} + ParserHead(ParserHead &&) = default; + +private: + ParserHead(ParserHead const &other) = default; + +public: + auto getLocation() const { + return FileLineColLoc::get(locName, 0, anchorOffset + offset); + } + + FailureOr<StringRef> consumeNBytes(size_t nBytes) { + LDBG() << "Consume " << nBytes << " bytes"; + LDBG() << " Bytes remaining: " << size(); + LDBG() << " Current offset: " << offset; + if (nBytes > size()) + return emitError(getLocation(), "trying to extract ") + << nBytes << "bytes when only " << size() << "are available"; + + StringRef res = head.slice(offset, offset + nBytes); + offset += nBytes; + LDBG() << " Updated offset (+" << nBytes << "): " << offset; + return res; + } + + FailureOr<std::byte> consumeByte() { + auto res = consumeNBytes(1); + if (failed(res)) + return failure(); + return std::byte{*res->bytes_begin()}; + } + + template <typename T> + FailureOr<T> parseLiteral(); + + FailureOr<uint32_t> parseVectorSize(); + +private: + // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed + // if parseLiteral specialization were moved here, but default GCC on Ubuntu + // 22.04 has bug with template specialization in class declaration + inline FailureOr<uint32_t> parseUI32(); + inline FailureOr<int64_t> parseI64(); + +public: + FailureOr<StringRef> parseName() { + FailureOr<uint32_t> size = parseVectorSize(); + if (failed(size)) + return failure(); + + return consumeNBytes(*size); + } + + FailureOr<WasmSectionType> parseWasmSectionType() { + FailureOr<std::byte> id = consumeByte(); + if (failed(id)) + return failure(); + if (std::to_integer<unsigned>(*id) > highestWasmSectionID) + return emitError(getLocation(), "invalid section ID: ") + << static_cast<int>(*id); + return static_cast<WasmSectionType>(*id); + } + + FailureOr<LimitType> parseLimit(MLIRContext *ctx) { + using WasmLimits = WasmBinaryEncoding::LimitHeader; + FileLineColLoc limitLocation = getLocation(); + FailureOr<std::byte> limitHeader = consumeByte(); + if (failed(limitHeader)) + return failure(); + + if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader)) + return emitError(limitLocation, "invalid limit header: ") + << static_cast<int>(*limitHeader); + FailureOr<uint32_t> minParse = parseUI32(); + if (failed(minParse)) + return failure(); + std::optional<uint32_t> max{std::nullopt}; + if (*limitHeader == WasmLimits::bothLimits) { + FailureOr<uint32_t> maxParse = parseUI32(); + if (failed(maxParse)) + return failure(); + max = *maxParse; + } + return LimitType::get(ctx, *minParse, max); + } + + FailureOr<Type> parseValueType(MLIRContext *ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr<std::byte> typeEncoding = consumeByte(); + if (failed(typeEncoding)) + return failure(); + switch (*typeEncoding) { + case WasmBinaryEncoding::Type::i32: + return IntegerType::get(ctx, 32); + case 
WasmBinaryEncoding::Type::i64: + return IntegerType::get(ctx, 64); + case WasmBinaryEncoding::Type::f32: + return Float32Type::get(ctx); + case WasmBinaryEncoding::Type::f64: + return Float64Type::get(ctx); + case WasmBinaryEncoding::Type::v128: + return IntegerType::get(ctx, 128); + case WasmBinaryEncoding::Type::funcRef: + return wasmssa::FuncRefType::get(ctx); + case WasmBinaryEncoding::Type::externRef: + return wasmssa::ExternRefType::get(ctx); + default: + return emitError(typeLoc, "invalid value type encoding: ") + << static_cast<int>(*typeEncoding); + } + } + + FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) { + using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability; + FailureOr<Type> typeParsed = parseValueType(ctx); + if (failed(typeParsed)) + return failure(); + FileLineColLoc mutLoc = getLocation(); + FailureOr<std::byte> mutSpec = consumeByte(); + if (failed(mutSpec)) + return failure(); + if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec)) + return emitError(mutLoc, "invalid global mutability specifier: ") + << static_cast<int>(*mutSpec); + return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable}; + } + + FailureOr<TupleType> parseResultType(MLIRContext *ctx) { + FailureOr<uint32_t> nParamsParsed = parseVectorSize(); + if (failed(nParamsParsed)) + return failure(); + uint32_t nParams = *nParamsParsed; + SmallVector<Type> res{}; + res.reserve(nParams); + for (size_t i = 0; i < nParams; ++i) { + FailureOr<Type> parsedType = parseValueType(ctx); + if (failed(parsedType)) + return failure(); + res.push_back(*parsedType); + } + return TupleType::get(ctx, res); + } + + FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr<std::byte> funcTypeHeader = consumeByte(); + if (failed(funcTypeHeader)) + return failure(); + if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType) + return emitError(typeLoc, "invalid function type header byte. 
Expecting ") + << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType) + << " got " << std::to_integer<unsigned>(*funcTypeHeader); + FailureOr<TupleType> inputTypes = parseResultType(ctx); + if (failed(inputTypes)) + return failure(); + + FailureOr<TupleType> resTypes = parseResultType(ctx); + if (failed(resTypes)) + return failure(); + + return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes()); + } + + FailureOr<TypeIdxRecord> parseTypeIndex() { + FailureOr<uint32_t> res = parseUI32(); + if (failed(res)) + return failure(); + return TypeIdxRecord{*res}; + } + + FailureOr<TableType> parseTableType(MLIRContext *ctx) { + FailureOr<Type> elmTypeParse = parseValueType(ctx); + if (failed(elmTypeParse)) + return failure(); + if (!isWasmRefType(*elmTypeParse)) + return emitError(getLocation(), "invalid element type for table"); + FailureOr<LimitType> limitParse = parseLimit(ctx); + if (failed(limitParse)) + return failure(); + return TableType::get(ctx, *elmTypeParse, *limitParse); + } + + FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) { + FileLineColLoc importLoc = getLocation(); + FailureOr<std::byte> importType = consumeByte(); + auto packager = [](auto parseResult) -> FailureOr<ImportDesc> { + if (llvm::failed(parseResult)) + return failure(); + return {*parseResult}; + }; + if (failed(importType)) + return failure(); + switch (*importType) { + case WasmBinaryEncoding::Import::typeID: + return packager(parseTypeIndex()); + case WasmBinaryEncoding::Import::tableType: + return packager(parseTableType(ctx)); + case WasmBinaryEncoding::Import::memType: + return packager(parseLimit(ctx)); + case WasmBinaryEncoding::Import::globalType: + return packager(parseGlobalType(ctx)); + default: + return emitError(importLoc, "invalid import type descriptor: ") + << static_cast<int>(*importType); + } + } + + parsed_inst_t parseExpression(OpBuilder &builder, + WasmModuleSymbolTables const &symbols, + ArrayRef<local_val_t> locals = {}) { + auto eParser = ExpressionParser{*this, symbols, locals}; + return eParser.parse(builder); + } + + bool end() const { return curHead().empty(); } + + ParserHead copy() const { return *this; } + +private: + StringRef curHead() const { return head.drop_front(offset); } + + FailureOr<std::byte> peek() const { + if (end()) + return emitError( + getLocation(), + "trying to peek at next byte, but input stream is empty"); + return static_cast<std::byte>(curHead().front()); + } + + size_t size() const { return head.size() - offset; } + + StringRef head; + StringAttr locName; + unsigned anchorOffset{0}; + unsigned offset{0}; +}; + +template <> +FailureOr<float> ParserHead::parseLiteral<float>() { + auto bytes = consumeNBytes(4); + if (failed(bytes)) + return failure(); + float result; + std::memcpy(&result, bytes->bytes_begin(), 4); + return result; +} + +template <> +FailureOr<double> ParserHead::parseLiteral<double>() { + auto bytes = consumeNBytes(8); + if (failed(bytes)) + return failure(); + double result; + std::memcpy(&result, bytes->bytes_begin(), 8); + return result; +} + +template <> +FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() { + char const *error = nullptr; + uint32_t res{0}; + unsigned encodingSize{0}; + StringRef src = curHead(); + uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max())) + return emitError(getLocation()) << "literal does not fit on 32 
bits"; + + res = static_cast<uint32_t>(decoded); + offset += encodingSize; + return res; +} + +template <> +FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() { + char const *error = nullptr; + int32_t res{0}; + unsigned encodingSize{0}; + StringRef src = curHead(); + int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) || + std::isgreater(std::numeric_limits<int32_t>::min(), decoded)) + return emitError(getLocation()) << "literal does not fit on 32 bits"; + + res = static_cast<int32_t>(decoded); + offset += encodingSize; + return res; +} + +template <> +FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() { + char const *error = nullptr; + unsigned encodingSize{0}; + StringRef src = curHead(); + int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + offset += encodingSize; + return res; +} + +FailureOr<uint32_t> ParserHead::parseVectorSize() { + return parseLiteral<uint32_t>(); +} + +inline FailureOr<uint32_t> ParserHead::parseUI32() { + return parseLiteral<uint32_t>(); +} + +inline FailureOr<int64_t> ParserHead::parseI64() { + return parseLiteral<int64_t>(); +} + +template <std::byte opCode> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) { + return emitError(*currentOpLoc, "unknown instruction opcode: ") + << static_cast<int>(opCode); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void ValueStack::dump() const { + llvm::dbgs() << "================= Wasm ValueStack =======================\n"; + llvm::dbgs() << "size: " << size() << "\n"; + llvm::dbgs() << "<Top>" + << "\n"; + // Stack is pushed to via push_back. Therefore the top of the stack is the + // end of the vector. Iterate in reverse so that the first thing we print + // is the top of the stack. + size_t stackSize = size(); + for (size_t idx = 0; idx < stackSize; idx++) { + size_t actualIdx = stackSize - 1 - idx; + llvm::dbgs() << " "; + values[actualIdx].dump(); + } + llvm::dbgs() << "<Bottom>" + << "\n"; + llvm::dbgs() << "=========================================================\n"; +} +#endif + +parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { + LDBG() << "Popping from ValueStack\n" + << " Elements(s) to pop: " << operandTypes.size() << "\n" + << " Current stack size: " << values.size(); + if (operandTypes.size() > values.size()) + return emitError(*opLoc, + "stack doesn't contain enough values. Trying to get ") + << operandTypes.size() << " operands on a stack containing only " + << values.size() << " values."; + size_t stackIdxOffset = values.size() - operandTypes.size(); + SmallVector<Value> res{}; + res.reserve(operandTypes.size()); + for (size_t i{0}; i < operandTypes.size(); ++i) { + Value operand = values[i + stackIdxOffset]; + Type stackType = operand.getType(); + if (stackType != operandTypes[i]) + return emitError(*opLoc, "invalid operand type on stack. 
Expecting ") + << operandTypes[i] << ", value on stack is of type " << stackType + << "."; + LDBG() << " POP: " << operand; + res.push_back(operand); + } + values.resize(values.size() - operandTypes.size()); + LDBG() << " Updated stack size: " << values.size(); + return res; +} + +LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) { + LDBG() << "Pushing to ValueStack\n" + << " Elements(s) to push: " << results.size() << "\n" + << " Current stack size: " << values.size(); + for (Value val : results) { + if (!isWasmValueType(val.getType())) + return emitError(*opLoc, "invalid value type on stack: ") + << val.getType(); + LDBG() << " PUSH: " << val; + values.push_back(val); + } + + LDBG() << " Updated stack size: " << values.size(); + return success(); +} + +template <std::byte EndParseByte> +parsed_inst_t ExpressionParser::parse(OpBuilder &builder, + UniqueByte<EndParseByte> endByte) { + auto res = parse(builder, ByteSequence<EndParseByte>{}); + if (failed(res)) + return failure(); + return res->opResults; +} + +template <std::byte... ExpressionParseEnd> +FailureOr<ExpressionParser::ParseResultWithInfo> +ExpressionParser::parse(OpBuilder &builder, + ByteSequence<ExpressionParseEnd...> parsingEndFilters) { + SmallVector<Value> res; + for (;;) { + currentOpLoc = parser.getLocation(); + FailureOr<std::byte> opCode = parser.consumeByte(); + if (failed(opCode)) + return failure(); + if (isValueOneOf(*opCode, parsingEndFilters)) + return {{res, *opCode}}; + parsed_inst_t resParsed; + resParsed = dispatchToInstParser(*opCode, builder); + if (failed(resParsed)) + return failure(); + std::swap(res, *resParsed); + if (failed(pushResults(res))) + return failure(); + } +} + +template <typename T> +inline Type buildLiteralType(OpBuilder &); + +template <> +inline Type buildLiteralType<int32_t>(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +inline Type buildLiteralType<int64_t>(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +[[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +[[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +inline Type buildLiteralType<float>(OpBuilder &builder) { + return builder.getF32Type(); +} + +template <> +inline Type buildLiteralType<double>(OpBuilder &builder) { + return builder.getF64Type(); +} + +template <typename ValT, + typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>> +struct AttrHolder; + +template <typename ValT> +struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> { + using type = IntegerAttr; +}; + +template <typename ValT> +struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> { + using type = FloatAttr; +}; + +template <typename ValT> +using attr_holder_t = typename AttrHolder<ValT>::type; + +template <typename ValT, + typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>> +attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) { + return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val); +} + +template <typename valueT> +parsed_inst_t ExpressionParser::parseConstInst( + OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) { + auto parsedConstant = parser.parseLiteral<valueT>(); + if (failed(parsedConstant)) + return failure(); + auto constOp = + ConstOp::create(builder, *currentOpLoc, + buildLiteralAttr<valueT>(builder, *parsedConstant)); + return 
{{constOp.getResult()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) { + return parseConstInst<int32_t>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) { + return parseConstInst<int64_t>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) { + return parseConstInst<float>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) { + return parseConstInst<double>(builder); +} + +class WasmBinaryParser { +private: + struct SectionRegistry { + using section_location_t = StringRef; + + std::array<SmallVector<section_location_t>, highestWasmSectionID + 1> + registry; + + template <WasmSectionType SecType> + std::conditional_t<sectionShouldBeUnique(SecType), + std::optional<section_location_t>, + ArrayRef<section_location_t>> + getContentForSection() const { + constexpr auto idx = static_cast<size_t>(SecType); + if constexpr (sectionShouldBeUnique(SecType)) { + return registry[idx].empty() ? std::nullopt + : std::make_optional(registry[idx][0]); + } else { + return registry[idx]; + } + } + + bool hasSection(WasmSectionType secType) const { + return !registry[static_cast<size_t>(secType)].empty(); + } + + /// + /// @returns success if registration valid, failure in case registration + /// can't be done (if another section of same type already exist and this + /// section type should only be present once) + /// + LogicalResult registerSection(WasmSectionType secType, + section_location_t location, Location loc) { + if (sectionShouldBeUnique(secType) && hasSection(secType)) + return emitError(loc, + "trying to add a second instance of unique section"); + + registry[static_cast<size_t>(secType)].push_back(location); + emitRemark(loc, "Adding section with section ID ") + << static_cast<uint8_t>(secType); + return success(); + } + + LogicalResult populateFromBody(ParserHead ph) { + while (!ph.end()) { + FileLineColLoc sectionLoc = ph.getLocation(); + FailureOr<WasmSectionType> secType = ph.parseWasmSectionType(); + if (failed(secType)) + return failure(); + + FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>(); + if (failed(secSizeParsed)) + return failure(); + + uint32_t secSize = *secSizeParsed; + FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize); + if (failed(sectionContent)) + return failure(); + + LogicalResult registration = + registerSection(*secType, *sectionContent, sectionLoc); + + if (failed(registration)) + return failure(); + } + return success(); + } + }; + + auto getLocation(int offset = 0) const { + return FileLineColLoc::get(srcName, 0, offset); + } + + template <WasmSectionType> + LogicalResult parseSectionItem(ParserHead &, size_t); + + template <WasmSectionType section> + LogicalResult parseSection() { + auto secName = std::string{wasmSectionName<section>}; + auto sectionNameAttr = + StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION"); + unsigned offset = 0; + auto getLocation = [sectionNameAttr, &offset]() { + return FileLineColLoc::get(sectionNameAttr, 0, offset); + }; + auto secContent = registry.getContentForSection<section>(); + if (!secContent) { + LDBG() << secName << " section is not present in file."; + return success(); + } + + auto 
secSrc = secContent.value(); + ParserHead ph{secSrc, sectionNameAttr}; + FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize(); + if (failed(nElemsParsed)) + return failure(); + uint32_t nElems = *nElemsParsed; + LDBG() << "Starting to parse " << nElems << " items for section " + << secName; + for (size_t i = 0; i < nElems; ++i) { + if (failed(parseSectionItem<section>(ph, i))) + return failure(); + } + + if (!ph.end()) + return emitError(getLocation(), "unparsed garbage at end of section ") + << secName; + return success(); + } + + /// Handles the registration of a function import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TypeIdxRecord tid) { + using llvm::Twine; + if (tid.id >= symbols.moduleFuncTypes.size()) + return emitError(loc, "invalid type id: ") + << tid.id << ". Only " << symbols.moduleFuncTypes.size() + << " type registration."; + FunctionType type = symbols.moduleFuncTypes[tid.id]; + std::string symbol = symbols.getNewFuncSymbolName(); + auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName, + importName, type); + symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type}); + return funcOp.verify(); + } + + /// Handles the registration of a memory import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, LimitType limitType) { + std::string symbol = symbols.getNewMemorySymbolName(); + auto memOp = MemImportOp::create(builder, loc, symbol, moduleName, + importName, limitType); + symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)}); + return memOp.verify(); + } + + /// Handles the registration of a table import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TableType tableType) { + std::string symbol = symbols.getNewTableSymbolName(); + auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName, + importName, tableType); + symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)}); + return tableOp.verify(); + } + + /// Handles the registration of a global variable import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, GlobalTypeRecord globalType) { + std::string symbol = symbols.getNewGlobalSymbolName(); + auto giOp = + GlobalImportOp::create(builder, loc, symbol, moduleName, importName, + globalType.type, globalType.isMutable); + symbols.globalSymbols.push_back( + {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()}); + return giOp.verify(); + } + + // Detect occurence of errors + LogicalResult peekDiag(Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) + isValid = false; + return failure(); + } + +public: + WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx) + : builder{ctx}, ctx{ctx} { + ctx->getDiagEngine().registerHandler( + [this](Diagnostic &diag) { return peekDiag(diag); }); + ctx->loadAllAvailableDialects(); + if (sourceMgr.getNumBuffers() != 1) { + emitError(UnknownLoc::get(ctx), "one source file should be provided"); + return; + } + uint32_t sourceBufId = sourceMgr.getMainFileID(); + StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer(); + srcName = StringAttr::get( + ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier()); + + auto parser = ParserHead{source, srcName}; + auto const wasmHeader = StringRef{"\0asm", 4}; + FileLineColLoc magicLoc = parser.getLocation(); + FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size()); + if (failed(magic) || magic->compare(wasmHeader)) { + 
emitError(magicLoc, "source file does not contain valid Wasm header."); + return; + } + auto const expectedVersionString = StringRef{"\1\0\0\0", 4}; + FileLineColLoc versionLoc = parser.getLocation(); + FailureOr<StringRef> version = + parser.consumeNBytes(expectedVersionString.size()); + if (failed(version)) + return; + if (version->compare(expectedVersionString)) { + emitError(versionLoc, + "unsupported Wasm version. Only version 1 is supported."); + return; + } + LogicalResult fillRegistry = registry.populateFromBody(parser.copy()); + if (failed(fillRegistry)) + return; + + mOp = ModuleOp::create(builder, getLocation()); + builder.setInsertionPointToStart(&mOp.getBodyRegion().front()); + LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>(); + if (failed(parsingTypes)) + return; + + LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>(); + if (failed(parsingImports)) + return; + + firstInternalFuncID = symbols.funcSymbols.size(); + + LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>(); + if (failed(parsingFunctions)) + return; + + LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>(); + if (failed(parsingTables)) + return; + + LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>(); + if (failed(parsingMems)) + return; + + LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>(); + if (failed(parsingExports)) + return; + + // Copy over sizes of containers into statistics. + LDBG() << "WASM Imports:" + << "\n" + << " - Num functions: " << symbols.funcSymbols.size() << "\n" + << " - Num globals: " << symbols.globalSymbols.size() << "\n" + << " - Num memories: " << symbols.memSymbols.size() << "\n" + << " - Num tables: " << symbols.tableSymbols.size(); + } + + ModuleOp getModule() { + if (isValid) + return mOp; + if (mOp) + mOp.erase(); + return ModuleOp{}; + } + +private: + mlir::StringAttr srcName; + OpBuilder builder; + WasmModuleSymbolTables symbols; + MLIRContext *ctx; + ModuleOp mOp; + SectionRegistry registry; + size_t firstInternalFuncID{0}; + bool isValid{true}; +}; + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, + size_t) { + FileLineColLoc importLoc = ph.getLocation(); + auto moduleName = ph.parseName(); + if (failed(moduleName)) + return failure(); + + auto importName = ph.parseName(); + if (failed(importName)) + return failure(); + + FailureOr<ImportDesc> import = ph.parseImportDesc(ctx); + if (failed(import)) + return failure(); + + return std::visit( + [this, importLoc, &moduleName, &importName](auto import) { + return visitImport(importLoc, *moduleName, *importName, import); + }, + *import); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph, + size_t) { + FileLineColLoc exportLoc = ph.getLocation(); + + auto exportName = ph.parseName(); + if (failed(exportName)) + return failure(); + + FailureOr<std::byte> opcode = ph.consumeByte(); + if (failed(opcode)) + return failure(); + + FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>(); + if (failed(idx)) + return failure(); + + using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>, + SmallVector<GlobalSymbolRefContainer>, + SmallVector<FunctionSymbolRefContainer>>; + + SymbolRefDesc currentSymbolList; + std::string symbolType = ""; + switch (*opcode) { + case WasmBinaryEncoding::Export::function: + symbolType = "function"; + currentSymbolList = symbols.funcSymbols; + break; + case 
WasmBinaryEncoding::Export::table: + symbolType = "table"; + currentSymbolList = symbols.tableSymbols; + break; + case WasmBinaryEncoding::Export::memory: + symbolType = "memory"; + currentSymbolList = symbols.memSymbols; + break; + case WasmBinaryEncoding::Export::global: + symbolType = "global"; + currentSymbolList = symbols.globalSymbols; + break; + default: + return emitError(exportLoc, "invalid value for export type: ") + << std::to_integer<unsigned>(*opcode); + } + + auto currentSymbol = std::visit( + [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> { + if (*idx > list.size()) { + emitError( + exportLoc, + llvm::formatv( + "trying to export {0} {1} which is undefined in this scope", + symbolType, *idx)); + return failure(); + } + return list[*idx].symbol; + }, + currentSymbolList); + + if (failed(currentSymbol)) + return failure(); + + Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol); + SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public); + StringAttr symName = SymbolTable::getSymbolName(op); + return SymbolTable{mOp}.rename(symName, *exportName); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr<TableType> tableType = ph.parseTableType(ctx); + if (failed(tableType)) + return failure(); + LDBG() << " Parsed table description: " << *tableType; + StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName()); + auto tableOp = + TableOp::create(builder, opLocation, symbol.strref(), *tableType); + symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)}); + return success(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph, + size_t) { + FileLineColLoc opLoc = ph.getLocation(); + auto typeIdxParsed = ph.parseLiteral<uint32_t>(); + if (failed(typeIdxParsed)) + return failure(); + uint32_t typeIdx = *typeIdxParsed; + if (typeIdx >= symbols.moduleFuncTypes.size()) + return emitError(getLocation(), "invalid type index: ") << typeIdx; + std::string symbol = symbols.getNewFuncSymbolName(); + auto funcOp = + FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]); + Block *block = funcOp.addEntryBlock(); + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToEnd(block); + ReturnOp::create(builder, opLoc); + builder.restoreInsertionPoint(ip); + symbols.funcSymbols.push_back( + {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())}, + symbols.moduleFuncTypes[typeIdx]}); + return funcOp.verify(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph, + size_t) { + FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx); + if (failed(funcType)) + return failure(); + LDBG() << "Parsed function type " << *funcType; + symbols.moduleFuncTypes.push_back(*funcType); + return success(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr<LimitType> memory = ph.parseLimit(ctx); + if (failed(memory)) + return failure(); + + LDBG() << " Registering memory " << *memory; + std::string symbol = symbols.getNewMemorySymbolName(); + auto memOp = MemOp::create(builder, opLocation, symbol, *memory); + symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)}); + return success(); +} +} // namespace + +namespace mlir::wasm { +OwningOpRef<ModuleOp> 
importWebAssemblyToModule(llvm::SourceMgr &source, + MLIRContext *context) { + WasmBinaryParser wBN{source, context}; + ModuleOp mOp = wBN.getModule(); + if (mOp) + return {mOp}; + + return {nullptr}; +} +} // namespace mlir::wasm diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp new file mode 100644 index 0000000..03b9784 --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp @@ -0,0 +1,28 @@ +//===- TranslateRegistration.cpp - Register translation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "mlir/Tools/mlir-translate/Translation.h" + +using namespace mlir; + +namespace mlir { +void registerFromWasmTranslation() { + TranslateToMLIRRegistration registration{ + "import-wasm", "Translate WASM to MLIR", + [](llvm::SourceMgr &sourceMgr, + MLIRContext *context) -> OwningOpRef<Operation *> { + return wasm::importWebAssemblyToModule(sourceMgr, context); + }, + [](DialectRegistry ®istry) { + registry.insert<wasmssa::WasmSSADialect>(); + }}; +} +} // namespace mlir diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 4ccb83f..02dad69 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG() << "Processing simple op: " << *op; if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { - LDBG() - << "Simple op is not memory effect free or has live results, skipping: " - << *op; + LDBG() << "Simple op is not memory effect free or has live results, " + "preserving it: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); return; } LDBG() << "Simple op has all dead results and is memory effect free, scheduling " "for removal: " - << *op; + << OpWithFlags(op, OpPrintingFlags().skipRegions()); cl.operations.push_back(op); collectNonLiveValues(nonLiveSet, op->getResults(), BitVector(op->getNumResults(), true)); @@ -728,19 +727,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, /// Removes dead values collected in RDVFinalCleanupList. /// To be run once when all dead values have been collected. static void cleanUpDeadVals(RDVFinalCleanupList &list) { + LDBG() << "Starting cleanup of dead values..."; + // 1. Operations + LDBG() << "Cleaning up " << list.operations.size() << " operations"; for (auto &op : list.operations) { + LDBG() << "Erasing operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); op->dropAllUses(); op->erase(); } // 2. Values + LDBG() << "Cleaning up " << list.values.size() << " values"; for (auto &v : list.values) { + LDBG() << "Dropping all uses of value: " << v; v.dropAllUses(); } // 3. 
Functions + LDBG() << "Cleaning up " << list.functions.size() << " functions"; for (auto &f : list.functions) { + LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName(); + LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments"; + LDBG() << " Erasing " << f.nonLiveRets.count() + << " non-live return values"; // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. @@ -749,44 +760,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } // 4. Operands + LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { - if (o.op->getNumOperands() > 0) + if (o.op->getNumOperands() > 0) { + LDBG() << "Erasing " << o.nonLive.count() + << " non-live operands from operation: " + << OpWithFlags(o.op, OpPrintingFlags().skipRegions()); o.op->eraseOperands(o.nonLive); + } } // 5. Results + LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { + LDBG() << "Erasing " << r.nonLive.count() + << " non-live results from operation: " + << OpWithFlags(r.op, OpPrintingFlags().skipRegions()); dropUsesAndEraseResults(r.op, r.nonLive); } // 6. Blocks + LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; for (auto &b : list.blocks) { // blocks that are accessed via multiple codepaths processed once if (b.b->getNumArguments() != b.nonLiveArgs.size()) continue; + LDBG() << "Erasing " << b.nonLiveArgs.count() + << " non-live arguments from block: " << b.b; // it iterates backwards because erase invalidates all successor indexes for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { if (!b.nonLiveArgs[i]) continue; + LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); b.b->getArgument(i).dropAllUses(); b.b->eraseArgument(i); } } // 7. Successor Operands + LDBG() << "Cleaning up " << list.successorOperands.size() + << " successor operand lists"; for (auto &op : list.successorOperands) { SuccessorOperands successorOperands = op.branch.getSuccessorOperands(op.successorIndex); // blocks that are accessed via multiple codepaths processed once if (successorOperands.size() != op.nonLiveOperands.size()) continue; + LDBG() << "Erasing " << op.nonLiveOperands.count() + << " non-live successor operands from successor " + << op.successorIndex << " of branch: " + << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); // it iterates backwards because erase invalidates all successor indexes for (int i = successorOperands.size() - 1; i >= 0; --i) { if (!op.nonLiveOperands[i]) continue; + LDBG() << " Erasing successor operand " << i << ": " + << successorOperands[i]; successorOperands.erase(i); } } + + LDBG() << "Finished cleanup of dead values"; } struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 0c26b4e..7494ca9 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -182,15 +182,24 @@ private: /// conversions.) static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__"; +/// Return the operation that defines all values in the vector. Return nullptr +/// if the values are not defined by the same operation. 
+static Operation *getCommonDefiningOp(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = values.front().getDefiningOp(); + for (Value v : llvm::drop_begin(values)) { + if (v.getDefiningOp() != op) + return nullptr; + } + return op; +} + /// A vector of values is a pure type conversion if all values are defined by /// the same operation and the operation has the `kPureTypeConversionMarker` /// attribute. static bool isPureTypeConversion(const ValueVector &values) { assert(!values.empty() && "expected non-empty value vector"); - Operation *op = values.front().getDefiningOp(); - for (Value v : llvm::drop_begin(values)) - if (v.getDefiningOp() != op) - return false; + Operation *op = getCommonDefiningOp(values); return op && op->hasAttr(kPureTypeConversionMarker); } @@ -841,7 +850,7 @@ namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : context(ctx), config(config) {} + : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {} //===--------------------------------------------------------------------===// // State Management @@ -863,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// failure. template <typename RewriteTy, typename... Args> void appendRewrite(Args &&...args) { + assert(config.allowPatternRollback && "appending rewrites is not allowed"); rewrites.push_back( std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); } @@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasOpReplaced(Operation *op) const; /// Lookup the most recently mapped values with the desired types in the - /// mapping. - /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. + /// mapping, taking into account only replacements. Perform a best-effort + /// search for existing materializations with the desired types. /// /// If `skipPureTypeConversions` is "true", materializations that are pure /// type conversions are not considered. @@ -1066,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ConversionValueMapping mapping; /// Ordered list of block operations (creations, splits, motions). + /// This vector is maintained only if `allowPatternRollback` is set to + /// "true". Otherwise, all IR rewrites are materialized immediately and no + /// bookkeeping is needed. SmallVector<std::unique_ptr<IRRewrite>> rewrites; /// A set of operations that should no longer be considered for legalization. @@ -1089,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// by the current pattern. SetVector<Block *> patternInsertedBlocks; + /// A list of unresolved materializations that were created by the current + /// pattern. + DenseSet<UnrealizedConversionCastOp> patternMaterializations; + /// A mapping for looking up metadata of unresolved materializations. DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo> unresolvedMaterializations; @@ -1104,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. 
const ConversionConfig &config; + /// A set of erased operations. This set is utilized only if + /// `allowPatternRollback` is set to "false". Conceptually, this set is + /// similar to `replacedOps` (which is maintained when the flag is set to + /// "true"). However, erasing from a DenseSet is more efficient than erasing + /// from a SetVector. + DenseSet<Operation *> erasedOps; + + /// A set of erased blocks. This set is utilized only if + /// `allowPatternRollback` is set to "false". + DenseSet<Block *> erasedBlocks; + + /// A rewriter that notifies the listener (if any) about all IR + /// modifications. This rewriter is utilized only if `allowPatternRollback` + /// is set to "false". If the flag is set to "true", the listener is notified + /// with a separate mechanism (e.g., in `IRRewrite::commit`). + IRRewriter notifyingRewriter; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1111,8 +1138,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { SmallPtrSet<Operation *, 1> pendingRootUpdates; /// A raw output stream used to prefix the debug log. - llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(), - llvm::dbgs(), /*HasPendingNewline=*/false}; + llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(), + llvm::dbgs()}; /// A logger used to emit diagnostics during the conversion process. llvm::ScopedPrinter logger{os}; @@ -1140,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); - if (!repl) - return; - +static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, + Value repl) { if (isa<BlockArgument>(repl)) { rewriter.replaceAllUsesWith(arg, repl); return; @@ -1161,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + if (!repl) + return; + performReplaceBlockArg(rewriter, arg, repl); +} + void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { @@ -1246,6 +1277,30 @@ void ConversionPatternRewriterImpl::applyRewrites() { ValueVector ConversionPatternRewriterImpl::lookupOrDefault( Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const { + // Helper function that looks up a single value. + auto lookup = [&](const ValueVector &values) -> ValueVector { + assert(!values.empty() && "expected non-empty value vector"); + + // If the pattern rollback is enabled, use the mapping to look up the + // values. + if (config.allowPatternRollback) + return mapping.lookup(values); + + // Otherwise, look up values by examining the IR. All replacements have + // already been materialized in IR. + Operation *op = getCommonDefiningOp(values); + if (!op) + return {}; + auto castOp = dyn_cast<UnrealizedConversionCastOp>(op); + if (!castOp) + return {}; + if (!this->unresolvedMaterializations.contains(castOp)) + return {}; + if (castOp.getOutputs() != values) + return {}; + return castOp.getInputs(); + }; + // Helper function that looks up each value in `values` individually and then // composes the results. 
If that fails, it tries to look up the entire vector // at once. @@ -1253,7 +1308,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault( // If possible, replace each value with (one or multiple) mapped values. ValueVector next; for (Value v : values) { - ValueVector r = mapping.lookup({v}); + ValueVector r = lookup({v}); if (!r.empty()) { llvm::append_range(next, r); } else { @@ -1273,7 +1328,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault( // be stored (and looked up) in the mapping. But for performance reasons, // we choose to reuse existing IR (when possible) instead of creating it // multiple times. - ValueVector r = mapping.lookup(values); + ValueVector r = lookup(values); if (r.empty()) { // No mapping found: The lookup stops here. return {}; @@ -1347,15 +1402,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state, void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, StringRef patternName) { for (auto &rewrite : - llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) { - if (!config.allowPatternRollback && - !isa<UnresolvedMaterializationRewrite>(rewrite)) { - // Unresolved materializations can always be rolled back (erased). - llvm::report_fatal_error("pattern '" + patternName + - "' rollback of IR modifications requested"); - } + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) rewrite->rollback(); - } rewrites.resize(numRewritesToKeep); } @@ -1419,12 +1467,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation is ignored or was replaced. - return replacedOps.count(op) || ignoredOps.count(op); + return wasOpReplaced(op) || ignoredOps.count(op); } bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { // Check to see if this operation was replaced. - return replacedOps.count(op); + return replacedOps.count(op) || erasedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1508,7 +1556,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // a bit more efficient, so we try to do that when possible. bool fastPath = !config.listener; if (fastPath) { - appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); + if (config.allowPatternRollback) + appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); newBlock->getOperations().splice(newBlock->end(), block->getOperations()); } else { while (!block->empty()) @@ -1556,7 +1605,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( replaceUsesOfBlockArgument(origArg, replArgs, converter); } - appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); + if (config.allowPatternRollback) + appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1585,23 +1635,37 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // tracking the materialization like we do for other operations. OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); - auto convertOp = + UnrealizedConversionCastOp convertOp = UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); + if (config.attachDebugMaterializationKind) { + StringRef kindStr = + kind == MaterializationKind::Source ? 
"source" : "target"; + convertOp->setAttr("__kind__", builder.getStringAttr(kindStr)); + } if (isPureTypeConversion) convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); - if (!valuesToMap.empty()) - mapping.map(valuesToMap, convertOp.getResults()); + + // Register the materialization. if (castOp) *castOp = convertOp; unresolvedMaterializations[convertOp] = UnresolvedMaterializationInfo(converter, kind, originalType); - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, - std::move(valuesToMap)); + if (config.allowPatternRollback) { + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); + appendRewrite<UnresolvedMaterializationRewrite>(convertOp, + std::move(valuesToMap)); + } else { + patternMaterializations.insert(convertOp); + } return convertOp.getResults(); } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { + assert(config.allowPatternRollback && + "this code path is valid only in rollback mode"); + // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. @@ -1663,26 +1727,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(op->getParentOp()) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) && "attempting to insert into a block within a replaced/erased op"); + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyOperationInserted(op, previous); + if (wasDetached) { - // If the op was detached, it is most likely a newly created op. - // TODO: If the same op is inserted multiple times from a detached state, - // the rollback mechanism may erase the same op multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateOperationRewrite>(op); + // If the op was detached, it is most likely a newly created op. Add it the + // set of newly created ops, so that it will be legalized. If this op is + // not a newly created op, it will be legalized a second time, which is + // inefficient but harmless. patternNewOps.insert(op); + + if (config.allowPatternRollback) { + // TODO: If the same op is inserted multiple times from a detached + // state, the rollback mechanism may erase the same op multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateOperationRewrite>(op); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased operations that must be kept up to date. + erasedOps.erase(op); + } return; } // The op was moved from one place to another. - appendRewrite<MoveOperationRewrite>(op, previous); + if (config.allowPatternRollback) + appendRewrite<MoveOperationRewrite>(op, previous); +} + +/// Given that `fromRange` is about to be replaced with `toRange`, compute +/// replacement values with the types of `fromRange`. 
+static SmallVector<Value> +getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, + const SmallVector<SmallVector<Value>> &toRange, + const TypeConverter *converter) { + assert(!impl.config.allowPatternRollback && + "this code path is valid only in 'no rollback' mode"); + SmallVector<Value> repls; + for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) { + if (from.use_empty()) { + // The replaced value is dead. No replacement value is needed. + repls.push_back(Value()); + continue; + } + + if (to.empty()) { + // The replaced value is dropped. Materialize a replacement value "out of + // thin air". + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(from), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/ValueRange(), + /*outputTypes=*/from.getType(), /*originalType=*/Type(), + converter)[0]; + repls.push_back(srcMat); + continue; + } + + if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) { + // The replacement value already has the correct type. Use it directly. + repls.push_back(to[0]); + continue; + } + + // The replacement value has the wrong type. Build a source materialization + // to the original type. + // TODO: This is a bit inefficient. We should try to reuse existing + // materializations if possible. This would require an extension of the + // `lookupOrDefault` API. + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(to), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(), + /*originalType=*/Type(), converter)[0]; + repls.push_back(srcMat); + } + + return repls; } void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { - assert(newValues.size() == op->getNumResults()); + assert(newValues.size() == op->getNumResults() && + "incorrect number of replacement values"); + + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + SmallVector<Value> repls = getReplacementValues( + *this, op->getResults(), newValues, currentTypeConverter); + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + op->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + op->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Replace the op with the replacement values and notify the listener. 
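For illustration (not part of the patch): the nested vector-of-vectors shape consumed by getReplacementValues corresponds to 1:N replacements made by conversion patterns. A minimal sketch of the producing side; op, rewriter, lowHalf, and highHalf are hypothetical, while replaceOpWithMultiple is the existing ConversionPatternRewriter entry point that funnels into the ConversionPatternRewriterImpl::replaceOp overload shown here.

  // Sketch of a 1:N replacement inside a ConversionPattern (hypothetical values).
  SmallVector<SmallVector<Value>> repl;
  repl.push_back({lowHalf, highHalf}); // result #0 -> replaced by two values
  repl.push_back({});                  // result #1 -> dropped; the driver builds a
                                       // source materialization if it still has uses
  rewriter.replaceOpWithMultiple(op, std::move(repl));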
+ notifyingRewriter.replaceOp(op, repls); + return; + } + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Check if replaced op is an unresolved materialization, i.e., an @@ -1722,11 +1879,46 @@ void ConversionPatternRewriterImpl::replaceOp( void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( BlockArgument from, ValueRange to, const TypeConverter *converter) { + if (!config.allowPatternRollback) { + SmallVector<Value> toConv = llvm::to_vector(to); + SmallVector<Value> repls = + getReplacementValues(*this, from, {toConv}, converter); + IRRewriter r(from.getContext()); + Value repl = repls.front(); + if (!repl) + return; + + performReplaceBlockArg(r, from, repl); + return; + } + appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter); mapping.map(from, to); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + block->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + block->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Erase the block and notify the listener. + notifyingRewriter.eraseBlock(block); + return; + } + assert(!wasOpReplaced(block->getParentOp()) && "attempting to erase a block within a replaced/erased op"); appendRewrite<EraseBlockRewrite>(block); @@ -1760,23 +1952,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(newParentOp) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) && "attempting to insert into a region within a replaced/erased op"); (void)newParentOp; + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyBlockInserted(block, previous, previousIt); + patternInsertedBlocks.insert(block); if (wasDetached) { // If the block was detached, it is most likely a newly created block. - // TODO: If the same block is inserted multiple times from a detached state, - // the rollback mechanism may erase the same block multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateBlockRewrite>(block); + if (config.allowPatternRollback) { + // TODO: If the same block is inserted multiple times from a detached + // state, the rollback mechanism may erase the same block multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateBlockRewrite>(block); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased blocks that must be kept up to date. + erasedBlocks.erase(block); + } return; } // The block was moved from one place to another. 
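A usage-level sketch (not part of the patch) of the eager path implemented in these hunks. The helper and listener names are invented; allowPatternRollback, listener, and applyPartialConversion are existing API.

#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;

// With rollback disabled, replacements and erasures are materialized
// immediately and the listener is notified through the notifying rewriter
// rather than at commit time.
struct PrintingListener : public RewriterBase::Listener {
  void notifyOperationErased(Operation *op) override {
    llvm::errs() << "erased: " << op->getName() << "\n";
  }
};

static LogicalResult convertEagerly(Operation *root,
                                    const ConversionTarget &target,
                                    const FrozenRewritePatternSet &patterns) {
  PrintingListener listener;
  ConversionConfig config;
  config.allowPatternRollback = false;
  config.listener = &listener;
  return applyPartialConversion(root, target, patterns, config);
}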
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt); + if (config.allowPatternRollback) + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1956,7 +2162,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // a bit more efficient, so we try to do that when possible. bool fastPath = !getConfig().listener; - if (fastPath) + if (fastPath && impl->config.allowPatternRollback) impl->inlineBlockBefore(source, dest, before); // Replace all uses of block arguments. @@ -1982,6 +2188,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + // Pattern rollback is not allowed: no extra bookkeeping is needed. + PatternRewriter::startOpModification(op); + return; + } assert(!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"); #ifndef NDEBUG @@ -1991,20 +2202,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); - PatternRewriter::finalizeOpModification(op); impl->patternModifiedOps.insert(op); + if (!impl->config.allowPatternRollback) { + PatternRewriter::finalizeOpModification(op); + if (getConfig().listener) + getConfig().listener->notifyOperationModified(op); + return; + } // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif } void ConversionPatternRewriter::cancelOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + PatternRewriter::cancelOpModification(op); + return; + } #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); @@ -2029,17 +2249,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// -SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( +FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands( ArrayRef<ValueRange> operands) const { SmallVector<Value> oneToOneOperands; oneToOneOperands.reserve(operands.size()); for (ValueRange operand : operands) { if (operand.size() != 1) - llvm::report_fatal_error("pattern '" + getDebugName() + - "' does not support 1:N conversion"); + return failure(); + oneToOneOperands.push_back(operand.front()); } - return oneToOneOperands; + return std::move(oneToOneOperands); } LogicalResult @@ -2257,15 +2477,17 @@ OperationLegalizer::legalize(Operation *op, return success(); } - // If the operation isn't legal, try to fold it in-place. - // TODO: Should we always try to do this, even if the op is - // already legal? - if (succeeded(legalizeWithFold(op, rewriter))) { - LLVM_DEBUG({ - logSuccess(logger, "operation was folded"); - logger.startLine() << logLineComment; - }); - return success(); + // If the operation is not legal, try to fold it in-place if the folding mode + // is 'BeforePatterns'. 'Never' will skip this. 
+ const ConversionConfig &config = rewriter.getConfig(); + if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) { + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } } // Otherwise, we need to apply a legalization pattern to this operation. @@ -2277,6 +2499,18 @@ OperationLegalizer::legalize(Operation *op, return success(); } + // If the operation can't be legalized via patterns, try to fold it in-place + // if the folding mode is 'AfterPatterns'. + if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) { + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } + } + LLVM_DEBUG({ logFailure(logger, "no matched legalization pattern"); logger.startLine() << logLineComment; @@ -2425,17 +2659,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (!rewriterImpl.config.allowPatternRollback) { - // Returning "failure" after modifying IR is not allowed. + // Erase all unresolved materializations. + for (auto op : rewriterImpl.patternMaterializations) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + rewriterImpl.patternMaterializations.clear(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Expensive pattern check that can detect API violations. if (checkOp) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + "' returned failure but IR did change"); } - } #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2459,6 +2699,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // successfully applied. auto onSuccess = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + if (!rewriterImpl.config.allowPatternRollback) { + // Eagerly erase unused materializations. + for (auto op : rewriterImpl.patternMaterializations) { + if (op->use_empty()) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + } + rewriterImpl.patternMaterializations.clear(); + } SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector<Operation *> modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); @@ -2549,6 +2799,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. for (Block *block : insertedBlocks) { + if (impl.erasedBlocks.contains(block)) + continue; + // Only check blocks outside of the current operation. 
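A short sketch (not part of the patch) of selecting the folding behaviour implemented by the hunks above. Assumptions: foldingMode is the ConversionConfig member read via rewriter.getConfig() here, and Never is the third DialectConversionFoldingMode enumerator mentioned in the comment; the helper name is invented.

static LogicalResult convertWithLateFolding(Operation *root,
                                            const ConversionTarget &target,
                                            const FrozenRewritePatternSet &patterns) {
  ConversionConfig config;
  // Fold only ops that no pattern could legalize. BeforePatterns restores the
  // previous eager behaviour; Never disables folding during conversion.
  config.foldingMode = DialectConversionFoldingMode::AfterPatterns;
  return applyPartialConversion(root, target, patterns, config);
}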
Operation *parentOp = block->getParentOp(); if (!parentOp || parentOp == op || block->getNumArguments() == 0) diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 607b86c..0324588 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -15,6 +15,8 @@ #include "mlir/Config/mlir-config.h" #include "mlir/IR/Action.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" @@ -23,7 +25,7 @@ #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" @@ -178,9 +180,8 @@ static Operation *getDumpRootOp(Operation *op) { return op; } static void logSuccessfulFolding(Operation *op) { - llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n"; - op->dump(); - llvm::dbgs() << "\n\n"; + LDBG() << "// *** IR Dump After Successful Folding ***\n" + << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs()); } #endif // NDEBUG @@ -394,8 +395,12 @@ private: function_ref<void(Diagnostic &)> reasonCallback) override; #ifndef NDEBUG + /// A raw output stream used to prefix the debug log. + + llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(), + llvm::dbgs()}; /// A logger used to emit information during the application process. - llvm::ScopedPrinter logger{llvm::dbgs()}; + llvm::ScopedPrinter logger{os}; #endif /// The low-level pattern applicator. @@ -871,7 +876,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { ctx->executeAction<GreedyPatternRewriteIteration>( [&] { - continueRewrites = processWorklist(); + continueRewrites = false; + + // Erase unreachable blocks + // Operations like: + // %add = arith.addi %add, %add : i64 + // are legal in unreachable code. Unfortunately many patterns would be + // unsafe to apply on such IR and can lead to crashes or infinite + // loops. + continueRewrites |= + succeeded(eraseUnreachableBlocks(rewriter, region)); + + continueRewrites |= processWorklist(); // After applying patterns, make sure that the CFG of each of the // regions is kept up to date. 
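The utility called in the hunk above is also usable on its own; a small sketch (the helper name is invented, the rest is the existing RegionUtils API):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/RegionUtils.h"
using namespace mlir;

// Erase every block that is unreachable from its region's entry block, as the
// greedy driver now does at the start of each iteration. Returns success()
// iff at least one block was erased.
static LogicalResult dropUnreachableBlocks(Operation *root) {
  IRRewriter rewriter(root->getContext());
  return eraseUnreachableBlocks(rewriter, root->getRegions());
}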
@@ -917,10 +933,9 @@ mlir::applyPatternsGreedily(Region ®ion, RegionPatternRewriteDriver driver(region.getContext(), patterns, config, region); LogicalResult converged = std::move(driver).simplify(changed); - LLVM_DEBUG(if (failed(converged)) { - llvm::dbgs() << "The pattern rewrite did not converge after scanning " - << config.getMaxIterations() << " times\n"; - }); + if (failed(converged)) + LDBG() << "The pattern rewrite did not converge after scanning " + << config.getMaxIterations() << " times"; return converged; } @@ -1052,9 +1067,8 @@ LogicalResult mlir::applyOpPatternsGreedily( LogicalResult converged = std::move(driver).simplify(ops, changed); if (allErased) *allErased = surviving.empty(); - LLVM_DEBUG(if (failed(converged)) { - llvm::dbgs() << "The pattern rewrite did not converge after " - << config.getMaxNumRewrites() << " rewrites"; - }); + if (failed(converged)) + LDBG() << "The pattern rewrite did not converge after " + << config.getMaxNumRewrites() << " rewrites"; return converged; } diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index eeb4052..5ea3105 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -13,6 +13,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/CallInterfaces.h" @@ -182,6 +183,11 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src, IRMapping &valueMapping) { for (auto &block : *src) { for (auto &op : block) { + // UnrealizedConversionCastOp is inlineable but cannot implement the + // inliner interface due to layering constraints. + if (isa<UnrealizedConversionCastOp>(op)) + continue; + // Check this operation. if (!interface.isLegalToInline(&op, insertRegion, shouldCloneInlinedRegion, valueMapping)) { diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index a1d975d..31ae1d1 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -23,12 +23,15 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include <deque> #include <iterator> using namespace mlir; +#define DEBUG_TYPE "region-utils" + void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion) { for (auto &use : llvm::make_early_inc_range(orig.getUses())) { @@ -182,19 +185,34 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove( // TODO: We could likely merge this with the DCE algorithm below. LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef<Region> regions) { + LDBG() << "Starting eraseUnreachableBlocks with " << regions.size() + << " regions"; + // Set of blocks found to be reachable within a given region. llvm::df_iterator_default_set<Block *, 16> reachable; // If any blocks were found to be dead. 
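The logging in these files uses the LDBG facility from llvm/Support/DebugLog.h instead of raw llvm::dbgs(); a minimal sketch of the idiom (the function and its arguments are placeholders, and the exact command-line syntax for raising the verbosity level is an assumption, not something this patch defines):

#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "region-utils"

// LDBG() messages print when the "region-utils" debug type is enabled, e.g.
// with -debug-only=region-utils; LDBG(2) messages are additionally gated
// behind a higher verbosity level for the same debug type.
static void logTraversal(unsigned numBlocks, unsigned numReachable) {
  LDBG() << "Processing region with " << numBlocks << " blocks";
  LDBG(2) << "Found " << numReachable << " reachable blocks";
}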
- bool erasedDeadBlocks = false; + int erasedDeadBlocks = 0; SmallVector<Region *, 1> worklist; worklist.reserve(regions.size()); for (Region &region : regions) worklist.push_back(&region); + + LDBG(2) << "Initial worklist size: " << worklist.size(); + while (!worklist.empty()) { Region *region = worklist.pop_back_val(); - if (region->empty()) + if (region->empty()) { + LDBG(2) << "Skipping empty region"; continue; + } + + LDBG(2) << "Processing region with " << region->getBlocks().size() + << " blocks"; + if (region->getParentOp()) + LDBG(2) << " -> for operation: " + << OpWithFlags(region->getParentOp(), + OpPrintingFlags().skipRegions()); // If this is a single block region, just collect the nested regions. if (region->hasOneBlock()) { @@ -209,13 +227,17 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, for (Block *block : depth_first_ext(&region->front(), reachable)) (void)block /* Mark all reachable blocks */; + LDBG(2) << "Found " << reachable.size() << " reachable blocks out of " + << region->getBlocks().size() << " total blocks"; + // Collect all of the dead blocks and push the live regions onto the // worklist. for (Block &block : llvm::make_early_inc_range(*region)) { if (!reachable.count(&block)) { + LDBG() << "Erasing unreachable block: " << &block; block.dropAllDefinedValueUses(); rewriter.eraseBlock(&block); - erasedDeadBlocks = true; + ++erasedDeadBlocks; continue; } @@ -226,7 +248,10 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, } } - return success(erasedDeadBlocks); + LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks + << " dead blocks"; + + return success(erasedDeadBlocks > 0); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index ee5c642..1382550 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -13,18 +13,40 @@ #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Rewrite/PatternApplicator.h" -#include "llvm/Support/Debug.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "walk-rewriter" namespace mlir { +// Find all reachable blocks in the region and add them to the reachableBlocks +// set. +static void findReachableBlocks(Region &region, + DenseSet<Block *> &reachableBlocks) { + Block *entryBlock = &region.front(); + reachableBlocks.insert(entryBlock); + // Traverse the CFG and add all reachable blocks to the reachableBlocks set. + SmallVector<Block *> worklist({entryBlock}); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + Operation *terminator = &block->back(); + for (Block *successor : terminator->getSuccessors()) { + if (reachableBlocks.contains(successor)) + continue; + worklist.push_back(successor); + reachableBlocks.insert(successor); + } + } +} + namespace { struct WalkAndApplyPatternsAction final : tracing::ActionImpl<WalkAndApplyPatternsAction> { @@ -88,20 +110,104 @@ void walkAndApplyPatterns(Operation *op, PatternApplicator applicator(patterns); applicator.applyDefaultCostModel(); + // Iterator over all reachable operations in the region.
+ // Also keep track if we visited the nested regions of the current op + // already to drive the post-order traversal. + struct RegionReachableOpIterator { + RegionReachableOpIterator(Region *region) : region(region) { + regionIt = region->begin(); + if (regionIt != region->end()) + blockIt = regionIt->begin(); + if (!llvm::hasSingleElement(*region)) + findReachableBlocks(*region, reachableBlocks); + } + // Advance the iterator to the next reachable operation. + void advance() { + assert(regionIt != region->end()); + hasVisitedRegions = false; + if (blockIt == regionIt->end()) { + ++regionIt; + while (regionIt != region->end() && + !reachableBlocks.contains(&*regionIt)) + ++regionIt; + if (regionIt != region->end()) + blockIt = regionIt->begin(); + return; + } + ++blockIt; + if (blockIt != regionIt->end()) { + LDBG() << "Incrementing block iterator, next op: " + << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions()); + } + } + // The region we're iterating over. + Region *region; + // The Block currently being iterated over. + Region::iterator regionIt; + // The Operation currently being iterated over. + Block::iterator blockIt; + // The set of blocks that are reachable in the current region. + DenseSet<Block *> reachableBlocks; + // Whether we've visited the nested regions of the current op already. + bool hasVisitedRegions = false; + }; + + // Worklist of regions to visit to drive the post-order traversal. + SmallVector<RegionReachableOpIterator> worklist; + + LDBG() << "Starting walk-based pattern rewrite driver"; ctx->executeAction<WalkAndApplyPatternsAction>( [&] { + // Perform a post-order traversal of the regions, visiting each + // reachable operation. for (Region ®ion : op->getRegions()) { - region.walk([&](Operation *visitedOp) { - LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( - llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n";); + assert(worklist.empty()); + if (region.empty()) + continue; + + // Prime the worklist with the entry block of this region. + worklist.push_back({®ion}); + while (!worklist.empty()) { + RegionReachableOpIterator &it = worklist.back(); + if (it.regionIt == it.region->end()) { + // We're done with this region. + worklist.pop_back(); + continue; + } + if (it.blockIt == it.regionIt->end()) { + // We're done with this block. + it.advance(); + continue; + } + Operation *op = &*it.blockIt; + // If we haven't visited the nested regions of this op yet, + // enqueue them. + if (!it.hasVisitedRegions) { + it.hasVisitedRegions = true; + for (Region &nestedRegion : llvm::reverse(op->getRegions())) { + if (nestedRegion.empty()) + continue; + worklist.push_back({&nestedRegion}); + } + } + // If we're not at the back of the worklist, we've enqueued some + // nested region for processing. We'll come back to this op later + // (post-order) + if (&it != &worklist.back()) + continue; + + // Preemptively increment the iterator, in case the current op + // would be erased. + it.advance(); + + LDBG() << "Visiting op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - erasedListener.visitedOp = visitedOp; + erasedListener.visitedOp = op; #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); - } - }); + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + LDBG() << "\tOp matched and rewritten"; + } } }, {op}); |
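Finally, a usage sketch of the driver this traversal belongs to (the wrapper is invented; walkAndApplyPatterns is the existing entry point declared in WalkPatternRewriteDriver.h):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
using namespace mlir;

// One-shot, walk-based rewrite: every reachable operation is visited once in
// post-order and patterns are applied without a fixpoint loop; with this
// change, operations in unreachable blocks are skipped entirely.
static void runWalkRewrite(Operation *root, RewritePatternSet &&patterns) {
  walkAndApplyPatterns(root, FrozenRewritePatternSet(std::move(patterns)));
}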