Diffstat (limited to 'mlir/lib')
202 files changed, 13767 insertions, 3259 deletions
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 70b56ca..a93e605 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -180,23 +180,20 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( return; } - /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() - /// on a LoopLikeInterface return the lower/upper bound for that result if - /// possible. - auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, - Type boundType, Block *block, bool getUpper) { + /// Given a lower bound, upper bound, or step from a LoopLikeInterface return + /// the lower/upper bound for that result if possible. + auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType, + Block *block, bool getUpper) { unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); - if (loopBound.has_value()) { - if (auto attr = dyn_cast<Attribute>(*loopBound)) { - if (auto bound = dyn_cast_or_null<IntegerAttr>(attr)) - return bound.getValue(); - } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) { - const IntegerValueRangeLattice *lattice = - getLatticeElementFor(getProgramPointBefore(block), value); - if (lattice != nullptr && !lattice->getValue().isUninitialized()) - return getUpper ? lattice->getValue().getValue().smax() - : lattice->getValue().getValue().smin(); - } + if (auto attr = dyn_cast<Attribute>(loopBound)) { + if (auto bound = dyn_cast<IntegerAttr>(attr)) + return bound.getValue(); + } else if (auto value = llvm::dyn_cast<Value>(loopBound)) { + const IntegerValueRangeLattice *lattice = + getLatticeElementFor(getProgramPointBefore(block), value); + if (lattice != nullptr && !lattice->getValue().isUninitialized()) + return getUpper ? lattice->getValue().getValue().smax() + : lattice->getValue().getValue().smin(); } // Given the results of getConstant{Lower,Upper}Bound() // or getConstantStep() on a LoopLikeInterface return the lower/upper @@ -207,38 +204,43 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( // Infer bounds for loop arguments that have static bounds if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { - std::optional<Value> iv = loop.getSingleInductionVar(); - if (!iv) { + std::optional<llvm::SmallVector<Value>> maybeIvs = + loop.getLoopInductionVars(); + if (!maybeIvs) { return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( op, successor, argLattices, firstIndex); } - Block *block = iv->getParentBlock(); - std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); - std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); - std::optional<OpFoldResult> step = loop.getSingleStep(); - APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block, - /*getUpper=*/false); - APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block, - /*getUpper=*/true); - // Assume positivity for uniscoverable steps by way of getUpper = true. - APInt stepVal = - getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true); - - if (stepVal.isNegative()) { - std::swap(min, max); - } else { - // Correct the upper bound by subtracting 1 so that it becomes a <= - // bound, because loops do not generally include their upper bound. - max -= 1; - } + // This shouldn't be returning nullopt if there are induction variables.
+ SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds(); + SmallVector<OpFoldResult> upperBounds = *loop.getLoopUpperBounds(); + SmallVector<OpFoldResult> steps = *loop.getLoopSteps(); + for (auto [iv, lowerBound, upperBound, step] : + llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) { + Block *block = iv.getParentBlock(); + APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block, + /*getUpper=*/false); + APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block, + /*getUpper=*/true); + // Assume positivity for uniscoverable steps by way of getUpper = true. + APInt stepVal = + getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true); + + if (stepVal.isNegative()) { + std::swap(min, max); + } else { + // Correct the upper bound by subtracting 1 so that it becomes a <= + // bound, because loops do not generally include their upper bound. + max -= 1; + } - // If we infer the lower bound to be larger than the upper bound, the - // resulting range is meaningless and should not be used in further - // inferences. - if (max.sge(min)) { - IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); - auto ivRange = ConstantIntRanges::fromSigned(min, max); - propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); + // If we infer the lower bound to be larger than the upper bound, the + // resulting range is meaningless and should not be used in further + // inferences. + if (max.sge(min)) { + IntegerValueRangeLattice *ivEntry = getLatticeElement(iv); + auto ivRange = ConstantIntRanges::fromSigned(min, max); + propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); + } } return; } diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp index 75d592e..c31b277 100644 --- a/mlir/lib/Analysis/Presburger/Barvinok.cpp +++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp @@ -178,13 +178,13 @@ mlir::presburger::detail::solveParametricEquations(FracMatrix equations) { for (unsigned i = 0; i < d; ++i) { // First ensure that the diagonal element is nonzero, by swapping // it with a row that is non-zero at column i. 
- if (equations(i, i) != 0) - continue; - for (unsigned j = i + 1; j < d; ++j) { - if (equations(j, i) == 0) - continue; - equations.swapRows(j, i); - break; + if (equations(i, i) == 0) { + for (unsigned j = i + 1; j < d; ++j) { + if (equations(j, i) == 0) + continue; + equations.swapRows(j, i); + break; + } } Fraction diagElement = equations(i, i); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 812043d..26197ce 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -21,6 +21,7 @@ #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Utils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallBitVector.h" @@ -442,6 +443,14 @@ void IntegerRelation::removeInequality(unsigned pos) { inequalities.removeRow(pos); } +void IntegerRelation::removeConstraint(unsigned pos) { + if (pos >= getNumInequalities()) { + removeEquality(pos - getNumInequalities()); + } else { + removeInequality(pos); + } +} + void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) { if (start >= end) return; @@ -1112,15 +1121,29 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart, return posLimit - posStart; } +static std::optional<unsigned> +findEqualityWithNonZeroAfterRow(IntegerRelation &rel, unsigned fromRow, + unsigned colIdx) { + assert(fromRow < rel.getNumEqualities() && colIdx < rel.getNumCols() && + "position out of bounds"); + for (unsigned rowIdx = fromRow, e = rel.getNumEqualities(); rowIdx < e; + ++rowIdx) { + if (rel.atEq(rowIdx, colIdx) != 0) + return rowIdx; + } + return std::nullopt; +} + bool IntegerRelation::gaussianEliminate() { gcdTightenInequalities(); unsigned firstVar = 0, vars = getNumVars(); unsigned nowDone, eqs; std::optional<unsigned> pivotRow; for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) { - // Finds the first non-empty column. + // Finds the first non-empty column that we haven't dealt with. for (; firstVar < vars; ++firstVar) { - if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true))) + if ((pivotRow = + findEqualityWithNonZeroAfterRow(*this, nowDone, firstVar))) break; } // The matrix has been normalized to row echelon form. @@ -1143,6 +1166,10 @@ bool IntegerRelation::gaussianEliminate() { inequalities.normalizeRow(i); } gcdTightenInequalities(); + + // The column is finished. Tell the next iteration to start at the next + // column. + firstVar++; } // No redundant rows. @@ -1724,12 +1751,64 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize( return minDiff; } +void IntegerRelation::pruneOrthogonalConstraints(unsigned pos) { + llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows; + + // Early exit if there are no constraints. + unsigned numConstraints = getNumConstraints(); + if (numConstraints == 0) + return; + + llvm::SmallVector<unsigned> rowStack, colStack({pos}); + // The following code performs a graph traversal, starting from the target + // variable, to identify all variables (recorded in relatedCols) and + // constraints (recorded in relatedRows) belonging to the same connected + // component. + while (!rowStack.empty() || !colStack.empty()) { + if (!rowStack.empty()) { + unsigned currentRow = rowStack.pop_back_val(); + // Push all variables associated with this constraint to relatedCols + // and colStack.
+ for (unsigned colIndex = 0; colIndex < getNumVars(); ++colIndex) { + if (atConstraint(currentRow, colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } + } else { + unsigned currentCol = colStack.pop_back_val(); + // Push all constraints that are associated with this variable to related + // rows and the row stack. + for (unsigned rowIndex = 0; rowIndex < numConstraints; ++rowIndex) { + if (atConstraint(rowIndex, currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } + } + } + + // Prune all constraints not related to target variable. + for (int constraintId = numConstraints - 1; constraintId >= 0; + --constraintId) { + if (!relatedRows.contains(constraintId)) + removeConstraint((unsigned)constraintId); + } +} + template <bool isLower> std::optional<DynamicAPInt> IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { assert(pos < getNumVars() && "invalid position"); // Project to 'pos'. + // Prune orthogonal constraints to reduce unnecessary computations and + // accelerate the bound computation. + pruneOrthogonalConstraints(pos); projectOut(0, pos); + + // After projecting out values, more orthogonal constraints may be exposed. + // Prune these orthogonal constraints again. + pruneOrthogonalConstraints(0); projectOut(1, getNumVars() - 1); // Check if there's an equality equating the '0'^th variable to a constant. int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false); @@ -2265,11 +2344,11 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { newLb[d] = lbFloorDivisor; newUb[d] = -lbFloorDivisor; // Copy over the symbolic part + constant term. - std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars()); + llvm::copy(minLb, newLb.begin() + getNumDimVars()); std::transform(newLb.begin() + getNumDimVars(), newLb.end(), newLb.begin() + getNumDimVars(), std::negate<DynamicAPInt>()); - std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars()); + llvm::copy(maxUb, newUb.begin() + getNumDimVars()); boundingLbs.emplace_back(newLb); boundingUbs.emplace_back(newUb); diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index bb60564..83a2c28 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -255,20 +255,13 @@ void Matrix<T>::fillRow(unsigned row, const T &value) { } // moveColumns is implemented by moving the columns adjacent to the source range -// to their final position. When moving right (i.e. dstPos > srcPos), the range -// of the adjacent columns is [srcPos + num, dstPos + num). When moving left -// (i.e. dstPos < srcPos) the range of the adjacent columns is [dstPos, srcPos). -// First, zeroed out columns are inserted in the final positions of the adjacent -// columns. Then, the adjacent columns are moved to their final positions by -// swapping them with the zeroed columns. Finally, the now zeroed adjacent -// columns are deleted. +// to their final position. template <typename T> void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) { if (num == 0) return; - int offset = dstPos - srcPos; - if (offset == 0) + if (dstPos == srcPos) return; assert(srcPos + num <= getNumColumns() && @@ -276,23 +269,19 @@ void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) { assert(dstPos + num <= getNumColumns() && "move destination range exceeds matrix columns"); - unsigned insertCount = offset > 0 ? 
offset : -offset; - unsigned finalAdjStart = offset > 0 ? srcPos : srcPos + num; - unsigned curAdjStart = offset > 0 ? srcPos + num : dstPos; - // TODO: This can be done using std::rotate. - // Insert new zero columns in the positions where the adjacent columns are to - // be moved. - insertColumns(finalAdjStart, insertCount); - // Update curAdjStart if insertion of new columns invalidates it. - if (finalAdjStart < curAdjStart) - curAdjStart += insertCount; - - // Swap the adjacent columns with inserted zero columns. - for (unsigned i = 0; i < insertCount; ++i) - swapColumns(finalAdjStart + i, curAdjStart + i); - - // Delete the now redundant zero columns. - removeColumns(curAdjStart, insertCount); + unsigned numRows = getNumRows(); + // std::rotate(start, middle, end) permutes the elements of [start, end] to + // [middle, end) + [start, middle). NOTE: &at(i, srcPos + num) will trigger an + // assert. + if (dstPos > srcPos) { + for (unsigned i = 0; i < numRows; ++i) { + std::rotate(&at(i, srcPos), &at(i, srcPos) + num, &at(i, dstPos) + num); + } + return; + } + for (unsigned i = 0; i < numRows; ++i) { + std::rotate(&at(i, dstPos), &at(i, srcPos), &at(i, srcPos) + num); + } } template <typename T> diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 870a713..05681ce 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -31,8 +31,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // StructType //===--------------------------------------------------------------------===// - auto llvmStructType = - mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); + auto llvmStructType = mlir_type_subclass( + m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID); llvmStructType .def_classmethod( @@ -137,7 +137,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // PointerType //===--------------------------------------------------------------------===// - mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType) + mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType, + mlirLLVMPointerTypeGetTypeID) .def_classmethod( "get", [](const nb::object &cls, std::optional<unsigned> addressSpace, diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 0155023..0b079b4 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { "op.", nb::arg("op")); + m.def( + "infer_contraction_dimensions_from_maps", + [](std::vector<MlirAffineMap> indexingMaps) + -> std::optional<MlirLinalgContractionDimensions> { + if (indexingMaps.empty()) + return std::nullopt; + + MlirLinalgContractionDimensions dims = + mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(), + indexingMaps.size()); + + // Detect "empty" result from invalid input or failed inference. 
+ if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && + mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { + return std::nullopt; + } + return dims; + }, + "Infers contraction dimensions (batch/m/n/k) from a list of affine " + "maps.", + nb::arg("indexing_maps")); + m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp, "Checks if the given operation is a Linalg convolution operation.", nb::arg("op")); diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 8bb493e..be0785b1 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -75,13 +75,13 @@ NB_MODULE(_mlirExecutionEngine, m) { "__init__", [](PyExecutionEngine &self, MlirModule module, int optLevel, const std::vector<std::string> &sharedLibPaths, - bool enableObjectDump) { + bool enableObjectDump, bool enablePIC) { llvm::SmallVector<MlirStringRef, 4> libPaths; for (const std::string &path : sharedLibPaths) libPaths.push_back({path.c_str(), path.length()}); - MlirExecutionEngine executionEngine = - mlirExecutionEngineCreate(module, optLevel, libPaths.size(), - libPaths.data(), enableObjectDump); + MlirExecutionEngine executionEngine = mlirExecutionEngineCreate( + module, optLevel, libPaths.size(), libPaths.data(), + enableObjectDump, enablePIC); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); @@ -89,7 +89,7 @@ NB_MODULE(_mlirExecutionEngine, m) { }, nb::arg("module"), nb::arg("opt_level") = 2, nb::arg("shared_libs") = nb::list(), - nb::arg("enable_object_dump") = true, + nb::arg("enable_object_dump") = true, nb::arg("enable_pic") = false, "Create a new ExecutionEngine instance for the given Module. The " "module must contain only dialects that can be translated to LLVM. " "Perform transformations and code generation at the optimization " diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cda4fe1..2e0c2b8 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -18,6 +18,7 @@ #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "nanobind/nanobind.h" +#include "nanobind/typing.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -32,33 +33,6 @@ using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; -//------------------------------------------------------------------------------ -// Docstrings (trivial, non-duplicated docstrings are included inline). -//------------------------------------------------------------------------------ - -static const char kContextParseTypeDocstring[] = - R"(Parses the assembly form of a type. - -Returns a Type object or raises an MLIRError if the type cannot be parsed. 
- -See also: https://mlir.llvm.org/docs/LangRef/#type-system -)"; - -static const char kContextGetCallSiteLocationDocstring[] = - R"(Gets a Location representing a caller and callsite)"; - -static const char kContextGetFileLocationDocstring[] = - R"(Gets a Location representing a file, line and column)"; - -static const char kContextGetFileRangeDocstring[] = - R"(Gets a Location representing a file, line and column range)"; - -static const char kContextGetFusedLocationDocstring[] = - R"(Gets a Location representing a fused location with optional metadata)"; - -static const char kContextGetNameLocationDocString[] = - R"(Gets a Location representing a named location with optional child location)"; - static const char kModuleParseDocstring[] = R"(Parses a module's assembly format from a string. @@ -67,132 +41,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; -static const char kModuleCAPICreate[] = - R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr). -Note this returns a new object BUT _clear_mlir_module(module) must be called to -prevent double-frees (of the underlying mlir::Module). -)"; - -static const char kOperationCreateDocstring[] = - R"(Creates a new operation. - -Args: - name: Operation name (e.g. "dialect.operation"). - results: Sequence of Type representing op result types. - attributes: Dict of str:Attribute. - successors: List of Block for the operation's successors. - regions: Number of regions to create. - location: A Location object (defaults to resolve from context manager). - ip: An InsertionPoint (defaults to resolve from context manager or set to - False to disable insertion, even with an insertion point set in the - context manager). - infer_type: Whether to infer result types. -Returns: - A new "detached" Operation object. Detached operations can be added - to blocks, which causes them to become "attached." -)"; - -static const char kOperationPrintDocstring[] = - R"(Prints the assembly form of the operation to a file like object. - -Args: - file: The file like object to write to. Defaults to sys.stdout. - binary: Whether to write bytes (True) or str (False). Defaults to False. - large_elements_limit: Whether to elide elements attributes above this - number of elements. Defaults to None (no limit). - large_resource_limit: Whether to elide resource attributes above this - number of characters. Defaults to None (no limit). If large_elements_limit - is set and this is None, the behavior will be to use large_elements_limit - as large_resource_limit. - enable_debug_info: Whether to print debug/location information. Defaults - to False. - pretty_debug_info: Whether to format debug information for easier reading - by a human (warning: the result is unparseable). - print_generic_op_form: Whether to print the generic assembly forms of all - ops. Defaults to False. - use_local_Scope: Whether to print in a way that is more optimized for - multi-threaded access but may not be consistent with how the overall - module prints. - assume_verified: By default, if not printing generic form, the verifier - will be run and if it fails, generic form will be printed with a comment - about failed verification. While a reasonable default for interactive use, - for systematic use, it is often better for the caller to verify explicitly - and report failures in a more robust fashion. Set this to True if doing this - in order to avoid running a redundant verification. 
If the IR is actually - invalid, behavior is undefined. - skip_regions: Whether to skip printing regions. Defaults to False. -)"; - -static const char kOperationPrintStateDocstring[] = - R"(Prints the assembly form of the operation to a file like object. - -Args: - file: The file like object to write to. Defaults to sys.stdout. - binary: Whether to write bytes (True) or str (False). Defaults to False. - state: AsmState capturing the operation numbering and flags. -)"; - -static const char kOperationGetAsmDocstring[] = - R"(Gets the assembly form of the operation with all options available. - -Args: - binary: Whether to return a bytes (True) or str (False) object. Defaults to - False. - ... others ...: See the print() method for common keyword arguments for - configuring the printout. -Returns: - Either a bytes or str object, depending on the setting of the 'binary' - argument. -)"; - -static const char kOperationPrintBytecodeDocstring[] = - R"(Write the bytecode form of the operation to a file like object. - -Args: - file: The file like object to write to. - desired_version: The version of bytecode to emit. -Returns: - The bytecode writer status. -)"; - -static const char kOperationStrDunderDocstring[] = - R"(Gets the assembly form of the operation with default options. - -If more advanced control over the assembly formatting or I/O options is needed, -use the dedicated print or get_asm method, which supports keyword arguments to -customize behavior. -)"; - static const char kDumpDocstring[] = - R"(Dumps a debug representation of the object to stderr.)"; - -static const char kAppendBlockDocstring[] = - R"(Appends a new block, with argument types as positional args. - -Returns: - The created block. -)"; - -static const char kValueDunderStrDocstring[] = - R"(Returns the string form of the value. - -If the value is a block argument, this is the assembly form of its type and the -position in the argument list. If the value is an operation result, this is -equivalent to printing the operation that produced it. -)"; - -static const char kGetNameAsOperand[] = - R"(Returns the string form of value as an operand (i.e., the ValueID). -)"; - -static const char kValueReplaceAllUsesWithDocstring[] = - R"(Replace all uses of value with the new value, updating anything in -the IR that uses 'self' to use the other value instead. -)"; + "Dumps a debug representation of the object to stderr."; static const char kValueReplaceAllUsesExceptDocstring[] = - R"("Replace all uses of this value with the 'with' value, except for those -in 'exceptions'. 'exceptions' can be either a single operation or a list of + R"(Replace all uses of this value with the `with` value, except for those +in `exceptions`. `exceptions` can be either a single operation or a list of operations. )"; @@ -274,22 +128,26 @@ struct PyGlobalDebugFlag { // Debug flags. 
nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + &PyGlobalDebugFlag::set, "LLVM-wide debug flag.") .def_static( "set_types", [](const std::string &type) { nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugType(type.c_str()); }, - "types"_a, "Sets specific debug types to be produced by LLVM") - .def_static("set_types", [](const std::vector<std::string> &types) { - std::vector<const char *> pointers; - pointers.reserve(types.size()); - for (const std::string &str : types) - pointers.push_back(str.c_str()); - nb::ft_lock_guard lock(mutex); - mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); - }); + "types"_a, "Sets specific debug types to be produced by LLVM.") + .def_static( + "set_types", + [](const std::vector<std::string> &types) { + std::vector<const char *> pointers; + pointers.reserve(types.size()); + for (const std::string &str : types) + pointers.push_back(str.c_str()); + nb::ft_lock_guard lock(mutex); + mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); + }, + "types"_a, + "Sets multiple specific debug types to be produced by LLVM."); } private: @@ -316,12 +174,18 @@ struct PyAttrBuilderMap { static void bind(nb::module_ &m) { nb::class_<PyAttrBuilderMap>(m, "AttrBuilder") - .def_static("contains", &PyAttrBuilderMap::dunderContains) - .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) + .def_static("contains", &PyAttrBuilderMap::dunderContains, + "attribute_kind"_a, + "Checks whether an attribute builder is registered for the " + "given attribute kind.") + .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed, + "attribute_kind"_a, + "Gets the registered attribute builder for the given " + "attribute kind.") .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, "Register an attribute builder for building MLIR " - "attributes from python values."); + "attributes from Python values."); } }; @@ -341,8 +205,8 @@ namespace { class PyRegionIterator { public: - PyRegionIterator(PyOperationRef operation) - : operation(std::move(operation)) {} + PyRegionIterator(PyOperationRef operation, int nextIndex) + : operation(std::move(operation)), nextIndex(nextIndex) {} PyRegionIterator &dunderIter() { return *this; } @@ -357,13 +221,15 @@ public: static void bind(nb::module_ &m) { nb::class_<PyRegionIterator>(m, "RegionIterator") - .def("__iter__", &PyRegionIterator::dunderIter) - .def("__next__", &PyRegionIterator::dunderNext); + .def("__iter__", &PyRegionIterator::dunderIter, + "Returns an iterator over the regions in the operation.") + .def("__next__", &PyRegionIterator::dunderNext, + "Returns the next region in the iteration."); } private: PyOperationRef operation; - int nextIndex = 0; + intptr_t nextIndex = 0; }; /// Regions of an op are fixed length and indexed numerically so are represented @@ -382,11 +248,12 @@ public: PyRegionIterator dunderIter() { operation->checkValid(); - return PyRegionIterator(operation); + return PyRegionIterator(operation, startIndex); } static void bindDerived(ClassTy &c) { - c.def("__iter__", &PyRegionList::dunderIter); + c.def("__iter__", &PyRegionList::dunderIter, + "Returns an iterator over the regions in the sequence."); } private: @@ -430,8 +297,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyBlockIterator>(m, "BlockIterator") - .def("__iter__", &PyBlockIterator::dunderIter) - .def("__next__", &PyBlockIterator::dunderNext); + .def("__iter__", 
&PyBlockIterator::dunderIter, + "Returns an iterator over the blocks in the operation's region.") + .def("__next__", &PyBlockIterator::dunderNext, + "Returns the next block in the iteration."); } private: @@ -493,10 +362,19 @@ public: static void bind(nb::module_ &m) { nb::class_<PyBlockList>(m, "BlockList") - .def("__getitem__", &PyBlockList::dunderGetItem) - .def("__iter__", &PyBlockList::dunderIter) - .def("__len__", &PyBlockList::dunderLen) - .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, + .def("__getitem__", &PyBlockList::dunderGetItem, + "Returns the block at the specified index.") + .def("__iter__", &PyBlockList::dunderIter, + "Returns an iterator over blocks in the operation's region.") + .def("__len__", &PyBlockList::dunderLen, + "Returns the number of blocks in the operation's region.") + .def("append", &PyBlockList::appendBlock, + R"( + Appends a new block, with argument types as positional args. + + Returns: + The created block. + )", nb::arg("args"), nb::kw_only(), nb::arg("arg_locs") = std::nullopt); } @@ -527,8 +405,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOperationIterator>(m, "OperationIterator") - .def("__iter__", &PyOperationIterator::dunderIter) - .def("__next__", &PyOperationIterator::dunderNext); + .def("__iter__", &PyOperationIterator::dunderIter, + "Returns an iterator over the operations in an operation's block.") + .def("__next__", &PyOperationIterator::dunderNext, + "Returns the next operation in the iteration."); } private: @@ -584,9 +464,12 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOperationList>(m, "OperationList") - .def("__getitem__", &PyOperationList::dunderGetItem) - .def("__iter__", &PyOperationList::dunderIter) - .def("__len__", &PyOperationList::dunderLen); + .def("__getitem__", &PyOperationList::dunderGetItem, + "Returns the operation at the specified index.") + .def("__iter__", &PyOperationList::dunderIter, + "Returns an iterator over operations in the list.") + .def("__len__", &PyOperationList::dunderLen, + "Returns the number of operations in the list."); } private: @@ -609,8 +492,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOpOperand>(m, "OpOperand") - .def_prop_ro("owner", &PyOpOperand::getOwner) - .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); + .def_prop_ro("owner", &PyOpOperand::getOwner, + "Returns the operation that owns this operand.") + .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber, + "Returns the operand number in the owning operation."); } private: @@ -634,8 +519,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOpOperandIterator>(m, "OpOperandIterator") - .def("__iter__", &PyOpOperandIterator::dunderIter) - .def("__next__", &PyOpOperandIterator::dunderNext); + .def("__iter__", &PyOpOperandIterator::dunderIter, + "Returns an iterator over operands.") + .def("__next__", &PyOpOperandIterator::dunderNext, + "Returns the next operand in the iteration."); } private: @@ -1524,9 +1411,10 @@ nb::object PyOperation::create(std::string_view name, } // Construct the operation. 
+ PyMlirContext::ErrorCapture errors(location.getContext()); MlirOperation operation = mlirOperationCreate(&state); if (!operation.ptr) - throw nb::value_error("Operation creation failed"); + throw MLIRError("Operation creation failed", errors.take()); PyOperationRef created = PyOperation::createDetached(location.getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1596,7 +1484,11 @@ public: /// Binds the Python module objects to functions of this class. static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); + auto cls = ClassTy( + m, DerivedTy::pyClassName, nb::is_generic(), + nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])") + .str() + .c_str())); cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")); cls.def_static( "isinstance", @@ -1626,16 +1518,21 @@ public: static void bindDerived(ClassTy &c) { c.def_prop_ro( - "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> { + "owner", + [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> { assert(mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in " "the IR"); return self.getParentOperation().getObject(); - }); - c.def_prop_ro("result_number", [](PyOpResult &self) { - return mlirOpResultGetResultNumber(self.get()); - }); + }, + "Returns the operation that produces this result."); + c.def_prop_ro( + "result_number", + [](PyOpResult &self) { + return mlirOpResultGetResultNumber(self.get()); + }, + "Returns the position of this result in the operation's result list."); } }; @@ -1671,13 +1568,18 @@ public: operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyOpResultList &self) { - return getValueTypes(self, self.operation->getContext()); - }); - c.def_prop_ro("owner", - [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> { - return self.operation->createOpView(); - }); + c.def_prop_ro( + "types", + [](PyOpResultList &self) { + return getValueTypes(self, self.operation->getContext()); + }, + "Returns a list of types for all results in this result list."); + c.def_prop_ro( + "owner", + [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> { + return self.operation->createOpView(); + }, + "Returns the operation that owns this result list."); } PyOperationRef &getOperation() { return operation; } @@ -2427,19 +2329,31 @@ public: using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyBlockArgument &self) { - return PyBlock(self.getParentOperation(), - mlirBlockArgumentGetOwner(self.get())); - }); - c.def_prop_ro("arg_number", [](PyBlockArgument &self) { - return mlirBlockArgumentGetArgNumber(self.get()); - }); + c.def_prop_ro( + "owner", + [](PyBlockArgument &self) { + return PyBlock(self.getParentOperation(), + mlirBlockArgumentGetOwner(self.get())); + }, + "Returns the block that owns this argument."); + c.def_prop_ro( + "arg_number", + [](PyBlockArgument &self) { + return mlirBlockArgumentGetArgNumber(self.get()); + }, + "Returns the position of this argument in the block's argument list."); c.def( "set_type", [](PyBlockArgument &self, PyType type) { return mlirBlockArgumentSetType(self.get(), type); }, - nb::arg("type")); + nb::arg("type"), "Sets the type of this block argument."); + c.def( + "set_location", + [](PyBlockArgument &self, PyLocation loc) { + return mlirBlockArgumentSetLocation(self.get(), loc); + }, + nb::arg("loc"), 
"Sets the location of this block argument."); } }; @@ -2462,9 +2376,12 @@ public: operation(std::move(operation)), block(block) {} static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyBlockArgumentList &self) { - return getValueTypes(self, self.operation->getContext()); - }); + c.def_prop_ro( + "types", + [](PyBlockArgumentList &self) { + return getValueTypes(self, self.operation->getContext()); + }, + "Returns a list of types for all arguments in this argument list."); } private: @@ -2516,7 +2433,9 @@ public: } static void bindDerived(ClassTy &c) { - c.def("__setitem__", &PyOpOperandList::dunderSetItem); + c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"), + nb::arg("value"), + "Sets the operand at the specified index to a new value."); } private: @@ -2571,7 +2490,8 @@ public: } static void bindDerived(ClassTy &c) { - c.def("__setitem__", &PyOpSuccessors::dunderSetItem); + c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"), + nb::arg("block"), "Sets the successor block at the specified index."); } private: @@ -2743,55 +2663,70 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") - .def("__contains__", &PyOpAttributeMap::dunderContains) - .def("__len__", &PyOpAttributeMap::dunderLen) - .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) - .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) - .def("__setitem__", &PyOpAttributeMap::dunderSetItem) - .def("__delitem__", &PyOpAttributeMap::dunderDelItem) - .def("__iter__", - [](PyOpAttributeMap &self) { - nb::list keys; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute) { - keys.append(nb::str(name.data, name.length)); - }); - return nb::iter(keys); - }) - .def("keys", - [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute) { - out.append(nb::str(name.data, name.length)); - }); - return out; - }) - .def("values", - [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef, MlirAttribute attr) { - out.append(PyAttribute(self.operation->getContext(), attr) - .maybeDownCast()); - }); - return out; - }) - .def("items", [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute attr) { - out.append(nb::make_tuple( - nb::str(name.data, name.length), - PyAttribute(self.operation->getContext(), attr) - .maybeDownCast())); - }); - return out; - }); + .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"), + "Checks if an attribute with the given name exists in the map.") + .def("__len__", &PyOpAttributeMap::dunderLen, + "Returns the number of attributes in the map.") + .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, + nb::arg("name"), "Gets an attribute by name.") + .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, + nb::arg("index"), "Gets a named attribute by index.") + .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"), + nb::arg("attr"), "Sets an attribute with the given name.") + .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"), + "Deletes an attribute with the given name.") + .def( + "__iter__", + [](PyOpAttributeMap &self) { + nb::list keys; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + keys.append(nb::str(name.data, 
name.length)); + }); + return nb::iter(keys); + }, + "Iterates over attribute names.") + .def( + "keys", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + out.append(nb::str(name.data, name.length)); + }); + return out; + }, + "Returns a list of attribute names.") + .def( + "values", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef, MlirAttribute attr) { + out.append(PyAttribute(self.operation->getContext(), attr) + .maybeDownCast()); + }); + return out; + }, + "Returns a list of attribute values.") + .def( + "items", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute attr) { + out.append(nb::make_tuple( + nb::str(name.data, name.length), + PyAttribute(self.operation->getContext(), attr) + .maybeDownCast())); + }); + return out; + }, + "Returns a list of `(name, attribute)` tuples."); } private: @@ -2979,62 +2914,103 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Mapping of Diagnostics. //---------------------------------------------------------------------------- nb::class_<PyDiagnostic>(m, "Diagnostic") - .def_prop_ro("severity", &PyDiagnostic::getSeverity) - .def_prop_ro("location", &PyDiagnostic::getLocation) - .def_prop_ro("message", &PyDiagnostic::getMessage) - .def_prop_ro("notes", &PyDiagnostic::getNotes) - .def("__str__", [](PyDiagnostic &self) -> nb::str { - if (!self.isValid()) - return nb::str("<Invalid Diagnostic>"); - return self.getMessage(); - }); + .def_prop_ro("severity", &PyDiagnostic::getSeverity, + "Returns the severity of the diagnostic.") + .def_prop_ro("location", &PyDiagnostic::getLocation, + "Returns the location associated with the diagnostic.") + .def_prop_ro("message", &PyDiagnostic::getMessage, + "Returns the message text of the diagnostic.") + .def_prop_ro("notes", &PyDiagnostic::getNotes, + "Returns a tuple of attached note diagnostics.") + .def( + "__str__", + [](PyDiagnostic &self) -> nb::str { + if (!self.isValid()) + return nb::str("<Invalid Diagnostic>"); + return self.getMessage(); + }, + "Returns the diagnostic message as a string."); nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo") - .def("__init__", - [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { - new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); - }) - .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) - .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) - .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) - .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) - .def("__str__", - [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); + .def( + "__init__", + [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { + new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); + }, + "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.") + .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity, + "The severity level of the diagnostic.") + .def_ro("location", &PyDiagnostic::DiagnosticInfo::location, + "The location associated with the diagnostic.") + .def_ro("message", &PyDiagnostic::DiagnosticInfo::message, + "The message text of the diagnostic.") + .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes, + "List of attached note diagnostics.") + .def( + "__str__", + [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }, + "Returns 
the diagnostic message as a string."); nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler") - .def("detach", &PyDiagnosticHandler::detach) - .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) - .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) - .def("__enter__", &PyDiagnosticHandler::contextEnter) + .def("detach", &PyDiagnosticHandler::detach, + "Detaches the diagnostic handler from the context.") + .def_prop_ro("attached", &PyDiagnosticHandler::isAttached, + "Returns True if the handler is attached to a context.") + .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError, + "Returns True if an error was encountered during diagnostic " + "handling.") + .def("__enter__", &PyDiagnosticHandler::contextEnter, + "Enters the diagnostic handler as a context manager.") .def("__exit__", &PyDiagnosticHandler::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none()); + nb::arg("traceback").none(), + "Exits the diagnostic handler context manager."); // Expose DefaultThreadPool to python nb::class_<PyThreadPool>(m, "ThreadPool") - .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) - .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) - .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); + .def( + "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }, + "Creates a new thread pool with default concurrency.") + .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency, + "Returns the maximum number of threads in the pool.") + .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr, + "Returns the raw pointer to the LLVM thread pool as a string."); nb::class_<PyMlirContext>(m, "Context") - .def("__init__", - [](PyMlirContext &self) { - MlirContext context = mlirContextCreateWithThreading(false); - new (&self) PyMlirContext(context); - }) - .def_static("_get_live_count", &PyMlirContext::getLiveCount) - .def("_get_context_again", - [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> { - PyMlirContextRef ref = PyMlirContext::forContext(self.get()); - return ref.releaseObject(); - }) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) + .def( + "__init__", + [](PyMlirContext &self) { + MlirContext context = mlirContextCreateWithThreading(false); + new (&self) PyMlirContext(context); + }, + R"( + Creates a new MLIR context. + + The context is the top-level container for all MLIR objects. It owns the storage + for types, attributes, locations, and other core IR objects. 
A context can be + configured to allow or disallow unregistered dialects and can have dialects + loaded on-demand.)") + .def_static("_get_live_count", &PyMlirContext::getLiveCount, + "Gets the number of live Context objects.") + .def( + "_get_context_again", + [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> { + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); + }, + "Gets another reference to the same context.") + .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount, + "Gets the number of live modules owned by this context.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule, + "Gets a capsule wrapping the MlirContext.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyMlirContext::createFromCapsule) - .def("__enter__", &PyMlirContext::contextEnter) + &PyMlirContext::createFromCapsule, + "Creates a Context from a capsule wrapping MlirContext.") + .def("__enter__", &PyMlirContext::contextEnter, + "Enters the context as a context manager.") .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none()) + nb::arg("exc_value").none(), nb::arg("traceback").none(), + "Exits the context manager.") .def_prop_ro_static( "current", [](nb::object & /*class*/) @@ -3045,14 +3021,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { return nb::cast(context); }, nb::sig("def current(/) -> Context | None"), - "Gets the Context bound to the current thread or raises ValueError") + "Gets the Context bound to the current thread or returns None if no " + "context is set.") .def_prop_ro( "dialects", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Gets a container for accessing dialects by name") + "Gets a container for accessing dialects by name.") .def_prop_ro( "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Alias for 'dialect'") + "Alias for `dialects`.") .def( "get_dialect_descriptor", [=](PyMlirContext &self, std::string &name) { @@ -3065,7 +3042,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyDialectDescriptor(self.getRef(), dialect); }, nb::arg("dialect_name"), - "Gets or loads a dialect by name, returning its descriptor object") + "Gets or loads a dialect by name, returning its descriptor object.") .def_prop_rw( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { @@ -3073,67 +3050,110 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); - }) + }, + "Controls whether unregistered dialects are allowed in this context.") .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, nb::arg("callback"), - "Attaches a diagnostic handler that will receive callbacks") + "Attaches a diagnostic handler that will receive callbacks.") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { mlirContextEnableMultithreading(self.get(), enable); }, - nb::arg("enable")) - .def("set_thread_pool", - [](PyMlirContext &self, PyThreadPool &pool) { - // we should disable multi-threading first before setting - // new thread pool otherwise the assert in - // MLIRContext::setThreadPool will be raised. 
- mlirContextEnableMultithreading(self.get(), false); - mlirContextSetThreadPool(self.get(), pool.get()); - }) - .def("get_num_threads", - [](PyMlirContext &self) { - return mlirContextGetNumThreads(self.get()); - }) - .def("_mlir_thread_pool_ptr", - [](PyMlirContext &self) { - MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); - std::stringstream ss; - ss << pool.ptr; - return ss.str(); - }) + nb::arg("enable"), + R"( + Enables or disables multi-threading support in the context. + + Args: + enable: Whether to enable (True) or disable (False) multi-threading. + )") + .def( + "set_thread_pool", + [](PyMlirContext &self, PyThreadPool &pool) { + // we should disable multi-threading first before setting + // new thread pool otherwise the assert in + // MLIRContext::setThreadPool will be raised. + mlirContextEnableMultithreading(self.get(), false); + mlirContextSetThreadPool(self.get(), pool.get()); + }, + R"( + Sets a custom thread pool for the context to use. + + Args: + pool: A ThreadPool object to use for parallel operations. + + Note: + Multi-threading is automatically disabled before setting the thread pool.)") + .def( + "get_num_threads", + [](PyMlirContext &self) { + return mlirContextGetNumThreads(self.get()); + }, + "Gets the number of threads in the context's thread pool.") + .def( + "_mlir_thread_pool_ptr", + [](PyMlirContext &self) { + MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); + std::stringstream ss; + ss << pool.ptr; + return ss.str(); + }, + "Gets the raw pointer to the LLVM thread pool as a string.") .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { return mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - nb::arg("operation_name")) + nb::arg("operation_name"), + R"( + Checks whether an operation with the given name is registered. + + Args: + operation_name: The fully qualified name of the operation (e.g., `arith.addf`). + + Returns: + True if the operation is registered, False otherwise.)") .def( "append_dialect_registry", [](PyMlirContext &self, PyDialectRegistry ®istry) { mlirContextAppendDialectRegistry(self.get(), registry); }, - nb::arg("registry")) + nb::arg("registry"), + R"( + Appends the contents of a dialect registry to the context. + + Args: + registry: A DialectRegistry containing dialects to append.)") .def_prop_rw("emit_error_diagnostics", &PyMlirContext::getEmitErrorDiagnostics, &PyMlirContext::setEmitErrorDiagnostics, - "Emit error diagnostics to diagnostic handlers. By default " - "error diagnostics are captured and reported through " - "MLIRError exceptions.") - .def("load_all_available_dialects", [](PyMlirContext &self) { - mlirContextLoadAllAvailableDialects(self.get()); - }); + R"( + Controls whether error diagnostics are emitted to diagnostic handlers. + + By default, error diagnostics are captured and reported through MLIRError exceptions.)") + .def( + "load_all_available_dialects", + [](PyMlirContext &self) { + mlirContextLoadAllAvailableDialects(self.get()); + }, + R"( + Loads all dialects available in the registry into the context. 
+ + This eagerly loads all dialects that have been registered, making them + immediately available for use.)"); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor //---------------------------------------------------------------------------- nb::class_<PyDialectDescriptor>(m, "DialectDescriptor") - .def_prop_ro("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = mlirDialectGetNamespace(self.get()); - return nb::str(ns.data, ns.length); - }) + .def_prop_ro( + "namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + return nb::str(ns.data, ns.length); + }, + "Returns the namespace of the dialect.") .def( "__repr__", [](PyDialectDescriptor &self) { @@ -3143,35 +3163,43 @@ void mlir::python::populateIRCore(nb::module_ &m) { repr.append(">"); return repr; }, - nb::sig("def __repr__(self) -> str")); + nb::sig("def __repr__(self) -> str"), + "Returns a string representation of the dialect descriptor."); //---------------------------------------------------------------------------- // Mapping of PyDialects //---------------------------------------------------------------------------- nb::class_<PyDialects>(m, "Dialects") - .def("__getitem__", - [=](PyDialects &self, std::string keyName) { - MlirDialect dialect = - self.getDialectForKey(keyName, /*attrError=*/false); - nb::object descriptor = - nb::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(keyName, std::move(descriptor)); - }) - .def("__getattr__", [=](PyDialects &self, std::string attrName) { - MlirDialect dialect = - self.getDialectForKey(attrName, /*attrError=*/true); - nb::object descriptor = - nb::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(attrName, std::move(descriptor)); - }); + .def( + "__getitem__", + [=](PyDialects &self, std::string keyName) { + MlirDialect dialect = + self.getDialectForKey(keyName, /*attrError=*/false); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(keyName, std::move(descriptor)); + }, + "Gets a dialect by name using subscript notation.") + .def( + "__getattr__", + [=](PyDialects &self, std::string attrName) { + MlirDialect dialect = + self.getDialectForKey(attrName, /*attrError=*/true); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(attrName, std::move(descriptor)); + }, + "Gets a dialect by name using attribute notation."); //---------------------------------------------------------------------------- // Mapping of PyDialect //---------------------------------------------------------------------------- nb::class_<PyDialect>(m, "Dialect") - .def(nb::init<nb::object>(), nb::arg("descriptor")) - .def_prop_ro("descriptor", - [](PyDialect &self) { return self.getDescriptor(); }) + .def(nb::init<nb::object>(), nb::arg("descriptor"), + "Creates a Dialect from a DialectDescriptor.") + .def_prop_ro( + "descriptor", [](PyDialect &self) { return self.getDescriptor(); }, + "Returns the DialectDescriptor for this dialect.") .def( "__repr__", [](const nb::object &self) { @@ -3181,31 +3209,43 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::str(" (class ") + clazz.attr("__module__") + nb::str(".") + clazz.attr("__name__") + nb::str(")>"); }, - nb::sig("def __repr__(self) -> str")); + nb::sig("def __repr__(self) -> str"), + "Returns a string 
representation of the dialect."); //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- nb::class_<PyDialectRegistry>(m, "DialectRegistry") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule, + "Gets a capsule wrapping the MlirDialectRegistry.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyDialectRegistry::createFromCapsule) - .def(nb::init<>()); + &PyDialectRegistry::createFromCapsule, + "Creates a DialectRegistry from a capsule wrapping " + "`MlirDialectRegistry`.") + .def(nb::init<>(), "Creates a new empty dialect registry."); //---------------------------------------------------------------------------- // Mapping of Location //---------------------------------------------------------------------------- nb::class_<PyLocation>(m, "Location") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) - .def("__enter__", &PyLocation::contextEnter) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule, + "Gets a capsule wrapping the MlirLocation.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule, + "Creates a Location from a capsule wrapping MlirLocation.") + .def("__enter__", &PyLocation::contextEnter, + "Enters the location as a context manager.") .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none()) - .def("__eq__", - [](PyLocation &self, PyLocation &other) -> bool { - return mlirLocationEqual(self, other); - }) - .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) + nb::arg("exc_value").none(), nb::arg("traceback").none(), + "Exits the location context manager.") + .def( + "__eq__", + [](PyLocation &self, PyLocation &other) -> bool { + return mlirLocationEqual(self, other); + }, + "Compares two locations for equality.") + .def( + "__eq__", [](PyLocation &self, nb::object other) { return false; }, + "Compares location with non-location object (always returns False).") .def_prop_ro_static( "current", [](nb::object & /*class*/) -> std::optional<PyLocation *> { @@ -3217,7 +3257,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { // clang-format off nb::sig("def current(/) -> Location | None"), // clang-format on - "Gets the Location bound to the current thread or raises ValueError") + "Gets the Location bound to the current thread or raises ValueError.") .def_static( "unknown", [](DefaultingPyMlirContext context) { @@ -3225,13 +3265,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationUnknownGet(context->get())); }, nb::arg("context") = nb::none(), - "Gets a Location representing an unknown location") + "Gets a Location representing an unknown location.") .def_static( "callsite", [](PyLocation callee, const std::vector<PyLocation> &frames, DefaultingPyMlirContext context) { if (frames.empty()) - throw nb::value_error("No caller frames provided"); + throw nb::value_error("No caller frames provided."); MlirLocation caller = frames.back().get(); for (const PyLocation &frame : llvm::reverse(llvm::ArrayRef(frames).drop_back())) @@ -3240,18 +3280,23 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationCallSiteGet(callee.get(), caller)); }, nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(), - 
kContextGetCallSiteLocationDocstring) - .def("is_a_callsite", mlirLocationIsACallSite) - .def_prop_ro("callee", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationCallSiteGetCallee(self)); - }) - .def_prop_ro("caller", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationCallSiteGetCaller(self)); - }) + "Gets a Location representing a caller and callsite.") + .def("is_a_callsite", mlirLocationIsACallSite, + "Returns True if this location is a CallSiteLoc.") + .def_prop_ro( + "callee", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCallee(self)); + }, + "Gets the callee location from a CallSiteLoc.") + .def_prop_ro( + "caller", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCaller(self)); + }, + "Gets the caller location from a CallSiteLoc.") .def_static( "file", [](std::string filename, int line, int col, @@ -3262,7 +3307,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { context->get(), toMlirStringRef(filename), line, col)); }, nb::arg("filename"), nb::arg("line"), nb::arg("col"), - nb::arg("context") = nb::none(), kContextGetFileLocationDocstring) + nb::arg("context") = nb::none(), + "Gets a Location representing a file, line and column.") .def_static( "file", [](std::string filename, int startLine, int startCol, int endLine, @@ -3274,17 +3320,25 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), nb::arg("end_line"), nb::arg("end_col"), - nb::arg("context") = nb::none(), kContextGetFileRangeDocstring) - .def("is_a_file", mlirLocationIsAFileLineColRange) - .def_prop_ro("filename", - [](MlirLocation loc) { - return mlirIdentifierStr( - mlirLocationFileLineColRangeGetFilename(loc)); - }) - .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine) - .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn) - .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine) - .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn) + nb::arg("context") = nb::none(), + "Gets a Location representing a file, line and column range.") + .def("is_a_file", mlirLocationIsAFileLineColRange, + "Returns True if this location is a FileLineColLoc.") + .def_prop_ro( + "filename", + [](MlirLocation loc) { + return mlirIdentifierStr( + mlirLocationFileLineColRangeGetFilename(loc)); + }, + "Gets the filename from a FileLineColLoc.") + .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine, + "Gets the start line number from a `FileLineColLoc`.") + .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn, + "Gets the start column number from a `FileLineColLoc`.") + .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine, + "Gets the end line number from a `FileLineColLoc`.") + .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn, + "Gets the end column number from a `FileLineColLoc`.") .def_static( "fused", [](const std::vector<PyLocation> &pyLocations, @@ -3300,8 +3354,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), location); }, nb::arg("locations"), nb::arg("metadata") = nb::none(), - nb::arg("context") = nb::none(), kContextGetFusedLocationDocstring) - .def("is_a_fused", mlirLocationIsAFused) + nb::arg("context") = nb::none(), + "Gets a Location representing a fused location with optional " + "metadata.") + .def("is_a_fused", mlirLocationIsAFused, + "Returns True if this 
location is a `FusedLoc`.") .def_prop_ro( "locations", [](PyLocation &self) { @@ -3314,7 +3371,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { for (unsigned i = 0; i < numLocations; ++i) pyLocations.emplace_back(self.getContext(), locations[i]); return pyLocations; - }) + }, + "Gets the list of locations from a `FusedLoc`.") .def_static( "name", [](std::string name, std::optional<PyLocation> childLoc, @@ -3327,17 +3385,24 @@ void mlir::python::populateIRCore(nb::module_ &m) { : mlirLocationUnknownGet(context->get()))); }, nb::arg("name"), nb::arg("childLoc") = nb::none(), - nb::arg("context") = nb::none(), kContextGetNameLocationDocString) - .def("is_a_name", mlirLocationIsAName) - .def_prop_ro("name_str", - [](MlirLocation loc) { - return mlirIdentifierStr(mlirLocationNameGetName(loc)); - }) - .def_prop_ro("child_loc", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationNameGetChildLoc(self)); - }) + nb::arg("context") = nb::none(), + "Gets a Location representing a named location with optional child " + "location.") + .def("is_a_name", mlirLocationIsAName, + "Returns True if this location is a `NameLoc`.") + .def_prop_ro( + "name_str", + [](MlirLocation loc) { + return mlirIdentifierStr(mlirLocationNameGetName(loc)); + }, + "Gets the name string from a `NameLoc`.") + .def_prop_ro( + "child_loc", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationNameGetChildLoc(self)); + }, + "Gets the child location from a `NameLoc`.") .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { @@ -3345,41 +3410,59 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationFromAttribute(attribute)); }, nb::arg("attribute"), nb::arg("context") = nb::none(), - "Gets a Location from a LocationAttr") + "Gets a Location from a `LocationAttr`.") .def_prop_ro( "context", [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that owns the Location") + "Context that owns the `Location`.") .def_prop_ro( "attr", [](PyLocation &self) { return PyAttribute(self.getContext(), mlirLocationGetAttribute(self)); }, - "Get the underlying LocationAttr") + "Get the underlying `LocationAttr`.") .def( "emit_error", [](PyLocation &self, std::string message) { mlirEmitError(self, message.c_str()); }, - nb::arg("message"), "Emits an error at this location") - .def("__repr__", [](PyLocation &self) { - PyPrintAccumulator printAccum; - mlirLocationPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }); + nb::arg("message"), + R"( + Emits an error diagnostic at this location. 
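For context, a minimal usage sketch of the Location factory methods bound above (standard `mlir.ir` API; the file and value names are illustrative):

    from mlir.ir import Context, Location

    with Context():
        file_loc = Location.file("example.mlir", line=10, col=4)
        name_loc = Location.name("my_value", childLoc=file_loc)
        fused_loc = Location.fused([file_loc, name_loc])
        call_loc = Location.callsite(file_loc, frames=[name_loc])
        assert call_loc.is_a_callsite() and fused_loc.is_a_fused()
        # Diagnostics can be attached to any location.
        file_loc.emit_error("something went wrong")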
+ + Args: + message: The error message to emit.)") + .def( + "__repr__", + [](PyLocation &self) { + PyPrintAccumulator printAccum; + mlirLocationPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + "Returns the assembly representation of the location."); //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable()) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule, + "Gets a capsule wrapping the MlirModule.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, - kModuleCAPICreate) - .def("_clear_mlir_module", &PyModule::clearMlirModule) + R"( + Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`). + + This returns a new object **BUT** `_clear_mlir_module(module)` must be called to + prevent double-frees (of the underlying `mlir::Module`).)") + .def("_clear_mlir_module", &PyModule::clearMlirModule, + R"( + Clears the internal MLIR module reference. + + This is used internally to prevent double-free when ownership is transferred + via the C API capsule mechanism. Not intended for normal use.)") .def_static( "parse", [](const std::string &moduleAsm, DefaultingPyMlirContext context) @@ -3427,13 +3510,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("loc") = nb::none(), "Creates an empty module") + nb::arg("loc") = nb::none(), "Creates an empty module.") .def_prop_ro( "context", [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that created the Module") + "Context that created the `Module`.") .def_prop_ro( "operation", [](PyModule &self) -> nb::typed<nb::object, PyOperation> { @@ -3442,7 +3525,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { self.getRef().releaseObject()) .releaseObject(); }, - "Accesses the module as an operation") + "Accesses the module as an operation.") .def_prop_ro( "body", [](PyModule &self) { @@ -3452,7 +3535,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); return returnBlock; }, - "Return the block for this module") + "Return the block for this module.") .def( "dump", [](PyModule &self) { @@ -3465,39 +3548,59 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, - nb::sig("def __str__(self) -> str"), kOperationStrDunderDocstring) + nb::sig("def __str__(self) -> str"), + R"( + Gets the assembly form of the operation with default options. + + If more advanced control over the assembly formatting or I/O options is needed, + use the dedicated print or get_asm method, which supports keyword arguments to + customize behavior. 
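A short sketch of the Module entry points documented here (the parsed IR is deliberately trivial; printing options are only exercised, not exhaustively shown):

    import io
    from mlir.ir import Context, Location, Module

    with Context(), Location.unknown():
        parsed = Module.parse("module {}")
        empty = Module.create()            # Location taken from the enclosing context manager
        print(parsed.operation.name)       # builtin.module
        print(str(empty))                  # defers to the module operation's __str__
        # For more control, go through the operation's print/get_asm.
        buf = io.StringIO()
        parsed.operation.print(use_local_scope=True, file=buf)
        asm = parsed.operation.get_asm(enable_debug_info=True)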
+ )") .def( "__eq__", [](PyModule &self, PyModule &other) { return mlirModuleEqual(self.get(), other.get()); }, - "other"_a) - .def("__hash__", - [](PyModule &self) { return mlirModuleHashValue(self.get()); }); + "other"_a, "Compares two modules for equality.") + .def( + "__hash__", + [](PyModule &self) { return mlirModuleHashValue(self.get()); }, + "Returns the hash value of the module."); //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- nb::class_<PyOperationBase>(m, "_OperationBase") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, - [](PyOperationBase &self) { - return self.getOperation().getCapsule(); - }) - .def("__eq__", - [](PyOperationBase &self, PyOperationBase &other) { - return mlirOperationEqual(self.getOperation().get(), - other.getOperation().get()); - }) - .def("__eq__", - [](PyOperationBase &self, nb::object other) { return false; }) - .def("__hash__", - [](PyOperationBase &self) { - return mlirOperationHashValue(self.getOperation().get()); - }) - .def_prop_ro("attributes", - [](PyOperationBase &self) { - return PyOpAttributeMap(self.getOperation().getRef()); - }) + .def_prop_ro( + MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }, + "Gets a capsule wrapping the `MlirOperation`.") + .def( + "__eq__", + [](PyOperationBase &self, PyOperationBase &other) { + return mlirOperationEqual(self.getOperation().get(), + other.getOperation().get()); + }, + "Compares two operations for equality.") + .def( + "__eq__", + [](PyOperationBase &self, nb::object other) { return false; }, + "Compares operation with non-operation object (always returns " + "False).") + .def( + "__hash__", + [](PyOperationBase &self) { + return mlirOperationHashValue(self.getOperation().get()); + }, + "Returns the hash value of the operation.") + .def_prop_ro( + "attributes", + [](PyOperationBase &self) { + return PyOpAttributeMap(self.getOperation().getRef()); + }, + "Returns a dictionary-like map of operation attributes.") .def_prop_ro( "context", [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> { @@ -3505,22 +3608,28 @@ void mlir::python::populateIRCore(nb::module_ &m) { concreteOperation.checkValid(); return concreteOperation.getContext().getObject(); }, - "Context that owns the Operation") - .def_prop_ro("name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = concreteOperation.get(); - return mlirIdentifierStr(mlirOperationGetName(operation)); - }) - .def_prop_ro("operands", - [](PyOperationBase &self) { - return PyOpOperandList(self.getOperation().getRef()); - }) - .def_prop_ro("regions", - [](PyOperationBase &self) { - return PyRegionList(self.getOperation().getRef()); - }) + "Context that owns the operation.") + .def_prop_ro( + "name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + return mlirIdentifierStr(mlirOperationGetName(operation)); + }, + "Returns the fully qualified name of the operation.") + .def_prop_ro( + "operands", + [](PyOperationBase &self) { + return PyOpOperandList(self.getOperation().getRef()); + }, + "Returns the list of operation operands.") + .def_prop_ro( + "regions", + [](PyOperationBase &self) { + return PyRegionList(self.getOperation().getRef()); + }, + "Returns the list of 
operation regions.") .def_prop_ro( "results", [](PyOperationBase &self) { @@ -3551,14 +3660,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { "defined or derived from."), nb::for_setter("Sets the source location the operation was defined " "or derived from.")) - .def_prop_ro("parent", - [](PyOperationBase &self) - -> std::optional<nb::typed<nb::object, PyOperation>> { - auto parent = self.getOperation().getParentOperation(); - if (parent) - return parent->getObject(); - return {}; - }) + .def_prop_ro( + "parent", + [](PyOperationBase &self) + -> std::optional<nb::typed<nb::object, PyOperation>> { + auto parent = self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return {}; + }, + "Returns the parent operation, or `None` if at top level.") .def( "__str__", [](PyOperationBase &self) { @@ -3579,7 +3690,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::overload_cast<PyAsmState &, nb::object, bool>( &PyOperationBase::print), nb::arg("state"), nb::arg("file") = nb::none(), - nb::arg("binary") = false, kOperationPrintStateDocstring) + nb::arg("binary") = false, + R"( + Prints the assembly form of the operation to a file like object. + + Args: + state: `AsmState` capturing the operation numbering and flags. + file: Optional file like object to write to. Defaults to sys.stdout. + binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)") .def("print", nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>, bool, bool, bool, bool, bool, bool, nb::object, @@ -3594,10 +3712,47 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("file") = nb::none(), nb::arg("binary") = false, nb::arg("skip_regions") = false, - kOperationPrintDocstring) + R"( + Prints the assembly form of the operation to a file like object. + + Args: + large_elements_limit: Whether to elide elements attributes above this + number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource attributes above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. + enable_debug_info: Whether to print debug/location information. Defaults + to False. + pretty_debug_info: Whether to format debug information for easier reading + by a human (warning: the result is unparseable). Defaults to False. + print_generic_op_form: Whether to print the generic assembly forms of all + ops. Defaults to False. + use_local_scope: Whether to print in a way that is more optimized for + multi-threaded access but may not be consistent with how the overall + module prints. + use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as + prefixes for the SSA identifiers. Defaults to False. + assume_verified: By default, if not printing generic form, the verifier + will be run and if it fails, generic form will be printed with a comment + about failed verification. While a reasonable default for interactive use, + for systematic use, it is often better for the caller to verify explicitly + and report failures in a more robust fashion. Set this to True if doing this + in order to avoid running a redundant verification. If the IR is actually + invalid, behavior is undefined. + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. 
+ skip_regions: Whether to skip printing regions. Defaults to False.)") .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), nb::arg("desired_version") = nb::none(), - kOperationPrintBytecodeDocstring) + R"( + Write the bytecode form of the operation to a file like object. + + Args: + file: The file like object to write to. + desired_version: Optional version of bytecode to emit. + Returns: + The bytecode writer status.)") .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. nb::arg("binary") = false, @@ -3609,7 +3764,17 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("use_local_scope") = false, nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, - kOperationGetAsmDocstring) + R"( + Gets the assembly form of the operation with all options available. + + Args: + binary: Whether to return a bytes (True) or str (False) object. Defaults to + False. + ... others ...: See the print() method for common keyword arguments for + configuring the printout. + Returns: + Either a bytes or str object, depending on the setting of the `binary` + argument.)") .def("verify", &PyOperationBase::verify, "Verify the operation. Raises MLIRError if verification fails, and " "returns true otherwise.") @@ -3621,18 +3786,31 @@ void mlir::python::populateIRCore(nb::module_ &m) { "block.") .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, nb::arg("other"), - "Given an operation 'other' that is within the same parent block, " - "return" - "whether the current operation is before 'other' in the operation " - "list" - "of the parent block.") + R"( + Checks if this operation is before another in the same block. + + Args: + other: Another operation in the same parent block. + + Returns: + True if this operation is before `other` in the operation list of the parent block.)") .def( "clone", [](PyOperationBase &self, const nb::object &ip) -> nb::typed<nb::object, PyOperation> { return self.getOperation().clone(ip); }, - nb::arg("ip") = nb::none()) + nb::arg("ip") = nb::none(), + R"( + Creates a deep copy of the operation. + + Args: + ip: Optional insertion point where the cloned operation should be inserted. + If None, the current insertion point is used. If False, the operation + remains detached. + + Returns: + A new Operation that is a clone of this operation.)") .def( "detach_from_parent", [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> { @@ -3653,13 +3831,24 @@ void mlir::python::populateIRCore(nb::module_ &m) { return operation.isAttached(); }, "Reports if the operation is attached to its parent block.") - .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) + .def( + "erase", [](PyOperationBase &self) { self.getOperation().erase(); }, + R"( + Erases the operation and frees its memory. + + Note: + After erasing, any Python references to the operation become invalid.)") .def("walk", &PyOperationBase::walk, nb::arg("callback"), nb::arg("walk_order") = MlirWalkPostOrder, // clang-format off - nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None") + nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"), // clang-format on - ); + R"( + Walks the operation tree with a callback function. + + Args: + callback: A callable that takes an Operation and returns a WalkResult. 
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)"); nb::class_<PyOperation, PyOperationBase>(m, "Operation") .def_static( @@ -3692,7 +3881,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), nb::arg("successors") = nb::none(), nb::arg("regions") = 0, nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(), - nb::arg("infer_type") = false, kOperationCreateDocstring) + nb::arg("infer_type") = false, + R"( + Creates a new operation. + + Args: + name: Operation name (e.g. `dialect.operation`). + results: Optional sequence of Type representing op result types. + operands: Optional operands of the operation. + attributes: Optional Dict of {str: Attribute}. + successors: Optional List of Block for the operation's successors. + regions: Number of regions to create (default = 0). + location: Optional Location object (defaults to resolve from context manager). + ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager). + infer_type: Whether to infer result types (default = False). + Returns: + A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)") .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, @@ -3705,18 +3909,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("context") = nb::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule, + "Gets a capsule wrapping the MlirOperation.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyOperation::createFromCapsule) - .def_prop_ro("operation", - [](nb::object self) -> nb::typed<nb::object, PyOperation> { - return self; - }) - .def_prop_ro("opview", - [](PyOperation &self) -> nb::typed<nb::object, PyOpView> { - return self.createOpView(); - }) - .def_prop_ro("block", &PyOperation::getBlock) + &PyOperation::createFromCapsule, + "Creates an Operation from a capsule wrapping MlirOperation.") + .def_prop_ro( + "operation", + [](nb::object self) -> nb::typed<nb::object, PyOperation> { + return self; + }, + "Returns self (the operation).") + .def_prop_ro( + "opview", + [](PyOperation &self) -> nb::typed<nb::object, PyOpView> { + return self.createOpView(); + }, + R"( + Returns an OpView of this operation. + + Note: + If the operation has a registered and loaded dialect then this OpView will + be concrete wrapper class.)") + .def_prop_ro("block", &PyOperation::getBlock, + "Returns the block containing this operation.") .def_prop_ro( "successors", [](PyOperationBase &self) { @@ -3830,7 +4046,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("cls"), nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", nb::arg("context") = nb::none(), - "Parses a specific, generated OpView based on class level attributes"); + "Parses a specific, generated OpView based on class level attributes."); //---------------------------------------------------------------------------- // Mapping of PyRegion. 
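A hedged sketch of Operation.create as documented above; the `test.*` op name is made up, hence allow_unregistered_dialects:

    from mlir.ir import (Context, Location, Module, InsertionPoint, Operation,
                         IntegerType, IntegerAttr)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            i32 = IntegerType.get_signless(32)
            op = Operation.create(
                "test.constant_like",
                results=[i32],
                attributes={"value": IntegerAttr.get(i32, 42)},
            )
        print(op.results[0].type)  # i32
        print(module)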
@@ -3856,17 +4072,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyBlockIterator(self.getParentOperation(), firstBlock); }, "Iterates over blocks in the region.") - .def("__eq__", - [](PyRegion &self, PyRegion &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); + .def( + "__eq__", + [](PyRegion &self, PyRegion &other) { + return self.get().ptr == other.get().ptr; + }, + "Compares two regions for pointer equality.") + .def( + "__eq__", [](PyRegion &self, nb::object &other) { return false; }, + "Compares region with non-region object (always returns False)."); //---------------------------------------------------------------------------- // Mapping of PyBlock. //---------------------------------------------------------------------------- nb::class_<PyBlock>(m, "Block") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule, + "Gets a capsule wrapping the MlirBlock.") .def_prop_ro( "owner", [](PyBlock &self) -> nb::typed<nb::object, PyOpView> { @@ -3893,14 +4114,26 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirBlockAddArgument(self.get(), type, loc)); }, "type"_a, "loc"_a, - "Append an argument of the specified type to the block and returns " - "the newly added argument.") + R"( + Appends an argument of the specified type to the block. + + Args: + type: The type of the argument to add. + loc: The source location for the argument. + + Returns: + The newly added block argument.)") .def( "erase_argument", [](PyBlock &self, unsigned index) { return mlirBlockEraseArgument(self.get(), index); }, - "Erase the argument at 'index' and remove it from the argument list.") + nb::arg("index"), + R"( + Erases the argument at the specified index. + + Args: + index: The index of the argument to erase.)") .def_prop_ro( "operations", [](PyBlock &self) { @@ -3928,7 +4161,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirBlockDetach(b); mlirRegionAppendOwnedBlock(region.get(), b); }, - "Append this block to a region, transferring ownership if necessary") + nb::arg("region"), + R"( + Appends this block to a region. + + Transfers ownership if the block is currently owned by another region. + + Args: + region: The region to append the block to.)") .def( "create_before", [](PyBlock &self, const nb::args &pyArgTypes, @@ -3969,15 +4209,21 @@ void mlir::python::populateIRCore(nb::module_ &m) { firstOperation); }, "Iterates over operations in the block.") - .def("__eq__", - [](PyBlock &self, PyBlock &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) - .def("__hash__", - [](PyBlock &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + .def( + "__eq__", + [](PyBlock &self, PyBlock &other) { + return self.get().ptr == other.get().ptr; + }, + "Compares two blocks for pointer equality.") + .def( + "__eq__", [](PyBlock &self, nb::object &other) { return false; }, + "Compares block with non-block object (always returns False).") + .def( + "__hash__", + [](PyBlock &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the block.") .def( "__str__", [](PyBlock &self) { @@ -4000,8 +4246,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { self.getParentOperation().getObject()); }, nb::arg("operation"), - "Appends an operation to this block. 
If the operation is currently " - "in another block, it will be moved.") + R"( + Appends an operation to this block. + + If the operation is currently in another block, it will be moved. + + Args: + operation: The operation to append to the block.)") .def_prop_ro( "successors", [](PyBlock &self) { @@ -4022,10 +4273,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_<PyInsertionPoint>(m, "InsertionPoint") .def(nb::init<PyBlock &>(), nb::arg("block"), "Inserts after the last operation but still inside the block.") - .def("__enter__", &PyInsertionPoint::contextEnter) + .def("__enter__", &PyInsertionPoint::contextEnter, + "Enters the insertion point as a context manager.") .def("__exit__", &PyInsertionPoint::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none()) + nb::arg("traceback").none(), + "Exits the insertion point context manager.") .def_prop_ro_static( "current", [](nb::object & /*class*/) { @@ -4036,20 +4289,50 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::sig("def current(/) -> InsertionPoint"), "Gets the InsertionPoint bound to the current thread or raises " - "ValueError if none has been set") + "ValueError if none has been set.") .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - nb::arg("block"), "Inserts at the beginning of the block.") + nb::arg("block"), + R"( + Creates an insertion point at the beginning of a block. + + Args: + block: The block at whose beginning operations should be inserted. + + Returns: + An InsertionPoint at the block's beginning.)") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - nb::arg("block"), "Inserts before the block terminator.") + nb::arg("block"), + R"( + Creates an insertion point before a block's terminator. + + Args: + block: The block whose terminator to insert before. + + Returns: + An InsertionPoint before the terminator. + + Raises: + ValueError: If the block has no terminator.)") .def_static("after", &PyInsertionPoint::after, nb::arg("operation"), - "Inserts after the operation.") + R"( + Creates an insertion point immediately after an operation. + + Args: + operation: The operation after which to insert. + + Returns: + An InsertionPoint after the operation.)") .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), - "Inserts an operation.") + R"( + Inserts an operation at this insertion point. + + Args: + operation: The operation to insert.)") .def_prop_ro( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, - "Returns the block that this InsertionPoint points to.") + "Returns the block that this `InsertionPoint` points to.") .def_prop_ro( "ref_operation", [](PyInsertionPoint &self) @@ -4061,7 +4344,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " - "the block"); + "the block."); //---------------------------------------------------------------------------- // Mapping of PyAttribute. @@ -4070,10 +4353,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. 
.def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"), - "Casts the passed attribute to the generic Attribute") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyAttribute::createFromCapsule) + "Casts the passed attribute to the generic `Attribute`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule, + "Gets a capsule wrapping the MlirAttribute.") + .def_static( + MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule, + "Creates an Attribute from a capsule wrapping `MlirAttribute`.") .def_static( "parse", [](const std::string &attrSpec, DefaultingPyMlirContext context) @@ -4086,33 +4371,49 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyAttribute(context.get()->getRef(), attr).maybeDownCast(); }, nb::arg("asm"), nb::arg("context") = nb::none(), - "Parses an attribute from an assembly form. Raises an MLIRError on " + "Parses an attribute from an assembly form. Raises an `MLIRError` on " "failure.") .def_prop_ro( "context", [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that owns the Attribute") - .def_prop_ro("type", - [](PyAttribute &self) -> nb::typed<nb::object, PyType> { - return PyType(self.getContext(), - mlirAttributeGetType(self)) - .maybeDownCast(); - }) + "Context that owns the `Attribute`.") + .def_prop_ro( + "type", + [](PyAttribute &self) -> nb::typed<nb::object, PyType> { + return PyType(self.getContext(), mlirAttributeGetType(self)) + .maybeDownCast(); + }, + "Returns the type of the `Attribute`.") .def( "get_named", [](PyAttribute &self, std::string name) { return PyNamedAttribute(self, std::move(name)); }, - nb::keep_alive<0, 1>(), "Binds a name to the attribute") - .def("__eq__", - [](PyAttribute &self, PyAttribute &other) { return self == other; }) - .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) - .def("__hash__", - [](PyAttribute &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + nb::keep_alive<0, 1>(), + R"( + Binds a name to the attribute, creating a `NamedAttribute`. + + Args: + name: The name to bind to the `Attribute`. + + Returns: + A `NamedAttribute` with the given name and this attribute.)") + .def( + "__eq__", + [](PyAttribute &self, PyAttribute &other) { return self == other; }, + "Compares two attributes for equality.") + .def( + "__eq__", [](PyAttribute &self, nb::object &other) { return false; }, + "Compares attribute with non-attribute object (always returns " + "False).") + .def( + "__hash__", + [](PyAttribute &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the attribute.") .def( "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, kDumpDocstring) @@ -4125,61 +4426,69 @@ void mlir::python::populateIRCore(nb::module_ &m) { return printAccum.join(); }, "Returns the assembly form of the Attribute.") - .def("__repr__", - [](PyAttribute &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, attribute values are generally considered useful and - // are printed. This may need to be re-evaluated if debug dumps end - // up being excessive. 
- PyPrintAccumulator printAccum; - printAccum.parts.append("Attribute("); - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_prop_ro("typeid", - [](PyAttribute &self) { - MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - return PyTypeID(mlirTypeID); - }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> { - return self.maybeDownCast(); - }); + .def( + "__repr__", + [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and + // are printed. This may need to be re-evaluated if debug dumps end + // up being excessive. + PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + "Returns a string representation of the attribute.") + .def_prop_ro( + "typeid", + [](PyAttribute &self) { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return PyTypeID(mlirTypeID); + }, + "Returns the `TypeID` of the attribute.") + .def( + MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> { + return self.maybeDownCast(); + }, + "Downcasts the attribute to a more specific attribute if possible."); //---------------------------------------------------------------------------- // Mapping of PyNamedAttribute //---------------------------------------------------------------------------- nb::class_<PyNamedAttribute>(m, "NamedAttribute") - .def("__repr__", - [](PyNamedAttribute &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("NamedAttribute("); - printAccum.parts.append( - nb::str(mlirIdentifierStr(self.namedAttr.name).data, - mlirIdentifierStr(self.namedAttr.name).length)); - printAccum.parts.append("="); - mlirAttributePrint(self.namedAttr.attribute, - printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) + .def( + "__repr__", + [](PyNamedAttribute &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("NamedAttribute("); + printAccum.parts.append( + nb::str(mlirIdentifierStr(self.namedAttr.name).data, + mlirIdentifierStr(self.namedAttr.name).length)); + printAccum.parts.append("="); + mlirAttributePrint(self.namedAttr.attribute, + printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + "Returns a string representation of the named attribute.") .def_prop_ro( "name", [](PyNamedAttribute &self) { return mlirIdentifierStr(self.namedAttr.name); }, - "The name of the NamedAttribute binding") + "The name of the `NamedAttribute` binding.") .def_prop_ro( "attr", [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"), - "The underlying generic attribute of the NamedAttribute binding"); + "The underlying generic attribute of the `NamedAttribute` binding."); //---------------------------------------------------------------------------- // Mapping of PyType. 
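For reference, a minimal sketch of the Attribute and NamedAttribute APIs documented above:

    from mlir.ir import Context, Attribute, IntegerAttr

    with Context():
        attr = Attribute.parse("42 : i32")
        print(attr.type)                   # i32
        named = attr.get_named("value")    # a NamedAttribute binding
        print(named.name, named.attr)
        print(IntegerAttr(attr).value)     # 42, via the cast-from constructor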
@@ -4188,9 +4497,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. .def(nb::init<PyType &>(), nb::arg("cast_from_type"), - "Casts the passed type to the generic Type") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) + "Casts the passed type to the generic `Type`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule, + "Gets a capsule wrapping the `MlirType`.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule, + "Creates a Type from a capsule wrapping `MlirType`.") .def_static( "parse", [](std::string typeSpec, @@ -4203,21 +4514,31 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyType(context.get()->getRef(), type).maybeDownCast(); }, nb::arg("asm"), nb::arg("context") = nb::none(), - kContextParseTypeDocstring) + R"( + Parses the assembly form of a type. + + Returns a Type object or raises an `MLIRError` if the type cannot be parsed. + + See also: https://mlir.llvm.org/docs/LangRef/#type-system)") .def_prop_ro( "context", [](PyType &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that owns the Type") - .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) + "Context that owns the `Type`.") + .def( + "__eq__", [](PyType &self, PyType &other) { return self == other; }, + "Compares two types for equality.") .def( "__eq__", [](PyType &self, nb::object &other) { return false; }, - nb::arg("other").none()) - .def("__hash__", - [](PyType &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + nb::arg("other").none(), + "Compares type with non-type object (always returns False).") + .def( + "__hash__", + [](PyType &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the `Type`.") .def( "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) .def( @@ -4228,60 +4549,84 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.getUserData()); return printAccum.join(); }, - "Returns the assembly form of the type.") - .def("__repr__", - [](PyType &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, types are an exception as they typically have compact - // assembly forms and printing them is useful. - PyPrintAccumulator printAccum; - printAccum.parts.append("Type("); - mlirTypePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyType &self) -> nb::typed<nb::object, PyType> { - return self.maybeDownCast(); - }) - .def_prop_ro("typeid", [](PyType &self) { - MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); - if (!mlirTypeIDIsNull(mlirTypeID)) - return PyTypeID(mlirTypeID); - auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self))); - throw nb::value_error( - (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); - }); + "Returns the assembly form of the `Type`.") + .def( + "__repr__", + [](PyType &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. 
+ // However, types are an exception as they typically have compact + // assembly forms and printing them is useful. + PyPrintAccumulator printAccum; + printAccum.parts.append("Type("); + mlirTypePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + "Returns a string representation of the `Type`.") + .def( + MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyType &self) -> nb::typed<nb::object, PyType> { + return self.maybeDownCast(); + }, + "Downcasts the Type to a more specific `Type` if possible.") + .def_prop_ro( + "typeid", + [](PyType &self) { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return PyTypeID(mlirTypeID); + auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self))); + throw nb::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); + }, + "Returns the `TypeID` of the `Type`, or raises `ValueError` if " + "`Type` has no " + "`TypeID`."); //---------------------------------------------------------------------------- // Mapping of PyTypeID. //---------------------------------------------------------------------------- nb::class_<PyTypeID>(m, "TypeID") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule, + "Gets a capsule wrapping the `MlirTypeID`.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule, + "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.") // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether // the Python objects are the same (i.e., PyTypeID is a value type). - .def("__eq__", - [](PyTypeID &self, PyTypeID &other) { return self == other; }) - .def("__eq__", - [](PyTypeID &self, const nb::object &other) { return false; }) + .def( + "__eq__", + [](PyTypeID &self, PyTypeID &other) { return self == other; }, + "Compares two `TypeID`s for equality.") + .def( + "__eq__", + [](PyTypeID &self, const nb::object &other) { return false; }, + "Compares TypeID with non-TypeID object (always returns False).") // Note, this gives the hash value of the underlying TypeID, not the // hash value of the Python object, nor the hash value of the // MlirTypeID wrapper. - .def("__hash__", [](PyTypeID &self) { - return static_cast<size_t>(mlirTypeIDHashValue(self)); - }); + .def( + "__hash__", + [](PyTypeID &self) { + return static_cast<size_t>(mlirTypeIDHashValue(self)); + }, + "Returns the hash value of the `TypeID`."); //---------------------------------------------------------------------------- // Mapping of Value. 
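A quick sketch of the Type and TypeID surface described here:

    from mlir.ir import Context, Type, IntegerType, F32Type

    with Context():
        t = Type.parse("i32")
        assert IntegerType.isinstance(t)
        assert t.typeid == IntegerType.get_signless(32).typeid
        assert t.typeid != F32Type.get().typeid
        print(t)   # i32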
//---------------------------------------------------------------------------- - nb::class_<PyValue>(m, "Value") - .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) + m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type")); + + nb::class_<PyValue>(m, "Value", nb::is_generic(), + nb::sig("class Value(Generic[_T])")) + .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"), + "Creates a Value reference from another `Value`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule, + "Gets a capsule wrapping the `MlirValue`.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule, + "Creates a `Value` from a capsule wrapping `MlirValue`.") .def_prop_ro( "context", [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> { @@ -4312,23 +4657,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { assert(false && "Value must be a block argument or an op result"); return nb::none(); }, - // clang-format off - nb::sig("def owner(self) -> Operation | Block | None")) - // clang-format on - .def_prop_ro("uses", - [](PyValue &self) { - return PyOpOperandIterator( - mlirValueGetFirstUse(self.get())); - }) - .def("__eq__", - [](PyValue &self, PyValue &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyValue &self, nb::object other) { return false; }) - .def("__hash__", - [](PyValue &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + "Returns the owner of the value (`Operation` for results, `Block` " + "for " + "arguments).") + .def_prop_ro( + "uses", + [](PyValue &self) { + return PyOpOperandIterator(mlirValueGetFirstUse(self.get())); + }, + "Returns an iterator over uses of this value.") + .def( + "__eq__", + [](PyValue &self, PyValue &other) { + return self.get().ptr == other.get().ptr; + }, + "Compares two values for pointer equality.") + .def( + "__eq__", [](PyValue &self, nb::object other) { return false; }, + "Compares value with non-value object (always returns False).") + .def( + "__hash__", + [](PyValue &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the value.") .def( "__str__", [](PyValue &self) { @@ -4339,7 +4691,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.parts.append(")"); return printAccum.join(); }, - kValueDunderStrDocstring) + R"( + Returns the string form of the value. + + If the value is a block argument, this is the assembly form of its type and the + position in the argument list. If the value is an operation result, this is + equivalent to printing the operation that produced it. + )") .def( "get_name", [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) { @@ -4359,7 +4717,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { return printAccum.join(); }, nb::arg("use_local_scope") = false, - nb::arg("use_name_loc_as_prefix") = false) + nb::arg("use_name_loc_as_prefix") = false, + R"( + Returns the string form of value as an operand. + + Args: + use_local_scope: Whether to use local scope for naming. + use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix. 
+ + Returns: + The value's name as it appears in IR (e.g., `%0`, `%arg0`).)") .def( "get_name", [](PyValue &self, PyAsmState &state) { @@ -4370,25 +4737,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.getUserData()); return printAccum.join(); }, - nb::arg("state"), kGetNameAsOperand) - .def_prop_ro("type", - [](PyValue &self) -> nb::typed<nb::object, PyType> { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())) - .maybeDownCast(); - }) + nb::arg("state"), + "Returns the string form of value as an operand (i.e., the ValueID).") + .def_prop_ro( + "type", + [](PyValue &self) -> nb::typed<nb::object, PyType> { + return PyType(self.getParentOperation()->getContext(), + mlirValueGetType(self.get())) + .maybeDownCast(); + }, + "Returns the type of the value.") .def( "set_type", [](PyValue &self, const PyType &type) { - return mlirValueSetType(self.get(), type); + mlirValueSetType(self.get(), type); }, - nb::arg("type")) + nb::arg("type"), "Sets the type of the value.", + nb::sig("def set_type(self, type: _T)")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { mlirValueReplaceAllUsesOfWith(self.get(), with.get()); }, - kValueReplaceAllUsesWithDocstring) + "Replace all uses of value with the new value, updating anything in " + "the IR that uses `self` to use the other value instead.") .def( "replace_all_uses_except", [](PyValue &self, PyValue &with, PyOperation &exception) { @@ -4434,10 +4806,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("with_"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyValue &self) -> nb::typed<nb::object, PyValue> { - return self.maybeDownCast(); - }) + .def( + MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) -> nb::typed<nb::object, PyValue> { + return self.maybeDownCast(); + }, + "Downcasts the `Value` to a more specific kind if possible.") .def_prop_ro( "location", [](MlirValue self) { @@ -4445,7 +4819,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyMlirContext::forContext(mlirValueGetContext(self)), mlirValueGetLocation(self)); }, - "Returns the source location the value"); + "Returns the source location of the value."); PyBlockArgument::bind(m); PyOpResult::bind(m); @@ -4453,43 +4827,105 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_<PyAsmState>(m, "AsmState") .def(nb::init<PyValue &, bool>(), nb::arg("value"), - nb::arg("use_local_scope") = false) + nb::arg("use_local_scope") = false, + R"( + Creates an `AsmState` for consistent SSA value naming. + + Args: + value: The value to create state for. + use_local_scope: Whether to use local scope for naming.)") .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"), - nb::arg("use_local_scope") = false); + nb::arg("use_local_scope") = false, + R"( + Creates an AsmState for consistent SSA value naming. + + Args: + op: The operation to create state for. + use_local_scope: Whether to use local scope for naming.)"); //---------------------------------------------------------------------------- // Mapping of SymbolTable. 
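A hedged sketch of the Value accessors and use-replacement bound above (made-up `test.*` ops; the printed SSA name depends on numbering):

    from mlir.ir import (Context, Location, Module, InsertionPoint, Operation,
                         IntegerType)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            i32 = IntegerType.get_signless(32)
            a = Operation.create("test.a", results=[i32]).results[0]
            b = Operation.create("test.b", results=[i32]).results[0]
            user = Operation.create("test.use", operands=[a])
        print(a.type, a.get_name())   # i32 and an SSA name such as %0
        a.replace_all_uses_with(b)
        assert user.operands[0] == b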
//---------------------------------------------------------------------------- nb::class_<PySymbolTable>(m, "SymbolTable") - .def(nb::init<PyOperationBase &>()) - .def("__getitem__", - [](PySymbolTable &self, - const std::string &name) -> nb::typed<nb::object, PyOpView> { - return self.dunderGetItem(name); - }) - .def("insert", &PySymbolTable::insert, nb::arg("operation")) - .def("erase", &PySymbolTable::erase, nb::arg("operation")) - .def("__delitem__", &PySymbolTable::dunderDel) - .def("__contains__", - [](PySymbolTable &table, const std::string &name) { - return !mlirOperationIsNull(mlirSymbolTableLookup( - table, mlirStringRefCreate(name.data(), name.length()))); - }) + .def(nb::init<PyOperationBase &>(), + R"( + Creates a symbol table for an operation. + + Args: + operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`). + + Raises: + TypeError: If the operation is not a symbol table.)") + .def( + "__getitem__", + [](PySymbolTable &self, + const std::string &name) -> nb::typed<nb::object, PyOpView> { + return self.dunderGetItem(name); + }, + R"( + Looks up a symbol by name in the symbol table. + + Args: + name: The name of the symbol to look up. + + Returns: + The operation defining the symbol. + + Raises: + KeyError: If the symbol is not found.)") + .def("insert", &PySymbolTable::insert, nb::arg("operation"), + R"( + Inserts a symbol operation into the symbol table. + + Args: + operation: An operation with a symbol name to insert. + + Returns: + The symbol name attribute of the inserted operation. + + Raises: + ValueError: If the operation does not have a symbol name.)") + .def("erase", &PySymbolTable::erase, nb::arg("operation"), + R"( + Erases a symbol operation from the symbol table. + + Args: + operation: The symbol operation to erase. + + Note: + The operation is also erased from the IR and invalidated.)") + .def("__delitem__", &PySymbolTable::dunderDel, + "Deletes a symbol by name from the symbol table.") + .def( + "__contains__", + [](PySymbolTable &table, const std::string &name) { + return !mlirOperationIsNull(mlirSymbolTableLookup( + table, mlirStringRefCreate(name.data(), name.length()))); + }, + "Checks if a symbol with the given name exists in the table.") // Static helpers. .def_static("set_symbol_name", &PySymbolTable::setSymbolName, - nb::arg("symbol"), nb::arg("name")) + nb::arg("symbol"), nb::arg("name"), + "Sets the symbol name for a symbol operation.") .def_static("get_symbol_name", &PySymbolTable::getSymbolName, - nb::arg("symbol")) + nb::arg("symbol"), + "Gets the symbol name from a symbol operation.") .def_static("get_visibility", &PySymbolTable::getVisibility, - nb::arg("symbol")) + nb::arg("symbol"), + "Gets the visibility attribute of a symbol operation.") .def_static("set_visibility", &PySymbolTable::setVisibility, - nb::arg("symbol"), nb::arg("visibility")) + nb::arg("symbol"), nb::arg("visibility"), + "Sets the visibility attribute of a symbol operation.") .def_static("replace_all_symbol_uses", &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), - nb::arg("new_symbol"), nb::arg("from_op")) + nb::arg("new_symbol"), nb::arg("from_op"), + "Replaces all uses of a symbol with a new symbol name within " + "the given operation.") .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, nb::arg("from_op"), nb::arg("all_sym_uses_visible"), - nb::arg("callback")); + nb::arg("callback"), + "Walks symbol tables starting from an operation with a " + "callback function."); // Container bindings. 
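A minimal SymbolTable sketch, assuming the func dialect is registered in the context (as it is with the bundled `mlir` Python package):

    from mlir.ir import Context, Module, SymbolTable

    with Context():
        module = Module.parse("func.func private @foo()")
        symbols = SymbolTable(module.operation)
        assert "foo" in symbols
        foo = symbols["foo"]                        # the func.func operation
        print(SymbolTable.get_symbol_name(foo))     # "foo"
        print(SymbolTable.get_visibility(foo))      # "private"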
PyBlockArgumentList::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index a14f09f..ba767ad 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -24,6 +24,8 @@ using namespace mlir::python; NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; + m.attr("T") = nb::type_var("T"); + m.attr("U") = nb::type_var("U"); nb::class_<PyGlobals>(m, "_Globals") .def_prop_rw("dialect_search_modules", @@ -102,6 +104,10 @@ NB_MODULE(_mlir, m) { return opClass; }); }, + // clang-format off + nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) " + "-> typing.Callable[[type[T]], type[T]]"), + // clang-format on "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); @@ -114,6 +120,10 @@ NB_MODULE(_mlir, m) { return typeCaster; }); }, + // clang-format off + nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( @@ -126,6 +136,10 @@ NB_MODULE(_mlir, m) { return valueCaster; }); }, + // clang-format off + nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index 64ea4329..aea195f 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -19,6 +19,7 @@ #include "llvm/Support/raw_ostream.h" #include <string> +#include <typeinfo> #include <variant> template <> @@ -344,7 +345,16 @@ public: /// Binds the indexing and length methods in the Python class. static void bind(nanobind::module_ &m) { - auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName) + const std::type_info &elemTy = typeid(ElementTy); + PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy); + assert(elemTyInfo && + "expected nb_type_lookup to succeed for Sliceable elemTy"); + nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo); + std::string sig = std::string("class ") + Derived::pyClassName + + "(collections.abc.Sequence[" + + nanobind::cast<std::string>(elemTyName) + "])"; + auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName, + nanobind::sig(sig.c_str())) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); @@ -395,7 +405,6 @@ public: /// Hook for derived classes willing to bind more methods. static void bindDerived(ClassTy &) {} -private: intptr_t startIndex; intptr_t length; intptr_t step; diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 1659437..0ac5fc5 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -27,6 +27,7 @@ #include <cstddef> #include <cstdint> +#include <deque> #include <list> #include <memory> #include <numeric> @@ -830,6 +831,23 @@ namespace { /// This class provides support for reading attribute and type entries from the /// bytecode. 
Attribute and Type entries are read lazily on demand, so we use /// this reader to manage when to actually parse them from the bytecode. +/// +/// The parsing of attributes & types is generally recursive, which can lead to +/// stack overflows for deeply nested structures, so we track a few extra pieces +/// of information to avoid this: +/// +/// - `depth`: The current depth while parsing nested attributes. We defer on +/// parsing deeply nested attributes to avoid potential stack overflows. The +/// deferred parsing is achieved by reporting a failure when parsing a nested +/// attribute/type and registering the index of the encountered attribute/type +/// in the deferred parsing worklist. Hence, a failure with a deferred entry +/// does not constitute an error, but it does require that callers return on +/// the first failure rather than attempting additional parses. +/// - `deferredWorklist`: A list of attribute/type indices that we could not +/// parse due to hitting the depth limit. The worklist is used to capture the +/// indices of attributes/types that need to be parsed/reparsed when we hit +/// the depth limit. This enables moving the tracking of what needs to be +/// parsed to the heap. class AttrTypeReader { /// This class represents a single attribute or type entry. template <typename T> @@ -863,12 +881,34 @@ public: ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData); + + LogicalResult readAttribute(uint64_t index, Attribute &result, + uint64_t depth = 0) { + return readEntry(attributes, index, result, "attribute", depth); + } + + LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) { + return readEntry(types, index, result, "type", depth); + } + /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. - Attribute resolveAttribute(size_t index) { - return resolveEntry(attributes, index, "Attribute"); + Attribute resolveAttribute(size_t index, uint64_t depth = 0) { + return resolveEntry(attributes, index, "Attribute", depth); + } + Type resolveType(size_t index, uint64_t depth = 0) { + return resolveEntry(types, index, "Type", depth); + } + + Attribute getAttributeOrSentinel(size_t index) { + if (index >= attributes.size()) + return nullptr; + return attributes[index].entry; + } + Type getTypeOrSentinel(size_t index) { + if (index >= types.size()) + return nullptr; + return types[index].entry; } - Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } /// Parse a reference to an attribute or type using the given reader. LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { @@ -909,23 +949,33 @@ public: llvm::getTypeName<T>(), ", but got: ", baseResult); } + /// Add an index to the deferred worklist for re-parsing. + void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); } + private: /// Resolve the given entry at `index`. template <typename T> - T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, - StringRef entryType); + T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index, + StringRef entryType, uint64_t depth = 0); - /// Parse an entry using the given reader that was encoded using the textual - /// assembly format. + /// Read the entry at the given index, returning failure if the entry is not + /// yet resolved.
template <typename T> - LogicalResult parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType); + LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index, + T &result, StringRef entryType, uint64_t depth); /// Parse an entry using the given reader that was encoded using a custom /// bytecode format. template <typename T> LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, - StringRef entryType); + StringRef entryType, uint64_t index, + uint64_t depth); + + /// Parse an entry using the given reader that was encoded using the textual + /// assembly format. + template <typename T> + LogicalResult parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType); /// The string section reader used to resolve string references when parsing /// custom encoded attribute/type entries. @@ -951,6 +1001,10 @@ private: /// Reference to the parser configuration. const ParserConfig &parserConfig; + + /// Worklist for deferred attribute/type parsing. This is used to handle + /// deeply nested structures like CallSiteLoc iteratively. + std::vector<uint64_t> deferredWorklist; }; class DialectReader : public DialectBytecodeReader { @@ -959,10 +1013,11 @@ public: const StringSectionReader &stringReader, const ResourceSectionReader &resourceReader, const llvm::StringMap<BytecodeDialect *> &dialectsMap, - EncodingReader &reader, uint64_t &bytecodeVersion) + EncodingReader &reader, uint64_t &bytecodeVersion, + uint64_t depth = 0) : attrTypeReader(attrTypeReader), stringReader(stringReader), resourceReader(resourceReader), dialectsMap(dialectsMap), - reader(reader), bytecodeVersion(bytecodeVersion) {} + reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {} InFlightDiagnostic emitError(const Twine &msg) const override { return reader.emitError(msg); @@ -998,14 +1053,40 @@ public: // IR //===--------------------------------------------------------------------===// + /// The maximum depth to eagerly parse nested attributes/types before + /// deferring. + static constexpr uint64_t maxAttrTypeDepth = 5; + LogicalResult readAttribute(Attribute &result) override { - return attrTypeReader.parseAttribute(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + if (depth > maxAttrTypeDepth) { + if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) { + result = attr; + return success(); + } + attrTypeReader.addDeferredParsing(index); + return failure(); + } + return attrTypeReader.readAttribute(index, result, depth + 1); } LogicalResult readOptionalAttribute(Attribute &result) override { return attrTypeReader.parseOptionalAttribute(reader, result); } LogicalResult readType(Type &result) override { - return attrTypeReader.parseType(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + if (depth > maxAttrTypeDepth) { + if (Type type = attrTypeReader.getTypeOrSentinel(index)) { + result = type; + return success(); + } + attrTypeReader.addDeferredParsing(index); + return failure(); + } + return attrTypeReader.readType(index, result, depth + 1); } FailureOr<AsmDialectResourceHandle> readResourceHandle() override { @@ -1095,6 +1176,7 @@ private: const llvm::StringMap<BytecodeDialect *> &dialectsMap; EncodingReader &reader; uint64_t &bytecodeVersion; + uint64_t depth; }; /// Wraps the properties section and handles reading properties out of it. 
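Editorial aside: the depth-limit-plus-worklist scheme described in the comments above (and implemented by resolveEntry in the next hunk) can be illustrated with a small standalone sketch. Everything below is invented for illustration and is not part of the patch; it only mirrors the shape of the mechanism: recursion carries a depth counter, entries blocked by the depth limit are recorded in a heap-allocated worklist, and blocked entries are retried after their dependencies, so deeply nested structures no longer grow the native stack. It assumes acyclic references between entries.

#include <cstdint>
#include <deque>
#include <optional>
#include <set>
#include <vector>

struct Entry {
  std::vector<uint64_t> nestedIndices; // indices this entry depends on
  std::optional<int> value;            // set once resolved
};

struct Reader {
  static constexpr uint64_t kMaxDepth = 5;
  std::vector<Entry> entries;
  std::vector<uint64_t> deferred;

  // Mirrors readAttribute/readType: returns false and records `index` when
  // the depth budget is exhausted and the entry is not yet resolved.
  bool read(uint64_t index, int &out, uint64_t depth) {
    Entry &e = entries[index];
    if (e.value) {
      out = *e.value;
      return true;
    }
    if (depth > kMaxDepth) {
      deferred.push_back(index);
      return false;
    }
    int sum = 1;
    for (uint64_t dep : e.nestedIndices) {
      int v;
      if (!read(dep, v, depth + 1))
        return false; // give up on the first failure, like the bytecode reader
      sum += v;
    }
    e.value = sum;
    out = sum;
    return true;
  }

  // Mirrors the slow path of resolveEntry: retry blocked entries after their
  // dependencies, keeping the bookkeeping on the heap instead of the stack.
  std::optional<int> resolve(uint64_t index) {
    std::deque<uint64_t> worklist{index};
    std::set<uint64_t> inWorklist{index};
    while (!worklist.empty()) {
      uint64_t cur = worklist.front();
      worklist.pop_front();
      deferred.clear();
      int v;
      if (read(cur, v, /*depth=*/0)) {
        inWorklist.erase(cur);
        continue;
      }
      if (deferred.empty())
        return std::nullopt;   // hard failure, not a deferral
      worklist.push_back(cur); // retry after its dependencies
      for (auto it = deferred.rbegin(); it != deferred.rend(); ++it)
        if (inWorklist.insert(*it).second)
          worklist.push_front(*it);
    }
    return entries[index].value;
  }
};

In the sketch, as in the reader, a failed parse is only an error when nothing was deferred; otherwise the failure simply means "come back once the dependencies at the front of the worklist have been resolved."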
@@ -1238,69 +1320,112 @@ LogicalResult AttrTypeReader::initialize( } template <typename T> -T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, - StringRef entryType) { +T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, + uint64_t index, StringRef entryType, + uint64_t depth) { if (index >= entries.size()) { emitError(fileLoc) << "invalid " << entryType << " index: " << index; return {}; } - // If the entry has already been resolved, there is nothing left to do. - Entry<T> &entry = entries[index]; - if (entry.entry) - return entry.entry; + // Fast path: Try direct parsing without worklist overhead. This handles the + // common case where there are no deferred dependencies. + assert(deferredWorklist.empty()); + T result; + if (succeeded(readEntry(entries, index, result, entryType, depth))) { + assert(deferredWorklist.empty()); + return result; + } + if (deferredWorklist.empty()) { + // Failed with no deferred entries is error. + return T(); + } - // Parse the entry. - EncodingReader reader(entry.data, fileLoc); + // Slow path: Use worklist to handle deferred dependencies. Use a deque to + // iteratively resolve entries with dependencies. + // - Pop from front to process + // - Push new dependencies to front (depth-first) + // - Move failed entries to back (retry after dependencies) + std::deque<size_t> worklist; + llvm::DenseSet<size_t> inWorklist; - // Parse based on how the entry was encoded. - if (entry.hasCustomEncoding) { - if (failed(parseCustomEntry(entry, reader, entryType))) - return T(); - } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { - return T(); + // Add the original index and any dependencies from the fast path attempt. + worklist.push_back(index); + inWorklist.insert(index); + for (uint64_t idx : llvm::reverse(deferredWorklist)) { + if (inWorklist.insert(idx).second) + worklist.push_front(idx); } - if (!reader.empty()) { - reader.emitError("unexpected trailing bytes after " + entryType + " entry"); - return T(); + while (!worklist.empty()) { + size_t currentIndex = worklist.front(); + worklist.pop_front(); + + // Clear the deferred worklist before parsing to capture any new entries. + deferredWorklist.clear(); + + T result; + if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) { + inWorklist.erase(currentIndex); + continue; + } + + if (deferredWorklist.empty()) { + // Parsing failed with no deferred entries which implies an error. + return T(); + } + + // Move this entry to the back to retry after dependencies. + worklist.push_back(currentIndex); + + // Add dependencies to the front (in reverse so they maintain order). + for (uint64_t idx : llvm::reverse(deferredWorklist)) { + if (inWorklist.insert(idx).second) + worklist.push_front(idx); + } + deferredWorklist.clear(); } - return entry.entry; + return entries[index].entry; } template <typename T> -LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType) { - StringRef asmStr; - if (failed(reader.parseNullTerminatedString(asmStr))) - return failure(); +LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries, + uint64_t index, T &result, + StringRef entryType, uint64_t depth) { + if (index >= entries.size()) + return emitError(fileLoc) << "invalid " << entryType << " index: " << index; - // Invoke the MLIR assembly parser to parse the entry text. 
- size_t numRead = 0; - MLIRContext *context = fileLoc->getContext(); - if constexpr (std::is_same_v<T, Type>) - result = - ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); - else - result = ::parseAttribute(asmStr, context, Type(), &numRead, - /*isKnownNullTerminated=*/true); - if (!result) + // If the entry has already been resolved, return it. + Entry<T> &entry = entries[index]; + if (entry.entry) { + result = entry.entry; + return success(); + } + + // If the entry hasn't been resolved, try to parse it. + EncodingReader reader(entry.data, fileLoc); + LogicalResult parseResult = + entry.hasCustomEncoding + ? parseCustomEntry(entry, reader, entryType, index, depth) + : parseAsmEntry(entry.entry, reader, entryType); + if (failed(parseResult)) return failure(); - // Ensure there weren't dangling characters after the entry. - if (numRead != asmStr.size()) { - return reader.emitError("trailing characters found after ", entryType, - " assembly format: ", asmStr.drop_front(numRead)); - } + if (!reader.empty()) + return reader.emitError("unexpected trailing bytes after " + entryType + + " entry"); + + result = entry.entry; return success(); } template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, - StringRef entryType) { + StringRef entryType, + uint64_t index, uint64_t depth) { DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, - reader, bytecodeVersion); + reader, bytecodeVersion, depth); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); @@ -1350,6 +1475,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, return success(!!entry.entry); } +template <typename T> +LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType) { + StringRef asmStr; + if (failed(reader.parseNullTerminatedString(asmStr))) + return failure(); + + // Invoke the MLIR assembly parser to parse the entry text. + size_t numRead = 0; + MLIRContext *context = fileLoc->getContext(); + if constexpr (std::is_same_v<T, Type>) + result = + ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); + else + result = ::parseAttribute(asmStr, context, Type(), &numRead, + /*isKnownNullTerminated=*/true); + if (!result) + return failure(); + + // Ensure there weren't dangling characters after the entry. 
+ if (numRead != asmStr.size()) { + return reader.emitError("trailing characters found after ", entryType, + " assembly format: ", asmStr.drop_front(numRead)); + } + return success(); +} + //===----------------------------------------------------------------------===// // Bytecode Reader //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index eaad8a8..bf23176 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) { return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace)); } +MlirTypeID mlirLLVMPointerTypeGetTypeID() { + return wrap(LLVM::LLVMPointerType::getTypeID()); +} + bool mlirTypeIsALLVMPointerType(MlirType type) { return isa<LLVM::LLVMPointerType>(unwrap(type)); } @@ -73,6 +77,10 @@ bool mlirTypeIsALLVMStructType(MlirType type) { return isa<LLVM::LLVMStructType>(unwrap(type)); } +MlirTypeID mlirLLVMStructTypeGetTypeID() { + return wrap(LLVM::LLVMStructType::getTypeID()); +} + bool mlirLLVMStructTypeIsLiteral(MlirType type) { return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified(); } @@ -159,9 +167,8 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations, return wrap(DIExpressionAttr::get( unwrap(ctx), - llvm::map_to_vector( - unwrapList(nOperations, operations, attrStorage), - [](Attribute a) { return cast<DIExpressionElemAttr>(a); }))); + llvm::map_to_vector(unwrapList(nOperations, operations, attrStorage), + llvm::CastTo<DIExpressionElemAttr>))); } MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) { @@ -202,7 +209,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( cast<DIExpressionAttr>(unwrap(allocated)), cast<DIExpressionAttr>(unwrap(associated)), llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }))); + llvm::CastTo<DINodeAttr>))); } MlirAttribute mlirLLVMDIDerivedTypeAttrGet( @@ -308,7 +315,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, return wrap(DISubroutineTypeAttr::get( unwrap(ctx), callingConvention, llvm::map_to_vector(unwrapList(nTypes, types, attrStorage), - [](Attribute a) { return cast<DITypeAttr>(a); }))); + llvm::CastTo<DITypeAttr>))); } MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) { @@ -338,10 +345,10 @@ MlirAttribute mlirLLVMDISubprogramAttrGet( cast<DISubroutineTypeAttr>(unwrap(type)), llvm::map_to_vector( unwrapList(nRetainedNodes, retainedNodes, nodesStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }), + llvm::CastTo<DINodeAttr>), llvm::map_to_vector( unwrapList(nAnnotations, annotations, annotationsStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }))); + llvm::CastTo<DINodeAttr>))); } MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) { @@ -398,7 +405,7 @@ MlirAttribute mlirLLVMDIImportedEntityAttrGet( cast<DINodeAttr>(unwrap(entity)), cast<DIFileAttr>(unwrap(file)), line, cast<StringAttr>(unwrap(name)), llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }))); + llvm::CastTo<DINodeAttr>))); } MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name, diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 5c2a65d..75c811a 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -7,6 +7,7 @@ 
//===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Linalg.h" +#include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { const linalg::ContractionDimensions &contractionDims = *maybeDims; MLIRContext *ctx = linalgOp.getContext(); - auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute { - return wrap( - DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals))); + auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals))); }; result.batch = toAttr(contractionDims.batch); @@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, + size_t numMaps) { + MlirLinalgContractionDimensions result{}; + if (!indexingMaps || numMaps == 0) + return result; + + SmallVector<AffineMap, 3> maps; + maps.reserve(numMaps); + for (size_t i = 0; i < numMaps; ++i) { + maps.push_back(unwrap(indexingMaps[i])); + } + + FailureOr<linalg::ContractionDimensions> maybeDims = + linalg::inferContractionDims(maps); + if (failed(maybeDims)) + return result; + + MLIRContext *ctx = maps[0].getContext(); + + auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals))); + }; + + result.batch = toAttr(maybeDims->batch); + result.m = toAttr(maybeDims->m); + result.n = toAttr(maybeDims->n); + result.k = toAttr(maybeDims->k); + + return result; +} + MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) { auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op)); if (!linalgOp) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 2dbb993..81d86ad 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -22,7 +22,7 @@ using namespace mlir; extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, - bool enableObjectDump) { + bool enableObjectDump, bool enablePIC) { static bool initOnce = [] { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm @@ -38,12 +38,17 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); if (!tmBuilderOrError) { - llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; + llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host " + "because: \n"; + consumeError(tmBuilderOrError.takeError()); return MlirExecutionEngine{nullptr}; } + if (enablePIC) + tmBuilderOrError->setRelocationModel(llvm::Reloc::PIC_); auto tmOrError = tmBuilderOrError->createTargetMachine(); if (!tmOrError) { - llvm::errs() << "Failed to create a TargetMachine for the host\n"; + llvm::errs() << "Failed to create a TargetMachine for the host because: \n"; + consumeError(tmOrError.takeError()); return MlirExecutionEngine{nullptr}; } @@ -60,8 +65,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, jitOptions.jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(optLevel); 
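// Editorial sketch, not part of the patch: a hypothetical C-API caller opting
// into position-independent code via the new trailing `enablePIC` parameter of
// mlirExecutionEngineCreate. `module` is assumed to be a previously created
// MlirModule and "libruntime.so" is an invented library name.
//
//   MlirStringRef libs[] = {mlirStringRefCreateFromCString("libruntime.so")};
//   MlirExecutionEngine jit = mlirExecutionEngineCreate(
//       module, /*optLevel=*/2, /*numPaths=*/1, libs,
//       /*enableObjectDump=*/false, /*enablePIC=*/true);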
jitOptions.sharedLibPaths = libPaths; jitOptions.enableObjectDump = enableObjectDump; - auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions); + auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions, + std::move(tmOrError.get())); if (!jitOrError) { + llvm::errs() << "Failed to create an ExecutionEngine because: \n"; consumeError(jitOrError.takeError()); return MlirExecutionEngine{nullptr}; } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index f5f4ed3..e2e236a 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -536,7 +536,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, if (failed(memrefType.getStridesAndOffset(strides_, *offset))) return mlirLogicalResultFailure(); - (void)std::copy(strides_.begin(), strides_.end(), strides); + (void)llvm::copy(strides_, strides); return mlirLogicalResultSuccess(); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 1881865..ffcbed8 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -1129,6 +1129,11 @@ void mlirBlockArgumentSetType(MlirValue value, MlirType type) { blockArg.setType(unwrap(type)); } +void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc) { + if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value))) + blockArg.setLoc(unwrap(loc)); +} + MlirOperation mlirOpResultGetOwner(MlirValue value) { return wrap(llvm::dyn_cast<OpResult>(unwrap(value)).getOwner()); } diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 3a307a0..7584b17 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -16,8 +16,10 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -42,6 +44,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8); constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); constexpr Chipset kGfx942 = Chipset(9, 4, 2); constexpr Chipset kGfx950 = Chipset(9, 5, 0); +constexpr Chipset kGfx1250 = Chipset(12, 5, 0); /// Convert an unsigned number `val` to i32. static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, @@ -79,12 +82,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter, return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value); } -static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, - bool value) { - Type llvmI1 = rewriter.getI1Type(); - return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); -} - /// Returns the linear index used to access an element in the memref. 
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, @@ -509,10 +506,16 @@ struct MemoryCounterWaitOpLowering if (std::optional<int> exp = adaptor.getExp()) ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); + if (std::optional<int> tensor = adaptor.getTensor()) + ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor); + rewriter.eraseOp(op); return success(); } + if (adaptor.getTensor()) + return op.emitOpError("unsupported chipset"); + auto getVal = [](Attribute attr) -> unsigned { if (attr) return cast<IntegerAttr>(attr).getInt(); @@ -684,12 +687,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, /// intrinsics having been defined before the AMD backend supported bfloat. We /// similarly need to pack 8-bit float types into integers as if they were i8 /// (which they are for the backend's purposes). -static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, - Location loc, - const TypeConverter *typeConverter, - bool isUnsigned, Value llvmInput, - Value mlirInput, - SmallVector<Value, 4> &operands) { +static void wmmaPushInputOperand( + ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, + Value mlirInput, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast<VectorType>(inputType); if (!vectorType) { @@ -697,10 +699,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, return; } Type elemType = vectorType.getElementType(); - - if (elemType.isBF16()) - llvmInput = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; @@ -719,8 +717,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, } else if (elemType.isSignedInteger()) { localIsUnsigned = false; } - Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); - operands.push_back(sign); + attrs.push_back( + NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned))); } int64_t numBits = @@ -751,18 +749,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, - bool clamp, SmallVector<Value, 4> &operands) { + bool clamp, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs) { Type inputType = output.getType(); auto vectorType = dyn_cast<VectorType>(inputType); Type elemType = vectorType.getElementType(); - if (elemType.isBF16()) - output = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { - operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); + attrs.push_back( + NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset))); } else if (elemType.isInteger(32)) { - operands.push_back(createI1Constant(rewriter, loc, clamp)); + attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp))); } } @@ -1160,7 +1157,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, k, isRDNA3); // Handle gfx1250. 
- if (chipset == Chipset{12, 5, 0}) + if (chipset == kGfx1250) return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, elemDestType, k); @@ -1311,11 +1308,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); - // The WMMA operations represent vectors of bf16s as vectors of i16s, so we - // need to bitcast bfloats to i16 and then bitcast them back. + bool isGFX1250 = chipset >= kGfx1250; + + // The WMMA operations represent vectors of bf16s as vectors of i16s + // (except on gfx1250), so we need to bitcast bfloats to i16 and then + // bitcast them back. + auto aType = cast<VectorType>(adaptor.getSourceA().getType()); + auto bType = cast<VectorType>(adaptor.getSourceB().getType()); + auto destCType = cast<VectorType>(adaptor.getDestC().getType()); + bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250; + bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250; + bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250; + bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250; VectorType rawOutType = outType; - if (outType.getElementType().isBF16()) + if (castOutToI16) rawOutType = outType.clone(rewriter.getI16Type()); + Value a = adaptor.getSourceA(); + if (castAToI16) + a = LLVM::BitcastOp::create(rewriter, loc, + aType.clone(rewriter.getI16Type()), a); + Value b = adaptor.getSourceB(); + if (castBToI16) + b = LLVM::BitcastOp::create(rewriter, loc, + bType.clone(rewriter.getI16Type()), b); + Value destC = adaptor.getDestC(); + if (castDestCToI16) + destC = LLVM::BitcastOp::create( + rewriter, loc, destCType.clone(rewriter.getI16Type()), destC); std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); @@ -1325,18 +1344,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) return op.emitOpError("subwordOffset not supported on gfx12+"); - OperationState loweredOp(loc, *maybeIntrinsic); - loweredOp.addTypes(rawOutType); - SmallVector<Value, 4> operands; - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), - adaptor.getSourceA(), op.getSourceA(), operands); - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), - adaptor.getSourceB(), op.getSourceB(), operands); - wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), - op.getSubwordOffset(), op.getClamp(), operands); + SmallVector<NamedAttribute, 4> attrs; + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a, + op.getSourceA(), operands, attrs, "signA"); + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b, + op.getSourceB(), operands, attrs, "signB"); + wmmaPushOutputOperand(rewriter, loc, typeConverter, destC, + op.getSubwordOffset(), op.getClamp(), operands, + attrs); + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes(rawOutType); loweredOp.addOperands(operands); + loweredOp.addAttributes(attrs); Operation *lowered = rewriter.create(loweredOp); Operation *maybeCastBack = lowered; @@ -1492,6 +1513,20 @@ struct ExtPackedFp8OpLowering final ConversionPatternRewriter &rewriter) const override; }; +struct ScaledExtPackedMatrixOpLowering final + : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> { + ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : 
ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(ScaledExtPackedMatrixOp op, + ScaledExtPackedMatrixOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + struct PackedTrunc2xFp8OpLowering final : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> { PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, @@ -1600,6 +1635,173 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( return success(); } +int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf, + int32_t firstScaleByte) { + // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f* + // operations, the attributes blockSize, sourceType, scaleWaveHalf, and + // firstScaleByte are merged into a single attribute scaleSel. This is how + // those values are merged together. (Note: scaleWaveHalf isn't a high-level + // attribute but is derifed from firstScaleLane). + assert(llvm::is_contained({16, 32}, blockSize)); + assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth)); + + const bool isFp8 = bitWidth == 8; + const bool isBlock16 = blockSize == 16; + + if (!isFp8) { + int32_t bit0 = isBlock16; + assert(llvm::is_contained({0, 1, 2}, firstScaleByte)); + int32_t bit1 = (firstScaleByte == 2) << 1; + assert(llvm::is_contained({0, 1}, scaleWaveHalf)); + int32_t bit2 = scaleWaveHalf << 2; + return bit2 | bit1 | bit0; + } + + int32_t bit0 = isBlock16; + // firstScaleByte is guaranteed to be defined by two bits. + assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); + int32_t bits2and1 = firstScaleByte << 1; + assert(llvm::is_contained({0, 1}, scaleWaveHalf)); + int32_t bit3 = scaleWaveHalf << 3; + int32_t bits = bit3 | bits2and1 | bit0; + // These are invalid cases. 
+ assert(!llvm::is_contained( + {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); + return bits; +} + +static std::optional<StringRef> +scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; + if (isa<fp4>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); + return std::nullopt; + } + if (isa<fp8>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); + return std::nullopt; + } + if (isa<bf8>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); + return std::nullopt; + } + if (isa<fp6>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); + return std::nullopt; + } + if (isa<bf6>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); + return std::nullopt; + } + llvm_unreachable("invalid combination of element types for packed conversion " + "instructions"); +} + +LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite( + ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; + Location loc = op.getLoc(); + if (chipset != kGfx1250) { + return rewriter.notifyMatchFailure( + loc, + "Scaled fp packed conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + } + // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that + // is being selected. 
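// Editorial worked example (not from the patch) making the getScaleSel bit
// packing above concrete:
//  - fp4 source, blockSize = 16, firstScaleLane = 16, firstScaleByte = 2:
//    scaleWaveHalf = 16 / 16 = 1, so bit0 = 1 (block of 16), bit1 = 1
//    (firstScaleByte == 2), bit2 = 1 (second wave half) => scaleSel = 0b111.
//  - fp8 source, blockSize = 32, firstScaleLane = 0, firstScaleByte = 1:
//    bit0 = 0, bits 2..1 = 0b01 << 1, bit3 = 0 => scaleSel = 0b0010, which is
//    not one of the invalid fp8 combinations asserted above.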
+ int32_t scaleWaveHalf = op.getFirstScaleLane() / 16; + int32_t firstScaleByte = op.getFirstScaleByte(); + int32_t blockSize = op.getBlockSize(); + auto sourceType = cast<VectorType>(op.getSource().getType()); + auto srcElemType = cast<FloatType>(sourceType.getElementType()); + unsigned bitWidth = srcElemType.getWidth(); + + auto targetType = cast<VectorType>(op.getResult().getType()); + auto destElemType = cast<FloatType>(targetType.getElementType()); + + IntegerType i32 = rewriter.getI32Type(); + Value source = adaptor.getSource(); + Type llvmResultType = typeConverter->convertType(op.getResult().getType()); + Type packedType = nullptr; + if (isa<fp4>(srcElemType)) { + packedType = i32; + packedType = getTypeConverter()->convertType(packedType); + } else if (isa<fp8, bf8>(srcElemType)) { + packedType = VectorType::get(2, i32); + packedType = getTypeConverter()->convertType(packedType); + } else if (isa<fp6, bf6>(srcElemType)) { + packedType = VectorType::get(3, i32); + packedType = getTypeConverter()->convertType(packedType); + } else { + llvm_unreachable("invalid element type for packed scaled ext"); + } + + if (!packedType || !llvmResultType) { + return rewriter.notifyMatchFailure(op, "type conversion failed"); + } + + std::optional<StringRef> maybeIntrinsic = + scaledExtPacked816ToIntrinsic(srcElemType, destElemType); + if (!maybeIntrinsic.has_value()) + return op.emitOpError( + "no intrinsic matching packed scaled conversion on the given chipset"); + + int32_t scaleSel = + getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte); + Value castedScale = + LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); + + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes({llvmResultType}); + loweredOp.addOperands({castedSource, castedScale}); + + SmallVector<NamedAttribute, 1> attrs; + attrs.push_back( + NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel))); + + loweredOp.addAttributes(attrs); + Operation *lowered = rewriter.create(loweredOp); + rewriter.replaceOp(op, lowered); + + return success(); +} + LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -2073,6 +2275,441 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> { } }; +struct AMDGPUMakeDmaBaseLowering + : public ConvertOpToLLVMPattern<MakeDmaBaseOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset < kGfx1250) + return op->emitOpError("make_dma_base is only supported on gfx1250"); + + Location loc = op.getLoc(); + + ValueRange ldsIndices = adaptor.getLdsIndices(); + Value lds = adaptor.getLds(); + auto ldsMemRefType = cast<MemRefType>(op.getLds().getType()); + + Value ldsPtr = + getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices); + + ValueRange globalIndices = adaptor.getGlobalIndices(); + Value global = adaptor.getGlobal(); + auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType()); + + Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType, + global, globalIndices); + + Type i32 = 
rewriter.getI32Type(); + Type i64 = rewriter.getI64Type(); + + Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr); + Value castForGlobalAddr = + LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr); + + Value lowHalf = + LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr); + + Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr, + createI64Constant(rewriter, loc, 32)); + + Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift); + + Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1); + Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask); + + Value typeField = createI32Constant(rewriter, loc, 2 << 30); + Value highHalfPlusType = + LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField); + + Value c0 = createI32Constant(rewriter, loc, 0); + Value c1 = createI32Constant(rewriter, loc, 1); + Value c2 = createI32Constant(rewriter, loc, 2); + Value c3 = createI32Constant(rewriter, loc, 3); + + Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32)); + assert(v4i32 && "expected type conversion to succeed"); + Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32); + result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + castForLdsAddr, c1); + result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + highHalfPlusType, c3); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct AMDGPUMakeDmaDescriptorLowering + : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter), + chipset(chipset) {} + Chipset chipset; + + Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); } + + Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc, + Value accumulator, Value value, int64_t shift) const { + shift = shift % 32; + Value shiftAmount; + if (shift != 0) { + shiftAmount = createI32Constant(rewriter, loc, shift % 32); + value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount); + } + + if (matchPattern(accumulator, mlir::m_Zero())) + return value; + + return LLVM::OrOp::create(rewriter, loc, accumulator, value); + } + + Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0) const { + Value mask = op.getWorkgroupMask(); + if (!mask) + return sgpr0; + + Type i32 = rewriter.getI32Type(); + Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask); + return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0); + } + + Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + // Compute data_size. 
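// Editorial note (not from the patch): the offsets passed to setValueAtOffset
// are absolute bit positions within the 256-bit second descriptor group
// (sgpr0..sgpr7); the helper only keeps offset % 32 because the caller already
// chose which 32-bit sgpr word the field lives in. Worked example: a 16-bit
// element type gives dataSize = log2(16 / 8) = 1, which is shifted to bit 16
// of sgpr0, i.e. descriptor bits 16..17.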
+ unsigned elementTypeWidthInBits = op.getElementTypeWidth(); + assert( + llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) && + "expected type width to be 8, 16, 32, or 64."); + int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8); + Value size = createI32Constant(rewriter, loc, dataSize); + return setValueAtOffset(rewriter, loc, sgpr0, size, 16); + } + + Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr; + if (!atomic_barrier_enable) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18); + } + + Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool iterate_enable = adaptor.getGlobalIncrement() != nullptr; + if (!iterate_enable) + return sgpr0; + + // TODO: In future PR, add other required fields for iteration. + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19); + } + + Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20); + } + + Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptorm, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + if (!op.getWorkgroupMask()) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21); + } + + Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + IntegerType i32 = rewriter.getI32Type(); + Value padInterval = adaptor.getPadInterval(); + // pre-condition: padInterval can be a power of two between 2 and 256. + padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32, + padInterval, false); + padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]); + // post-condition: padInterval can be a value between 0 and 7. + return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22); + } + + Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + Value padAmount = adaptor.getPadAmount(); + // pre-condition: padAmount is a value between 1-128. + padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]); + // post-condition: padAmount is a value between 0-127. 
+ return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25); + } + + Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr1, + ArrayRef<Value> consts) const { + bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr; + if (!atomic_barrier_enable) + return sgpr1; + + Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress(); + auto barrierAddressTy = + cast<MemRefType>(op.getAtomicBarrierAddress().getType()); + ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices(); + atomicBarrierAddress = + getStridedElementPtr(rewriter, loc, barrierAddressTy, + atomicBarrierAddress, atomicBarrierIndices); + IntegerType i32 = rewriter.getI32Type(); + // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies + // that the 3 LSBs are zero. + atomicBarrierAddress = + LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress); + atomicBarrierAddress = + LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]); + Value mask = createI32Constant(rewriter, loc, 0xFFFF); + atomicBarrierAddress = + LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask); + return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32); + } + + std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr1, Value sgpr2, + ArrayRef<Value> consts) const { + SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes(); + OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back(); + Value tensorDim0; + if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult)) + tensorDim0 = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDim0 = cast<Value>(tensorDim0OpFoldResult); + + Value c16 = createI32Constant(rewriter, loc, 16); + Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16); + sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48); + sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16); + return {sgpr1, sgpr2}; + } + + std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr2, Value sgpr3, + ArrayRef<Value> consts) const { + // TODO: Generalize to setTensorDimX. 
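// Editorial worked example (not from the patch): tensor dim 1 starts at
// descriptor bit 80, which is not 32-bit aligned, so the field straddles two
// words. 80 % 32 == 16, so the low 16 bits of the dimension are shifted into
// bits 16..31 of sgpr2 (descriptor bits 80..95), while tensorDim1 >> 16 is
// written at offset 96, landing in bits 0..15 of sgpr3 (descriptor bits
// 96..111).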
+ SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes(); + OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1); + Value tensorDim1; + if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult)) + tensorDim1 = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDim1 = cast<Value>(tensorDim1OpFoldResult); + + Value c16 = createI32Constant(rewriter, loc, 16); + Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16); + sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80); + sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16); + return {sgpr2, sgpr3}; + } + + Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr, ArrayRef<Value> consts, size_t dimX, + int64_t offset) const { + SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes(); + + if (mixedSharedSizes.size() <= dimX) + return sgpr; + + OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX); + Value tileDimX; + if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) + tileDimX = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tileDimX = cast<Value>(tileDimXOpFoldResult); + + return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset); + } + + Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr3, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112); + } + + Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr4, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128); + } + + Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr4, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144); + } + + std::pair<Value, Value> + setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgprY, Value sgprZ, ArrayRef<Value> consts, + size_t dimX, int64_t offset) const { + SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides(); + + if (mixedGlobalStrides.size() <= dimX) + return {sgprY, sgprZ}; + + OpFoldResult tensorDimXStrideOpFoldResult = + *(mixedGlobalStrides.rbegin() + dimX); + Value tensorDimXStride; + if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult)) + tensorDimXStride = + createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult); + + constexpr int64_t first48bits = (1ll << 48) - 1; + Value mask = createI64Constant(rewriter, loc, first48bits); + tensorDimXStride = + LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride); + IntegerType i32 = rewriter.getI32Type(); + Value tensorDimXStrideLow = + LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride); + + int64_t shift = (offset % 32) == 0 ? 
32 : offset % 32; + Value shiftVal = createI64Constant(rewriter, loc, shift); + Value tensorDimXStrideHigh = + LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal); + tensorDimXStrideHigh = + LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh); + + sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset); + sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh, + offset + shift); + return {sgprY, sgprZ}; + } + + std::pair<Value, Value> + setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const { + return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts, + 0, 160); + } + + std::pair<Value, Value> + setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const { + return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts, + 1, 208); + } + + Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<Value> consts) const { + Value sgprs[8]; + for (int64_t i = 0; i < 8; i++) { + sgprs[i] = consts[0]; + } + + sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]); + sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts); + + sgprs[1] = + setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts); + std::tie(sgprs[1], sgprs[2]) = + setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts); + std::tie(sgprs[2], sgprs[3]) = + setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts); + + sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts); + sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts); + sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts); + std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride( + op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts); + std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride( + op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts); + + IntegerType i32 = rewriter.getI32Type(); + Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32)); + assert(v8i32 && "expected type conversion to succeed"); + Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32); + + for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) { + dgroup1 = + LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant); + } + + return dgroup1; + } + + LogicalResult + matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset < kGfx1250) + return op->emitOpError( + "make_dma_descriptor is only supported on gfx1250"); + + if (op.getRank() > 2) + return op->emitOpError("unimplemented"); + + Location loc = op.getLoc(); + + IntegerType i32 = rewriter.getI32Type(); + [[maybe_unused]] Type v4i32 = + this->typeConverter->convertType(VectorType::get(4, i32)); + 
assert(v4i32 && "expected type conversion to succeed"); + + SmallVector<Value> consts; + for (int64_t i = 0; i < 8; i++) + consts.push_back(createI32Constant(rewriter, loc, i)); + + Value dgroup0 = this->getDGroup0(adaptor); + Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts); + + SmallVector<Value> results = {dgroup0, dgroup1}; + rewriter.replaceOpWithMultiple(op, {results}); + return success(); + } +}; + struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> { using Base::Base; @@ -2087,6 +2724,11 @@ struct ConvertAMDGPUToROCDLPass RewritePatternSet patterns(ctx); LLVMTypeConverter converter(ctx); + converter.addConversion([&](TDMBaseType type) -> Type { + Type i32 = IntegerType::get(type.getContext(), 32); + return converter.convertType(VectorType::get(4, i32)); + }); + populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); LLVMConversionTarget target(getContext()); target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); @@ -2122,25 +2764,27 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, Chipset chipset) { populateAMDGPUMemorySpaceAttributeConversions(converter); - patterns - .add<FatRawBufferCastLowering, - RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, - RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, - RawBufferOpLowering<RawBufferAtomicFaddOp, - ROCDL::RawPtrBufferAtomicFaddOp>, - RawBufferOpLowering<RawBufferAtomicFmaxOp, - ROCDL::RawPtrBufferAtomicFmaxOp>, - RawBufferOpLowering<RawBufferAtomicSmaxOp, - ROCDL::RawPtrBufferAtomicSmaxOp>, - RawBufferOpLowering<RawBufferAtomicUminOp, - ROCDL::RawPtrBufferAtomicUminOp>, - RawBufferOpLowering<RawBufferAtomicCmpswapOp, - ROCDL::RawPtrBufferAtomicCmpSwap>, - AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, - SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, - WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, - PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, - PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, - TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset); + patterns.add< + FatRawBufferCastLowering, + RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, + RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, + RawBufferOpLowering<RawBufferAtomicFaddOp, + ROCDL::RawPtrBufferAtomicFaddOp>, + RawBufferOpLowering<RawBufferAtomicFmaxOp, + ROCDL::RawPtrBufferAtomicFmaxOp>, + RawBufferOpLowering<RawBufferAtomicSmaxOp, + ROCDL::RawPtrBufferAtomicSmaxOp>, + RawBufferOpLowering<RawBufferAtomicUminOp, + ROCDL::RawPtrBufferAtomicUminOp>, + RawBufferOpLowering<RawBufferAtomicCmpswapOp, + ROCDL::RawPtrBufferAtomicCmpSwap>, + AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, + WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering, + ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, + PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, + GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering, + AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter, + chipset); patterns.add<AMDGPUSwizzleBitModeLowering>(converter); } diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp new file mode 100644 index 0000000..79816fc --- /dev/null +++ 
b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -0,0 +1,665 @@ +//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::func; + +static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, FunctionType funcT, bool setPrivate, + SymbolTableCollection *symbolTables = nullptr) { + OpBuilder::InsertionGuard g(b); + assert(!symTable->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&symTable->getRegion(0).front()); + FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT); + if (setPrivate) + funcOp.setPrivate(); + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); + symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); + } + return funcOp; +} + +/// Helper function to look up or create the symbol for a runtime library +/// function with the given parameter types. Returns an int64_t, unless a +/// different result type is specified. +static FailureOr<FuncOp> +lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, TypeRange paramTypes, + SymbolTableCollection *symbolTables = nullptr, + Type resultType = {}) { + if (!resultType) + resultType = IntegerType::get(symTable->getContext(), 64); + std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); + FailureOr<FuncOp> func = + lookupFnDecl(symTable, funcName, funcT, symbolTables); + // Failed due to type mismatch. + if (failed(func)) + return func; + // Successfully matched existing decl. + if (*func) + return *func; + + return createFnDecl(b, symTable, funcName, funcT, + /*setPrivate=*/true, symbolTables); +} + +/// Helper function to look up or create the symbol for a runtime library +/// function for a binary arithmetic operation. +/// +/// Parameter 1: APFloat semantics +/// Parameter 2: Left-hand side operand +/// Parameter 3: Right-hand side operand +/// +/// This function will return a failure if the function is found but has an +/// unexpected signature. 
+/// +static FailureOr<FuncOp> +lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, + SymbolTableCollection *symbolTables = nullptr) { + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type}, + symbolTables); +} + +static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) { + int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + return arith::ConstantOp::create(b, loc, b.getI32Type(), + b.getIntegerAttr(b.getI32Type(), sem)); +} + +/// Given two operands of vector type and vector result type (with the same +/// shape), call the given function for each pair of scalar operands and +/// package the result into a vector. If the given operands and result type are +/// not vectors, call the function directly. The second operand is optional. +template <typename Fn, typename... Values> +static Value forEachScalarValue(RewriterBase &rewriter, Location loc, + Value operand1, Value operand2, Type resultType, + Fn fn) { + auto vecTy1 = dyn_cast<VectorType>(operand1.getType()); + if (operand2) { + // Sanity check: Operand types must match. + assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) && + "expected same vector types"); + } + if (!vecTy1) { + // Not a vector. Call the function directly. + return fn(operand1, operand2, resultType); + } + + // Prepare scalar operands. + ResultRange sclars1 = + vector::ToElementsOp::create(rewriter, loc, operand1)->getResults(); + SmallVector<Value> scalars2; + if (!operand2) { + // No second operand. Create a vector of empty values. + scalars2.assign(vecTy1.getNumElements(), Value()); + } else { + llvm::append_range( + scalars2, + vector::ToElementsOp::create(rewriter, loc, operand2)->getResults()); + } + + // Call the function for each pair of scalar operands. + auto resultVecType = cast<VectorType>(resultType); + SmallVector<Value> results; + for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) { + Value result = fn(scalar1, scalar2, resultVecType.getElementType()); + results.push_back(result); + } + + // Package the results into a vector. + return vector::FromElementsOp::create( + rewriter, loc, + vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()), + results); +} + +/// Check preconditions for the conversion: +/// 1. All operands / results must be integers or floats (or vectors thereof). +/// 2. The bitwidth of the operands / results must be <= 64. +static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) { + for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) { + Type type = value.getType(); + if (auto vecTy = dyn_cast<VectorType>(type)) { + type = vecTy.getElementType(); + } + if (!type.isIntOrFloat()) { + return rewriter.notifyMatchFailure( + op, "only integers and floats (or vectors thereof) are supported"); + } + if (type.getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + } + return success(); +} + +/// Rewrite a binary arithmetic operation to an APFloat function call. 
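An editorial sketch (hypothetical, not part of this patch) of what the runtime side of these `_mlir_apfloat_*` calls could look like, assuming the convention established above: the first parameter is an llvm::APFloatBase::Semantics enum value and operands/results travel as raw bit patterns zero-extended into int64_t. The function name `_mlir_apfloat_add`, the choice of "add" as the registered operation name, and the rounding mode are all assumptions made for illustration; the pattern defined next emits calls with exactly this shape for binary float ops.

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <cstdint>

// Hypothetical runtime entry point for a binary float op; the real runtime
// library may differ.
extern "C" int64_t _mlir_apfloat_add(int32_t sem, int64_t lhsBits,
                                     int64_t rhsBits) {
  const llvm::fltSemantics &semantics = llvm::APFloatBase::EnumToSemantics(
      static_cast<llvm::APFloatBase::Semantics>(sem));
  unsigned width = llvm::APFloatBase::semanticsSizeInBits(semantics);
  llvm::APFloat lhs(semantics,
                    llvm::APInt(width, static_cast<uint64_t>(lhsBits)));
  llvm::APFloat rhs(semantics,
                    llvm::APInt(width, static_cast<uint64_t>(rhsBits)));
  lhs.add(rhs, llvm::APFloat::rmNearestTiesToEven);
  // Return the raw result bits; the generated IR truncates them back to the
  // original float width and bitcasts to the float type.
  return static_cast<int64_t>(lhs.bitcastToAPInt().getZExtValue());
}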
+template <typename OpTy> +struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> { + BinaryArithOpToAPFloatConversion(MLIRContext *context, + const char *APFloatName, + SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + APFloatName(APFloatName) {}; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + FailureOr<FuncOp> fn = + lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), + [&](Value lhs, Value rhs, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(resultType); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, lhs)); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, rhs)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, floatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + const char *APFloatName; +}; + +template <typename OpTy> +struct FpToFpConversion final : OpRewritePattern<OpTy> { + FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = lookupOrCreateApFloatFn( + rewriter, symTable, "convert", {i32Type, i32Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inFloatTy = cast<FloatType>(operand1.getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); + + // Call APFloat function. 
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outFloatTy = cast<FloatType>(resultType); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + std::array<Value, 3> params = {inSemValue, outSemValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create( + rewriter, loc, outIntWType, resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, outFloatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +template <typename OpTy> +struct FpToIntConversion final : OpRewritePattern<OpTy> { + FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + isUnsigned(isUnsigned) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inFloatTy = cast<FloatType>(operand1.getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); + + // Call APFloat function. + Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outIntTy = cast<IntegerType>(resultType); + Value outWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, outIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector<Value> params = {inSemValue, outWidthValue, + isUnsignedValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. 
+ return arith::TruncIOp::create(rewriter, loc, outIntTy, + resultOp->getResult(0)); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + +template <typename OpTy> +struct IntToFpConversion final : OpRewritePattern<OpTy> { + IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + isUnsigned(isUnsigned) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inIntTy = cast<IntegerType>(operand1.getType()); + Value operandBits = operand1; + if (operandBits.getType().getIntOrFloatBitWidth() < 64) { + if (isUnsigned) { + operandBits = + arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits); + } else { + operandBits = + arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits); + } + } + + // Call APFloat function. + auto outFloatTy = cast<FloatType>(resultType); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + Value inWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, inIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector<Value> params = {outSemValue, inWidthValue, + isUnsignedValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create( + rewriter, loc, outIntWType, resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, outFloatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + +struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> { + CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::CmpFOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. 
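The three conversion patterns above (fp-to-fp, fp-to-int, int-to-fp) likewise only fix a calling convention. Assuming the runtime side wraps APFloat's conversion entry points, the "convert" and "convert_to_int" helpers could be sketched as follows; the rounding-mode choices are illustrative and not taken from this patch:

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include <cstdint>

// Hypothetical runtime implementations of the helpers looked up above.
extern "C" uint64_t _mlir_apfloat_convert(int32_t inSem, int32_t outSem,
                                          uint64_t bits) {
  const llvm::fltSemantics &src = llvm::APFloatBase::EnumToSemantics(
      static_cast<llvm::APFloatBase::Semantics>(inSem));
  const llvm::fltSemantics &dst = llvm::APFloatBase::EnumToSemantics(
      static_cast<llvm::APFloatBase::Semantics>(outSem));
  llvm::APFloat val(
      src, llvm::APInt(llvm::APFloatBase::semanticsSizeInBits(src), bits));
  bool losesInfo = false;
  (void)val.convert(dst, llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return val.bitcastToAPInt().getZExtValue();
}

extern "C" uint64_t _mlir_apfloat_convert_to_int(int32_t inSem,
                                                 int32_t outWidth,
                                                 bool isUnsigned,
                                                 uint64_t bits) {
  const llvm::fltSemantics &src = llvm::APFloatBase::EnumToSemantics(
      static_cast<llvm::APFloatBase::Semantics>(inSem));
  llvm::APFloat val(
      src, llvm::APInt(llvm::APFloatBase::semanticsSizeInBits(src), bits));
  llvm::APSInt result(outWidth, isUnsigned);
  bool isExact = false;
  (void)val.convertToInteger(result, llvm::APFloat::rmTowardZero, &isExact);
  // The pattern truncates this i64 back to the requested integer width.
  return result.getZExtValue();
}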
+ auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i8Type = IntegerType::get(symTable->getContext(), 8); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "compare", + {i32Type, i64Type, i64Type}, nullptr, i8Type); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), + [&](Value lhs, Value rhs, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(lhs.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, lhs)); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, rhs)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + Value comparisonResult = + func::CallOp::create(rewriter, loc, TypeRange(i8Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Generate an i1 SSA value that is "true" if the comparison result + // matches the given `val`. + auto checkResult = [&](llvm::APFloat::cmpResult val) { + return arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, comparisonResult, + arith::ConstantOp::create( + rewriter, loc, i8Type, + rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val))) + .getResult()); + }; + // Generate an i1 SSA value that is "true" if the comparison result + // matches any of the given `vals`. + std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> + checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) { + Value first = checkResult(vals.front()); + if (vals.size() == 1) + return first; + Value rest = checkResults(vals.drop_front()); + return arith::OrIOp::create(rewriter, loc, first, rest) + .getResult(); + }; + + // This switch-case statement was taken from arith::applyCmpPredicate. + Value result; + switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: + result = + arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 0)) + .getResult(); + break; + case arith::CmpFPredicate::OEQ: + result = checkResult(llvm::APFloat::cmpEqual); + break; + case arith::CmpFPredicate::OGT: + result = checkResult(llvm::APFloat::cmpGreaterThan); + break; + case arith::CmpFPredicate::OGE: + result = checkResults( + {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::OLT: + result = checkResult(llvm::APFloat::cmpLessThan); + break; + case arith::CmpFPredicate::OLE: + result = checkResults( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ONE: + // Not cmpUnordered and not cmpUnordered. + result = checkResults( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::ORD: + // Not cmpUnordered. 
+ result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UEQ: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UGT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::UGE: + result = checkResults({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ULT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan}); + break; + case arith::CmpFPredicate::ULE: + result = checkResults({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UNE: + // Not cmpEqual. + result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpUnordered}); + break; + case arith::CmpFPredicate::UNO: + result = checkResult(llvm::APFloat::cmpUnordered); + break; + case arith::CmpFPredicate::AlwaysTrue: + result = + arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 1)) + .getResult(); + break; + } + return result; + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> { + NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::NegFOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(operand1.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, operand1)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, operandBits}; + Value negatedBits = + func::CallOp::create(rewriter, loc, TypeRange(i64Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Truncate result to the original width. 
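The predicate dispatch in CmpFOpToAPFloatConversion relies on the "compare" helper returning llvm::APFloat::cmpResult encoded in an i8. Assuming it wraps APFloat::compare, a matching sketch:

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <cstdint>

// Hypothetical runtime implementation; the encoding must line up with the
// cmpResult constants matched by checkResult/checkResults above.
extern "C" int8_t _mlir_apfloat_compare(int32_t semEnum, uint64_t lhsBits,
                                        uint64_t rhsBits) {
  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
      static_cast<llvm::APFloatBase::Semantics>(semEnum));
  unsigned width = llvm::APFloatBase::semanticsSizeInBits(sem);
  llvm::APFloat lhs(sem, llvm::APInt(width, lhsBits));
  llvm::APFloat rhs(sem, llvm::APInt(width, rhsBits));
  // One of cmpLessThan, cmpEqual, cmpGreaterThan, cmpUnordered.
  return static_cast<int8_t>(lhs.compare(rhs));
}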
+ Value truncatedBits = + arith::TruncIOp::create(rewriter, loc, intWType, negatedBits); + return arith::BitcastOp::create(rewriter, loc, floatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +namespace { +struct ArithToAPFloatConversionPass final + : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { + using Base::Base; + + void runOnOperation() override; +}; + +void ArithToAPFloatConversionPass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add", + getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>( + context, "subtract", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>( + context, "multiply", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>( + context, "divide", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>( + context, "remainder", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>( + context, "minnum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>( + context, "maxnum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>( + context, "minimum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>( + context, "maximum", getOperation()); + patterns + .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>, + CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>( + context, getOperation()); + patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(), + /*isUnsigned=*/true); + patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(), + /*isUnsigned=*/true); + LogicalResult result = success(); + ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) { + result = failure(); + } + // NB: if you don't return failure, no other diag handlers will fire (see + // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). 
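A minimal standalone illustration of the handler contract that NB refers to: a handler that returns success() consumes the diagnostic, so handlers registered earlier never see it, while returning failure() lets the diagnostic keep propagating.

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"

// Sketch of the error-latching idiom used by the pass above.
static bool emitsAnError(mlir::MLIRContext &ctx) {
  bool sawError = false;
  mlir::ScopedDiagnosticHandler latch(&ctx, [&](mlir::Diagnostic &diag) {
    if (diag.getSeverity() == mlir::DiagnosticSeverity::Error)
      sawError = true;
    // Returning failure() keeps the diagnostic propagating to other handlers.
    return mlir::failure();
  });
  mlir::emitError(mlir::UnknownLoc::get(&ctx), "example error");
  return sawError;
}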
+ return failure(); + }); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + if (failed(result)) + return signalPassFailure(); +} +} // namespace diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt new file mode 100644 index 0000000..31fce7a --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArithToAPFloat + ArithToAPFloat.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRArithTransforms + MLIRFuncDialect + MLIRFuncUtils + MLIRVectorDialect + ) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index b609990..220826d 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" @@ -280,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), + /*propAttr=*/Attribute{}, *getTypeConverter(), rewriter); } @@ -481,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(), + op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "unsupported floating point type"); + Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); LLVM::FastmathFlags fmf = diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index bebf1b8..613dc6d 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) +add_subdirectory(ArithToAPFloat) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index 86d02e6..6a0c211 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite( op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - op->getAttrs(), *getTypeConverter(), rewriter); + op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(), + rewriter); } }; diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 93fe2ed..2220f61 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -374,9 +374,12 @@ 
FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp( // Create a memory effect attribute corresponding to readnone. if (funcOp->hasAttr(readnoneAttrName)) { auto memoryAttr = LLVM::MemoryEffectsAttr::get( - rewriter.getContext(), - {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef, - LLVM::ModRefInfo::NoModRef}); + rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::NoModRef, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef}); newFuncOp.setMemoryEffectsAttr(memoryAttr); } diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 425594b..f143a9e 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -66,7 +66,10 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef; auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); func.setMemoryEffectsAttr(memAttr); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index d64c4d6..5848489 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -419,7 +419,10 @@ struct LowerGpuOpsToNVVMOpsPass final if (this->hasRedux) populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns); configureGpuToNVVMConversionLegality(target); - if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed( + applyPartialConversion(m, target, std::move(llvmPatterns), config))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 99c059c..6254de8 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" using namespace mlir; @@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) { if (type.getElementType().isF32()) return type.getOperand() == "COp" ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32; - + if (type.getElementType().isF64()) + return NVVM::MMATypes::f64; if (type.getElementType().isSignedInteger(8)) return NVVM::MMATypes::s8; if (type.getElementType().isUnsignedInteger(8)) @@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering // then passed on to the intrinsic call. Emit llvm ops to extract individual // values form lowered memrefs. SmallVector<Value> unpackedOps; - auto unpackOp = [&](Value operand) { + // f64 a and b fragments are not structs but scalars. + if (!isa<LLVM::LLVMStructType>(operand.getType())) { + unpackedOps.push_back(operand); + return; + } + // every other type is lowered to an LLVM struct, extract the values. 
auto structType = cast<LLVM::LLVMStructType>(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i); @@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering return failure(); Location loc = subgroupMmaConstantOp.getLoc(); Value cst = adaptor.getOperands()[0]; - LLVM::LLVMStructType type = convertMMAToLLVMType( + Type type = convertMMAToLLVMType( cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType())); + // If the element is not a struct, it means it's a scalar f64. + auto structType = dyn_cast<LLVM::LLVMStructType>(type); + if (!structType) { + rewriter.replaceOp(subgroupMmaConstantOp, cst); + return success(); + } // If the element type is a vector create a vector from the operand. - if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) { + if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) { Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = LLVM::ConstantOp::create(rewriter, loc, @@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering } cst = vecCst; } - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type); - for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType); + for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) { matrixStruct = LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i); } @@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering return failure(); Location loc = subgroupMmaElementwiseOp.getLoc(); size_t numOperands = adaptor.getOperands().size(); - LLVM::LLVMStructType destType = convertMMAToLLVMType( + Type destType = convertMMAToLLVMType( cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType())); - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType); - for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { + + // If the element is not a struct, it means it's a scalar f64. + LLVM::LLVMStructType structDestTy = + dyn_cast<LLVM::LLVMStructType>(destType); + if (!structDestTy) { + SmallVector<Value> operands; + for (auto operand : adaptor.getOperands()) { + operands.push_back(operand); + } + Value element = createScalarOp( + rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands); + rewriter.replaceOp(subgroupMmaElementwiseOp, element); + return success(); + } + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy); + for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) { SmallVector<Value> extractedOperands; for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { extractedOperands.push_back(LLVM::ExtractValueOp::create( @@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering } // namespace /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. 
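Since the f64 a/b fragments are handled as plain scalars in the patterns above (and convertMMAToLLVMType, just below, can now return f64 instead of an LLVMStructType), consumers have to branch on the fragment type. A hypothetical helper restating that handling, assuming the surrounding file's `using namespace mlir;`:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper; mirrors the struct-vs-scalar handling in this patch.
static SmallVector<Value> unpackMmaFragment(OpBuilder &b, Location loc,
                                            Value fragment) {
  auto structTy = dyn_cast<LLVM::LLVMStructType>(fragment.getType());
  // f64 a and b fragments are plain scalars after this change.
  if (!structTy)
    return {fragment};
  // Every other fragment is still an LLVM struct; extract its members.
  SmallVector<Value> elems;
  for (size_t i = 0, e = structTy.getBody().size(); i < e; ++i)
    elems.push_back(LLVM::ExtractValueOp::create(b, loc, fragment, i));
  return elems;
}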
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { +Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { NVVM::MMAFrag frag = convertOperand(type.getOperand()); NVVM::MMATypes eltType = getElementType(type); auto nRow = type.getShape()[0]; auto nCol = type.getShape()[1]; std::pair<Type, unsigned> typeInfo = NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext()); + // Special handling for f64 a and b fragments + Type f64Ty = Float64Type::get(type.getContext()); + if (typeInfo.first == f64Ty && typeInfo.second == 1) { + return f64Ty; + } return LLVM::LLVMStructType::getLiteral( type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); } diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp index bc2f2f2..d4b4c46 100644 --- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp +++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp @@ -107,16 +107,16 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); - Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, -1)); // Compute `x`. Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero); @@ -157,14 +157,14 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value one = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value one = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); // Compute the non-zero result. 
Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one); @@ -193,16 +193,16 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); - Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, -1)); // Compute `x`. Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 48a0319..f28a6cc 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Detail methods //===----------------------------------------------------------------------===// -void LLVM::detail::setNativeProperties(Operation *op, - IntegerOverflowFlags overflowFlags) { - if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) - iface.setOverflowFlags(overflowFlags); -} - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector<Type> resultTypes; @@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite( } // Create the operation through state since we don't know its C++ type. - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, targetAttrs); - - setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands, + resultTypes, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); // If the operation produced 0 or 1 result, return them immediately. 
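oneToOneRewrite (and its vector variant below) now threads an opaque properties attribute into the generically created op instead of taking IntegerOverflowFlags. The callers updated elsewhere in this patch simply pass Attribute{}; a caller that wants to carry a source op's native properties across the rewrite could, hypothetically, forward them as sketched here. This is only meaningful when the source and target ops agree on their properties layout, and it assumes the usual `using namespace mlir;` of these files:

#include "mlir/Conversion/LLVMCommon/Pattern.h"

// Hypothetical call site for the new oneToOneRewrite signature.
static LogicalResult
rewriteForwardingProps(Operation *srcOp, StringRef targetOp,
                       ValueRange operands,
                       const LLVMTypeConverter &typeConverter,
                       ConversionPatternRewriter &rewriter) {
  Attribute props = srcOp->getPropertiesAsAttribute();
  return LLVM::detail::oneToOneRewrite(srcOp, targetOp, operands,
                                       srcOp->getAttrs(), props, typeConverter,
                                       rewriter);
}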
if (numResults == 0) diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index e7dd0b5..e5969c2 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -116,18 +116,38 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( auto llvmNDVectorTy = operands[0].getType(); if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) - return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, - rewriter, overflowFlags); - - auto callback = [op, targetOp, targetAttrs, overflowFlags, + return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr, + typeConverter, rewriter); + auto callback = [op, targetOp, targetAttrs, propertiesAttr, &rewriter](Type llvm1DVectorTy, ValueRange operands) { - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), - operands, llvm1DVectorTy, targetAttrs); - LLVM::detail::setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), + operands, llvm1DVectorTy, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); return newOp->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } + +/// Return the given type if it's a floating point type. If the given type is +/// a vector type, return its element type if it's a floating point type. +static FloatType getFloatingPointType(Type type) { + if (auto floatType = dyn_cast<FloatType>(type)) + return floatType; + if (auto vecType = dyn_cast<VectorType>(type)) + return dyn_cast<FloatType>(vecType.getElementType()); + return nullptr; +} + +bool LLVM::detail::isUnsupportedFloatingPointType( + const TypeConverter &typeConverter, Type type) { + FloatType floatType = getFloatingPointType(type); + if (!floatType) + return false; + Type convertedType = typeConverter.convertType(floatType); + if (!convertedType) + return true; + return !isa<FloatType>(convertedType); +} diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 16ef11a..59a16df 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -93,13 +93,13 @@ public: /// Different MPI implementations have different communicator types. /// Using i64 as a portable, intermediate type. /// Appropriate cast needs to take place before calling MPI functions. - virtual Value getCommWorld(const Location loc, + virtual Value getCommWorld(Location loc, ConversionPatternRewriter &rewriter) = 0; /// Type converter provides i64 type for communicator type. /// Converts to native type, which might be ptr or int or whatever. 
- virtual Value castComm(const Location loc, - ConversionPatternRewriter &rewriter, Value comm) = 0; + virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter, + Value comm) = 0; /// Get the MPI_STATUS_IGNORE value (typically a pointer type). virtual intptr_t getStatusIgnore() = 0; @@ -109,13 +109,12 @@ public: /// Gets or creates an MPI datatype as a value which corresponds to the given /// type. - virtual Value getDataType(const Location loc, - ConversionPatternRewriter &rewriter, Type type) = 0; + virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter, + Type type) = 0; /// Gets or creates an MPI_Op value which corresponds to the given /// enum value. - virtual Value getMPIOp(const Location loc, - ConversionPatternRewriter &rewriter, + virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter, mpi::MPI_ReductionOpEnum opAttr) = 0; }; diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 11f866c..0a382d8 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -122,7 +122,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, return totalSizeBytes.getResult(); } -static emitc::ApplyOp +static emitc::AddressOfOp createPointerFromEmitcArray(Location loc, OpBuilder &builder, TypedValue<emitc::ArrayType> arrayValue) { @@ -133,9 +133,9 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder, llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex); emitc::SubscriptOp subPtr = emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); - emitc::ApplyOp ptr = emitc::ApplyOp::create( + emitc::AddressOfOp ptr = emitc::AddressOfOp::create( builder, loc, emitc::PointerType::get(arrayType.getElementType()), - builder.getStringAttr("&"), subPtr); + subPtr); return ptr; } @@ -225,12 +225,12 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> { auto srcArrayValue = cast<TypedValue<emitc::ArrayType>>(operands.getSource()); - emitc::ApplyOp srcPtr = + emitc::AddressOfOp srcPtr = createPointerFromEmitcArray(loc, rewriter, srcArrayValue); auto targetArrayValue = cast<TypedValue<emitc::ArrayType>>(operands.getTarget()); - emitc::ApplyOp targetPtr = + emitc::AddressOfOp targetPtr = createPointerFromEmitcArray(loc, rewriter, targetArrayValue); emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( @@ -319,8 +319,8 @@ struct ConvertGetGlobal final emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create( rewriter, op.getLoc(), lvalueType, operands.getNameAttr()); emitc::PointerType pointerType = emitc::PointerType::get(resultTy); - rewriter.replaceOpWithNewOp<emitc::ApplyOp>( - op, pointerType, rewriter.getStringAttr("&"), globalLValue); + rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType, + globalLValue); return success(); } rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 9348d3c1..64a7f56 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -922,15 +922,12 @@ struct NVGPUMBarrierArriveExpectTxLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value txcount = truncToI32(b, adaptor.getTxcount()); - - if (isMbarrierShared(op.getBarriers().getType())) { - 
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( - op, barrier, txcount, adaptor.getPredicate()); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( - op, barrier, txcount, adaptor.getPredicate()); + op, Type{}, // return-value is optional and is void by default + barrier, txcount, // barrier and txcount + NVVM::MemScopeKind::CTA, // default scope is CTA + false, // relaxed-semantics is false + adaptor.getPredicate()); return success(); } }; @@ -949,13 +946,6 @@ struct NVGPUMBarrierTryWaitParityLowering Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( - op, barrier, phase, ticks); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, phase, ticks); return success(); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 021e31a..7fdc23a 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -66,6 +66,9 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> { for (NamedAttribute attr : op->getAttrs()) { if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) { Type convertedType = converter->convertType(typeAttr.getValue()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert type in attribute"); convertedAttrs.emplace_back(attr.getName(), TypeAttr::get(convertedType)); } else { diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 37cfc9f..03842cc 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -36,6 +36,7 @@ namespace { struct SCFToControlFlowPass : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> { + using Base::Base; void runOnOperation() override; }; @@ -736,7 +737,9 @@ void SCFToControlFlowPass::runOnOperation() { target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 76a822b..309121f 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -453,10 +453,24 @@ static LogicalResult processParallelLoop( 1, 2, rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1)); + // Map through cloningMap first so we use values valid at the launch + // scope, then ensure they are launch-independent (or cloned constants). + Value mappedStep = cloningMap.lookupOrDefault(step); + Value mappedLowerBound = cloningMap.lookupOrDefault(lowerBound); + + mappedStep = ensureLaunchIndependent(mappedStep); + mappedLowerBound = ensureLaunchIndependent(mappedLowerBound); + + // If either cannot be made available above the launch, fail gracefully. 
+ if (!mappedStep || !mappedLowerBound) { + return rewriter.notifyMatchFailure( + parallelOp, "lower bound / step must be constant or defined above " + "the gpu.launch"); + } + newIndex = AffineApplyOp::create( rewriter, loc, annotation.getMap().compose(lowerAndStep), - ValueRange{operand, ensureLaunchIndependent(step), - ensureLaunchIndependent(lowerBound)}); + ValueRange{operand, mappedStep, mappedLowerBound}); // If there was also a bound, insert that, too. // TODO: Check that we do not assign bounds twice. if (annotation.getBound()) { diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 460595b..6423d49 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -188,7 +188,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, OpBuilder::InsertionGuard guard(builder); Type type = reduce.getOperands()[reductionIndex].getType(); auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(), - "__scf_reduction", type); + "__scf_reduction", type, + /*byref_element_type=*/{}); symbolTable.insert(decl); builder.createBlock(&decl.getInitializerRegion(), diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 50fca56..02b61bd 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1520,20 +1520,12 @@ public: if (!dstType) return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); - Location loc = tanOp.getLoc(); - Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); - Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); - rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); + rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType, + adaptor.getOperands()); return success(); } }; -/// Convert `spirv.Tanh` to -/// -/// exp(2x) - 1 -/// ----------- -/// exp(2x) + 1 -/// class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { public: using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; @@ -1546,18 +1538,8 @@ public: if (!dstType) return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); - Location loc = tanhOp.getLoc(); - Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); - Value multiplied = - LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); - Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); - Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value numerator = - LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); - Value denominator = - LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); - rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, - denominator); + rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType, + adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp index 9921a06..feb0489 100644 --- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp +++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp @@ -23,8 +23,11 @@ namespace mlir { using namespace mlir; -namespace { +//===----------------------------------------------------------------------===// +// PoisonOpLowering +//===----------------------------------------------------------------------===// +namespace { struct PoisonOpLowering : public 
ConvertOpToLLVMPattern<ub::PoisonOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> { matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; - } // namespace -//===----------------------------------------------------------------------===// -// PoisonOpLowering -//===----------------------------------------------------------------------===// - LogicalResult PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -61,6 +59,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, } //===----------------------------------------------------------------------===// +// UnreachableOpLowering +//===----------------------------------------------------------------------===// + +namespace { +struct UnreachableOpLowering + : public ConvertOpToLLVMPattern<ub::UnreachableOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace +LogicalResult + +UnreachableOpLowering::matchAndRewrite( + ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp<LLVM::UnreachableOp>(op); + return success(); +} + +//===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -93,7 +114,7 @@ struct UBToLLVMConversionPass void mlir::ub::populateUBToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<PoisonOpLowering>(converter); + patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp index 244d214..3831387 100644 --- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp +++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp @@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> { } }; +struct UnreachableOpLowering final : OpConversionPattern<ub::UnreachableOp> { + using Base::Base; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<spirv::UnreachableOp>(op); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final void mlir::ub::populateUBToSPIRVConversionPatterns( const SPIRVTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<PoisonOpLowering>(converter, patterns.getContext()); + patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter, + patterns.getContext()); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 69a317ec..05d541f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -345,7 +345,8 @@ public: matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); - MemRefType memRefType = 
scatter.getMemRefType(); + auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType()); + assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return rewriter.notifyMatchFailure(scatter, "memref type not supported"); @@ -1654,6 +1655,20 @@ private: return failure(); } } + } else if (auto floatTy = dyn_cast<FloatType>(printType)) { + // Print other floating-point types using the APFloat runtime library. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + Value floatBits = + LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value); + printer = + LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables); + emitCall(rewriter, loc, printer.value(), + ValueRange({semValue, floatBits})); + return success(); } else { return failure(); } diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 1b4d1a4..079e1e2 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -519,8 +519,13 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { return lowerToScatteredLoadOp(readOp, rewriter); } - // Perform common data transfer checks. VectorType vecTy = readOp.getVectorType(); + + // Lower using load.gather in 1D case + if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim()) + return lowerToScatteredLoadOp(readOp, rewriter); + + // Perform common data transfer checks. if (failed(storeLoadPreconditions(rewriter, readOp, vecTy))) return failure(); @@ -562,7 +567,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices, /*packed=*/nullptr, transposeAttr, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(readOp, loadOp); return success(); @@ -616,7 +622,8 @@ struct TransferWriteLowering auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc, indices, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(writeOp, storeOp); return success(); @@ -720,7 +727,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> { xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices, /*packed=*/nullptr, /*transpose=*/nullptr, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(loadOp, loadNdOp); return success(); @@ -758,7 +766,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> { auto storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(storeOp, storeNdOp); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index de552ce..0ecb50e 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16}; // Offsets to individual fields of the 8xi32 layout nd tensor 
descriptor. enum class NdTdescOffset : uint32_t { - BasePtr = 0, // Base pointer (i64) - BaseShapeW = 2, // Base shape width (i32) - BaseShapeH = 3, // Base shape height (i32) - TensorOffsetW = 4, // Tensor offset W (i32) - TensorOffsetH = 5 // Tensor offset H (i32) + BasePtr = 0, // Base pointer (i64) + BaseShapeW = 2, // Base shape width (i32) + BaseShapeH = 3, // Base shape height (i32) + BasePitch = 4, // Base pitch (i32) }; static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { @@ -151,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint, } } +// +// Note: +// Block operations for tile of sub byte element types are handled by +// emulating with larger element types. +// Tensor descriptor are keep intact and only ops consuming them are +// emulated +// + class CreateNdDescToXeVMPattern : public OpConversionPattern<xegpu::CreateNdDescOp> { using OpConversionPattern::OpConversionPattern; @@ -179,16 +186,12 @@ class CreateNdDescToXeVMPattern Value baseAddr; Value baseShapeW; Value baseShapeH; - Value offsetW; - Value offsetH; // Source can be a memref or a pointer (ui64, ui32, i64 or i32). SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes(); + SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides(); // Descriptor shape is expected to be 2D. int64_t rank = mixedSizes.size(); - if (rank != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D shape."); - auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. @@ -197,10 +200,20 @@ class CreateNdDescToXeVMPattern if (!sourceMemrefTy.hasRank()) { return rewriter.notifyMatchFailure(op, "Expected ranked Memref."); } - baseAddr = - memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); + // Access adaptor after failure check to avoid rolling back generated code + // for materialization cast. + baseAddr = adaptor.getSource(); } else { baseAddr = adaptor.getSource(); + if (baseAddr.getType() != i64Ty) { + // Pointer type may be i32. Cast to i64 if needed. + baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); + } + } + // 1D tensor descriptor is just the base address. + if (rank == 1) { + rewriter.replaceOp(op, baseAddr); + return success(); } // Utility for creating offset values from op fold result. auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec, @@ -209,19 +222,11 @@ class CreateNdDescToXeVMPattern val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); return val; }; - // Offsets are not supported (0 is used). - offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); - offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); // Get shape values from op fold results. baseShapeW = createOffset(mixedSizes, 1); baseShapeH = createOffset(mixedSizes, 0); - if (sourceMemrefTy) { - // Cast index to i64. - baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); - } else if (baseAddr.getType() != i64Ty) { - // Pointer type may be i32. Cast to i64 if needed. - baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); - } + // Get pitch value from op fold results. + Value basePitch = createOffset(mixedStrides, 0); // Populate payload. 
Value payLoadAsI64 = vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); @@ -235,12 +240,9 @@ class CreateNdDescToXeVMPattern payload = vector::InsertOp::create(rewriter, loc, baseShapeH, payload, static_cast<int>(NdTdescOffset::BaseShapeH)); - payload = vector::InsertOp::create( - rewriter, loc, offsetW, payload, - static_cast<int>(NdTdescOffset::TensorOffsetW)); - payload = vector::InsertOp::create( - rewriter, loc, offsetH, payload, - static_cast<int>(NdTdescOffset::TensorOffsetH)); + payload = + vector::InsertOp::create(rewriter, loc, basePitch, payload, + static_cast<int>(NdTdescOffset::BasePitch)); rewriter.replaceOp(op, payload); return success(); } @@ -257,108 +259,240 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { ConversionPatternRewriter &rewriter) const override { auto mixedOffsets = op.getMixedOffsets(); int64_t opOffsetsSize = mixedOffsets.size(); - if (opOffsetsSize != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdesc = adaptor.getTensorDesc(); auto tdescTy = op.getTensorDescType(); - if (tdescTy.getRank() != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); + auto tileRank = tdescTy.getRank(); + if (opOffsetsSize != tileRank) + return rewriter.notifyMatchFailure( + op, "Expected offset rank to match descriptor rank."); auto elemType = tdescTy.getElementType(); auto elemBitSize = elemType.getIntOrFloatBitWidth(); - if (elemBitSize % 8 != 0) + bool isSubByte = elemBitSize < 8; + uint64_t wScaleFactor = 1; + + if (!isSubByte && (elemBitSize % 8 != 0)) return rewriter.notifyMatchFailure( op, "Expected element type bit width to be multiple of 8."); + auto tileW = tdescTy.getDimSize(tileRank - 1); + // For sub-byte types, only 4-bit elements are currently supported. + if (isSubByte) { + if (elemBitSize != 4) + return rewriter.notifyMatchFailure( + op, "Only sub byte types of 4bits are supported."); + if (tileRank != 2) + return rewriter.notifyMatchFailure( + op, "Sub byte types are only supported for 2D tensor descriptors."); + auto subByteFactor = 8 / elemBitSize; + auto tileH = tdescTy.getDimSize(0); + // Handle the special case of a packed load. + if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { + if (op.getPacked().value_or(false)) { + // A packed load is implemented as packed loads of 8-bit elements. + if (tileH == systolicDepth * 4 && + tileW == executionSize * subByteFactor) { + // Use case: loading as the Matrix B operand with a pack request. + // The source is assumed to be pre-packed into 8-bit elements. + // Emulate with 8-bit loads with a pack request. + // scaled_tileW = executionSize + elemType = rewriter.getIntegerType(8); + tileW = executionSize; + wScaleFactor = subByteFactor; + } + } + } + // If not handled by the packed load case above, handle the other cases. + if (wScaleFactor == 1) { + auto sub16BitFactor = subByteFactor * 2; + if (tileW == executionSize * sub16BitFactor) { + // Use case: loading as the Matrix A operand. + // Emulate with 16-bit loads/stores. + // scaled_tileW = executionSize + elemType = rewriter.getIntegerType(16); + tileW = executionSize; + wScaleFactor = sub16BitFactor; + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported tile shape for sub byte types."); + } + } + // Recompute the element bit size for emulation.
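As an editorial aside (not part of the patch), the scale factors chosen in the sub-byte branch above can be made concrete for 4-bit elements, assuming executionSize == 16 as defined earlier in this file; systolicDepth is left symbolic since its value is not shown here.

// Illustrative sketch of the sub-byte scaling arithmetic (4-bit elements).
constexpr int elemBits = 4;
constexpr int execSize = 16;                      // executionSize above
constexpr int subByteFactor = 8 / elemBits;       // 2: elements per byte
constexpr int sub16BitFactor = subByteFactor * 2; // 4: elements per 16 bits
// Packed Matrix-B load: a tile of width execSize * subByteFactor == 32 is
// emulated as an i8 load of width 16, so wScaleFactor == 2.
// Other sub-byte loads/stores (Matrix-A path): a tile of width
// execSize * sub16BitFactor == 64 is emulated as a 16-bit access of width 16,
// so wScaleFactor == 4.
// The width and pitch in bytes and offsetW are later divided by wScaleFactor
// via an arithmetic shift by log2(wScaleFactor).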
+ elemBitSize = elemType.getIntOrFloatBitWidth(); + } - VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); - Value payLoadAsI64 = - vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); - Value basePtr = vector::ExtractOp::create( - rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr)); - Value baseShapeW = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); - Value baseShapeH = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); - // Offsets are provided by the op. - // convert them to i32. - Value offsetW = - getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); - offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetW); - Value offsetH = - getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); - offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetH); // Get address space from tensor descriptor memory space. auto ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); - // Convert base pointer (i64) to LLVM pointer type. - Value basePtrLLVM = - LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); - // Compute element byte size and surface width in bytes. - Value elemByteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); - Value surfaceW = - arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); - - // Get tile sizes and vblocks from the tensor descriptor type. - auto tileW = tdescTy.getDimSize(1); - auto tileH = tdescTy.getDimSize(0); - int32_t vblocks = tdescTy.getArrayLength(); - if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { - Value src = adaptor.getValue(); - // If store value is a scalar, get value from op instead of adaptor. - // Adaptor might have optimized away single element vector - if (src.getType().isIntOrFloat()) { - src = op.getValue(); + if (tileRank == 2) { + // Compute element byte size. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); + VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); + Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); + Value basePtr = + vector::ExtractOp::create(rewriter, loc, payLoadAsI64, + static_cast<int>(NdTdescOffset::BasePtr)); + Value baseShapeW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); + Value baseShapeH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); + Value basePitch = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch)); + // Offsets are provided by the op. + // convert them to i32. + Value offsetW = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); + offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetW); + Value offsetH = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetH); + // Convert base pointer (i64) to LLVM pointer type. + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); + // FIXME: width or pitch is not the same as baseShapeW it should be the + // stride of the second to last dimension in row major layout. + // Compute width in bytes. 
+ Value baseShapeWInBytes = + arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); + // Compute pitch in bytes. + Value basePitchBytes = + arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize); + + if (wScaleFactor > 1) { + // Scale offsetW, baseShapeWInBytes for sub byte emulation. + // Note: tileW is already scaled above. + Value wScaleFactorValLog2 = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor)); + baseShapeWInBytes = arith::ShRSIOp::create( + rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2); + basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes, + wScaleFactorValLog2); + offsetW = + arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2); } - VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); - if (!srcVecTy) - return rewriter.notifyMatchFailure( - op, "Expected store value to be a vector type."); - // Get flat vector type of integer type with matching element bit size. - VectorType newSrcVecTy = - encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); - if (srcVecTy != newSrcVecTy) - src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); - auto storeCacheControl = - translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); - xevm::BlockStore2dOp::create( - rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, - offsetH, elemBitSize, tileW, tileH, src, - xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); - rewriter.eraseOp(op); - } else { - auto loadCacheControl = - translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); - if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { - xevm::BlockPrefetch2dOp::create( - rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, - offsetH, elemBitSize, tileW, tileH, vblocks, - xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + // Get tile height from the tensor descriptor type. + auto tileH = tdescTy.getDimSize(0); + // Get vblocks from the tensor descriptor type. + int32_t vblocks = tdescTy.getArrayLength(); + if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { + Value src = adaptor.getValue(); + // If store value is a scalar, get value from op instead of adaptor. + // Adaptor might have optimized away single element vector + if (src.getType().isIntOrFloat()) { + src = op.getValue(); + } + VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + // Get flat vector type of integer type with matching element bit size. + VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + xevm::BlockStore2dOp::create( + rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, + basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); rewriter.eraseOp(op); } else { - VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); - const bool vnni = op.getPacked().value_or(false); - auto transposeValue = op.getTranspose(); - bool transpose = - transposeValue.has_value() && transposeValue.value()[0] == 1; - VectorType loadedTy = encodeVectorTypeTo( - dstVecTy, vnni ? 
rewriter.getI32Type() - : rewriter.getIntegerType(elemBitSize)); - - Value resultFlatVec = xevm::BlockLoad2dOp::create( - rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH, - surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks, - transpose, vnni, + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { + xevm::BlockPrefetch2dOp::create( + rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, + basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, + vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + rewriter.eraseOp(op); + } else { + VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); + const bool vnni = op.getPacked().value_or(false); + auto transposeValue = op.getTranspose(); + bool transpose = + transposeValue.has_value() && transposeValue.value()[0] == 1; + VectorType loadedTy = encodeVectorTypeTo( + dstVecTy, vnni ? rewriter.getI32Type() + : rewriter.getIntegerType(elemBitSize)); + + Value resultFlatVec = xevm::BlockLoad2dOp::create( + rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes, + baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW, + tileH, vblocks, transpose, vnni, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + resultFlatVec = vector::BitCastOp::create( + rewriter, loc, + encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), + resultFlatVec); + rewriter.replaceOp(op, resultFlatVec); + } + } + } else { + // 1D tensor descriptor. + // `tdesc` represents base address as i64 + // Offset in number of elements, need to multiply by element byte size. + // Compute byte offset. + // byteOffset = offset * elementByteSize + Value offset = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offset = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI64Type(), offset); + // Compute element byte size. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI64Type(), elemBitSize / 8); + Value byteOffset = + rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize); + // Final address = basePtr + byteOffset + Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>( + loc, tdesc, + getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(), + byteOffset)); + // Convert base pointer (i64) to LLVM pointer type. + Value finalPtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64); + if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { + Value src = adaptor.getValue(); + // If store value is a scalar, get value from op instead of adaptor. + // Adaptor might have optimized away single element vector + if (src.getType().isIntOrFloat()) { + src = op.getValue(); + } + VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + // Get flat vector type of integer type with matching element bit size. 
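As an editorial illustration (not part of the patch), the flattening requested by the comment above re-encodes the stored value as an integer vector of matching element bit width, so the following bitcast is a pure reinterpretation of the bits; for example, in the plain (non-emulated) case:

// Sketch: a store of vector<8xf16> with elemBitSize == 16 becomes
//   newSrcVecTy == vector<8xi16>
//   src = vector.bitcast %src : vector<8xf16> to vector<8xi16>
// before being handed to the 1D block store.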
+ VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>( + op, finalPtrLLVM, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); + } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + VectorType resTy = cast<VectorType>(op.getValue().getType()); + VectorType loadedTy = + encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize)); + Value load = xevm::BlockLoadOp::create( + rewriter, loc, loadedTy, finalPtrLLVM, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); - resultFlatVec = vector::BitCastOp::create( - rewriter, loc, - encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), - resultFlatVec); - rewriter.replaceOp(op, resultFlatVec); + if (loadedTy != resTy) + load = vector::BitCastOp::create(rewriter, loc, resTy, load); + rewriter.replaceOp(op, load); + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported operation: xegpu.prefetch_nd with tensor " + "descriptor rank == 1"); } } return success(); @@ -511,9 +645,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { } }; -// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions -// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than -// 32 bits will be converted to 32 bits. class CreateMemDescOpPattern final : public OpConversionPattern<xegpu::CreateMemDescOp> { public: @@ -522,16 +653,7 @@ public: matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resTy = op.getMemDesc(); - - // Create the result MemRefType with the same shape, element type, and - // memory space - auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy); - - Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); - auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, - op.getSource(), zero, ValueRange()); - rewriter.replaceOp(op, viewOp); + rewriter.replaceOp(op, adaptor.getSource()); return success(); } }; @@ -551,19 +673,27 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); - Value basePtrStruct = adaptor.getMemDesc(); + Value baseAddr32 = adaptor.getMemDesc(); Value mdescVal = op.getMemDesc(); // Load result or Store value Type can be vector or scalar. - Value data; - if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) - data = op.getResult(); - else - data = adaptor.getData(); - VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType()); + Type dataTy; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + Type resType = op.getResult().getType(); + // Some transforms may leave unit dimension in the 2D vector, adaptors do + // not catch it for results. 
+ if (auto vecType = dyn_cast<VectorType>(resType)) { + assert(llvm::count_if(vecType.getShape(), + [](int64_t d) { return d != 1; }) <= 1 && + "Expected either 1D vector or nD with unit dimensions"); + resType = VectorType::get({vecType.getNumElements()}, + vecType.getElementType()); + } + dataTy = resType; + } else + dataTy = adaptor.getData().getType(); + VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy); if (!valOrResVecTy) - valOrResVecTy = VectorType::get(1, data.getType()); - if (valOrResVecTy.getShape().size() != 1) - return rewriter.notifyMatchFailure(op, "Expected 1D data vector."); + valOrResVecTy = VectorType::get(1, dataTy); int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth(); @@ -579,21 +709,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); - Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, basePtrStruct); - - // Convert base pointer (ptr) to i32 - Value basePtrI32 = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), basePtrLLVM); - Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); linearOffset = arith::IndexCastUIOp::create( rewriter, loc, rewriter.getI32Type(), linearOffset); - basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, - elemByteSize); + Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, + linearOffset, elemByteSize); // convert base pointer (i32) to LLVM pointer type - basePtrLLVM = + Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); if (op.getSubgroupBlockIoAttr()) { @@ -929,20 +1052,22 @@ struct ConvertXeGPUToXeVMPass return VectorType::get(sum, elemType); }); typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { + // Scattered descriptors are not supported in XeVM lowering. if (type.isScattered()) + return {}; + if (type.getRank() == 1) return IntegerType::get(&getContext(), 64); auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); - // Convert MemDescType into flattened MemRefType for SLM + // Convert MemDescType into i32 for SLM typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { - Type elemTy = type.getElementType(); - int numElems = type.getNumElements(); - return MemRefType::get(numElems, elemTy, AffineMap(), 3); + return IntegerType::get(&getContext(), 32); }); typeConverter.addConversion([&](MemRefType type) -> Type { - // Convert MemRefType to i64 type. 
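For readability, an editorial summary (not part of the patch) of the type mappings this converter now installs, collecting the conversion callbacks above and below:

// XeGPU-to-XeVM type conversions:
//   TensorDescType, scattered          -> no conversion (unsupported here)
//   TensorDescType, rank 1             -> i64 (just the base address)
//   TensorDescType, otherwise          -> vector<8xi32> payload
//   MemDescType (SLM)                  -> i32
//   MemRefType in memory space 3 (SLM) -> i32
//   MemRefType, otherwise              -> i64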
+ if (type.getMemorySpaceAsInt() == 3) + return IntegerType::get(&getContext(), 32); return IntegerType::get(&getContext(), 64); }); @@ -1059,6 +1184,7 @@ struct ConvertXeGPUToXeVMPass }; typeConverter.addSourceMaterialization( singleElementVectorMaterializationCast); + typeConverter.addSourceMaterialization(vectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index f276984..20a420d 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -290,7 +290,7 @@ static LLVM::CallOp createDeviceFunctionCall( ArrayRef<Type> argTypes, ArrayRef<Value> args, mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs, LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { - auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); + auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); assert(moduleOp && "Expecting module"); Location loc = op->getLoc(); @@ -401,7 +401,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::NoModRef, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); auto funcAttrs = convergentNoUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; Value result = @@ -450,7 +453,10 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); funcAttr.memEffectsAttr = memAttr; LLVM::CallOp call = createDeviceFunctionCall( @@ -556,7 +562,10 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); funcAttr = noUnwindAttrs; funcAttr.memEffectsAttr = memAttr; } else { @@ -798,7 +807,10 @@ class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> { constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); call.setMemoryEffectsAttr(memAttr); rewriter.replaceOp(op, call); return success(); @@ -836,7 +848,10 @@ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> { constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( 
/*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); call.setMemoryEffectsAttr(memAttr); rewriter.replaceOp(op, call); return success(); diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index df955fc..b7a665b 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -55,6 +55,10 @@ void AMDGPUDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + >(); addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" @@ -339,19 +343,45 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// -// ScaledExtPacked816Op +// ScaledExtPackedMatrixOp //===----------------------------------------------------------------------===// -LogicalResult ScaledExtPacked816Op::verify() { +LogicalResult ScaledExtPackedMatrixOp::verify() { int blockSize = getBlockSize(); - assert((blockSize == 16 || blockSize == 32) && "invalid block size"); + assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size"); + int firstScaleByte = getFirstScaleByte(); - if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) { - return emitOpError( - "blockSize of 16 can only have firstScaleByte be 0 or 1."); - } - if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) { - return emitOpError( - "blockSize of 32 can only have firstScaleByte be 0 or 2."); + int firstScaleLane = getFirstScaleLane(); + auto sourceType = cast<VectorType>(getSource().getType()); + Type elementType = sourceType.getElementType(); + auto floatType = cast<FloatType>(elementType); + unsigned bitWidth = floatType.getWidth(); + + assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth)); + + const bool is_fp8 = bitWidth == 8; + const bool is_block_16 = blockSize == 16; + + if (!is_fp8) { + if (is_block_16) { + if (!llvm::is_contained({0, 1}, firstScaleByte)) { + return emitOpError("blockSize of 16 can only have firstScaleByte be 0 " + "or 1 for f4 and f6."); + } + } else { + if (!llvm::is_contained({0, 2}, firstScaleByte)) { + return emitOpError("blockSize of 32 can only have firstScaleByte be 0 " + "or 2 for f4 and f6."); + } + } + } else { + if (is_block_16) { + bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) || + ((firstScaleLane == 16) && (firstScaleByte == 2)); + if (!is_valid) { + return emitOpError("blockSize of 16 can only have (firstScaleLane, " + "firstScaleByte) be (0, 0) or (16, 2) for f8."); + } + } } return success(); @@ -567,6 +597,53 @@ LogicalResult PermlaneSwapOp::verify() { } //===----------------------------------------------------------------------===// +// MemoryCounterWaitOp +//===----------------------------------------------------------------------===// + +namespace { +/// Fuse adjacent memory counter wait ops, taking the minimum value of the +/// counters. 
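A minimal stand-alone sketch (not part of the patch) of the per-counter rule the pattern below implements; the unsigned counter type is an assumption made only for illustration:

#include <algorithm>
#include <optional>

// Fuse two optional wait counters: if both ops wait on a counter, keep the
// tighter (smaller) bound; otherwise keep whichever is set, if any.
static std::optional<unsigned> fuseCounter(std::optional<unsigned> lhs,
                                           std::optional<unsigned> rhs) {
  if (lhs && rhs)
    return std::min(*lhs, *rhs);
  return lhs ? lhs : rhs;
}
// e.g. fuseCounter(2, 1) == 1 and fuseCounter(4, std::nullopt) == 4.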
+struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> { + using Base::Base; + + LogicalResult matchAndRewrite(MemoryCounterWaitOp op, + PatternRewriter &rewriter) const override { + auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode()); + if (!next) + return failure(); + + auto setters = {&MemoryCounterWaitOp::setLoad, + &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs, + &MemoryCounterWaitOp::setExp, + &MemoryCounterWaitOp::setTensor}; + auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(), + op.getTensor()}; + auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(), + next.getExp(), next.getTensor()}; + rewriter.modifyOpInPlace(op, [&] { + for (auto [setter, lhs, rhs] : + llvm::zip_equal(setters, lhsVals, rhsVals)) { + if (lhs && rhs) { + (op.*setter)(std::min(*lhs, *rhs)); + } else if (lhs) { + (op.*setter)(*lhs); + } else if (rhs) { + (op.*setter)(*rhs); + } + } + }); + rewriter.eraseOp(next); + return success(); + } +}; +} // namespace + +void MemoryCounterWaitOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add<FuseMemoryCounterWaitOp>(context); +} + +//===----------------------------------------------------------------------===// // GatherToLDSOp //===----------------------------------------------------------------------===// @@ -662,19 +739,123 @@ LogicalResult TransposeLoadOp::verify() { }; auto validNumElems = kValidLoadSizeMap.find(elementTypeSize); - if (validNumElems == kValidLoadSizeMap.end()) { + if (validNumElems == kValidLoadSizeMap.end()) return emitOpError("Unsupported element type size for transpose load: ") << elementTypeSize << " bits"; - } - if (numElements != validNumElems->second) { + + if (numElements != validNumElems->second) return emitOpError( "Transferring type size mismatch: expected num of elements: ") << validNumElems->second; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MakeDmaBaseOp +//===----------------------------------------------------------------------===// + +LogicalResult MakeDmaBaseOp::verify() { + + auto ldsType = cast<MemRefType>(getLds().getType()); + auto globalType = cast<MemRefType>(getGlobal().getType()); + if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace())) + return emitOpError( + "lds memref must have workgroup address space attribute."); + if (!hasGlobalMemorySpace(globalType.getMemorySpace())) + return emitOpError( + "global memref must have global address space attribute."); + + Type elementType = ldsType.getElementType(); + unsigned width = elementType.getIntOrFloatBitWidth(); + + if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width)) + return emitOpError( + "element type must be 1, 2, 4, or 8 bytes long but type was ") + << width << " bits long."; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MakeDmaDescriptorOp +//===----------------------------------------------------------------------===// + +LogicalResult MakeDmaDescriptorOp::verify() { + ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides(); + + if (globalStaticStrides.empty()) + return emitOpError("strides must not be empty."); + if (globalStaticStrides.back() != 1) + return emitOpError("strides for the innermost dimension must be 1."); + + ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes(); + size_t rank = globalStaticSizes.size(); + if (rank > 5) + return emitOpError("tensor and tile must be at most of 
rank 5."); + if (rank != globalStaticStrides.size()) + return emitOpError("strides and sizes must have same rank."); + + ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes(); + if (rank != sharedStaticSizes.size()) + return emitOpError("tensor must have same rank as tile."); + + unsigned elementTypeWidth = getElementTypeWidth(); + if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth)) + return emitOpError( + "element type width must be 1, 2, 4 or 8 bytes, but was ") + << elementTypeWidth << " bits long"; + + if (Value atomicBarrierAddress = getAtomicBarrierAddress()) { + auto atomicBarrierAddressType = + cast<MemRefType>(atomicBarrierAddress.getType()); + bool barrierInLDS = + hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()); + if (!barrierInLDS) + return emitOpError("atomic barrier address must be in LDS."); } + if (getEarlyTimeout() && !getWorkgroupMask()) + return emitOpError( + "early timeout does not apply when workgroup_mask is not set."); return success(); } +OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) { + SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes()); + SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides()); + SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes()); + + if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true))) + return nullptr; + + SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides, + dynamicSharedSizes; + SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides, + staticSharedSizes; + + dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes, + staticGlobalSizes); + setGlobalStaticSizes(staticGlobalSizes); + getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes); + + dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides, + staticGlobalStrides); + setGlobalStaticStrides(staticGlobalStrides); + getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides); + + dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes, + staticSharedSizes); + setSharedStaticSizes(staticSharedSizes); + getSharedDynamicSizesMutable().assign(dynamicSharedSizes); + return getResult(); +} + //===----------------------------------------------------------------------===// // ScaledMFMAOp //===----------------------------------------------------------------------===// @@ -813,5 +994,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results, #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp index f15c63c..89ef51f 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp @@ -33,19 +33,18 @@ using namespace mlir::amdgpu; /// This pattern supports lowering of: `vector.maskedload` to `vector.load` /// and `arith.select` if the memref is in buffer address space. 
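Schematically (an editorial illustration, not part of the patch), the lowering this comment describes turns the masked load into an unmasked load whose result is merged with the pass-through value; it relies on fat raw buffer loads being well defined rather than faulting when a lane is out of bounds.

// Before (roughly):
//   %r = vector.maskedload %fatBuf[%i], %mask, %passthru
// After (roughly):
//   %l = vector.load %fatBuf[%i]
//   %r = arith.select %mask, %l, %passthru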
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, - vector::MaskedLoadOp maskedOp) { - auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType()); +static LogicalResult hasBufferAddressSpace(Type type) { + auto memRefType = dyn_cast<MemRefType>(type); if (!memRefType) - return rewriter.notifyMatchFailure(maskedOp, "not a memref source"); + return failure(); Attribute addrSpace = memRefType.getMemorySpace(); if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace)) - return rewriter.notifyMatchFailure(maskedOp, "no address space"); + return failure(); if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() != amdgpu::AddressSpace::FatRawBuffer) - return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space"); + return failure(); return success(); } @@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> { LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp, PatternRewriter &rewriter) const override { if (maskedOp->hasAttr(kMaskedloadNeedsMask)) - return failure(); + return rewriter.notifyMatchFailure(maskedOp, "already rewritten"); - if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) { - return failure(); + if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) { + return rewriter.notifyMatchFailure( + maskedOp, "isn't a load from a fat buffer resource"); } // Check if this is either a full inbounds load or an empty, oob load. If @@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp, PatternRewriter &rewriter) const override { + if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType()))) + return rewriter.notifyMatchFailure( + loadOp, "buffer loads are handled by a more specialized pattern"); + FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask()); if (failed(maybeCond)) { - return failure(); + return rewriter.notifyMatchFailure(loadOp, + "isn't loading a broadcasted scalar"); } Value cond = maybeCond.value(); @@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp, PatternRewriter &rewriter) const override { + // A condition-free implementation of fully masked stores requires + // 1) an accessor for the num_records field on buffer resources/fat pointers + // 2) knowledge that said field will always be set accurately - that is, + // that writes to x < num_records of offset wouldn't trap, which is + // something a pattern user would need to assert or we'd need to prove. + // + // Therefore, conditional stores to buffers still go down this path at + // present. 
+ FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask()); if (failed(maybeCond)) { return failure(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 0c35921..c6addfb 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -5421,7 +5421,7 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final return rewriter.notifyMatchFailure(op, "no unit basis entries to replace"); - if (newIndices.size() == 0) { + if (newIndices.empty()) { rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); return success(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index c942c02..b04e2d6 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -82,7 +82,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { ArrayRef<int64_t> oldShape = oldMemRefType.getShape(); SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank()); newShape[0] = 2; - std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); + llvm::copy(oldShape, newShape.begin() + 1); return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); }; diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 4743941..8f1249e 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1711,6 +1711,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) { outermost.getBody()->getOperations().splice( Block::iterator(secondOutermostLoop.getOperation()), innermost.getBody()->getOperations()); + for (auto [iter, init] : + llvm::zip_equal(secondOutermostLoop.getRegionIterArgs(), + secondOutermostLoop.getInits())) { + iter.replaceAllUsesWith(init); + iter.dropAllUses(); + } secondOutermostLoop.erase(); return success(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index de3efc9f..e256915 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -389,8 +389,8 @@ def TruncIExtUIToExtUI : // trunci(shrsi(x, c)) -> trunci(shrui(x, c)) def TruncIShrSIToTrunciShrUI : Pat<(Arith_TruncIOp:$tr - (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow), - (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow), + (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0), $exact), $overflow), + (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)), $exact), $overflow), [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index adeb50b..c4e81e5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -35,7 +35,7 @@ static Value createConst(Location loc, Type type, int value, } /// Create a float constant. 
-static Value createFloatConst(Location loc, Type type, APFloat value, +static Value createFloatConst(Location loc, Type type, const APFloat &value, PatternRewriter &rewriter) { auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast<ShapedType>(type)) { diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index 39e398b..cb7c3d7 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -150,7 +150,7 @@ public: rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask()); } - auto extOp = op.getLhs().getDefiningOp(); + auto *extOp = op.getLhs().getDefiningOp(); arm_sme::CombiningKind kind = op.getKind(); if (kind == arm_sme::CombiningKind::Add) { @@ -311,8 +311,8 @@ public: rhsMask = packInputs(rhs0Mask, rhs1Mask); } - auto lhsExtOp = op.getLhs().getDefiningOp(); - auto rhsExtOp = op.getRhs().getDefiningOp(); + auto *lhsExtOp = op.getLhs().getDefiningOp(); + auto *rhsExtOp = op.getRhs().getDefiningOp(); arm_sme::CombiningKind kind = op.getKind(); if (kind == arm_sme::CombiningKind::Add) { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index e0cf353..9b11270 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { return false; } -// bufferization.to_buffer is not allowed to change the rank. -static void ensureToBufferOpIsValid(Value tensor, Type memrefType) { -#ifndef NDEBUG - auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType()); - assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() == - rankedTensorType.getRank()) && - "to_buffer would be invalid: mismatching ranks"); -#endif -} - FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state) { @@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state); if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *bufferType); + return bufferization::ToBufferOp::create(rewriter, value.getLoc(), *bufferType, value) .getResult(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index d6c3cd6..bd177ba 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel mlir::LogicalResult verifyCompatibleBufferType( mlir::Type tensor, BufferLikeType bufferType, llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const { - assert(isa<TensorType>(tensor) && "expected tensor type"); - assert(isa<BaseMemRefType>(bufferType) && "expected memref type"); - auto tensorType = cast<ShapedType>(tensor); auto memrefType = cast<ShapedType>(bufferType); diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp index 51feec7..f8eb45c 100644 --- a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp +++ 
b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp @@ -17,6 +17,10 @@ // Pipeline implementation. //===----------------------------------------------------------------------===// +void mlir::bufferization::buildBufferDeallocationPipeline(OpPassManager &pm) { + buildBufferDeallocationPipeline(pm, BufferDeallocationPipelineOptions()); +} + void mlir::bufferization::buildBufferDeallocationPipeline( OpPassManager &pm, const BufferDeallocationPipelineOptions &options) { memref::ExpandReallocPassOptions expandAllocPassOptions{ @@ -44,5 +48,7 @@ void mlir::bufferization::registerBufferizationPipelines() { "The default pipeline for automatically inserting deallocation " "operations after one-shot bufferization. Deallocation operations " "(except `memref.realloc`) may not be present already.", - buildBufferDeallocationPipeline); + [](OpPassManager &pm, const BufferDeallocationPipelineOptions &options) { + buildBufferDeallocationPipeline(pm, options); + }); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 1784964..677c0ba 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dominance.h" #include "mlir/Interfaces/SubsetOpInterface.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir { namespace bufferization { @@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter, // this replacement. Operation *insertionPoint = findValidInsertionPoint(emptyTensorOp, user, neededValues); - if (!insertionPoint) - return {}; + if (!insertionPoint) { + // If no already suitable insertion point was found, attempt to move all + // needed values before the user. + if (failed(moveValueDefinitions(rewriter, neededValues, user))) + return {}; + insertionPoint = user; + } rewriter.setInsertionPoint(insertionPoint); Value replacement = diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 9ccbfd3..5dfe3e6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -497,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, // terminates. All of them must be equivalent subsets. 
SetVector<Value> backwardSlice = state.findValueInReverseUseDefChain(opOperand, matchingSubset); - return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset)); + return llvm::all_of(backwardSlice, matchingSubset); } /// Return "true" if the given "read" and potentially conflicting "write" are diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt index 58551bb..05a787f 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect MLIRControlFlowInterfaces MLIRIR MLIRSideEffectInterfaces + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index f1da1a1..d2078d8 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -445,6 +446,37 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { return success(replaced); } }; + +/// If the destination block of a conditional branch contains only +/// ub.unreachable, unconditionally branch to the other destination. +struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> { + using OpRewritePattern<CondBranchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // If the "true" destination is unreachable, branch to the "false" + // destination. + Block *trueDest = condbr.getTrueDest(); + Block *falseDest = condbr.getFalseDest(); + if (llvm::hasSingleElement(*trueDest) && + isa<ub::UnreachableOp>(trueDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest, + condbr.getFalseOperands()); + return success(); + } + + // If the "false" destination is unreachable, branch to the "true" + // destination. 
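For example (an editorial illustration, not part of the patch), when the false successor contains nothing but ub.unreachable the conditional branch collapses to an unconditional jump to the true successor:

// Before (roughly):
//   cf.cond_br %cond, ^live(%v : i32), ^dead
// ^dead:
//   ub.unreachable
// After (roughly):
//   cf.br ^live(%v : i32)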
+ if (llvm::hasSingleElement(*falseDest) && + isa<ub::UnreachableOp>(falseDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, + condbr.getTrueOperands()); + return success(); + } + + return failure(); + } +}; } // namespace void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -452,7 +484,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, SimplifyCondBranchIdenticalSuccessors, SimplifyCondBranchFromCondBranchOnSameCondition, - CondBranchTruthPropagation>(context); + CondBranchTruthPropagation, DropUnreachableCondBranch>(context); } SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index d478220..b0566dd 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -226,6 +226,21 @@ FailureOr<SmallVector<ReplacementItem>> parseFormatString( } //===----------------------------------------------------------------------===// +// AddressOfOp +//===----------------------------------------------------------------------===// + +LogicalResult AddressOfOp::verify() { + emitc::LValueType referenceType = getReference().getType(); + emitc::PointerType resultType = getResult().getType(); + + if (referenceType.getValueType() != resultType.getPointee()) + return emitOpError("requires result to be a pointer to the type " + "referenced by operand"); + + return success(); +} + +//===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -380,6 +395,20 @@ LogicalResult emitc::ConstantOp::verify() { OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } //===----------------------------------------------------------------------===// +// DereferenceOp +//===----------------------------------------------------------------------===// + +LogicalResult DereferenceOp::verify() { + emitc::PointerType pointerType = getPointer().getType(); + + if (pointerType.getPointee() != getResult().getType().getValueType()) + return emitOpError("requires result to be an lvalue of the type " + "pointed to by operand"); + + return success(); +} + +//===----------------------------------------------------------------------===// // ExpressionOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index b4cb093..d6dfd02 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp, return std::make_pair(*newFuncOpOrFailure, newCallOp); } + +FailureOr<func::FuncOp> +func::lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, SymbolTableCollection *symbolTables) { + FuncOp func; + if (symbolTables) { + func = symbolTables->lookupSymbolIn<FuncOp>( + symTable, StringAttr::get(symTable->getContext(), name)); + } else { + func = llvm::dyn_cast_or_null<FuncOp>( + SymbolTable::lookupSymbolIn(symTable, name)); + } + + if (!func) + return func; + + mlir::FunctionType foundFuncT = func.getFunctionType(); + // Assert the signature of the found function is same as expected + if (funcT != foundFuncT) { + return func.emitError("matched function '") + << 
name << "' but with different type: " << foundFuncT + << " (expected " << funcT << ")"; + } + return func; +} diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2..61a630a 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; } StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } bool MMAMatrixType::isValidElementType(Type elementType) { - return elementType.isF16() || elementType.isF32() || + return elementType.isF16() || elementType.isF32() || elementType.isF64() || elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) || elementType.isInteger(32); } @@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, if (!MMAMatrixType::isValidElementType(elementType)) return emitError() - << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32"; + << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64"; return success(); } diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt index ec68acf..85b7b1ce 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_dialect_library(MLIRGPUPipelines MLIRNVVMToLLVM MLIRReconcileUnrealizedCasts MLIRSCFToControlFlow + MLIRVectorToLLVMPass MLIRVectorToSCF MLIRXeGPUTransforms MLIRXeGPUToXeVM diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp index 2c3e466..5462cdd 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp @@ -72,6 +72,7 @@ void buildGpuPassPipeline(OpPassManager &pm, ConvertGpuOpsToNVVMOpsOptions opt; opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv; opt.indexBitwidth = options.indexBitWidth; + opt.allowPatternRollback = options.allowPatternRollback; pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps(opt)); pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp index 1a1485b..38313dc 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -63,13 +63,20 @@ void buildGPUPassPipeline(OpPassManager &pm, if (options.xegpuOpLevel == "workgroup") { pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); + xegpu::XeGPUPropagateLayoutOptions layoutOptions; + layoutOptions.layoutKind = "inst"; + pm.addNestedPass<gpu::GPUModuleOp>( + xegpu::createXeGPUPropagateLayout(layoutOptions)); pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking()); pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); } if (options.xegpuOpLevel == "subgroup" || options.xegpuOpLevel == "workgroup") { - pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout()); + xegpu::XeGPUPropagateLayoutOptions layoutOptions; + layoutOptions.layoutKind = "lane"; + pm.addNestedPass<gpu::GPUModuleOp>( + xegpu::createXeGPUPropagateLayout(layoutOptions)); 
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute()); pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); @@ -104,8 +111,11 @@ void buildPostGPUCommonPassPipeline( pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions)); } pm.addPass(createLowerAffinePass()); + pm.addPass(createConvertVectorToLLVMPass()); pm.addPass(createConvertToLLVMPass()); pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); // gpu-module-to-binary { GpuModuleToBinaryPassOptions gpuToModuleBinOptions; diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index cd13840..70d2e11 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -143,8 +143,8 @@ private: }; /// Erases `executeOp` and returns a clone with additional `results`. -async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, - ValueRange results) { +static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, + ValueRange results) { // Add values to async.yield op. Operation *yieldOp = executeOp.getBody()->getTerminator(); yieldOp->insertOperands(yieldOp->getNumOperands(), results); diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp index 212ccc9..8d10aac 100644 --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -169,7 +169,7 @@ LogicalResult getSegmentSizes(Operation *op, StringRef elemName, LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef<Variadicity> variadicities, SmallVectorImpl<int> &segmentSizes) { - return getSegmentSizes(op, "operand", "operand_segment_sizes", + return getSegmentSizes(op, "operand", "operandSegmentSizes", op->getNumOperands(), variadicities, segmentSizes); } @@ -180,7 +180,7 @@ LogicalResult getOperandSegmentSizes(Operation *op, LogicalResult getResultSegmentSizes(Operation *op, ArrayRef<Variadicity> variadicities, SmallVectorImpl<int> &segmentSizes) { - return getSegmentSizes(op, "result", "result_segment_sizes", + return getSegmentSizes(op, "result", "resultSegmentSizes", op->getNumResults(), variadicities, segmentSizes); } diff --git a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp index 183d0e3..887e8e1 100644 --- a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Transforms/InliningUtils.h" using namespace mlir; using namespace mlir::index; @@ -15,10 +16,23 @@ using namespace mlir::index; //===----------------------------------------------------------------------===// // IndexDialect //===----------------------------------------------------------------------===// +namespace { +/// This class defines the interface for handling inlining for index +/// dialect operations. +struct IndexInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + /// All index dialect ops can be inlined. 
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } +}; +} // namespace void IndexDialect::initialize() { registerAttributes(); registerOperations(); + addInterfaces<IndexInlinerInterface>(); declarePromisedInterface<ConvertToLLVMPatternInterface, IndexDialect>(); } diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index cc66fac..a73f0c1 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRFunctionInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR MLIRMemorySlotInterfaces diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index feaffa3..160b6ae 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } +FailureOr<LLVM::LLVMFuncOp> +mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kPrintApFloat, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64)}, + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); +} + static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index b8331e0..9f87e50 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -219,11 +219,16 @@ bool TBAANodeAttr::classof(Attribute attr) { MemoryEffectsAttr MemoryEffectsAttr::get(MLIRContext *context, ArrayRef<ModRefInfo> memInfoArgs) { if (memInfoArgs.empty()) - return MemoryEffectsAttr::get(context, ModRefInfo::ModRef, - ModRefInfo::ModRef, ModRefInfo::ModRef); - if (memInfoArgs.size() == 3) + return MemoryEffectsAttr::get(context, /*other=*/ModRefInfo::ModRef, + /*argMem=*/ModRefInfo::ModRef, + /*inaccessibleMem=*/ModRefInfo::ModRef, + /*errnoMem=*/ModRefInfo::ModRef, + /*targetMem0=*/ModRefInfo::ModRef, + /*targetMem1=*/ModRefInfo::ModRef); + if (memInfoArgs.size() == 6) return MemoryEffectsAttr::get(context, memInfoArgs[0], memInfoArgs[1], - memInfoArgs[2]); + memInfoArgs[2], memInfoArgs[3], + memInfoArgs[4], memInfoArgs[5]); return {}; } @@ -234,6 +239,12 @@ bool MemoryEffectsAttr::isReadWrite() { return false; if (this->getOther() != ModRefInfo::ModRef) return false; + if (this->getErrnoMem() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem0() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem1() != ModRefInfo::ModRef) + return false; return true; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp 
b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1bf4a1c..5b81948 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -4224,6 +4224,34 @@ LogicalResult InlineAsmOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// UDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability UDivOp::getSpeculatability() {
+  // X / 0 => UB
+  Value divisor = getRhs();
+  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
+    return Speculation::Speculatable;
+
+  return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
+// SDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability SDivOp::getSpeculatability() {
+  // This function conservatively assumes that all signed divisions by -1 are
+  // not speculatable.
+  // X / 0 => UB
+  // INT_MIN / -1 => UB
+  Value divisor = getRhs();
+  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
+      matchPattern(divisor, m_IntRangeWithoutNegOneS()))
+    return Speculation::Speculatable;
+
+  return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
 // LLVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ce93d18..5dc4fa2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -667,6 +667,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
 
 static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
 static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
+static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier";
 
 bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
   // See llvm/lib/IR/Type.cpp for reference.
@@ -676,6 +677,9 @@ bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { properties |= (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal); + if (getExtTypeName() == kAMDGCNNamedBarrier) + properties |= LLVMTargetExtType::CanBeGlobal; + return (properties & prop) == prop; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index d43f881..5ce56e6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/NVPTXAddrSpace.h" @@ -48,6 +49,47 @@ using namespace NVVM; static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic; //===----------------------------------------------------------------------===// +// Helper/Utility methods +//===----------------------------------------------------------------------===// + +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); +} + +static bool isPtrInGenericSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + +static bool isPtrInSharedClusterSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster); +} + +static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder, + llvm::Value *ptr, + NVVMMemorySpace targetAS) { + unsigned AS = static_cast<unsigned>(targetAS); + return builder.CreateAddrSpaceCast( + ptr, llvm::PointerType::get(builder.getContext(), AS)); +} + +// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM +static llvm::nvvm::CTAGroupKind +getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) { + switch (ctaGroup) { + case NVVM::CTAGroupKind::CTA_1: + return llvm::nvvm::CTAGroupKind::CG_1; + case NVVM::CTAGroupKind::CTA_2: + return llvm::nvvm::CTAGroupKind::CG_2; + } + llvm_unreachable("unsupported cta_group value"); +} + +//===----------------------------------------------------------------------===// // Verifier methods //===----------------------------------------------------------------------===// @@ -199,6 +241,83 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() { return success(); } +LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() { + bool isSharedCTA = isPtrInSharedCTASpace(getDstMem()); + if (isSharedCTA && getMulticastMask()) + return emitError("Multicast is not supported with shared::cta mode."); + + return success(); +} + +static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, + NVVM::MemScopeKind scope, + Value retVal = nullptr) { + if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER) + return op->emitError("mbarrier scope must be either CTA or Cluster"); + + bool isSharedCluster = isPtrInSharedClusterSpace(addr); + bool hasRetValue = static_cast<bool>(retVal); + if (isSharedCluster && hasRetValue) + return op->emitError( + "mbarrier in shared_cluster space cannot return any value"); + + return success(); +} + +LogicalResult MBarrierArriveOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + 
getRes()); +} + +LogicalResult MBarrierArriveDropOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveExpectTxOp::verify() { + // The inline-ptx version of this Op does not support all features. + // With predicate, this Op lowers to inline-ptx. So, verify and + // error-out if there are unsupported features. + if (getPredicate()) { + if (getScope() != NVVM::MemScopeKind::CTA) + return emitError("mbarrier scope must be CTA when using predicate"); + + if (isPtrInSharedClusterSpace(getAddr())) + return emitError("mbarrier in shared_cluster space is not supported when " + "using predicate"); + + if (getRes()) + return emitError("return-value is not supported when using predicate"); + + if (getRelaxed() == true) + return emitError("mbarrier with relaxed semantics is not supported when " + "using predicate"); + } + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveDropExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierCompleteTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTestWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTryWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + LogicalResult ConvertFloatToTF32Op::verify() { using RndMode = NVVM::FPRoundingMode; switch (getRnd()) { @@ -365,22 +484,71 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { return success(); } +LogicalResult PermuteOp::verify() { + using Mode = NVVM::PermuteMode; + bool hasHi = static_cast<bool>(getHi()); + + switch (getMode()) { + case Mode::DEFAULT: + case Mode::F4E: + case Mode::B4E: + if (!hasHi) + return emitError("mode '") + << stringifyPermuteMode(getMode()) << "' requires 'hi' operand."; + break; + case Mode::RC8: + case Mode::ECL: + case Mode::ECR: + case Mode::RC16: + if (hasHi) + return emitError("mode '") << stringifyPermuteMode(getMode()) + << "' does not accept 'hi' operand."; + break; + } + + return success(); +} + //===----------------------------------------------------------------------===// // Stochastic Rounding Conversion Ops //===----------------------------------------------------------------------===// -LogicalResult ConvertF32x2ToF16x2Op::verify() { - if (getRnd() != FPRoundingMode::RS) - return emitOpError("Only RS rounding mode is supported for " - "conversions from f32x2 to f16x2."); +static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, + FPRoundingMode rnd, + bool hasRandomBits, + Operation *op) { + static constexpr FPRoundingMode validRndModes[] = { + FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS}; + + if (!llvm::is_contained(validRndModes, rnd)) { + return op->emitOpError( + "Only RN, RZ, and RS rounding modes are supported for " + "conversions from f32x2 to ") + << dstType << "."; + } + + if (rnd == FPRoundingMode::RS) { + if (!hasRandomBits) { + return op->emitOpError("random_bits is required for RS rounding mode."); + } + } else { + if (hasRandomBits) { + return op->emitOpError( + "random_bits not supported for RN and RZ rounding modes."); + } + } + return success(); } +LogicalResult 
ConvertF32x2ToF16x2Op::verify() { + return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(), + getRandomBits() ? true : false, *this); +} + LogicalResult ConvertF32x2ToBF16x2Op::verify() { - if (getRnd() != FPRoundingMode::RS) - return emitOpError("Only RS rounding mode is supported for " - "conversions from f32x2 to bf16x2."); - return success(); + return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(), + getRandomBits() ? true : false, *this); } LogicalResult ConvertF32x4ToF8x4Op::verify() { @@ -919,6 +1087,482 @@ LogicalResult MmaOp::verify() { return success(); } +MMATypes MmaSpOp::accumPtxType() { + std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType( + getODSOperands(2).getTypes().front(), /*isAccumulator=*/true); + assert(val.has_value() && "accumulator PTX type should always be inferrable"); + return val.value(); +} + +MMATypes MmaSpOp::resultPtxType() { + std::optional<mlir::NVVM::MMATypes> val = + MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true); + assert(val.has_value() && "result PTX type should always be inferrable"); + return val.value(); +} + +mlir::NVVM::IDArgPair +MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MmaSpOp>(op); + + // Get operands + llvm::SmallVector<llvm::Value *> args; + for (mlir::Value v : thisOp.getOperands()) + args.push_back(mt.lookupValue(v)); + + // Get intrinsic ID using the existing getIntrinsicID method + auto intId = MmaSpOp::getIntrinsicID( + thisOp.getShape().getM(), thisOp.getShape().getN(), + thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(), + thisOp.getOrderedMetadata(), thisOp.getKind(), + *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(), + thisOp.accumPtxType(), thisOp.resultPtxType()); + + return {intId, args}; +} + +void MmaSpOp::print(OpAsmPrinter &p) { + SmallVector<Type, 4> regTypes; + struct OperandFragment { + StringRef operandName; + StringRef ptxTypeAttr; + SmallVector<Value, 4> regs; + explicit OperandFragment(StringRef name, StringRef ptxTypeName) + : operandName(name), ptxTypeAttr(ptxTypeName) {} + }; + + std::array<OperandFragment, 5> frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", ""), OperandFragment("sparseMetadata", ""), + OperandFragment("selector", "")}; + SmallVector<StringRef, 4> ignoreAttrNames{ + mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()}; + + // Handle variadic operands A, B, C + for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) { + auto &frag = frags[fragIdx]; + auto varOperandSpec = getODSOperandIndexAndLength(fragIdx); + for (auto operandIdx = varOperandSpec.first; + operandIdx < varOperandSpec.first + varOperandSpec.second; + operandIdx++) { + frag.regs.push_back(this->getOperand(operandIdx)); + if (operandIdx == varOperandSpec.first) { + regTypes.push_back(this->getOperand(operandIdx).getType()); + } + } + std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType( + regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + if (inferredType) + ignoreAttrNames.push_back(frag.ptxTypeAttr); + } + + // Handle sparse metadata and selector (single operands) + frags[3].regs.push_back(getSparseMetadata()); + frags[4].regs.push_back(getSparsitySelector()); + + auto printMmaSpOperand = [&](const OperandFragment &frag) -> void { + p << " " << frag.operandName; + p << "["; + p.printOperands(frag.regs); + p << "]"; + }; + + for (const auto &frag : frags) + 
printMmaSpOperand(frag); + + p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames); + p << " : "; + p << "("; + for (int i = 0; i < 3; ++i) { + p << regTypes[i]; + if (i < 2) + p << ", "; + } + p << ") -> " << getResult().getType(); +} + +void MmaSpOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape, + std::optional<MMAIntOverflow> intOverflow, + std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) { + + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + MLIRContext *ctx = builder.getContext(); + result.addAttribute( + "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2])); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands(sparseMetadata); + result.addOperands(sparsitySelector); + + if (multiplicandPtxTypes) { + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + } else { + if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) + result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); + if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) + result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); + } + + if (intOverflow.has_value()) + result.addAttribute("intOverflowBehavior", + MMAIntOverflowAttr::get(ctx, *intOverflow)); + + result.addTypes(resultType); + result.addAttribute( + MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()), + static_cast<int32_t>(operandB.size()), + static_cast<int32_t>(operandC.size()), 1, + 1})); // sparseMetadata and sparsitySelector +} + +ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) { + struct OperandFragment { + std::optional<MMATypes> elemtype; + SmallVector<OpAsmParser::UnresolvedOperand, 4> regs; + SmallVector<Type> regTypes; + }; + + Builder &builder = parser.getBuilder(); + std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector + + NamedAttrList namedAttributes; + + // A helper to parse the operand segments. + auto parseMmaSpOperand = [&](StringRef operandName, + OperandFragment &frag) -> LogicalResult { + if (parser.parseKeyword(operandName).failed()) + return failure(); + if (parser + .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) + .failed()) + return failure(); + return success(); + }; + + // Parse the operand segments. + if (parseMmaSpOperand("A", frags[0]).failed()) + return failure(); + if (parseMmaSpOperand("B", frags[1]).failed()) + return failure(); + if (parseMmaSpOperand("C", frags[2]).failed()) + return failure(); + if (parseMmaSpOperand("sparseMetadata", frags[3]).failed()) + return failure(); + if (parseMmaSpOperand("selector", frags[4]).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse the type specification and resolve operands. 
+ SmallVector<Type, 3> operandTypes; + if (failed(parser.parseColon())) + return failure(); + if (failed(parser.parseLParen())) + return failure(); + if (failed(parser.parseTypeList(operandTypes))) + return failure(); + if (failed(parser.parseRParen())) + return failure(); + if (operandTypes.size() != 3) + return parser.emitError( + parser.getNameLoc(), + "expected one type for each operand segment but got " + + Twine(operandTypes.size()) + " types"); + for (const auto &iter : llvm::enumerate(operandTypes)) { + auto &frag = frags[iter.index()]; + frag.regTypes.resize(frag.regs.size(), iter.value()); + if (failed(parser.resolveOperands(frag.regs, frag.regTypes, + parser.getNameLoc(), result.operands))) + return failure(); + frag.elemtype = + MmaOp::inferOperandMMAType(frag.regTypes[0], + /*isAccumulator*/ iter.index() >= 2); + } + + Type resultType; + if (parser.parseArrow() || parser.parseType(resultType)) + return failure(); + frags[5].elemtype = + MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true); + + // Resolve sparse metadata and selector (assume i32 type) + Type i32Type = builder.getIntegerType(32); + if (parser + .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + if (parser + .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + + std::array<StringRef, 2> names{"multiplicandAPtxType", + "multiplicandBPtxType"}; + for (unsigned idx = 0; idx < names.size(); idx++) { + const auto &frag = frags[idx]; + std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]); + if (!frag.elemtype.has_value() && !attr.has_value()) { + return parser.emitError( + parser.getNameLoc(), + "attribute " + names[idx] + + " is not provided explicitly and cannot be inferred"); + } + if (!attr.has_value()) + result.addAttribute( + names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); + } + + result.addTypes(resultType); + if (!namedAttributes.empty()) + result.addAttributes(namedAttributes); + result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast<int32_t>(frags[0].regs.size()), + static_cast<int32_t>(frags[1].regs.size()), + static_cast<int32_t>(frags[2].regs.size()), + 1, // sparseMetadata + 1 // sparsitySelector + })); + return success(); +} + +LogicalResult MmaSpOp::verify() { + MLIRContext *context = getContext(); + auto f16Ty = Float16Type::get(context); + auto i32Ty = IntegerType::get(context, 32); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f32Ty = Float32Type::get(context); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + auto s32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty}); + auto f32x8StructTy = + LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty)); + auto f16x2x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty}); + auto f32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty}); + auto s32x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty}); + + std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(), + getShapeAttr().getK()}; + + // These variables define the set of allowed data types for matrices A, B, C, + // and result. 
+ using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>; + using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>; + AllowedShapes allowedShapes; + AllowedTypes expectedA; + AllowedTypes expectedB; + AllowedTypes expectedC; + SmallVector<Type> expectedResult; + + // When M = 16, we just need to calculate the number of 8xk tiles, where + // k is a factor that depends on the data type. + if (mmaShape[0] == 16) { + int64_t kFactor; + Type multiplicandFragType; + switch (*getMultiplicandAPtxType()) { + case MMATypes::tf32: + kFactor = 4; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k8 and m16n8k16 for tf32 + allowedShapes.push_back({16, 8, 8}); + allowedShapes.push_back({16, 8, 16}); + break; + case MMATypes::bf16: + kFactor = 8; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k16 and m16n8k32 for bf16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::f16: + kFactor = 8; + multiplicandFragType = f16x2Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k16 and m16n8k32 for f16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::s4: + case MMATypes::u4: + kFactor = 32; + // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4 + allowedShapes.push_back({16, 8, 64}); + allowedShapes.push_back({16, 8, 128}); + break; + case MMATypes::s8: + case MMATypes::u8: + kFactor = 16; + // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8 + allowedShapes.push_back({16, 8, 32}); + allowedShapes.push_back({16, 8, 64}); + break; + case MMATypes::e4m3: + case MMATypes::e5m2: + case MMATypes::e3m2: + case MMATypes::e2m3: + case MMATypes::e2m1: + kFactor = 32; + multiplicandFragType = i32Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k64 for FP8 types + allowedShapes.push_back({16, 8, 64}); + break; + default: + return emitError("invalid shape or multiplicand type: " + + stringifyEnum(getMultiplicandAPtxType().value())); + } + + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedResult.push_back(s32x4StructTy); + expectedC.emplace_back(4, i32Ty); + multiplicandFragType = i32Ty; + } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 && + *getMultiplicandAPtxType() <= MMATypes::e2m1) { + // FP8 types + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } else { + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } + + // For sparse MMA, A operand is compressed (2:4 sparsity means half the + // elements) + int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2; + int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor); + expectedA.emplace_back(unitA, multiplicandFragType); + expectedB.emplace_back(unitB, multiplicandFragType); + + if (resultPtxType() != accumPtxType()) + return emitOpError("ctype does not match dtype"); + } + + // In the M=8 case, there is only 1 possible case per data type. 
+ if (mmaShape[0] == 8) { + if (*getMultiplicandAPtxType() == MMATypes::f16) { + expectedA.emplace_back(2, f16x2Ty); + expectedB.emplace_back(2, f16x2Ty); + expectedResult.push_back(f16x2x4StructTy); + expectedResult.push_back(f32x8StructTy); + expectedC.emplace_back(4, f16x2Ty); + expectedC.emplace_back(8, f32Ty); + allowedShapes.push_back({8, 8, 4}); + } + if (*getMultiplicandAPtxType() == MMATypes::f64) { + Type f64Ty = Float64Type::get(context); + expectedA.emplace_back(1, f64Ty); + expectedB.emplace_back(1, f64Ty); + expectedC.emplace_back(2, f64Ty); + expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( + context, SmallVector<Type>(2, f64Ty))); + allowedShapes.push_back({8, 8, 4}); + } + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedA.push_back({i32Ty}); + expectedB.push_back({i32Ty}); + expectedC.push_back({i32Ty, i32Ty}); + expectedResult.push_back(s32x2StructTy); + if (isInt4PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 32}); + if (isInt8PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 16}); + } + } + + std::string errorMessage; + llvm::raw_string_ostream errorStream(errorMessage); + + // Check that we matched an existing shape/dtype combination. + if (expectedA.empty() || expectedB.empty() || expectedC.empty() || + !llvm::is_contained(allowedShapes, mmaShape)) { + errorStream << "unimplemented variant for MMA shape <"; + llvm::interleaveComma(mmaShape, errorStream); + errorStream << ">"; + return emitOpError(errorMessage); + } + + // Verify the operand types for segments of A, B, and C operands. + std::array<StringRef, 3> operandNames{"A", "B", "C"}; + for (const auto &iter : llvm::enumerate( + SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) { + auto spec = this->getODSOperandIndexAndLength(iter.index()); + SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first, + operand_type_begin() + spec.first + + spec.second); + bool match = llvm::is_contained(iter.value(), operandTySeg); + + if (!match) { + errorStream << "Could not match types for the " + << operandNames[iter.index()] + << " operands; expected one of "; + for (const auto &x : iter.value()) { + errorStream << x.size() << "x" << x[0] << " "; + } + errorStream << "but got "; + llvm::interleaveComma(operandTySeg, errorStream); + return emitOpError(errorMessage); + } + } + + // Check the result type + if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { + return expectedResultType == getResult().getType(); + })) { + errorStream + << "Could not match allowed types for the result; expected one of "; + llvm::interleaveComma(expectedResult, errorStream); + errorStream << " but got " << getResult().getType(); + return emitOpError(errorMessage); + } + + // Ensure int4/int8 MMA variants specify the accum overflow behavior + // attribute. 
+  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
+      isInt8PtxType(*getMultiplicandAPtxType())) {
+    if (!getIntOverflowBehavior())
+      return emitOpError("op requires " +
+                         getIntOverflowBehaviorAttrName().strref() +
+                         " attribute");
+  }
+
+  // Validate sparse metadata type (should be i32)
+  if (!getSparseMetadata().getType().isInteger(32)) {
+    return emitOpError() << "sparse metadata must be i32 type";
+  }
+
+  // Validate sparsity selector type (should be i32)
+  if (!getSparsitySelector().getType().isInteger(32)) {
+    return emitOpError() << "sparsity selector must be i32 type";
+  }
+
+  return success();
+}
+
 LogicalResult ShflOp::verify() {
   auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
@@ -1454,6 +2098,13 @@ bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
   return true; // Has manual mapping
 }
 
+LogicalResult NVVM::FenceSyncRestrictOp::verify() {
+  if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+      getOrder() != NVVM::MemOrderKind::RELEASE)
+    return emitOpError("only acquire and release semantics are supported");
+  return success();
+}
+
 LogicalResult NVVM::FenceProxyOp::verify() {
   if (getKind() == NVVM::ProxyKind::TENSORMAP)
     return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1476,7 +2127,6 @@ LogicalResult NVVM::FenceProxyAcquireOp::verify() {
   if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
     return emitOpError("uni-directional proxies only support tensormap "
                        "for to_proxy attribute");
-
   return success();
 }
 
@@ -1488,7 +2138,19 @@ LogicalResult NVVM::FenceProxyReleaseOp::verify() {
   if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
     return emitOpError("uni-directional proxies only support tensormap "
                        "for to_proxy attribute");
+  return success();
+}
+
+LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
+  if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+      getOrder() != NVVM::MemOrderKind::RELEASE)
+    return emitOpError("only acquire and release semantics are supported");
+
+  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
+    return emitOpError("only generic is supported for from_proxy attribute");
+  if (getToProxy() != NVVM::ProxyKind::async)
+    return emitOpError("only async is supported for to_proxy attribute");
   return success();
 }
 
@@ -1504,6 +2166,15 @@ LogicalResult NVVM::BarrierOp::verify() {
   if (getNumberOfThreads() && !getBarrierId())
     return emitOpError(
         "barrier id is missing, it should be set between 0 to 15");
+
+  if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
+    return emitOpError("reductions are only available when id is 0");
+
+  if ((getReductionOp() && !getReductionPredicate()) ||
+      (!getReductionOp() && getReductionPredicate()))
+    return emitOpError("reduction predicate and reduction operation must be "
+                       "specified together");
+
   return success();
 }
 
@@ -1741,24 +2412,68 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
 //===----------------------------------------------------------------------===//
 
 std::string NVVM::MBarrierInitOp::getPtx() {
-  unsigned addressSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
-  return (addressSpace == NVVMMemorySpace::Shared)
-             ? std::string("mbarrier.init.shared.b64 [%0], %1;")
-             : std::string("mbarrier.init.b64 [%0], %1;");
+  bool isShared = isPtrInSharedCTASpace(getAddr());
+  return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+                  : std::string("mbarrier.init.b64 [%0], %1;");
+}
+
+std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
+  bool isShared = isPtrInSharedCTASpace(getAddr());
+  return isShared
             ?
std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;") + : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); +} + +std::string NVVM::MBarrierTryWaitParityOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + llvm::StringRef space = isShared ? ".shared" : ""; + + return llvm::formatv("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}", + space); } //===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// -static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { - auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); - return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); -} +mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::BarrierOp>(op); + llvm::Value *barrierId = thisOp.getBarrierId() + ? mt.lookupValue(thisOp.getBarrierId()) + : builder.getInt32(0); + llvm::Intrinsic::ID id; + llvm::SmallVector<llvm::Value *> args; + if (thisOp.getNumberOfThreads()) { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count; + args.push_back(barrierId); + args.push_back(mt.lookupValue(thisOp.getNumberOfThreads())); + } else if (thisOp.getReductionOp()) { + switch (*thisOp.getReductionOp()) { + case NVVM::BarrierReduction::AND: + id = llvm::Intrinsic::nvvm_barrier0_and; + break; + case NVVM::BarrierReduction::OR: + id = llvm::Intrinsic::nvvm_barrier0_or; + break; + case NVVM::BarrierReduction::POPC: + id = llvm::Intrinsic::nvvm_barrier0_popc; + break; + } + args.push_back(mt.lookupValue(thisOp.getReductionPredicate())); + } else { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all; + args.push_back(barrierId); + } -static bool isPtrInSharedCTASpace(mlir::Value ptr) { - return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); + return {id, std::move(args)}; } mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( @@ -1787,15 +2502,213 @@ mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( return {id, {mt.lookupValue(thisOp.getAddr())}}; } +mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 
1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierArriveOp>(op); - bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); - llvm::Intrinsic::ID id = isShared - ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared - : llvm::Intrinsic::nvvm_mbarrier_arrive; - return {id, {mt.lookupValue(thisOp.getAddr())}}; + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? 
mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +bool MBarrierArriveExpectTxOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> + &asmValues) { + // Add all the operands but not the attrs to the asmValues list. + // The attrs here are used to generate the right variants for + // intrinsics-lowering. So, we ignore them while generating inline-PTX. + for (auto val : getOperands()) + asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read}); + + return false; +} + +mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 
1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; } mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( @@ -1813,17 +2726,100 @@ mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( return {id, std::move(args)}; } -mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( +mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { - auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op); bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); - llvm::Intrinsic::ID id = isShared - ? 
llvm::Intrinsic::nvvm_mbarrier_test_wait_shared - : llvm::Intrinsic::nvvm_mbarrier_test_wait; + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete; // Fill the Intrinsic Args llvm::SmallVector<llvm::Value *> args; args.push_back(mt.lookupValue(thisOp.getAddr())); - args.push_back(mt.lookupValue(thisOp.getState())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: isPhaseParity + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, input}}; +} + +mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + bool hasTicks = static_cast<bool>(thisOp.getTicks()); + // bit-0: isPhaseParity + // bit-1: Scope + // bit-2: hasTicks + size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) | + (isPhaseParity ? 
1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the mbarrier pointer + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mbar); + args.push_back(mt.lookupValue(thisOp.getStateOrPhase())); + if (hasTicks) + args.push_back(mt.lookupValue(thisOp.getTicks())); return {id, std::move(args)}; } @@ -1914,11 +2910,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(mt.lookupValue(thisOp.getSrcMem())); args.push_back(mt.lookupValue(thisOp.getSize())); - // Multicast mask, if available. + // Multicast mask for shared::cluster only, if available. mlir::Value multicastMask = thisOp.getMulticastMask(); const bool hasMulticastMask = static_cast<bool>(multicastMask); - llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); - args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); + const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem()); + if (!isSharedCTA) { + llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); + args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) + : i16Unused); + } // Cache hint, if available. mlir::Value cacheHint = thisOp.getL2CacheHint(); @@ -1927,11 +2927,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); // Flag arguments for multicast and cachehint. - args.push_back(builder.getInt1(hasMulticastMask)); + if (!isSharedCTA) + args.push_back(builder.getInt1(hasMulticastMask)); args.push_back(builder.getInt1(hasCacheHint)); llvm::Intrinsic::ID id = - llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; + isSharedCTA + ? 
llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta + : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; return {id, std::move(args)}; } @@ -2646,30 +3649,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ }() -llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() { - bool hasRelu = getRelu(); - bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); +NVVM::IDArgPair +ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rn, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rz, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rs, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite, + }; + + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; - if (hasRelu && hasSatFinite) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite; - if (hasRelu) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu; - if (hasSatFinite) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite; - return llvm::Intrinsic::nvvm_ff2f16x2_rs; + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op"); + } } -llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() { - bool hasRelu = getRelu(); - bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); +NVVM::IDArgPair +ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rn, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rz, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rs, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite, + }; - if (hasRelu && hasSatFinite) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite; - if (hasRelu) - return 
llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu; - if (hasSatFinite) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite; - return llvm::Intrinsic::nvvm_ff2bf16x2_rs; + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; + + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op"); + } } llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { @@ -3010,6 +4083,630 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs( return {intrinsicID, args}; } +mlir::NVVM::IDArgPair +PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::PermuteOp>(op); + NVVM::PermuteMode mode = thisOp.getMode(); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e, + llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8, + llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr, + llvm::Intrinsic::nvvm_prmt_rc16}; + + unsigned modeIndex = static_cast<unsigned>(mode); + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getLo())); + + // Only first 3 modes (Default, f4e, b4e) need the hi operand. 
+ if (modeIndex < 3) + args.push_back(mt.lookupValue(thisOp.getHi())); + + args.push_back(mt.lookupValue(thisOp.getSelector())); + + return {IDs[modeIndex], args}; +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair +Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + const bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + + using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>; + using CtaGroupArray = std::array<EnableAShiftArray, 2>; + using IsATensorArray = std::array<CtaGroupArray, 2>; + using HasScaleInputDArray = std::array<IsATensorArray, 2>; + using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>; + + // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift] + static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = { + { // without diable output lane + {{// without scale input D + {{ + // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift, + }}}, + }}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift, + }}}}}}}, + // with disable output lane + {{ // without scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2, + notIntrinsic}}}, + {{// cg1 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift, + }, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift, + }}}}}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2, + notIntrinsic}}}, + // tensor + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift}, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2, + llvm::Intrinsic:: + 
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift, + }}}}}}}}}; + + llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD()); + bool hasScaleInputD = ScaleInputD != nullptr; + + llvm::Value *DisableOutputLane = + mt.lookupValue(thisOp.getDisableOutputLane()); + bool hasDisableOutputLane = DisableOutputLane != nullptr; + + const unsigned ctaGroup = + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())); + + llvm::Intrinsic::ID ID = + tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor] + [ctaGroup - 1][thisOp.getAShift()]; + + assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp."); + + if (hasScaleInputD) + args.push_back(ScaleInputD); + + if (hasDisableOutputLane) + args.push_back(DisableOutputLane); + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + + if (!hasDisableOutputLane) + args.push_back(builder.getInt32(ctaGroup)); + + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, + NVVM::CTAGroupKind ctaGroup, bool hasAShift, + NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) { + + if (disableOutputLane) { + mlir::VectorType disableOutputLaneType = + cast<mlir::VectorType>(disableOutputLane.getType()); + if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 && + disableOutputLaneType.getNumElements() != 4) || + (ctaGroup == NVVM::CTAGroupKind::CTA_2 && + disableOutputLaneType.getNumElements() != 8)) + return emitError(loc) << "Disable Output Lane of length " + << disableOutputLaneType.getNumElements() + << " is incompatible with CtaGroupAttr"; + } + + if (hasAShift && !isATensor) + return emitError( + loc, "A-shift can be applied only when matrix A is in tensor memory"); + + if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL || + collectorOp == Tcgen05MMACollectorOp::USE)) + return emitError( + loc, "Cannot use collector buffer operation fill or use with ashift"); + + return success(); +} + +LogicalResult Tcgen05MMAOp::verify() { + return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()), + getDisableOutputLane(), getCtaGroup(), getAShift(), + getCollectorOp(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>; + using CtaGroupArray = std::array<EnableAShiftArray, 2>; + using IsATensorArray = std::array<CtaGroupArray, 2>; + using HasScaleInputDArray = std::array<IsATensorArray, 2>; + using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>; + + // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift] + static 
constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = { + { // without diable output lane + {{// without scale input D + {{ + // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift, + }}}, + }}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, + notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, + notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift, + }}}}}}}, + // with disable output lane + {{ // without scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2, + notIntrinsic}}}, + {{// cg1 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift, + }, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift, + }}}}}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2, + notIntrinsic}}}, + // tensor + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift}, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift, + }}}}}}}}}; + + llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD()); + bool hasScaleInputD = ScaleInputD != nullptr; + + llvm::Value *DisableOutputLane = + mt.lookupValue(thisOp.getDisableOutputLane()); + bool hasDisableOutputLane = DisableOutputLane != nullptr; + + unsigned ctaGroup = + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())); + + llvm::Intrinsic::ID ID = + tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor] + [ctaGroup - 1][thisOp.getAShift()]; + + assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp."); + + if (hasScaleInputD) + args.push_back(ScaleInputD); + + if (hasDisableOutputLane) + args.push_back(DisableOutputLane); + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + + if (!hasDisableOutputLane) + args.push_back(builder.getInt32(ctaGroup)); + + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +LogicalResult Tcgen05MMASparseOp::verify() { + return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()), + getDisableOutputLane(), getCtaGroup(), getAShift(), + getCollectorOp(), getLoc()); +} + 
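Both tcgen05.mma lowerings above index an identically laid-out five-level intrinsic table; the following is a rough illustrative sketch of the lookup, where the flag values are invented for the example and only the index order comes from the table comment in the patch:

  // Layout: [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift].
  bool hasDisableOutputLane = false; // no disable-output-lane vector supplied
  bool hasScaleInputD = true;        // scale_d operand present
  bool isATensor = true;             // matrix A lives in tensor memory
  unsigned ctaGroup = 1;             // CTA_1 -> index 0, CTA_2 -> index 1
  bool aShift = false;               // ashift variants exist only for tensor A
  llvm::Intrinsic::ID id =
      tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
                   [ctaGroup - 1][aShift];
  // Slots with no valid intrinsic (e.g. shared-memory A combined with ashift)
  // hold notIntrinsic and are rejected by the assert in getIntrinsicIDAndArgs.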
+//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale + : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? 
llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.block_scale attributes"); + }(); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, + NVVM::Tcgen05MMABlockScaleKind kind, + NVVM::Tcgen05MMABlockScale blockScale, + Location loc) { + + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT && + kind == Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, "mxf4nvf4 requires block scale attribute"); + + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 && + kind != Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, + llvm::formatv("{} kind does not support block16 attribute", + stringifyEnum(kind))); + + return success(); +} + +LogicalResult Tcgen05MMABlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? 
llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes"); + }(); + + return {ID, args}; +} + +LogicalResult Tcgen05MMASparseBlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask; + } else + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask; + } else + ID = isATensor ? 
llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// @@ -3213,16 +4910,20 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) { "Minimum NVVM target SM version is sm_20"); } - gpuModuleOp->walk([&](Operation *op) { - if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { - const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion(); - if (!requirement.isCompatibleWith(targetSMVersion)) { - op->emitOpError() << "is not supported on " << getChip(); - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); + if (gpuModuleOp + ->walk([&](Operation *op) { + if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { + const NVVMCheckSMVersion requirement = + reqOp.getRequiredMinSMVersion(); + if (!requirement.isCompatibleWith(targetSMVersion)) { + op->emitOpError() << "is not supported on " << getChip(); + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }) + .wasInterrupted()) + return failure(); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp index 67573c4..12dd225 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -109,8 +109,12 @@ static Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); } +/// Adds DILexicalBlockFileAttr for operations with CallSiteLoc and operations +/// from different files than their containing function. static void setLexicalBlockFileAttr(Operation *op) { - if (auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc())) { + Location opLoc = op->getLoc(); + + if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) { auto callerLoc = callSiteLoc.getCaller(); auto calleeLoc = callSiteLoc.getCallee(); LLVM::DIScopeAttr scopeAttr; @@ -122,6 +126,45 @@ static void setLexicalBlockFileAttr(Operation *op) { op->setLoc( CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc)); } + + return; + } + + auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); + if (!funcOp) + return; + + FileLineColLoc opFileLoc = extractFileLoc(opLoc); + if (!opFileLoc) + return; + + FileLineColLoc funcFileLoc = extractFileLoc(funcOp.getLoc()); + if (!funcFileLoc) + return; + + StringRef opFile = opFileLoc.getFilename().getValue(); + StringRef funcFile = funcFileLoc.getFilename().getValue(); + + // Handle cross-file operations: add DILexicalBlockFileAttr when the + // operation's source file differs from its containing function. 
+ if (opFile != funcFile) { + auto funcOpLoc = llvm::dyn_cast_if_present<FusedLoc>(funcOp.getLoc()); + if (!funcOpLoc) + return; + auto scopeAttr = dyn_cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata()); + if (!scopeAttr) + return; + + auto *context = op->getContext(); + LLVM::DIFileAttr opFileAttr = + LLVM::DIFileAttr::get(context, llvm::sys::path::filename(opFile), + llvm::sys::path::parent_path(opFile)); + + LLVM::DILexicalBlockFileAttr lexicalBlockFileAttr = + LLVM::DILexicalBlockFileAttr::get(context, scopeAttr, opFileAttr, 0); + + Location newLoc = FusedLoc::get(context, {opLoc}, lexicalBlockFileAttr); + op->setLoc(newLoc); } } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index dcc1ef9..b4b1347 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { // FillOpInterface implementation //===----------------------------------------------------------------------===// +namespace { enum class MatchFillResult { Success = 0, NotLinalgOp, WrongNumOperands, - NotScalarInput + NotScalarInput, + TypeMismatch }; +} // namespace static MatchFillResult isFillInterfaceImpl(Operation *op) { auto linalgOp = dyn_cast<linalg::LinalgOp>(op); @@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) { if (!linalgOp.isScalar(value)) return MatchFillResult::NotScalarInput; + // Check that the scalar input type matches the output element type. + OpOperand *output = linalgOp.getDpsInitOperand(0); + Type scalarType = value->get().getType(); + Type outputElementType = getElementTypeOrSelf(output->get().getType()); + if (scalarType != outputElementType) + return MatchFillResult::TypeMismatch; + return MatchFillResult::Success; } LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { - auto res = isFillInterfaceImpl(op); + MatchFillResult res = isFillInterfaceImpl(op); if (res == MatchFillResult::NotLinalgOp) return op->emitError("expected a LinalgOp"); if (res == MatchFillResult::WrongNumOperands) return op->emitError("expected op with 1 input and 1 output"); if (res == MatchFillResult::NotScalarInput) return op->emitError("expected op with scalar input"); + if (res == MatchFillResult::TypeMismatch) { + auto linalgOp = cast<linalg::LinalgOp>(op); + Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType(); + Type outputElementType = + getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType()); + return op->emitOpError("expected fill value type (") + << scalarType << ") to match output element type (" + << outputElementType << ")"; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3dc45ed..33ec79b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1338,8 +1338,6 @@ Speculation::Speculatability GenericOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } -LogicalResult GenericOp::verify() { return success(); } - namespace { /// Remove linalg operations that are just copying the values from inputs to @@ -2091,7 +2089,7 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor, return failure(); // Single dimension transpose. 
- if (getPermutation().size() == 0) { + if (getPermutation().empty()) { result.push_back(getInput()); return success(); } @@ -4885,13 +4883,6 @@ void ElementwiseOp::print(OpAsmPrinter &p) { elidedAttrs); } -LogicalResult ElementwiseOp::verify() { - // All necessary checks are done either by - // - EnumAttr (e.g. unknown operation kind) - // - verifyStructuredOpInterface (incorrect map, sizes). - return success(); -} - /// Implements the block region builder for the ElementwiseOp. This is called by /// 'fillStructuredOpRegion'. void ElementwiseOp::regionBuilder( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index aa82063..b8c1bad 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -176,7 +176,8 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults( if (auto attr = dyn_cast<Attribute>(paramOrHandle)) { reified.push_back(cast<IntegerAttr>(attr).getInt()); continue; - } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) { + } + if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) { ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle)); if (params.size() != 1) return transformOp.emitSilenceableError() << "expected a single param"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 22690da..9e6c1e6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); auto rankReducedType = cast<RankedTensorType>( tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - reassociation->size(), sliceOp.getSourceType(), offsets, sizes, - strides)); + reassociation->size(), sliceOp.getSourceType(), sizes)); Location loc = sliceOp.getLoc(); Value newSlice = tensor::ExtractSliceOp::create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 05fc7cb..421ab5e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1038,6 +1038,62 @@ private: ControlFusionFn controlFoldingReshapes; }; +/// Carries information about a padded dimension. +struct PadDimInfo { + // The resulting shape after padding each dimension. + SmallVector<int64_t> paddedShape; + + // Low and high padding amounts for each dimension. + SmallVector<OpFoldResult> lowPad; + SmallVector<OpFoldResult> highPad; +}; + +/// Computes the expanded padding information for the given pad operation based +/// on the provided expanded shape and reassociation indices. Returns a list of +/// PadDimInfo containing the low and high padding amounts and the padded +/// size for each dimension, or failure if the expansion is not possible. +static FailureOr<PadDimInfo> +computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape, + ArrayRef<ReassociationIndices> reassociations, + PatternRewriter &rewriter) { + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to expand the dimensions, since it will change + // the index values on which the padding value depends. 
This is not currently + // supported by the pad expansion patterns, but it could be implemented + // similarly to the expansion of linalg.generic ops with linalg.index ops in + // the body, as is done in `updateExpandedGenericOpRegion`. + if (!padOp.getConstantPaddingValue()) + return failure(); + + // Expanded dimensions cannot have padding because the resulting padding may + // not be representable by a tensor.pad op. There are some special cases where + // it is possible (like expanding unit dims), but supporting these cases is + // NYI, so disallow it for now. + ArrayRef<int64_t> low = padOp.getStaticLow(); + ArrayRef<int64_t> high = padOp.getStaticHigh(); + for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { + if (reInd.size() != 1 && (l != 0 || h != 0)) + return failure(); + } + + SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad()); + SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad()); + ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.paddedShape.assign(expandedShape); + padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.paddedShape[reInd[0]] = paddedShape[idx]; + padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx]; + padDimInfo.highPad[reInd[0]] = mixedHighPad[idx]; + } + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByExpansion : public OpRewritePattern<tensor::PadOp> { public: @@ -1053,46 +1109,96 @@ public: padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, "fusion blocked by control function"); } - ArrayRef<int64_t> low = padOp.getStaticLow(); - ArrayRef<int64_t> high = padOp.getStaticHigh(); + RankedTensorType expandedType = reshapeOp.getSrcType(); SmallVector<ReassociationIndices> reassociations = reshapeOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); - for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { - if (reInd.size() != 1 && (l != 0 || h != 0)) - return failure(); + Location loc = padOp->getLoc(); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), + expandedPadding.lowPad, expandedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( + padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + +class FoldReshapeWithProducerPadOpByExpansion + : public OpRewritePattern<tensor::ExpandShapeOp> { +public: + FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + 
PatternRewriter &rewriter) const override { + tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>(); + if (!padOp) + return failure(); + + if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(expandOp, + "fusion blocked by control function"); } - SmallVector<OpFoldResult> newLow, newHigh; - RankedTensorType expandedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector<int64_t> expandedPaddedShape(expandedType.getShape()); + RankedTensorType expandedType = expandOp.getResultType(); + SmallVector<ReassociationIndices> reassociations = + expandOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); + + Location loc = expandOp->getLoc(); + SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape(); + SmallVector<int64_t> newExpandedShape(expandedType.getShape()); + rewriter.setInsertionPointAfterValue(padOp.getSource()); + SmallVector<OpFoldResult> padSrcSizes = + tensor::getMixedSizes(rewriter, loc, padOp.getSource()); for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + // We know that any reassociation with multiple dims is not padded because + // of the requirements of computeExpandedPadding. if (reInd.size() == 1) { - expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx]; - } - for (size_t i = 0; i < reInd.size(); ++i) { - newLow.push_back(padOp.getMixedLowPad()[idx]); - newHigh.push_back(padOp.getMixedHighPad()[idx]); + newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx); + newExpandedSizes[reInd[0]] = padSrcSizes[idx]; } } - - Location loc = padOp->getLoc(); - RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape); + RankedTensorType newExpandedType = expandedType.clone(newExpandedShape); + auto newExpandOp = tensor::ExpandShapeOp::create( + rewriter, loc, newExpandedType, padOp.getSource(), reassociations, + newExpandedSizes); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + rewriter.setInsertionPoint(expandOp); auto newPadOp = tensor::PadOp::create( - rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, expandedPaddedType, newExpandOp.getResult(), + expandedPadding.lowPad, expandedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); - rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( - padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + rewriter.replaceOp(expandOp, newPadOp.getResult()); return success(); } @@ -1921,6 +2027,62 @@ private: ControlFusionFn controlFoldingReshapes; }; +/// Computes the collapsed padding information for the given pad operation based +/// on the provided collapsed shape and reassociation indices. Returns a +/// PadDimInfo containing the low and high padding amounts and the collapsed +/// shape for each dimension, or failure if the collapse is not possible. +static FailureOr<PadDimInfo> +computeCollapsedPadding(tensor::PadOp padOp, + ArrayRef<ReassociationIndices> reassociations, + PatternRewriter &rewriter) { + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to collapse the dimensions, since it will change + // the index values on which the padding value depends. 
This is not currently + // supported by the pad collapsing patterns, but it could be implemented + // similarly to the collapsing of linalg.generic ops with linalg.index ops in + // the body, as is done in `generateCollapsedIndexingRegion`. + if (!padOp.getConstantPaddingValue()) + return failure(); + + // Collapsed dimensions cannot have padding because this can produce strided + // padding that isn't representable by a tensor.pad op. There are some special + // cases where it is possible (like collapsing unit dims), but supporting + // these cases is NYI, so disallow it for now. + ArrayRef<int64_t> low = padOp.getStaticLow(); + ArrayRef<int64_t> high = padOp.getStaticHigh(); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + for (int64_t dim : reInd) { + if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1) + return failure(); + } + } + + // Initialize padding values for collapsed tensors with zeros + ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + + // Update padding for dimensions that are not being collapsed, and compute + // the collapsed padded shape. + SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad()); + SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad()); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]]; + padDimInfo.highPad[idx] = mixedHighPad[reInd[0]]; + } + SaturatedInteger collapsedSize = SaturatedInteger::wrap(1); + for (int64_t dim : reInd) { + collapsedSize = + collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]); + } + padDimInfo.paddedShape.push_back(collapsedSize.asInteger()); + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByCollapsing : public OpRewritePattern<tensor::PadOp> { public: @@ -1936,57 +2098,40 @@ public: padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, "fusion blocked by control function"); } - ArrayRef<int64_t> low = padOp.getStaticLow(); - ArrayRef<int64_t> high = padOp.getStaticHigh(); SmallVector<ReassociationIndices> reassociations = reshapeOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); - for (auto reInd : reassociations) { - if (reInd.size() == 1) - continue; - if (llvm::any_of(reInd, [&](int64_t ind) { - return low[ind] != 0 || high[ind] != 0; - })) { - return failure(); - } - } - - SmallVector<OpFoldResult> newLow, newHigh; - RankedTensorType collapsedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape()); - SmallVector<OpFoldResult> expandedPaddedSizes( - getMixedValues(reshapeOp.getStaticOutputShape(), - reshapeOp.getOutputShape(), rewriter)); + SmallVector<OpFoldResult> expandedPaddedSizes = + reshapeOp.getMixedOutputShape(); AffineExpr d0, d1, d2; bindDims(rewriter.getContext(), d0, d1, d2); auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2}); Location loc = reshapeOp->getLoc(); - for (auto 
[idx, reInd] : llvm::enumerate(reassociations)) { - OpFoldResult l = padOp.getMixedLowPad()[reInd[0]]; - OpFoldResult h = padOp.getMixedHighPad()[reInd[0]]; + for (auto [reInd, l, h] : + llvm::zip_equal(reassociations, collapsedPadding.lowPad, + collapsedPadding.highPad)) { if (reInd.size() == 1) { - collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]]; - OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply( + expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply( rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]}); - expandedPaddedSizes[reInd[0]] = paddedSize; } - newLow.push_back(l); - newHigh.push_back(h); } RankedTensorType collapsedPaddedType = - paddedType.clone(collapsedPaddedShape); + padOp.getType().clone(collapsedPadding.paddedShape); auto newPadOp = tensor::PadOp::create( - rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), + collapsedPadding.lowPad, collapsedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( @@ -2000,6 +2145,52 @@ private: ControlFusionFn controlFoldingReshapes; }; +class FoldReshapeWithProducerPadOpByCollapsing + : public OpRewritePattern<tensor::CollapseShapeOp> { +public: + FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>(); + if (!padOp) + return failure(); + + if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(padOp, + "fusion blocked by control function"); + } + + SmallVector<ReassociationIndices> reassociations = + reshapeOp.getReassociationIndices(); + RankedTensorType collapsedPaddedType = reshapeOp.getResultType(); + FailureOr<PadDimInfo> maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); + + Location loc = reshapeOp->getLoc(); + auto newCollapseOp = tensor::CollapseShapeOp::create( + rewriter, loc, padOp.getSource(), reassociations); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(), + collapsedPadding.lowPad, collapsedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOp(reshapeOp, newPadOp.getResult()); + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + /// Pattern to collapse dimensions. 
template <typename LinalgType> class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { @@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( controlFoldingReshapes); patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(), controlFoldingReshapes); + patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(), + controlFoldingReshapes); patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), controlFoldingReshapes); } @@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( controlFoldingReshapes); patterns.add<FoldPadWithProducerReshapeOpByCollapsing>( patterns.getContext(), controlFoldingReshapes); + patterns.add<FoldReshapeWithProducerPadOpByCollapsing>( + patterns.getContext(), controlFoldingReshapes); patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(), controlFoldingReshapes); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp index 9974ccd..cbd6357 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -200,10 +200,10 @@ static void populateOpPayload( SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands(); updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); - SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range( - genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); - SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range( - newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector<OpOperand *> origOutputOperands = + llvm::to_vector(llvm::make_pointer_range(genericOp.getDpsInitsMutable())); + SmallVector<OpOperand *> newOutputOperands = + llvm::to_vector(llvm::make_pointer_range(newOp.getDpsInitsMutable())); updateReplacements(origOutputOperands, newOutputOperands, origOutsToNewOutsPos); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 9436f1c..161d978 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -913,8 +913,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter, llvm_unreachable("loop independence prerequisite not met"); // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0]. - std::copy(loopIterationCounts.begin(), loopIterationCounts.end(), - offsets.begin()); + llvm::copy(loopIterationCounts, offsets.begin()); hoistedPackedTensor = scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front()) ->getResult(0); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 40fc0d6..c2485a0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,69 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp); } +/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy` +/// with `dilations` and `strides`. 
+template <typename ConvOpTy> +static FailureOr<LinalgOp> +specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) { + SmallVector<Value> inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics() + ? TypeRange(ValueRange(outputs)) + : TypeRange{}; + LinalgOp namedOp; + // Ops with no dilations and no strides. + if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> || + std::is_same_v<ConvOpTy, linalg::Conv2DOp> || + std::is_same_v<ConvOpTy, linalg::Conv3DOp>) { + namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, + inputs, outputs); + } else { + Attribute stridesAttr = rewriter.getI64TensorAttr(strides); + Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); + namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>( + genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + } + return namedOp; +} + +/// Converts linalg.generic to named linalg.*conv/pooling* where possible. +static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector<int64_t> dilations, strides; +#define CONV_OP_SPECIALIZER(ConvOpTy) \ + if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \ + return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \ + strides); \ + // ----------------------------- + // Convolution ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::Conv1DOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp); + CONV_OP_SPECIALIZER(linalg::Conv2DOp); + CONV_OP_SPECIALIZER(linalg::Conv3DOp); + // ----------------------------- + // Depthwise Convolution ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp); + // ----------------------------- + // Pooling ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp); +#undef CONV_OP_SPECIALIZER + return failure(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -316,6 +379,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter, if (isaContractionOpInterface(genericOp)) { return specializeLinalgContractions(rewriter, genericOp); } + + // Convolution - e.g. 
*conv/pooling* + if (isaConvolutionOpInterface(genericOp)) { + return specializeLinalgConvolutions(rewriter, genericOp); + } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 705d6f2..8e14ef4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -452,8 +452,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc()); AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap(); - if (!shapeSizesToLoopsMap) - return failure(); + assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap"); auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 8a0440b..50a84ac 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -167,7 +167,7 @@ struct LinalgOpTilingInterface llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) { auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr); if (!dimExpr) - continue; + return failure(); unsigned position = dimExpr.getPosition(); auto it = mappedOffsets.find(position); if (it != mappedOffsets.end()) { @@ -357,6 +357,32 @@ struct LinalgOpTilingInterface /// Inline the op payload and store the result. return inlinePayload(builder, linalgOp, ivs, indexedValues); } + + bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) const { + // The verifier gives all the necessary requirements for consumer fusion. + return true; + } + + bool isOpFusableWithProducerSlices( + Operation *op, ArrayRef<unsigned> operandNumbers, + ArrayRef<SmallVector<OpFoldResult>> allOffsets, + ArrayRef<SmallVector<OpFoldResult>> allSizes) const { + + auto linalgOp = cast<LinalgOp>(op); + SmallVector<AffineMap> indexingMaps = + llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) { + OpOperand &opOperand = linalgOp->getOpOperand(operandNumber); + return linalgOp.getMatchingIndexingMap(&opOperand); + }); + // Check that offsets/sizes are consistent across all operands. + OpBuilder b(op); + SmallVector<OpFoldResult> mappedOffsets, mappedSizes; + return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps, + allOffsets, allSizes, mappedOffsets, + mappedSizes)); + } }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 027268c..67e2b9f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( "this is not supported ATM!"); } - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); - Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); - int64_t destRank = packOp.getDestRank(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. 
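The Transforms.cpp hunks surrounding this point drop the explicitly materialized zero offsets and unit strides in favor of sizes-only slice builders. A minimal sketch of the before/after call shape for the pack case, with the old form kept as a comment; both forms are taken from the hunks themselves, and the sizes-only builder is assumed to default offsets to zero and strides to one, matching the vectors it replaces:

  // Old form (removed below): offsets and strides spelled out per dimension.
  //   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  //   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  //   tensor::InsertSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
  //                                 packOp.getDest(), writeOffsets, writeSizes,
  //                                 writeStrides);
  // New form: only the sizes are passed explicitly.
  auto insert = tensor::InsertSliceOp::create(
      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);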
@@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( writeSizes.push_back(tileSizeOfr); } - // TODO: Add a constructor for tensor.insert_slice that doesn't require - // strides nor offsets. - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - auto insert = tensor::InsertSliceOp::create( - rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), - writeOffsets, writeSizes, writeStrides); + rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes); // 4. Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); @@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { - int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape(); ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); @@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Value source = unpackOp.getSource(); DenseMap<int64_t, OpFoldResult> dimAndTileMapping = unpackOp.getDimAndTileMapping(); - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of @@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // outer-tiled-dims being all 1), this will be // [ outer-untiled-dims, tile-sizes ] SmallVector<OpFoldResult> extractSliceSizes; - // The offset and strides attributes for ExtractSliceOp. - SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr); - SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr); // Shape for EmptyOp that's used as the init value for TransposeOp below. // This should be: @@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Type elemType = unpackOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); Value innerTile = tensor::ExtractSliceOp::create( - rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets, - extractSliceSizes, extractSliceStrides); + rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes); // 2. Transpose the tile to match the outer corresponding tile order. SmallVector<int64_t> perm = getPackUnpackRankReducedPerm( @@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // 3. Handle in-complete tiles if needed. It truncates trailing data from the // transposed tile. - int numLoops = shapeForEmptyOp.size(); - SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr); - SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr); SmallVector<OpFoldResult> tileSizes; ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape(); for (auto i : llvm::seq<unsigned>(0, destRank)) { @@ -1393,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( } auto partialTile = - tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0], - tileOffsets, tileSizes, tileStrides); + tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(), + transposedOp.getResult()[0], tileSizes); // 4. Insert the result to the destination tensor. 
SmallVector<OpFoldResult> writeSizes; - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); for (int i = 0, idx = 0; i < destRank; ++i) { if (dimAndTileMapping.count(i) || destShape[i] != 1) writeSizes.push_back(tileSizes[idx++]); @@ -1407,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( writeSizes.push_back(oneIdxAttr); } auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile, - unpackOp.getDest(), writeOffsets, - writeSizes, writeStrides); + unpackOp.getDest(), writeSizes); rewriter.replaceOp(unpackOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 19d2d85..bb3bccd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, auto vectorType = state.getCanonicalVecType( getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap); + SmallVector<Value> indices(linalgOp.getRank(outputOperand), + arith::ConstantIndexOp::create(rewriter, loc, 0)); + Operation *write; if (vectorType.getRank() > 0) { AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); - SmallVector<Value> indices( - linalgOp.getRank(outputOperand), - arith::ConstantIndexOp::create(rewriter, loc, 0)); value = broadcastIfNeeded(rewriter, value, vectorType); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create( @@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, value = vector::BroadcastOp::create(rewriter, loc, vectorType, value); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create(rewriter, loc, value, - outputOperand->get(), ValueRange{}); + outputOperand->get(), indices); } write = state.maskOperation(rewriter, write, linalgOp, opOperandMap); @@ -1890,9 +1890,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, // Create masked TransferReadOp. auto maskedRead = vector::createReadOrMaskedRead( - rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue, - useInBoundsInsteadOfMasking, - /*inputScalableVecSizes=*/{}); + rewriter, loc, packOp.getSource(), readVecType, padValue, + useInBoundsInsteadOfMasking); // Create ShapeCastOp. 
auto shapeCastOp = vector::ShapeCastOp::create( @@ -1977,9 +1976,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, } // -- Generate the read operation -- + VectorType readVecType = + VectorType::get(readVectorSizes, unpackTensorType.getElementType(), + readScalableVectorFlags); Value readResult = vector::createReadOrMaskedRead( - rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt, - useInBoundsInsteadOfMasking, readScalableVectorFlags); + rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt, + useInBoundsInsteadOfMasking); // -- Generate the transpose operation -- PackingMetadata packMetadata; @@ -2025,9 +2027,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, .reifyResultShapes(rewriter, reifiedReturnShapes); (void)status; // prevent unused variable warning on non-assert builds assert(succeeded(status) && "failed to reify result shapes"); + auto readType = VectorType::get(inputVectorSizes, padValue.getType()); auto maskedRead = vector::createReadOrMaskedRead( - rewriter, loc, padOp.getSource(), inputVectorSizes, padValue, - /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{}); + rewriter, loc, padOp.getSource(), readType, padValue, + /*useInBoundsInsteadOfMasking=*/false); // Create Xfer write Op Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], @@ -2222,9 +2225,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); Value read = mlir::vector::createReadOrMaskedRead( - rewriter, loc, opOperand.get(), readType.getShape(), + rewriter, loc, opOperand.get(), readType, /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), - /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims()); + /*useInBoundsInsteadOfMasking=*/false); vecOperands.push_back(read); } @@ -3165,9 +3168,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, SmallVector<Value> readIndices( vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0)); Value read = mlir::vector::createReadOrMaskedRead( - rewriter, loc, source, vecType.getShape(), padValue, - /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(), - /*inputScalableVecSizes=*/{}); + rewriter, loc, source, vecType, padValue, + /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); // Create write auto writeIndices = diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 6eeb206..01e6e1e 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -235,6 +235,731 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } +//===----------------------------------------------------------------------===// +// Convolution matcher utilities +//===----------------------------------------------------------------------===// + +/// Returns the BlockArgument that leads to `val`, if any. Traverses optional +/// ext* ops. 
+static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { + BlockArgument blockArg = dyn_cast<BlockArgument>(val); + if ((blockArg)) + return blockArg; + + Operation *defOp = val.getDefiningOp(); + if (!dyn_cast_if_present<arith::ExtFOp>(defOp) && + !dyn_cast_if_present<arith::ExtSIOp>(defOp) && + !dyn_cast_if_present<arith::ExtUIOp>(defOp)) { + return nullptr; + } + return dyn_cast<BlockArgument>(defOp->getOperand(0)); +} + +/// Utility to match block body for convolution ops. +/// The body is thus expected to yield :- +/// %out + (%lhs * %rhs) +/// where: %lhs, %rhs and %out are block arguments and +/// %lhs and %rhs can have optional upcast operation. +static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { + Operation *addOp = yieldVal.getDefiningOp(); + if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp)) + return false; + + Operation *mulOp = addOp->getOperand(1).getDefiningOp(); + if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp)) + return false; + + BlockArgument lhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0)); + BlockArgument rhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1)); + BlockArgument outBlockArg = + getBlockArgumentWithOptionalExtOps(addOp->getOperand(0)); + if (!lhsBlockArg || !rhsBlockArg || !outBlockArg || + lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body || + outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 || + rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2) + return false; + return true; +} + +/// Utility to match block body for linalg.pool* ops. +template <typename... OpTypes> +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + if (!(isa_and_present<OpTypes>(defOp) || ...)) + return false; + + BlockArgument lhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(0)); + BlockArgument rhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(1)); + if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || + rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || + rhsArg.getArgNumber() != 0) + return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, + body); +} + +// max_unsigned ops should not allow float data type. +// TODO(#164800): Retire OPDSL logic. +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, + body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, + body); +} + +// min_unsigned ops should not allow float data type. +// TODO(#164800): Retire OPDSL logic. 
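// Example of a body accepted by these pooling matchers, where %arg0 is the
// input element and %arg2 the accumulator (output) block argument:
//   %m = arith.maximumf %arg2, %arg0 : f32
//   linalg.yield %m : f32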
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, + body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body); +} + +static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, + uint32_t dimIndex) { + auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue(); + if (dimIndex < affineMap.getNumResults()) + return affineMap.getResult(dimIndex); + return nullptr; +} + +/// Check if `expr` is either: +/// - a dimension expr alone (implying multiplication by 1), or +/// - a multiplication of dimension expr by any positive constant != 1 +/// In both cases we will capture the dimension expression into `dim` and +/// return the constant multiplier. Returns -1 in case of a match failure. +static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) { + if ((dim = dyn_cast<AffineDimExpr>(expr))) + return 1; + + auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return -1; + + AffineExpr lhs = mulExpr.getLHS(); + AffineExpr rhs = mulExpr.getRHS(); + + AffineConstantExpr cst = nullptr; + if (((dim = dyn_cast<AffineDimExpr>(lhs)) && + (cst = dyn_cast<AffineConstantExpr>(rhs))) || + ((dim = dyn_cast<AffineDimExpr>(rhs)) && + (cst = dyn_cast<AffineConstantExpr>(lhs)))) + return cst.getValue(); + return -1; +} + +/// Given an array of AffineMaps `indexingMaps` verify the following +/// commutatively:- +/// indexingMaps[0].getResult(iDim) == +/// indexingMaps[1].getResult(fDim) * <c0> + +/// indexingMaps[n-1].getResult(oDim) * <c1> +/// where, +/// - c0 and c1 can be any constant, +/// - n is the size of the indexingMaps' array, +/// - 0, 1 and n-1 are input, filter and output map indices respectively, +/// - iDim, fDim and oDim are the input, filter and output dimension +/// indices in their respective indexing maps +/// Example: +/// #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) +/// -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)> +/// #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +/// #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +/// +/// Here, +/// #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3 +/// Therefore, +/// matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride) +/// would return true and update dilation = 3 and stride = 2 +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, + unsigned fDim, unsigned oDim, + int64_t &dilation, int64_t &stride) { + unsigned inputMapIdx = 0, filterMapIdx = 1, + outputMapIdx = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim); + auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) + return false; + + AffineExpr dim0, dim1; + int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0); + int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1); + + if (c0 == -1 || c1 == -1) + return false; + // Pattern matched with dims and constants extracted. 
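  // Worked example (mirroring the documentation above): for an input-map
  // result `d1 * 2 + d4 * 3`, the summands decompose as dim0 = d1 with c0 = 2
  // and dim1 = d4 with c1 = 3. If the filter map yields d4 and the output map
  // yields d1 for the queried dimensions, the second comparison below matches
  // and reports dilation = 3, stride = 2.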
+ AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim); + if (dim0 == fExpr && dim1 == oExpr) { + dilation = c0; + stride = c1; + return true; + } + if (dim1 == fExpr && dim0 == oExpr) { + dilation = c1; + stride = c0; + return true; + } + return false; +} + +/// Returns true if the given indexing maps matches with the expected indexing +/// maps. +static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected, + ArrayAttr indexingMaps, MLIRContext *context) { + SmallVector<AffineMap, 4> expectedIndexingMaps = + AffineMap::inferFromExprList(mapListExpected, context); + return indexingMaps == + ArrayAttr::get( + context, llvm::to_vector<4>(llvm::map_range( + expectedIndexingMaps, [&](AffineMap m) -> Attribute { + return AffineMapAttr::get(m); + }))); +} + +/// Enum representing pooling operation types used by ConvMatcherBuilder. +enum class PoolingType { + None, + MaxSigned, + MaxUnsigned, + MinSigned, + MinUnsigned, + Sum +}; + +/// Helper class for building convolution op matchers with minimal boilerplate. +/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well +/// as Pooling ops. +/// +/// Usage: Create an instance with the op, spatial rank, and output pointers for +/// extracted dilations/strides. Then chain matchStride() calls for each spatial +/// dimension, followed by matchMaps() to verify indexing maps, and finally +/// matchBody() to verify the operation body pattern. +/// +/// The `matched` flag starts as `true` and is set to `false` if any match step +/// fails. This allows chaining multiple match calls; once any match fails, all +/// subsequent calls become no-ops and the final result is `false`. +/// +/// The `dilations` and `strides` pointers are output parameters that get +/// populated with the extracted dilation and stride values from the operation's +/// indexing maps during matchStride() calls. These values are initially set to +/// 1 for each spatial dimension and updated as patterns are matched. +class ConvMatcherBuilder { + LinalgOp op; + MLIRContext *ctx; + SmallVector<int64_t> *dilations, *strides; + ArrayAttr indexingMaps; + PoolingType poolingType; + bool matched = true; + +public: + ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d, + SmallVector<int64_t> *s, + PoolingType poolingType = PoolingType::None) + : op(op), ctx(op->getContext()), dilations(d), strides(s), + indexingMaps(op.getIndexingMaps()), poolingType(poolingType) { + *dilations = SmallVector<int64_t>(spatialRank, 1); + *strides = SmallVector<int64_t>(spatialRank, 1); + } + + /// Get affine dimension expression for dimension `i`. + AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); } + + /// Build strided expression: base * stride[idx] + kernel * dilation[idx]. + AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) { + return base * (*strides)[idx] + kernel * (*dilations)[idx]; + } + + /// Match stride/dilation pattern for a spatial dimension. + /// Returns *this for method chaining. + ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim, + unsigned idx) { + if (matched) { + matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim, + (*dilations)[idx], (*strides)[idx]); + } + return *this; + } + + /// Match expected indexing maps layout. Returns *this for method chaining. 
+ ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) { + if (matched) + matched &= convLayoutMatches(maps, indexingMaps, ctx); + return *this; + } + + /// Match body pattern. This should be called last. + bool matchBody() { + if (!matched) + return false; + Block *body = op.getBlock(); + auto yieldOp = cast<linalg::YieldOp>(body->getTerminator()); + switch (poolingType) { + case PoolingType::None: + return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body); + case PoolingType::MaxSigned: + return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MaxUnsigned: + return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MinSigned: + return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MinUnsigned: + return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::Sum: + return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body); + } + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Matchers for specific convolution operation. +//===----------------------------------------------------------------------===// + +// #inputMap = affine_map<(W, w) -> (W + w)> +// #filterMap = affine_map<(W, w) -> (w)> +// #outputMap = affine_map<(W, w) -> (W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr W = m.dim(0); + AffineExpr w = m.dim(1); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchMaps({/*inputMap=*/{m.strided(W, w, 0)}, + /*filterMap=*/{w}, + /*outputMap=*/{W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)> +// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)> +// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DNwcWcfOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr w = m.dim(3); + AffineExpr c = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c}, + /*filterMap=*/{w, c, F}, + /*outputMap=*/{N, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)> +// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)> +// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DNcwFcwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr c = m.dim(3); + AffineExpr w = m.dim(4); + + return 
m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)}, + /*filterMap=*/{F, c, w}, + /*outputMap=*/{N, F, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)> +// #filterMap = affine_map<(H, W, h, w) -> (h, w)> +// #outputMap = affine_map<(H, W, h, w) -> (H, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv2DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr H = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr h = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{h, w}, + /*outputMap=*/{H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)> +// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)> +// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv3DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr D = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr d = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2) + .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2)}, + /*filterMap=*/{d, h, w}, + /*outputMap=*/{D, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)> +// #filterMap = affine_map<(N, W, C, w) -> (C, w)> +// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)}, + /*filterMap=*/{C, w}, + /*outputMap=*/{N, C, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, w) -> (w, C)> +// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, 
/*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C}, + /*outputMap=*/{N, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)> +// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr CM = m.dim(3); + AffineExpr w = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C, CM}, + /*outputMap=*/{N, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{C, h, w}, + /*outputMap=*/{N, C, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D + d, H + h, W + w, C)> +// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (d, h, w, C, CM)> +// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr d = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + AffineExpr C = m.dim(8); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), C}, + /*filterMap=*/{d, h, w, C, CM}, + /*outputMap=*/{N, D, 
H, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMaxOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MaxSigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMinOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MinSigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcSumOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::Sum); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) 
+ return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MaxUnsigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MinUnsigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 1382c7ac..d358362 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRMemorySlotInterfaces MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRUBDialect MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp index 6ff63df..a1e3f10 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index dfa2e4e..5404238 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, // Interfaces for AllocaOp 
//===----------------------------------------------------------------------===// -static bool isSupportedElementType(Type type) { - return llvm::isa<MemRefType>(type) || - OpBuilder(type.getContext()).getZeroAttr(type); -} - SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { MemRefType type = getType(); - if (!isSupportedElementType(type.getElementType())) - return {}; if (!type.hasStaticShape()) return {}; // Make sure the memref contains only a single element. @@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - assert(isSupportedElementType(slot.elemType)); - // TODO: support more types. - return TypeSwitch<Type, Value>(slot.elemType) - .Case([&](MemRefType t) { - return memref::AllocaOp::create(builder, getLoc(), t); - }) - .Default([&](Type t) { - return arith::ConstantOp::create(builder, getLoc(), t, - builder.getZeroAttr(t)); - }); + return ub::PoisonOp::create(builder, getLoc(), slot.elemType); } std::optional<PromotableAllocationOpInterface> diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1c21a2f..1035d7c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1074,13 +1074,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) { return subview.getDynamicSize(sourceIndex); } - if (auto sizeInterface = - dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { - assert(sizeInterface.isDynamicSize(unsignedIndex) && - "Expected dynamic subview size"); - return sizeInterface.getDynamicSize(unsignedIndex); - } - // dim(memrefcast) -> dim if (succeeded(foldMemRefCast(*this))) return getResult(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index bd02516..c9352e8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); - if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) + // ViewLikeOpInterface by itself doesn't guarantee to preserve the base + // pointer in general and `memref.view` is one such example, so just check + // for a few specific cases. 
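    // In particular, memref.view applies a byte shift to the source buffer, so
    // its result's aligned pointer can differ from the source's, whereas
    // memref.subview and memref.reinterpret_cast only rewrite the
    // offset/sizes/strides in the descriptor and keep the underlying allocated
    // and aligned pointers intact.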
+ if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() || + !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp)) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 214410f..3667fdb 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -347,28 +347,55 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices, isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation())))) return failure(); - llvm::TypeSwitch<Operation *, void>(loadOp) + + return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp) .Case([&](affine::AffineLoadOp op) { rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( loadOp, expandShapeOp.getViewSource(), sourceIndices); + return success(); }) .Case([&](memref::LoadOp op) { rewriter.replaceOpWithNewOp<memref::LoadOp>( loadOp, expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::LoadOp op) { rewriter.replaceOpWithNewOp<vector::LoadOp>( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::MaskedLoadOp op) { rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getPassThru()); + return success(); + }) + .Case([&](vector::TransferReadOp op) { + // We only support minor identity maps in the permutation attribute. + if (!op.getPermutationMap().isMinorIdentity()) + return failure(); + + // We only support the case where the source of the expand shape has + // rank greater than or equal to the vector rank. + const int64_t sourceRank = sourceIndices.size(); + const int64_t vectorRank = op.getVectorType().getRank(); + if (sourceRank < vectorRank) + return failure(); + + // We need to construct a new minor identity map since we will have lost + // some dimensions in folding away the expand shape. 
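        // For example, when the folded source memref has rank 4 and the vector
        // has rank 2, the rebuilt permutation map is the minor identity
        //   (d0, d1, d2, d3) -> (d2, d3),
        // i.e. the vector is read from the innermost source dimensions.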
+ auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank, + op.getContext()); + + rewriter.replaceOpWithNewOp<vector::TransferReadOp>( + op, op.getVectorType(), expandShapeOp.getViewSource(), + sourceIndices, minorIdMap, op.getPadding(), op.getMask(), + op.getInBounds()); + return success(); }) .DefaultUnreachable("unexpected operation"); - return success(); } template <typename OpTy> @@ -659,6 +686,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { LoadOpOfExpandShapeOpFolder<memref::LoadOp>, LoadOpOfExpandShapeOpFolder<vector::LoadOp>, LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>, + LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>, StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>, StoreOpOfExpandShapeOpFolder<memref::StoreOp>, StoreOpOfExpandShapeOpFolder<vector::StoreOp>, diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 6a81a15..c498c8a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> { if (!dimIndex) return failure(); - ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), - reifiedResultShapes))) + FailureOr<OpFoldResult> replacement = reifyDimOfResult( + rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex); + if (failed(replacement)) return failure(); - unsigned resultNumber = dimValue.getResultNumber(); - // Do not apply pattern if the IR is invalid (dim out of bounds). - if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size()) - return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds"); - Value replacement = getValueOrCreateConstantIndexOp( - rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]); - rewriter.replaceOp(dimOp, replacement); + // Check if the OpFoldResult is empty (unreifiable dimension). 
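    // reifyDimOfResult reports two distinct failure modes: a failed result
    // when the op cannot reify shapes at all (handled above) and a null
    // OpFoldResult for a dimension it cannot compute (handled below);
    // getValueOrCreateConstantIndexOp then materializes the surviving
    // OpFoldResult, reusing an existing Value or creating an index constant.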
+ if (!replacement.value()) + return failure(); + Value replacementVal = getValueOrCreateConstantIndexOp( + rewriter, dimOp.getLoc(), replacement.value()); + rewriter.replaceOp(dimOp, replacementVal); return success(); } }; @@ -166,12 +165,14 @@ namespace { struct ResolveRankedShapeTypeResultDimsPass final : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase< ResolveRankedShapeTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; struct ResolveShapedTypeResultDimsPass final : public memref::impl::ResolveShapedTypeResultDimsPassBase< ResolveShapedTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; @@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns( void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 14152c5..e5cc41e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -268,61 +268,82 @@ struct SubViewOpInterface MemRefType sourceType = subView.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); + auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); + for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + // Reset insertion point to before the operation for each dimension. builder.setInsertionPoint(subView); + Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp(builder, loc, subView.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = metadataOp.getSizes()[i]; - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). 
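      // Example with dimSize == 4: an empty slice (size == 0) may use any
      // offset in [0, 4], including one past the end, while a non-empty slice
      // must start in [0, 3] and its last accessed position
      // offset + (size - 1) * stride must also fall in [0, 3]
      // (generateInBoundsCheck builds the half-open predicate low <= v < high).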
+ Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage(op, diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 6200366..e548698 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -133,17 +133,20 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, } /// Returns true if all the uses of op are not read/load. 
-/// There can be SubviewOp users as long as all its users are also +/// There can be view-like-op users as long as all its users are also /// StoreOp/transfer_write. If return true it also fills out the uses, if it /// returns false uses is unchanged. static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) { std::vector<Operation *> opUses; for (OpOperand &use : op->getUses()) { Operation *useOp = use.getOwner(); + // Use escaped the scope + if (useOp->mightHaveTrait<OpTrait::IsTerminator>()) + return false; if (isa<memref::DeallocOp>(useOp) || (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 && !mlir::hasEffect<MemoryEffects::Read>(useOp)) || - (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) { + (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) { opUses.push_back(useOp); continue; } diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 2a857ed..0d05313 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { - auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn)); + auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn); Type elementType = getElementTypeOrSelf(memref.getType()); auto vt = VectorType::get(vectorShape, elementType); @@ -727,7 +727,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { toStore.push_back(v); }); - return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn)); + return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn); } static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, @@ -792,7 +792,7 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { if (failed(maybeInfo)) return failure(); - MmaSyncInfo info = *maybeInfo; + const MmaSyncInfo &info = *maybeInfo; auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; auto [lhsShape, rhsShape, resShape] = info.vectorShapes; Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef, diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp index 40e769e..1d775fb 100644 --- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) { return mlir::emitError(loc, "not yet implemented: " + message); } +bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) { + if (impl) + return impl->isValidSymbolUse(user, symbol, definingOpPtr); + return acc::isValidSymbolUse(user, symbol, definingOpPtr); +} + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 8c9c137..47f1222 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include 
"mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" @@ -203,12 +204,91 @@ struct MemRefPointerLikeModel return false; } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // Load from a memref - only valid for scalar memrefs (rank 0). + // This is because the address computation for memrefs is part of the load + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr); + if (!memrefValue) + return {}; + + auto memrefTy = memrefValue.getType(); + + // Only load from scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return {}; + + return memref::LoadOp::create(builder, loc, memrefValue); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + // Store to a memref - only valid for scalar memrefs (rank 0) + // This is because the address computation for memrefs is part of the store + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr); + if (!memrefValue) + return false; + + auto memrefTy = memrefValue.getType(); + + // Only store to scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return false; + + memref::StoreOp::create(builder, loc, valueToStore, memrefValue); + return true; + } }; struct LLVMPointerPointerLikeModel : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, LLVM::LLVMPointerType> { Type getElementType(Type pointer) const { return Type(); } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // For LLVM pointers, we need the valueType to determine what to load + if (!valueType) + return {}; + + return LLVM::LoadOp::create(builder, loc, valueType, srcPtr); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + LLVM::StoreOp::create(builder, loc, valueToStore, destPtr); + return true; + } +}; + +struct MemrefAddressOfGlobalModel + : public AddressOfGlobalOpInterface::ExternalModel< + MemrefAddressOfGlobalModel, memref::GetGlobalOp> { + SymbolRefAttr getSymbol(Operation *op) const { + auto getGlobalOp = cast<memref::GetGlobalOp>(op); + return getGlobalOp.getNameAttr(); + } +}; + +struct MemrefGlobalVariableModel + : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel, + memref::GlobalOp> { + bool isConstant(Operation *op) const { + auto globalOp = cast<memref::GlobalOp>(op); + return globalOp.getConstant(); + } + + Region *getInitRegion(Operation *op) const { + // GlobalOp uses attributes for initialization, not regions + return nullptr; + } }; /// Helper function for any of the times we need to modify an ArrayAttr based on @@ -302,6 +382,11 @@ void OpenACCDialect::initialize() { MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext()); LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( *getContext()); + + // Attach operation interfaces + memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>( + *getContext()); + memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext()); } //===----------------------------------------------------------------------===// @@ -467,6 
+552,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) { return success(); } +template <typename OpT, typename RecipeOpT> +static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) { + // Mappable types do not need a recipe because it is possible to generate one + // from its API. Reject reductions though because no API is available for them + // at this time. + if (mlir::acc::isMappableType(op.getVar().getType()) && + !std::is_same_v<OpT, acc::ReductionOp>) + return success(); + + mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr(); + if (!operandRecipe) + return op->emitOpError() << "recipe expected for " << operandName; + + auto decl = + SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, operandRecipe); + if (!decl) + return op->emitOpError() + << "expected symbol reference " << operandRecipe << " to point to a " + << operandName << " declaration"; + return success(); +} + static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var) { // Either `var` or `varPtr` keyword is required. @@ -573,6 +680,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, } } +static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, + mlir::SymbolRefAttr &recipeAttr) { + if (failed(parser.parseAttribute(recipeAttr))) + return failure(); + return success(); +} + +static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::SymbolRefAttr recipeAttr) { + p << recipeAttr; +} + //===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// @@ -595,6 +714,9 @@ LogicalResult acc::PrivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed( + checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private"))) + return failure(); return success(); } @@ -609,6 +731,9 @@ LogicalResult acc::FirstprivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>( + *this, "firstprivate"))) + return failure(); return success(); } @@ -637,6 +762,9 @@ LogicalResult acc::ReductionOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>( + *this, "reduction"))) + return failure(); return success(); } @@ -1322,6 +1450,28 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, return recipe; } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, + FirstprivateRecipeOp firstprivRecipe) { + // Create the private.recipe op with the same type as the firstprivate.recipe. + OpBuilder::InsertionGuard guard(builder); + auto varType = firstprivRecipe.getType(); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Clone the init region + IRMapping mapping; + firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping); + + // Clone destroy region if the firstprivate.recipe has one. 
+ if (!firstprivRecipe.getDestroyRegion().empty()) { + IRMapping mapping; + firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(), + mapping); + } + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1432,40 +1582,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() { } //===----------------------------------------------------------------------===// -// Custom parser and printer verifier for private clause -//===----------------------------------------------------------------------===// - -static ParseResult parseSymOperandList( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, - llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) { - llvm::SmallVector<SymbolRefAttr> attributes; - if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseArrow() || - parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - return success(); - }))) - return failure(); - llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), - attributes.end()); - symbols = ArrayAttr::get(parser.getContext(), arrayAttr); - return success(); -} - -static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, - mlir::TypeRange types, - std::optional<mlir::ArrayAttr> attributes) { - llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { - p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " - << std::get<1>(it).getType(); - }); -} - -//===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -1484,45 +1600,19 @@ static LogicalResult checkDataOperands(Op op, return success(); } -template <typename Op> -static LogicalResult -checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes, - mlir::OperandRange operands, llvm::StringRef operandName, - llvm::StringRef symbolName, bool checkOperandType = true) { - if (!operands.empty()) { - if (!attributes || attributes->size() != operands.size()) - return op->emitOpError() - << "expected as many " << symbolName << " symbol reference as " - << operandName << " operands"; - } else { - if (attributes) - return op->emitOpError() - << "unexpected " << symbolName << " symbol reference"; - return success(); - } - +template <typename OpT, typename RecipeOpT> +static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, + const mlir::ValueRange &operands, + llvm::StringRef operandName) { llvm::DenseSet<Value> set; - for (auto args : llvm::zip(operands, *attributes)) { - mlir::Value operand = std::get<0>(args); - + for (mlir::Value operand : operands) { + if (!mlir::isa<OpT>(operand.getDefiningOp())) + return accConstructOp->emitOpError() + << "expected " << operandName << " as defining op"; if (!set.insert(operand).second) - return op->emitOpError() + return accConstructOp->emitOpError() << operandName << " operand appears more than once"; - - mlir::Type varType = operand.getType(); - auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); - auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef); - if (!decl) - return op->emitOpError() - << "expected symbol reference " << symbolRef << " 
to point to a " - << operandName << " declaration"; - - if (checkOperandType && decl.getType() && decl.getType() != varType) - return op->emitOpError() << "expected " << operandName << " (" << varType - << ") to be the same type as " << operandName - << " declaration (" << decl.getType() << ")"; } - return success(); } @@ -1579,17 +1669,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch( } LogicalResult acc::ParallelOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -1720,7 +1810,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, mlir::ValueRange gangPrivateOperands, mlir::ValueRange gangFirstPrivateOperands, mlir::ValueRange dataClauseOperands) { - ParallelOp::build( odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr, @@ -1729,9 +1818,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, /*numGangsDeviceType=*/nullptr, numWorkers, /*numWorkersDeviceType=*/nullptr, vectorLength, /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond, - /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr, - gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands, - /*firstprivatizations=*/nullptr, dataClauseOperands, + /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands, + gangFirstPrivateOperands, dataClauseOperands, /*defaultAttr=*/nullptr, /*combined=*/nullptr); } @@ -1808,46 +1896,22 @@ void acc::ParallelOp::addWaitOperands( void acc::ParallelOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - 
recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } static ParseResult parseNumGangs( @@ -2415,17 +2479,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { } LogicalResult acc::SerialOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -2489,46 +2553,22 @@ void acc::SerialOp::addWaitOperands( void acc::SerialOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - 
llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -2658,6 +2698,27 @@ LogicalResult acc::KernelsOp::verify() { return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands()); } +void acc::KernelsOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getPrivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getFirstprivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getReductionOperandsMutable().append(op.getResult()); +} + void acc::KernelsOp::addNumWorkersOperand( MLIRContext *context, mlir::Value newValue, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { @@ -2967,19 +3028,21 @@ bool hasDuplicateDeviceTypes( } /// Check for duplicates in the DeviceType array attribute. -LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { +/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found. +static std::optional<mlir::acc::DeviceType> +checkDeviceTypes(mlir::ArrayAttr deviceTypes) { llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes; if (!deviceTypes) - return success(); + return std::nullopt; for (auto attr : deviceTypes) { auto deviceTypeAttr = mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr); if (!deviceTypeAttr) - return failure(); + return mlir::acc::DeviceType::None; if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second) - return failure(); + return deviceTypeAttr.getValue(); } - return success(); + return std::nullopt; } LogicalResult acc::LoopOp::verify() { @@ -3006,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() { getCollapseDeviceTypeAttr().getValue().size()) return emitOpError() << "collapse attribute count must match collapse" << " device_type count"; - if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) - return emitOpError() - << "duplicate device_type found in collapseDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in collapseDeviceType attribute"; // Check gang if (!getGangOperands().empty()) { @@ -3021,8 +3085,12 @@ LogicalResult acc::LoopOp::verify() { return emitOpError() << "gangOperandsArgType attribute count must match" << " gangOperands count"; } - if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) - return emitOpError() << "duplicate device_type found in gang attribute"; + if (getGangAttr()) { + if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in gang attribute"; + } if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getGangOperands(), getGangOperandsSegmentsAttr(), @@ -3030,22 +3098,30 
@@ LogicalResult acc::LoopOp::verify() { return failure(); // Check worker - if (failed(checkDeviceTypes(getWorkerAttr()))) - return emitOpError() << "duplicate device_type found in worker attribute"; - if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "workerNumOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in worker attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in workerNumOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), getWorkerNumOperandsDeviceTypeAttr(), "worker"))) return failure(); // Check vector - if (failed(checkDeviceTypes(getVectorAttr()))) - return emitOpError() << "duplicate device_type found in vector attribute"; - if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "vectorOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vector attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getVectorOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vectorOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), getVectorOperandsDeviceTypeAttr(), "vector"))) @@ -3110,19 +3186,19 @@ LogicalResult acc::LoopOp::verify() { } } - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (getCombined().has_value() && @@ -3556,45 +3632,21 @@ void acc::LoopOp::addGangOperands( void acc::LoopOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - 
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -4059,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() { if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " - "be present at the same time"; + "be present at the same time for device_type `" + << acc::stringifyDeviceType(dtype) << "`"; } return success(); @@ -4356,6 +4409,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { return std::nullopt; } +void RoutineOp::addSeq(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addVector(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addWorker(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + uint64_t val) { + llvm::SmallVector<mlir::Attribute> dimValues; + llvm::SmallVector<mlir::Attribute> deviceTypes; + + if (getGangDimAttr()) + llvm::copy(getGangDimAttr(), std::back_inserter(dimValues)); + if (getGangDimDeviceTypeAttr()) + llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes)); + + assert(dimValues.size() == deviceTypes.size()); + + if (effectiveDeviceTypes.empty()) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back( + acc::DeviceTypeAttr::get(context, acc::DeviceType::None)); + } else { + for (DeviceType dt : effectiveDeviceTypes) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt)); + } + } + assert(dimValues.size() == 
deviceTypes.size()); + + setGangDimAttr(mlir::ArrayAttr::get(context, dimValues)); + setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes)); +} + +void RoutineOp::addBindStrName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::StringAttr val) { + unsigned before = getBindStrNameDeviceTypeAttr() + ? getBindStrNameDeviceTypeAttr().size() + : 0; + + setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindStrNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindStrNameAttr()) + llvm::copy(getBindStrNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindStrNameAttr(mlir::ArrayAttr::get(context, vals)); +} + +void RoutineOp::addBindIDName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::SymbolRefAttr val) { + unsigned before = + getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0; + + setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindIdNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindIdNameAttr()) + llvm::copy(getBindIdNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindIdNameAttr(mlir::ArrayAttr::get(context, vals)); +} + //===----------------------------------------------------------------------===// // InitOp //===----------------------------------------------------------------------===// @@ -4739,3 +4886,12 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { .Default([&](mlir::Operation *) { return nullptr; })}; return dataOperands; } + +mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) { + auto recipe{ + llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp) + .Case<ACC_DATA_ENTRY_OPS>( + [&](auto entry) { return entry.getRecipeAttr(); }) + .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })}; + return recipe; +} diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp index 91262bd..67cdf10 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp @@ -237,11 +237,6 @@ public: void runOnOperation() override; private: - /// Collects all data clauses that dominate the compute construct. - /// Needed to determine if a variable is already covered by an existing data - /// clause. - SmallVector<Value> getDominatingDataClauses(Operation *computeConstructOp); - /// Looks through the `dominatingDataClauses` to find the original data clause /// op for an alias. Returns nullptr if no original data clause op is found. template <typename OpT> @@ -277,8 +272,7 @@ private: /// Generates recipes for a list of variables. void generateRecipes(ModuleOp &module, OpBuilder &builder, Operation *computeConstructOp, - const SmallVector<Value> &newOperands, - SmallVector<Attribute> &newRecipeSyms); + const SmallVector<Value> &newOperands); }; /// Determines if a variable is a candidate for implicit data mapping. 
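The clause-collection logic being dropped from this pass is not deleted outright; it resurfaces as acc::getDominatingDataClauses in OpenACCUtils.cpp further down in this diff. A minimal sketch of how a transform might consume that helper together with the new acc::getRecipe accessor; collectRecipeSymbols and its parameters are hypothetical, the sketch assumes the same headers as ACCImplicitData.cpp, and only the two acc:: utilities are taken from this patch:

static void collectRecipeSymbols(mlir::Operation *computeConstructOp,
                                 mlir::DominanceInfo &domInfo,
                                 mlir::PostDominanceInfo &postDomInfo) {
  // Data clauses already visible to the construct: its own operands, clauses
  // of enclosing acc.data regions, and declare_enter/declare_exit pairs that
  // dominate / post-dominate it.
  llvm::SmallVector<mlir::Value> clauses =
      mlir::acc::getDominatingDataClauses(computeConstructOp, domInfo,
                                          postDomInfo);
  for (mlir::Value clause : clauses)
    if (mlir::Operation *entry = clause.getDefiningOp())
      // With this patch the recipe symbol lives on the data-entry op itself;
      // a null SymbolRefAttr means the entry carries no recipe.
      if (mlir::SymbolRefAttr recipe = mlir::acc::getRecipe(entry))
        llvm::dbgs() << "recipe: " << recipe << "\n";
}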
@@ -301,62 +295,6 @@ static bool isCandidateForImplicitData(Value val, Region &accRegion) { return true; } -SmallVector<Value> -ACCImplicitData::getDominatingDataClauses(Operation *computeConstructOp) { - llvm::SmallSetVector<Value, 8> dominatingDataClauses; - - llvm::TypeSwitch<Operation *>(computeConstructOp) - .Case<acc::ParallelOp, acc::KernelsOp, acc::SerialOp>([&](auto op) { - for (auto dataClause : op.getDataClauseOperands()) { - dominatingDataClauses.insert(dataClause); - } - }) - .Default([](Operation *) {}); - - // Collect the data clauses from enclosing data constructs. - Operation *currParentOp = computeConstructOp->getParentOp(); - while (currParentOp) { - if (isa<acc::DataOp>(currParentOp)) { - for (auto dataClause : - dyn_cast<acc::DataOp>(currParentOp).getDataClauseOperands()) { - dominatingDataClauses.insert(dataClause); - } - } - currParentOp = currParentOp->getParentOp(); - } - - // Find the enclosing function/subroutine - auto funcOp = computeConstructOp->getParentOfType<FunctionOpInterface>(); - if (!funcOp) - return dominatingDataClauses.takeVector(); - - // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that - // dominate and post-dominate the compute construct and add their data - // clauses to the list. - auto &domInfo = this->getAnalysis<DominanceInfo>(); - auto &postDomInfo = this->getAnalysis<PostDominanceInfo>(); - funcOp->walk([&](acc::DeclareEnterOp declareEnterOp) { - if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) { - // Collect all `acc.declare_exit` ops for this token. - SmallVector<acc::DeclareExitOp> exits; - for (auto *user : declareEnterOp.getToken().getUsers()) - if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user)) - exits.push_back(declareExit); - - // Only add clauses if every `acc.declare_exit` op post-dominates the - // compute construct. 
- if (!exits.empty() && llvm::all_of(exits, [&](acc::DeclareExitOp exitOp) { - return postDomInfo.postDominates(exitOp, computeConstructOp); - })) { - for (auto dataClause : declareEnterOp.getDataClauseOperands()) - dominatingDataClauses.insert(dataClause); - } - } - }); - - return dominatingDataClauses.takeVector(); -} - template <typename OpT> Operation *ACCImplicitData::getOriginalDataClauseOpForAlias( Value var, OpBuilder &builder, OpT computeConstructOp, @@ -453,23 +391,23 @@ ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module, Value var, void ACCImplicitData::generateRecipes(ModuleOp &module, OpBuilder &builder, Operation *computeConstructOp, - const SmallVector<Value> &newOperands, - SmallVector<Attribute> &newRecipeSyms) { + const SmallVector<Value> &newOperands) { auto &accSupport = this->getAnalysis<acc::OpenACCSupport>(); for (auto var : newOperands) { auto loc{var.getLoc()}; - if (isa<acc::PrivateOp>(var.getDefiningOp())) { + if (auto privateOp = dyn_cast<acc::PrivateOp>(var.getDefiningOp())) { auto recipe = generatePrivateRecipe( module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); if (recipe) - newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(), - recipe.getSymName().str())); - } else if (isa<acc::FirstprivateOp>(var.getDefiningOp())) { + privateOp.setRecipeAttr( + SymbolRefAttr::get(module->getContext(), recipe.getSymName())); + } else if (auto firstprivateOp = + dyn_cast<acc::FirstprivateOp>(var.getDefiningOp())) { auto recipe = generateFirstprivateRecipe( module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); if (recipe) - newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(), - recipe.getSymName().str())); + firstprivateOp.setRecipeAttr(SymbolRefAttr::get( + module->getContext(), recipe.getSymName().str())); } else { accSupport.emitNYI(var.getLoc(), "implicit reduction"); } @@ -570,6 +508,8 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate( newDataOp = acc::PresentOp::create(builder, loc, var, /*structured=*/true, /*implicit=*/true, accSupport.getVariableName(var)); + newDataOp->setAttr(acc::getFromDefaultClauseAttrName(), + builder.getUnitAttr()); } else { auto copyinOp = acc::CopyinOp::create(builder, loc, var, @@ -611,56 +551,22 @@ static void legalizeValuesInRegion(Region &accRegion, } } -// Adds the private operands and private recipes to the data construct -// operation in a valid way (ensures that the index in the privatizationRecipes -// array matches the position of the private operand). +// Adds the private operands to the compute construct operation. template <typename OpT> -static void -addNewPrivateOperands(OpT &accOp, const SmallVector<Value> &privateOperands, - const SmallVector<Attribute> &privateRecipeSyms) { - assert(privateOperands.size() == privateRecipeSyms.size()); +static void addNewPrivateOperands(OpT &accOp, + const SmallVector<Value> &privateOperands) { if (privateOperands.empty()) return; - SmallVector<Attribute> completePrivateRecipesSyms; - SmallVector<Attribute> completeFirstprivateRecipesSyms; - SmallVector<Value> newPrivateOperands; - SmallVector<Value> newFirstprivateOperands; - - // Collect all of the existing recipes since they are held in an attribute. - // To add to it, we need to create a brand new one. 
- if (accOp.getPrivatizationRecipes().has_value()) - for (auto privatization : accOp.getPrivatizationRecipesAttr()) - completePrivateRecipesSyms.push_back(privatization); - if (accOp.getFirstprivatizationRecipes().has_value()) - for (auto privatization : accOp.getFirstprivatizationRecipesAttr()) - completeFirstprivateRecipesSyms.push_back(privatization); - - // Now separate between private and firstprivate operands. - for (auto [priv, privateRecipeSym] : - llvm::zip(privateOperands, privateRecipeSyms)) { + for (auto priv : privateOperands) { if (isa<acc::PrivateOp>(priv.getDefiningOp())) { - newPrivateOperands.push_back(priv); - completePrivateRecipesSyms.push_back(privateRecipeSym); + accOp.getPrivateOperandsMutable().append(priv); } else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) { - newFirstprivateOperands.push_back(priv); - completeFirstprivateRecipesSyms.push_back(privateRecipeSym); + accOp.getFirstprivateOperandsMutable().append(priv); } else { - llvm_unreachable("unhandled private operand"); + llvm_unreachable("unhandled reduction operand"); } } - - // Append all of the new private operands to their appropriate list. - accOp.getPrivateOperandsMutable().append(newPrivateOperands); - accOp.getFirstprivateOperandsMutable().append(newFirstprivateOperands); - - // Update the privatizationRecipes attributes to hold all of the new recipes. - if (!completePrivateRecipesSyms.empty()) - accOp.setPrivatizationRecipesAttr( - ArrayAttr::get(accOp.getContext(), completePrivateRecipesSyms)); - if (!completeFirstprivateRecipesSyms.empty()) - accOp.setFirstprivatizationRecipesAttr( - ArrayAttr::get(accOp.getContext(), completeFirstprivateRecipesSyms)); } static Operation *findDataExitOp(Operation *dataEntryOp) { @@ -808,7 +714,10 @@ void ACCImplicitData::generateImplicitDataOps( LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n" << computeConstructOp << "\n"); } - auto dominatingDataClauses = getDominatingDataClauses(computeConstructOp); + auto &domInfo = this->getAnalysis<DominanceInfo>(); + auto &postDomInfo = this->getAnalysis<PostDominanceInfo>(); + auto dominatingDataClauses = + acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo); for (auto var : candidateVars) { auto newDataClauseOp = generateDataClauseOpForCandidate( var, module, builder, computeConstructOp, dominatingDataClauses, @@ -829,13 +738,11 @@ void ACCImplicitData::generateImplicitDataOps( // of the data clause ops) legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands); - SmallVector<Attribute> newPrivateRecipeSyms; // 5) Generate private recipes which are required for properly attaching // private operands. if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && !std::is_same_v<OpT, acc::KernelEnvironmentOp>) - generateRecipes(module, builder, computeConstructOp, newPrivateOperands, - newPrivateRecipeSyms); + generateRecipes(module, builder, computeConstructOp, newPrivateOperands); // 6) Figure out insertion order for the new data clause operands. SmallVector<Value> sortedDataClauseOperands( @@ -846,15 +753,10 @@ void ACCImplicitData::generateImplicitDataOps( // 7) Generate the data exit operations. generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands, sortedDataClauseOperands); - // 8) Add all of the new operands to the compute construct op. 
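Because the recipe symbol now travels on the acc.private / acc.firstprivate data-entry op itself, attaching an implicit private variable no longer rebuilds any recipe array on the construct. A minimal sketch built on the addPrivatization helper added earlier in this diff; attachImplicitPrivate and its parameters are placeholder names and the IR in the comment is only approximate:

static void attachImplicitPrivate(mlir::MLIRContext *ctx,
                                  mlir::acc::ParallelOp parallelOp,
                                  mlir::acc::PrivateOp privateOp,
                                  mlir::acc::PrivateRecipeOp recipe) {
  // Stores the recipe symbol on the data-entry op and appends it to the
  // construct's private operands, yielding roughly:
  //   %p = acc.private varPtr(%v : memref<f32>) recipe(@priv_recipe)
  //        -> memref<f32>
  //   acc.parallel private(%p : memref<f32>) { ... }
  parallelOp.addPrivatization(ctx, privateOp, recipe);
}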
- assert(newPrivateOperands.size() == newPrivateRecipeSyms.size() && - "sizes must match"); if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && !std::is_same_v<OpT, acc::KernelEnvironmentOp>) - addNewPrivateOperands(computeConstructOp, newPrivateOperands, - newPrivateRecipeSyms); - + addNewPrivateOperands(computeConstructOp, newPrivateOperands); computeConstructOp.getDataClauseOperandsMutable().assign( sortedDataClauseOperands); } diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp new file mode 100644 index 0000000..8cab223 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp @@ -0,0 +1,431 @@ +//===- ACCImplicitDeclare.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass applies implicit `acc declare` actions to global variables +// referenced in OpenACC compute regions and routine functions. +// +// Overview: +// --------- +// Global references in an acc regions (for globals not marked with `acc +// declare` by the user) can be handled in one of two ways: +// - Mapped through data clauses +// - Implicitly marked as `acc declare` (this pass) +// +// Thus, the OpenACC specification focuses solely on implicit data mapping rules +// whose implementation is captured in `ACCImplicitData` pass. +// +// However, it is both advantageous and required for certain cases to +// use implicit `acc declare` instead: +// - Any functions that are implicitly marked as `acc routine` through +// `ACCImplicitRoutine` may reference globals. Since data mapping +// is only possible for compute regions, such globals can only be +// made available on device through `acc declare`. +// - Compiler can generate and use globals for cases needed in IR +// representation such as type descriptors or various names needed for +// runtime calls and error reporting - such cases often are introduced +// after a frontend semantic checking is done since it is related to +// implementation detail. Thus, such compiler generated globals would +// not have been visible for a user to mark with `acc declare`. +// - Constant globals such as filename strings or data initialization values +// are values that do not get mutated but are still needed for appropriate +// runtime execution. If a kernel is launched 1000 times, it is not a +// good idea to map such a global 1000 times. Therefore, such globals +// benefit from being marked with `acc declare`. +// +// This pass automatically +// marks global variables with the `acc.declare` attribute when they are +// referenced in OpenACC compute constructs or routine functions and meet +// the criteria noted above, ensuring +// they are properly handled for device execution. +// +// The pass performs two main optimizations: +// +// 1. Hoisting: For non-constant globals referenced in compute regions, the +// pass hoists the address-of operation out of the region when possible, +// allowing them to be implicitly mapped through normal data clause +// mechanisms rather than requiring declare marking. +// +// 2. 
Declaration: For globals that must be available on the device (constants, +// globals in routines, globals in recipe operations), the pass adds the +// `acc.declare` attribute with the copyin data clause. +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Operation Interface Implementation: Operations that compute addresses +// of global variables must implement the `acc::AddressOfGlobalOpInterface` +// and those that represent globals must implement the +// `acc::GlobalOpInterface`. Additionally, any operations that indirectly +// access globals must implement the `acc::IndirectGlobalAccessOpInterface`. +// +// 2. Analysis Registration (Optional): If custom behavior is needed for +// determining if a symbol use is valid within GPU regions, the dialect +// should pre-register the `acc::OpenACCSupport` analysis. +// +// Examples: +// --------- +// +// Example 1: Non-constant global in compute region (hoisted) +// +// Before: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_scalar : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// %addr = memref.get_global @g_scalar : memref<f32> +// acc.serial { +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 2: Constant global in compute region (declared) +// +// Before: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 3: Global in acc routine (declared) +// +// Before: +// memref.global @g_data : memref<f32> = dense<0.0> +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// After: +// memref.global @g_data : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// Example 4: Global in private recipe (declared if recipe is used) +// +// Before: +// memref.global @g_init : memref<f32> = dense<0.0> +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... 
} +// } +// +// After: +// memref.global @g_init : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... } +// } +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITDECLARE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-implicit-declare" + +using namespace mlir; + +namespace { + +using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>; + +/// Checks whether a use of the requested `globalOp` should be considered +/// for hoisting out of acc region due to avoid `acc declare`ing something +/// that instead should be implicitly mapped. +static bool isGlobalUseCandidateForHoisting(Operation *globalOp, + Operation *user, + SymbolRefAttr symbol, + acc::OpenACCSupport &accSupport) { + // This symbol is valid in GPU region. This means semantics + // would change if moved to host - therefore it is not a candidate. + if (accSupport.isValidSymbolUse(user, symbol)) + return false; + + bool isConstant = false; + bool isFunction = false; + + if (auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp)) + isConstant = globalVarOp.isConstant(); + + if (isa<FunctionOpInterface>(globalOp)) + isFunction = true; + + // Constants should be kept in device code to ensure they are duplicated. + // Function references should be kept in device code to ensure their device + // addresses are computed. Everything else should be hoisted since we already + // proved they are not valid symbols in GPU region. + return !isConstant && !isFunction; +} + +/// Checks whether it is valid to use acc.declare marking on the global. +bool isValidForAccDeclare(Operation *globalOp) { + // For functions - we use acc.routine marking instead. + return !isa<FunctionOpInterface>(globalOp); +} + +/// Checks whether a recipe operation has meaningful use of its symbol that +/// justifies processing its regions for global references. Returns false if: +/// 1. The recipe has no symbol uses at all, or +/// 2. The only symbol use is the recipe's own symbol definition +template <typename RecipeOpT> +static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) { + std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod); + + // No recipe symbol uses. + if (!symbolUses.has_value() || symbolUses->empty()) + return false; + + // If more than one use, assume it's used. 
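  // Aside (a sketch, not part of this patch): the begin/std::next dance below
  // only asks whether there are at least two uses, without walking the whole
  // UseRange; with llvm/ADT/STLExtras.h it could equivalently be written as
  //   if (llvm::hasNItemsOrMore(*symbolUses, 2))
  //     return true;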
+ auto begin = symbolUses->begin(); + auto end = symbolUses->end(); + if (begin != end && std::next(begin) != end) + return true; + + // If single use, check if the use is the recipe itself. + const SymbolTable::SymbolUse &use = *symbolUses->begin(); + return use.getUser() != recipeOp.getOperation(); +} + +// Hoists addr_of operations for non-constant globals out of OpenACC regions. +// This way - they are implicitly mapped instead of being considered for +// implicit declare. +template <typename AccConstructT> +static void hoistNonConstantDirectUses(AccConstructT accOp, + acc::OpenACCSupport &accSupport) { + accOp.walk([&](acc::AddressOfGlobalOpInterface addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + if (symRef) { + Operation *globalOp = + SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef); + if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef, + accSupport)) { + addrOfOp->moveBefore(accOp); + LLVM_DEBUG( + llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t"; + accOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + } + } + }); +} + +// Collects the globals referenced in a device region +static void collectGlobalsFromDeviceRegion(Region ®ion, + GlobalOpSetT &globals, + acc::OpenACCSupport &accSupport, + SymbolTable &symTab) { + region.walk([&](Operation *op) { + // 1) Only consider relevant operations which use symbols + auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op); + if (addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + // 2) Found an operation which uses the symbol. Next determine if it + // is a candidate for `acc declare`. Some of the criteria considered + // is whether this symbol is not already a device one (either because + // acc declare is already used or this is a CUF global). + Operation *globalOp = nullptr; + bool isCandidate = !accSupport.isValidSymbolUse(op, symRef, &globalOp); + // 3) Add the candidate to the set of globals to be `acc declare`d. + if (isCandidate && globalOp && isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } else if (auto indirectAccessOp = + dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) { + // Process operations that indirectly access globals + llvm::SmallVector<SymbolRefAttr> symbols; + indirectAccessOp.getReferencedSymbols(symbols, &symTab); + for (SymbolRefAttr symRef : symbols) + if (Operation *globalOp = symTab.lookup(symRef.getLeafReference())) + if (isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } + }); +} + +// Adds the declare attribute to the operation `op`. +static void addDeclareAttr(MLIRContext *context, Operation *op, + acc::DataClause clause) { + op->setAttr(acc::getDeclareAttrName(), + acc::DeclareAttr::get(context, + acc::DataClauseAttr::get(context, clause))); +} + +// This pass applies implicit declare actions for globals referenced in +// OpenACC compute and routine regions. +class ACCImplicitDeclare + : public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> { +public: + using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *context = &getContext(); + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // 1) Start off by hoisting any AddressOf operations out of acc region + // for any cases we do not want to `acc declare`. This is because we can + // rely on implicit data mapping in majority of cases without uselessly + // polluting the device globals. 
+ mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + hoistNonConstantDirectUses(accOp, accSupport); + }); + }); + + // 2) Collect global symbols which need to be `acc declare`d. Do it for + // compute regions, acc routine, and existing globals with the declare + // attribute. + SymbolTable symTab(mod); + GlobalOpSetT globalsToAccDeclare; + mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + collectGlobalsFromDeviceRegion( + accOp.getRegion(), globalsToAccDeclare, accSupport, symTab); + }) + .Case<FunctionOpInterface>([&](auto func) { + if ((acc::isAccRoutine(func) || + acc::isSpecializedAccRoutine(func)) && + !func.isExternal()) + collectGlobalsFromDeviceRegion(func.getFunctionBody(), + globalsToAccDeclare, accSupport, + symTab); + }) + .Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) { + if (globalVarOp->getAttr(acc::getDeclareAttrName())) + if (Region *initRegion = globalVarOp.getInitRegion()) + collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare, + accSupport, symTab); + }) + .Case<acc::PrivateRecipeOp>([&](auto privateRecipe) { + if (hasRelevantRecipeUse(privateRecipe, mod)) { + collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) { + if (hasRelevantRecipeUse(firstprivateRecipe, mod)) { + collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare, + accSupport, symTab); + collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) { + if (hasRelevantRecipeUse(reductionRecipe, mod)) { + collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + reductionRecipe.getCombinerRegion(), globalsToAccDeclare, + accSupport, symTab); + } + }); + }); + + // 3) Finally, generate the appropriate declare actions needed to ensure + // this is considered for device global. + for (Operation *globalOp : globalsToAccDeclare) { + LLVM_DEBUG( + llvm::dbgs() << "Global is being `acc declare copyin`d: "; + globalOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + + // Mark it as declare copyin. + addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin); + + // TODO: May need to create the global constructor which does the mapping + // action. It is not yet clear if this is needed yet (since the globals + // might just end up in the GPU image without requiring mapping via + // runtime). + } + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp new file mode 100644 index 0000000..12efaf4 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp @@ -0,0 +1,237 @@ +//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements the implicit rules described in OpenACC specification +// for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1). +// +// "If no explicit routine directive applies to a procedure whose definition +// appears in the program unit being compiled, then the implementation applies +// an implicit routine directive to that procedure if any of the following +// conditions holds: +// - The procedure is called or its address is accessed in a compute region." +// +// The specification further states: +// "When the implementation applies an implicit routine directive to a +// procedure, it must recursively apply implicit routine directives to other +// procedures for which the above rules specify relevant dependencies. Such +// dependencies can form a cycle, so the implementation must take care to avoid +// infinite recursion." +// +// This pass implements these requirements by: +// 1. Walking through all OpenACC compute constructs and functions already +// marked with `acc routine` in the module and identifying function calls +// within these regions. +// 2. Creating implicit `acc.routine` operations for functions that don't +// already have routine declarations. +// 3. Recursively walking through all existing `acc routine` and creating +// implicit routine operations for function calls within these routines, +// while avoiding infinite recursion through proper tracking. +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Operation Interface Implementation: Operations that define functions +// or call functions should implement `mlir::FunctionOpInterface` and +// `mlir::CallOpInterface` respectively. +// +// 2. Analysis Registration (Optional): If custom behavior is needed for +// determining if a symbol use is valid within GPU regions, the dialect +// should pre-register the `acc::OpenACCSupport` analysis. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include <queue> + +#define DEBUG_TYPE "acc-implicit-routine" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITROUTINE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +namespace { + +using namespace mlir; + +class ACCImplicitRoutine + : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> { +private: + unsigned routineCounter = 0; + static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_"; + + // Count existing routine operations and update counter + void initRoutineCounter(ModuleOp module) { + module.walk([&](acc::RoutineOp routineOp) { routineCounter++; }); + } + + // Check if routine has a default bind clause or a device-type specific bind + // clause. Returns true if `acc routine` has a default bind clause or + // a device-type specific bind clause. 
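  // For illustration (a sketch, not taken from this patch), a bound routine
  // looks roughly like
  //   acc.routine @acc_routine_0 func(@foo) bind("foo_gpu")
  // i.e. the device call targets a symbol other than @foo, so recursing into
  // @foo's body would mark callees that never actually run on the device.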
+ bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op, + acc::DeviceType deviceType) { + // Fast check to avoid device-type specific lookups. + if (!op.getBindIdName() && !op.getBindStrName()) + return false; + return op.getBindNameValue().has_value() || + op.getBindNameValue(deviceType).has_value(); + } + + // Generate a unique name for the routine and create the routine operation + acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc, + FunctionOpInterface &callee) { + std::string routineName = + (accRoutinePrefix + std::to_string(routineCounter++)).str(); + auto routineOp = acc::RoutineOp::create( + builder, loc, + /* sym_name=*/builder.getStringAttr(routineName), + /* func_name=*/ + mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(callee.getName())), + /* bindIdName=*/nullptr, + /* bindStrName=*/nullptr, + /* bindIdNameDeviceType=*/nullptr, + /* bindStrNameDeviceType=*/nullptr, + /* worker=*/nullptr, + /* vector=*/nullptr, + /* seq=*/nullptr, + /* nohost=*/nullptr, + /* implicit=*/builder.getUnitAttr(), + /* gang=*/nullptr, + /* gangDim=*/nullptr, + /* gangDimDeviceType=*/nullptr); + + // Assert that the callee does not already have routine info attribute + assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) && + "function is already associated with a routine"); + + callee->setAttr( + acc::getRoutineInfoAttrName(), + mlir::acc::RoutineInfoAttr::get( + builder.getContext(), + {mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(routineName))})); + return routineOp; + } + + // Used to walk through a compute region looking for function calls. + void + implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + op->walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. + // Skip it because we don't know the call target. + if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + createRoutineOp(builder, callee.getLoc(), callee); + }); + } + + // Recursively handle calls within a routine operation + void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport, + acc::DeviceType targetDeviceType) { + // When a bind clause is used, the device-side call targets a different + // symbol than the function this `acc routine` applies to. Skip this case to + // avoid recursively applying implicit routine marking to calls that would + // not end up on the device. 
+ if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType)) + return; + + SymbolTable symTab(routineOp->getParentOfType<ModuleOp>()); + std::queue<acc::RoutineOp> routineQueue; + routineQueue.push(routineOp); + while (!routineQueue.empty()) { + auto currentRoutine = routineQueue.front(); + routineQueue.pop(); + auto func = symTab.lookup<FunctionOpInterface>( + currentRoutine.getFuncName().getLeafReference()); + func.walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. + // Skip it because we don't know the call target. + if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee); + routineQueue.push(newRoutineOp); + }); + } + } + +public: + using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase; + + void runOnOperation() override { + auto module = getOperation(); + mlir::OpBuilder builder(module.getContext()); + SymbolTable symTab(module); + initRoutineCounter(module); + + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // Handle compute regions + module.walk([&](Operation *op) { + if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) + implicitRoutineForCallsInComputeRegions(op, symTab, builder, + accSupport); + }); + + // Use the device type option from the pass options. + acc::DeviceType targetDeviceType = deviceType; + + // Handle existing routines + module.walk([&](acc::RoutineOp routineOp) { + implicitRoutineForCallsInRoutine(routineOp, builder, accSupport, + targetDeviceType); + }); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp new file mode 100644 index 0000000..f41ce276 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp @@ -0,0 +1,117 @@ +//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass converts acc.serial into acc.parallel with num_gangs(1) +// num_workers(1) vector_length(1). +// +// This transformation simplifies processing of acc regions by unifying the +// handling of serial and parallel constructs. Since an OpenACC serial region +// executes sequentially (like a parallel region with a single gang, worker, and +// vector), this conversion is semantically equivalent while enabling code reuse +// in later compilation stages. 
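// For illustration only (a rough sketch; the exact printed form of the
// clauses may differ):
//
//   Before:
//     acc.serial {
//       ...
//       acc.yield
//     }
//
//   After:
//     %c1_i32 = arith.constant 1 : i32
//     acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32)
//         vector_length(%c1_i32 : i32) {
//       ...
//       acc.yield
//     }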
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCLEGALIZESERIAL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-legalize-serial" + +namespace { +using namespace mlir; + +struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> { + using OpRewritePattern<acc::SerialOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::SerialOp serialOp, + PatternRewriter &rewriter) const override { + + const Location loc = serialOp.getLoc(); + + // Create a container holding the constant value of 1 for use as the + // num_gangs, num_workers, and vector_length attributes. + llvm::SmallVector<mlir::Value> numValues; + auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + numValues.push_back(value); + + // Since num_gangs is specified as both attributes and values, create a + // segment attribute. + llvm::SmallVector<int32_t> numGangsSegments; + numGangsSegments.push_back(numValues.size()); + auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments); + + // Create a device_type attribute set to `none` which ensures that + // the parallel dimensions specification applies to the default clauses. + llvm::SmallVector<mlir::Attribute> crtDeviceTypes; + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + rewriter.getContext(), mlir::acc::DeviceType::None); + crtDeviceTypes.push_back(crtDeviceTypeAttr); + auto devTypeAttr = + mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes); + + LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n"); + + // Create a new acc.parallel op with the same operands - except include the + // num_gangs, num_workers, and vector_length attributes. 
+ acc::ParallelOp parOp = acc::ParallelOp::create( + rewriter, loc, serialOp.getAsyncOperands(), + serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(), + serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(), + serialOp.getWaitOperandsDeviceTypeAttr(), + serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues, + gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues, + devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(), + serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(), + serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(), + serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(), + serialOp.getCombinedAttr()); + + parOp.getRegion().takeBody(serialOp.getRegion()); + + LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n"); + rewriter.replaceOp(serialOp, parOp); + + return success(); + } +}; + +class ACCLegalizeSerial + : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> { +public: + using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase; + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + patterns.insert<ACCSerialOpConversion>(context); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index f8fff59..10a1796 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -1,5 +1,8 @@ add_mlir_dialect_library(MLIROpenACCTransforms ACCImplicitData.cpp + ACCImplicitDeclare.cpp + ACCImplicitRoutine.cpp + ACCLegalizeSerial.cpp LegalizeDataValues.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index fbac28e..7f27b44 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,8 +9,13 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/Support/Casting.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -155,3 +160,109 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { return val; } + +bool mlir::acc::isValidSymbolUse(mlir::Operation *user, + mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr) { + mlir::Operation *definingOp = + mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol); + + // If there are no defining ops, we have no way to ensure validity because + // we cannot check for any attributes. + if (!definingOp) + return false; + + if (definingOpPtr) + *definingOpPtr = definingOp; + + // Check if the defining op is a recipe (private, reduction, firstprivate). + // Recipes are valid as they get materialized before being offloaded to + // device. They are only instructions for how to materialize. 
+ if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp, + mlir::acc::FirstprivateRecipeOp>(definingOp)) + return true; + + // Check if the defining op is a function + if (auto func = + mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) { + // If this symbol is actually an acc routine - then it is expected for it + // to be offloaded - therefore it is valid. + if (func->hasAttr(mlir::acc::getRoutineInfoAttrName())) + return true; + + // If this symbol is a call to an LLVM intrinsic, then it is likely valid. + // Check the following: + // 1. The function is private + // 2. The function has no body + // 3. Name starts with "llvm." + // 4. The function's name is a valid LLVM intrinsic name + if (func.getVisibility() == mlir::SymbolTable::Visibility::Private && + func.getFunctionBody().empty() && func.getName().starts_with("llvm.") && + llvm::Intrinsic::lookupIntrinsicID(func.getName()) != + llvm::Intrinsic::not_intrinsic) + return true; + } + + // A declare attribute is needed for symbol references. + bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName()); + return hasDeclare; +} + +llvm::SmallVector<mlir::Value> +mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp, + mlir::DominanceInfo &domInfo, + mlir::PostDominanceInfo &postDomInfo) { + llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses; + + llvm::TypeSwitch<mlir::Operation *>(computeConstructOp) + .Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>( + [&](auto op) { + for (auto dataClause : op.getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + }) + .Default([](mlir::Operation *) {}); + + // Collect the data clauses from enclosing data constructs. + mlir::Operation *currParentOp = computeConstructOp->getParentOp(); + while (currParentOp) { + if (mlir::isa<mlir::acc::DataOp>(currParentOp)) { + for (auto dataClause : mlir::dyn_cast<mlir::acc::DataOp>(currParentOp) + .getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + } + currParentOp = currParentOp->getParentOp(); + } + + // Find the enclosing function/subroutine + auto funcOp = + computeConstructOp->getParentOfType<mlir::FunctionOpInterface>(); + if (!funcOp) + return dominatingDataClauses.takeVector(); + + // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that + // dominate and post-dominate the compute construct and add their data + // clauses to the list. + funcOp->walk([&](mlir::acc::DeclareEnterOp declareEnterOp) { + if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) { + // Collect all `acc.declare_exit` ops for this token. + llvm::SmallVector<mlir::acc::DeclareExitOp> exits; + for (auto *user : declareEnterOp.getToken().getUsers()) + if (auto declareExit = mlir::dyn_cast<mlir::acc::DeclareExitOp>(user)) + exits.push_back(declareExit); + + // Only add clauses if every `acc.declare_exit` op post-dominates the + // compute construct. 
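  // Schematically (operand and assembly details abbreviated):
  //
  //   %tok = acc.declare_enter dataOperands(%a : ...)
  //   ...
  //   acc.parallel { ... }          // the compute construct
  //   ...
  //   acc.declare_exit token(%tok)
  //
  // The enter dominates and its only exit post-dominates the construct, so
  // %a is added to the dominating data clauses. If some exit tied to %tok did
  // not post-dominate the construct, none of its clauses would be added.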
+ if (!exits.empty() && + llvm::all_of(exits, [&](mlir::acc::DeclareExitOp exitOp) { + return postDomInfo.postDominates(exitOp, computeConstructOp); + })) { + for (auto dataClause : declareEnterOp.getDataClauseOperands()) + dominatingDataClauses.insert(dataClause); + } + } + }); + + return dominatingDataClauses.takeVector(); +} diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 1b069c6..103295d 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -617,6 +617,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, break; case ClauseScheduleKind::Auto: case ClauseScheduleKind::Runtime: + case ClauseScheduleKind::Distribute: chunkSize = std::nullopt; } @@ -1817,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, if (mapTypeMod == "ref_ptr_ptee") mapTypeBits |= ClauseMapFlags::ref_ptr_ptee; + if (mapTypeMod == "is_device_ptr") + mapTypeBits |= ClauseMapFlags::is_device_ptr; + return success(); }; @@ -1886,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op, mapTypeStrs.push_back("ref_ptee"); if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee)) mapTypeStrs.push_back("ref_ptr_ptee"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr)) + mapTypeStrs.push_back("is_device_ptr"); if (mapFlags == ClauseMapFlags::none) mapTypeStrs.push_back("none"); @@ -2824,6 +2830,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), + /*linear_var_types*/ nullptr, /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr, /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr, /*private_needs_barrier=*/false, @@ -2842,8 +2849,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, WsloopOp::build( builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars, - clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod, - clauses.ordered, clauses.privateVars, + clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait, + clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), @@ -2888,17 +2895,16 @@ LogicalResult WsloopOp::verifyRegions() { void SimdOp::build(OpBuilder &builder, OperationState &state, const SimdOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: linearVars, linearStepVars - SimdOp::build(builder, state, clauses.alignedVars, - makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, - /*linear_vars=*/{}, /*linear_step_vars=*/{}, - clauses.nontemporalVars, clauses.order, clauses.orderMod, - clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.reductionMod, - clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, - clauses.simdlen); + SimdOp::build( + builder, state, clauses.alignedVars, + makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, + clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes, + clauses.nontemporalVars, clauses.order, clauses.orderMod, + clauses.privateVars, makeArrayAttr(ctx, 
clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, + clauses.simdlen); } LogicalResult SimdOp::verify() { diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt index 423e1c3..b111117 100644 --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -19,5 +19,5 @@ add_mlir_dialect_library(MLIRSCFDialect MLIRSideEffectInterfaces MLIRTensorDialect MLIRValueBoundsOpInterface + MLIRTransformUtils ) - diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 881e256..c4bd31f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -26,6 +26,7 @@ #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() { } namespace { +/// Move a scf.if op that is directly before the scf.condition op in the while +/// before region, and whose condition matches the condition of the +/// scf.condition op, down into the while after region. +/// +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// %res = scf.if %cond -> (...) { +/// use(%additional_used_values) +/// ... // then block +/// scf.yield %then_value +/// } else { +/// scf.yield %else_value +/// } +/// scf.condition(%cond) %res, ... +/// } do { +/// ^bb0(%res_arg, ...): +/// use(%res_arg) +/// ... +/// +/// becomes +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// scf.condition(%cond) %else_value, ..., %additional_used_values +/// } do { +/// ^bb0(%res_arg ..., %additional_args): : +/// use(%additional_args) +/// ... // if then block +/// use(%then_value) +/// ... +struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { + using OpRewritePattern<scf::WhileOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + auto conditionOp = op.getConditionOp(); + + // Only support ifOp right before the condition at the moment. Relaxing this + // would require to: + // - check that the body does not have side-effects conflicting with + // operations between the if and the condition. + // - check that results of the if operation are only used as arguments to + // the condition. + auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode()); + + // Check that the ifOp is directly before the conditionOp and that it + // matches the condition of the conditionOp. Also ensure that the ifOp has + // no else block with content, as that would complicate the transformation. + // TODO: support else blocks with content. + if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || + (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) + return failure(); + + assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && + *ifOp->user_begin() == conditionOp)) && + "ifOp has unexpected uses"); + + Location loc = op.getLoc(); + + // Replace uses of ifOp results in the conditionOp with the yielded values + // from the ifOp branches. 
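      // Concretely, given
      //   %res = scf.if %cond -> (T) {
      //     scf.yield %then_value
      //   } else {
      //     scf.yield %else_value
      //   }
      //   scf.condition(%cond) %res
      // the condition argument becomes %else_value (the value observed when
      // the loop exits) and the matching after-region block argument becomes
      // %then_value (the value observed when the body runs).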
+ for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { + auto it = llvm::find(ifOp->getResults(), arg); + if (it != ifOp->getResults().end()) { + size_t ifOpIdx = it.getIndex(); + Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); + Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); + + rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); + rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); + } + } + + // Collect additional used values from before region. + SetVector<Value> additionalUsedValuesSet; + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { + if (&op.getBefore() == operand->get().getParentRegion()) + additionalUsedValuesSet.insert(operand->get()); + }); + + // Create new whileOp with additional used values as results. + auto additionalUsedValues = additionalUsedValuesSet.getArrayRef(); + auto additionalValueTypes = llvm::map_to_vector( + additionalUsedValues, [](Value val) { return val.getType(); }); + size_t additionalValueSize = additionalUsedValues.size(); + SmallVector<Type> newResultTypes(op.getResultTypes()); + newResultTypes.append(additionalValueTypes); + + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); + + rewriter.modifyOpInPlace(newWhileOp, [&] { + newWhileOp.getBefore().takeBody(op.getBefore()); + newWhileOp.getAfter().takeBody(op.getAfter()); + newWhileOp.getAfter().addArguments( + additionalValueTypes, + SmallVector<Location>(additionalValueSize, loc)); + }); + + rewriter.modifyOpInPlace(conditionOp, [&] { + conditionOp.getArgsMutable().append(additionalUsedValues); + }); + + // Replace uses of additional used values inside the ifOp then region with + // the whileOp after region arguments. + rewriter.replaceUsesWithIf( + additionalUsedValues, + newWhileOp.getAfterArguments().take_back(additionalValueSize), + [&](OpOperand &use) { + return ifOp.getThenRegion().isAncestor( + use.getOwner()->getParentRegion()); + }); + + // Inline ifOp then region into new whileOp after region. + rewriter.eraseOp(ifOp.thenYield()); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), + newWhileOp.getAfterBody()->begin()); + rewriter.eraseOp(ifOp); + rewriter.replaceOp(op, + newWhileOp->getResults().drop_back(additionalValueSize)); + return success(); + } +}; + /// Replace uses of the condition within the do block with true, since otherwise /// the block would not be evaluated. 
/// @@ -4343,7 +4471,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { LogicalResult matchAndRewrite(WhileOp loop, PatternRewriter &rewriter) const override { - auto oldBefore = loop.getBeforeBody(); + auto *oldBefore = loop.getBeforeBody(); ConditionOp oldTerm = loop.getConditionOp(); ValueRange beforeArgs = oldBefore->getArguments(); ValueRange termArgs = oldTerm.getArgs(); @@ -4364,7 +4492,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { beforeArgs); } - auto oldAfter = loop.getAfterBody(); + auto *oldAfter = loop.getAfterBody(); SmallVector<Type> newResultTypes(beforeArgs.size()); for (auto &&[i, j] : llvm::enumerate(*mapping)) @@ -4373,8 +4501,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { auto newLoop = WhileOp::create( rewriter, loop.getLoc(), newResultTypes, loop.getInits(), /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); - auto newBefore = newLoop.getBeforeBody(); - auto newAfter = newLoop.getAfterBody(); + auto *newBefore = newLoop.getBeforeBody(); + auto *newAfter = newLoop.getAfterBody(); SmallVector<Value> newResults(beforeArgs.size()); SmallVector<Value> newAfterArgs(beforeArgs.size()); @@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add<RemoveLoopInvariantArgsFromBeforeBlock, RemoveLoopInvariantValueYielded, WhileConditionTruth, WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, - WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 29b770f..009c2c3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( for (auto [outerLoop, innerLoop] : llvm::zip_equal(loops.drop_back(), loops.drop_front())) { // Again assume that all the outer loops are scf.for operations. - auto outerForLoop = cast<scf::ForOp>(outerLoop); + auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); auto outerLoopYield = cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); SmallVector<Value> newYields = @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter, return clonedSlices; } -/// Implementation of fusing consumer of a single slice by computing the -/// slice of the consumer in-place for scf loop. -FailureOr<scf::SCFFuseConsumerOfSliceResult> -mlir::scf::tileAndFuseConsumerOfSlices( - RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, - MutableArrayRef<LoopLikeOpInterface> loops) { - if (candidateSlices.empty()) { - return rewriter.notifyMatchFailure( - rewriter.getUnknownLoc(), - "no candidate slices provided for consumer fusion"); - } - // Return if `loops` is empty, return an error for now. Caller is expected - // to handle this case. 
- if (loops.empty()) { - return rewriter.notifyMatchFailure( - candidateSlices.front(), - "cannot call tile and fuse consumer with an empty loop nest"); - } +static FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, + ArrayRef<OpOperand *> consumerOpOperands, + ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "expected loops to be not empty"); - if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || - llvm::all_of(candidateSlices, - llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + // 1. Check assumption for loop with `reorderOperations` disabled. + if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) { return rewriter.notifyMatchFailure( - candidateSlices.front(), - "candidates slices need to be all `tensor.extract_slice`s or " - "`tensor.parallel_insert_slice`s"); - } - - // 1. Get the consumer of scf.for for the result yielded by - // tensor.insert_slice/parallel_insert_slice. - SmallVector<OpOperand *> consumerOpOperands; - Operation *consumerOp; - { - FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = - getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); - if (failed(maybeConsumerOpOperand)) { - return rewriter.notifyMatchFailure(candidateSlices.front(), - "could not fetch consumer to fuse"); - } - std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); - consumerOp = consumerOpOperands.front()->getOwner(); + loops.front(), "the first user of loop should not dominate any define " + "of consumer operand(s)"); } LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface innerMostLoop = loops.back(); - // Check assumption for loop with `reorderOperations` disabled. - if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { - return rewriter.notifyMatchFailure( - outerMostLoop, "the first user of loop should not dominate any define " - "of consumer operand(s)"); - } - OpBuilder::InsertionGuard g(rewriter); - // 2. Check consumer is not using scf loop's output as init. auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); if (!dstOp) @@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices( llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); }); + auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands); return scf::SCFFuseConsumerOfSliceResult{ - std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), + std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands), std::move(tileAndFuseResult->tiledOps)}; } +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumerOfSlices( + RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (candidateSlices.empty()) { + return rewriter.notifyMatchFailure( + rewriter.getUnknownLoc(), + "no candidate slices provided for consumer fusion"); + } + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. 
+ if (loops.empty()) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "cannot call tile and fuse consumer with an empty loop nest"); + } + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || + llvm::all_of(candidateSlices, + llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } + + // Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperands)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner(); + + return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp, + maybeConsumerOpOperands.value(), + candidateSlices, loops); +} + +/// For a given `result` of a `forallOp` return the +/// `tensor.parallel_insert_slice` op (or combining op) that is used to +/// construct this result. +static std::optional<Operation *> +getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { + if (result.getOwner() != forallOp) + return std::nullopt; + BlockArgument bbArg = forallOp.getTiedBlockArgument(result); + SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); + // If the number of combining ops is not 1, then this is unexpected. Return + // nullopt. + if (combiningOps.size() != 1) + return std::nullopt; + return combiningOps[0]; +} + +/// For a given result of the loop nest that is a tiled loop nest, return the +/// insert slice-like op that is used for consumer fusion +static std::optional<Operation *> +getProducingInsertSliceLikeOp(OpResult result, + ArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "Expected loops to be not empty"); + LoopLikeOpInterface outerMostLoop = loops.front(); + if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) { + assert(loops.size() == 1 && + "expected only a single loop when tiling using scf.forall"); + return getProducingParallelInsertSlice(forallOp, result); + } + // Assume that the loop nest is a nested `scf.for` that is created through + // tiling and retrieve the `tensor.insert_slice` operation used to construct + // the result. 
+ while (loops.size() != 1) { + LoopLikeOpInterface loop = loops.front(); + if (result.getOwner() != loop) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto innerForResult = + dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); + if (!innerForResult) + return std::nullopt; + result = innerForResult; + loops = loops.drop_front(); + } + LoopLikeOpInterface loop = loops.front(); + if (result.getOwner() != loop) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) + .getDefiningOp<tensor::InsertSliceOp>(); + if (!insertSliceOp) + return std::nullopt; + return insertSliceOp; +} + +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (!isa<TilingInterface>(consumer)) { + return rewriter.notifyMatchFailure( + consumer, "unhandled consumer that does not implement TilingInterface"); + } + + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + consumer, "cannot call tile and fuse consumer with an empty loop nest"); + } + + LoopLikeOpInterface outermostLoop = loops.front(); + + // Collect the operands of the consumer that come from the outermost loop of + // the loop nest. + SmallVector<OpOperand *> consumerFusableOperands; + for (OpOperand &opOperand : consumer->getOpOperands()) { + if (opOperand.get().getDefiningOp() == outermostLoop) { + consumerFusableOperands.push_back(&opOperand); + } + } + + // Nothing to fuse. Just return an empty set. + if (consumerFusableOperands.empty()) { + return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, + SmallVector<OpOperand *>{}, + SmallVector<Operation *>{}}; + } + + // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices + // for fusion. + SmallVector<Operation *> candidateSlices; + candidateSlices.reserve(consumerFusableOperands.size()); + for (OpOperand *opOperand : consumerFusableOperands) { + std::optional<Operation *> slice = + getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); + if (!slice) { + return rewriter.notifyMatchFailure( + consumer, + "couldnt find producing insert-slice like operation for operand"); + } + candidateSlices.push_back(slice.value()); + } + return tileAndFuseConsumerOfSlicesImpl( + rewriter, consumer, consumerFusableOperands, candidateSlices, loops); +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index f0b46e6..a846d7e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -220,6 +220,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() { } //===----------------------------------------------------------------------===// +// spirv.Switch +//===----------------------------------------------------------------------===// + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + DenseIntElementsAttr literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + build(builder, result, selector, defaultOperands, targetOperands, literals, + defaultTarget, targets); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef<APInt> literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast<int64_t>(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef<int32_t> literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast<int64_t>(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +LogicalResult SwitchOp::verify() { + std::optional<DenseIntElementsAttr> literals = getLiterals(); + BlockRange targets = getTargets(); + + if (!literals && targets.empty()) + return success(); + + Type selectorType = getSelector().getType(); + Type literalType = literals->getType().getElementType(); + if (literalType != selectorType) + return emitOpError() << "'selector' type (" << selectorType + << ") should match literals type (" << literalType + << ")"; + + if (literals && literals->size() != static_cast<int64_t>(targets.size())) + return emitOpError() << "number of literals (" << literals->size() + << ") should match number of targets (" + << targets.size() << ")"; + return success(); +} + +SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return SuccessorOperands(index == 0 ? 
getDefaultOperandsMutable() + : getTargetOperandsMutable(index - 1)); +} + +Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { + std::optional<DenseIntElementsAttr> literals = getLiterals(); + + if (!literals) + return getDefaultTarget(); + + SuccessorRange targets = getTargets(); + if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) { + for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>())) + if (literal == value.getValue()) + return targets[index]; + return getDefaultTarget(); + } + return nullptr; +} + +//===----------------------------------------------------------------------===// // spirv.mlir.loop //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index 2f3a28f..8575487 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, } } +/// Adapted from the cf.switch implementation. +/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? +/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, Type &selectorType, Block *&defaultTarget, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, + SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals, + SmallVectorImpl<Block *> &targets, + SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> + &targetOperands, + SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) { + if (parser.parseKeyword("default") || parser.parseColon() || + parser.parseSuccessor(defaultTarget)) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, + /*allowResultNumber=*/false) || + parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) + return failure(); + } + + SmallVector<APInt> values; + unsigned bitWidth = selectorType.getIntOrFloatBitWidth(); + while (succeeded(parser.parseOptionalComma())) { + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); + + Block *target; + SmallVector<OpAsmParser::UnresolvedOperand> operands; + SmallVector<Type> operandTypes; + if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target))) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (failed(parser.parseOperandList(operands, + OpAsmParser::Delimiter::None)) || + failed(parser.parseColonTypeList(operandTypes)) || + failed(parser.parseRParen())) + return failure(); + } + targets.push_back(target); + targetOperands.emplace_back(operands); + targetOperandTypes.emplace_back(operandTypes); + } + + if (!values.empty()) { + ShapedType literalType = + VectorType::get(static_cast<int64_t>(values.size()), selectorType); + literals = DenseIntElementsAttr::get(literalType, values); + } + return success(); +} + +static void +printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType, + Block *defaultTarget, OperandRange defaultOperands, + TypeRange defaultOperandTypes, DenseIntElementsAttr literals, + SuccessorRange targets, OperandRangeRange targetOperands, + const TypeRangeRange &targetOperandTypes) { + p << " default: "; + p.printSuccessorAndUseList(defaultTarget, defaultOperands); + + 
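  // The printed case list then looks roughly like:
  //
  //    default: ^bb1(%a : i32),
  //     0: ^bb2,
  //     42: ^bb3(%b : f32)
  //
  // (block names, operands, and types here are placeholders).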
if (!literals) + return; + + for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) { + p << ','; + p.printNewline(); + p << " "; + p << literal.getLimitedValue(); + p << ": "; + p.printSuccessorAndUseList(targets[index], targetOperands[index]); + } + p.printNewline(); +} + } // namespace mlir::spirv // TablenGen'erated operation definitions. diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index cb9b7f6..f07307f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, << type << " illegal: cannot handle zero-element tensors\n"); return nullptr; } + if (arrayElemCount > std::numeric_limits<unsigned>::max()) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot fit tensor into target type\n"); + return nullptr; + } Type arrayElemType = convertScalarType(targetEnv, options, scalarType); if (!arrayElemType) diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 645cbff..5941f7d 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames( //===----------------------------------------------------------------------===// void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr grid, - ArrayRef<GridAxesAttr> split_axes, - ArrayRef<int64_t> static_halos, - ArrayRef<int64_t> static_offsets) { + FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHalos, + ArrayRef<int64_t> staticOffsets) { return build( - b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {}); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes, - ArrayRef<int64_t> static_halos, - ArrayRef<int64_t> static_offsets) { + llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHalos, + ArrayRef<int64_t> staticOffsets) { return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid), - GridAxesArrayAttr::get(b.getContext(), split_axes), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), + GridAxesArrayAttr::get(b.getContext(), splitAxes), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {}); } void ShardingOp::build( ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes, - ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, - ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) { + FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes, + ::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes, + ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) { mlir::SmallVector<int64_t> staticHalos, staticDims; mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims; - 
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos); - dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims); + dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos); + dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims); return build( - b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes), ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims); } @@ -576,7 +575,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return failure(); } if (mlir::ShapedType::isDynamicShape(grid->getShape()) && - getStaticShardedDimsOffsets().size() > 0) { + !getStaticShardedDimsOffsets().empty()) { return emitError() << "sharded dims offsets are not allowed for " "device grids with dynamic shape."; } @@ -650,14 +649,14 @@ public: if (dynamicOffs.empty() && !staticOffs.empty()) { assert(staticOffs.size() >= 2); auto diff = staticOffs[1] - staticOffs[0]; - bool all_same = staticOffs.size() > 2; + bool allSame = staticOffs.size() > 2; for (auto i = 2u; i < staticOffs.size(); ++i) { if (staticOffs[i] - staticOffs[i - 1] != diff) { - all_same = false; + allSame = false; break; } } - if (all_same) { + if (allSame) { staticOffs.clear(); modified = true; } @@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const { bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); } -Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {} +Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {} Sharding::Sharding(Value rhs) { auto shardingOp = rhs.getDefiningOp<ShardingOp>(); @@ -767,21 +766,20 @@ Sharding::Sharding(Value rhs) { SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets())); } -Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_, - ArrayRef<GridAxesAttr> split_axes_, - ArrayRef<int64_t> static_halo_sizes_, - ArrayRef<int64_t> static_sharded_dims_offsets_, - ArrayRef<Value> dynamic_halo_sizes_, - ArrayRef<Value> dynamic_sharded_dims_offsets_) { - Sharding res(grid_); - if (split_axes_.empty()) { +Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid, + ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHaloSizes, + ArrayRef<int64_t> staticShardedDimsOffsets, + ArrayRef<Value> dynamicHaloSizes, + ArrayRef<Value> dynamicShardedDimsOffsets) { + Sharding res(grid); + if (splitAxes.empty()) { return res; } - res.split_axes.resize(split_axes_.size()); - for (auto [i, axis] : llvm::enumerate(split_axes_)) { - res.split_axes[i] = - GridAxesAttr::get(grid_.getContext(), axis.asArrayRef()); + res.split_axes.resize(splitAxes.size()); + for (auto [i, axis] : llvm::enumerate(splitAxes)) { + res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef()); } auto clone = [](const auto src, auto &dst) { @@ -789,10 +787,10 @@ Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_, llvm::copy(src, dst.begin()); }; - clone(static_halo_sizes_, res.static_halo_sizes); - clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets); - clone(dynamic_halo_sizes_, res.dynamic_halo_sizes); - clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets); + clone(staticHaloSizes, res.static_halo_sizes); + clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets); + clone(dynamicHaloSizes, res.dynamic_halo_sizes); + clone(dynamicShardedDimsOffsets, 
res.dynamic_sharded_dims_offsets); return res; } @@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames( void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<int64_t> dims, - ArrayRef<Value> dims_dyn, ::mlir::Value sharding, + ArrayRef<Value> dimsDyn, ::mlir::Value sharding, ::mlir::ValueRange device) { SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType()); - build(odsBuilder, odsState, resType, dims, dims_dyn, sharding, + build(odsBuilder, odsState, resType, dims, dimsDyn, sharding, SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device); } diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp index 3bfbf373..f954131 100644 --- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp @@ -184,7 +184,7 @@ ReshardingRquirementKind getReshardingRquirementKind( for (auto [result, sharding] : llvm::zip_equal(op->getResults(), resultShardings)) { - for (auto user : result.getUsers()) { + for (auto *user : result.getUsers()) { ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); if (!shardOp) { continue; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index ae7eef2..9db9814 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -1365,8 +1365,8 @@ public: arith::SubIOp::create(rewriter, loc, capacity, newSize); Value fillValue = constantZero(rewriter, loc, value.getType()); Value subBuffer = memref::SubViewOp::create( - rewriter, loc, newBuffer, /*offset=*/ValueRange{newSize}, - /*size=*/ValueRange{fillSize}, + rewriter, loc, newBuffer, /*offsets=*/ValueRange{newSize}, + /*sizes=*/ValueRange{fillSize}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); linalg::FillOp::create(rewriter, loc, fillValue, subBuffer); } @@ -1386,8 +1386,8 @@ public: memref::StoreOp::create(rewriter, loc, value, buffer, size); } else { Value subBuffer = memref::SubViewOp::create( - rewriter, loc, buffer, /*offset=*/ValueRange{size}, - /*size=*/ValueRange{n}, + rewriter, loc, buffer, /*offsets=*/ValueRange{size}, + /*sizes=*/ValueRange{n}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); linalg::FillOp::create(rewriter, loc, value, subBuffer); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index febec6d..23436a6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem, SmallVector<Value> scalarArgs(idxs); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); - vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask, - rhs); + vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem, + scalarArgs, indexVec, vmask, rhs); return; } vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index ffa8b40..9904803 100644 --- 
a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -80,6 +80,53 @@ inline static bool includesDenseOutput(SortMask mask) { return includesAny(mask, SortMask::kIncludeDenseOutput); } +/// Returns a sparsity rank for loop ordering: lower values indicate +/// dimensions that should be placed in outer loops. +/// 0 = Dense, 1 = Compressed, 2 = Singleton, 3 = Other/Unknown. +static unsigned getLoopSparsityRank(unsigned loop, ArrayRef<Value> allTensors, + ArrayRef<AffineMap> allMaps) { + // Start with highest rank. + unsigned minRank = 3; + + for (auto [tensor, map] : llvm::zip(allTensors, allMaps)) { + // Check if this loop accesses this tensor. + bool loopAccessesTensor = false; + unsigned tensorDim = 0; + for (AffineExpr expr : map.getResults()) { + if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { + if (dimExpr.getPosition() == loop) { + loopAccessesTensor = true; + break; + } + } + tensorDim++; + } + + if (loopAccessesTensor) { + const auto enc = getSparseTensorEncoding(tensor.getType()); + if (!enc) { + // Dense tensor - lowest rank. + return 0; + } else { + // Sparse tensor - check the level type for this dimension. + auto lvlTypes = enc.getLvlTypes(); + if (tensorDim < lvlTypes.size()) { + auto lvlType = lvlTypes[tensorDim]; + if (isDenseLT(lvlType)) { + return 0; // Dense level. + } else if (isCompressedLT(lvlType)) { + minRank = std::min(minRank, 1u); // Compressed level. + } else if (isSingletonLT(lvlType)) { + minRank = std::min(minRank, 2u); // Singleton level. + } + } + } + } + } + + return minRank; +} + AffineMap IterationGraphSorter::topoSort() { // The sorted result will put the first Reduction iterator to the // latest possible position. @@ -107,10 +154,33 @@ AffineMap IterationGraphSorter::topoSort() { case sparse_tensor::LoopOrderingStrategy::kDefault: src = it.back(); break; + case sparse_tensor::LoopOrderingStrategy::kDenseOuter: { + // Prefer dense, then compressed, then singleton dimensions outermost. + // Create combined tensor and map lists for analysis. + SmallVector<Value> allTensors = ins; + allTensors.push_back(out); + SmallVector<AffineMap> allMaps = loop2InsLvl; + allMaps.push_back(loop2OutLvl); + + // Find loop with minimum (lowest) sparsity rank. + unsigned minLoop = it[0]; + unsigned minRank = getLoopSparsityRank(minLoop, allTensors, allMaps); + + for (auto candidateLoop : it) { + unsigned rank = getLoopSparsityRank(candidateLoop, allTensors, allMaps); + if (rank < minRank || (rank == minRank && candidateLoop < minLoop)) { + minLoop = candidateLoop; + minRank = rank; + } + } + src = minLoop; + break; + } } loopOrder.push_back(src); - it.pop_back(); + // Remove the selected loop from the worklist. + it.erase(std::find(it.begin(), it.end(), src)); // Update in-degree, and push 0-degree node into worklist. for (unsigned dst = 0; dst < numLoops; dst++) { if (itGraph[src][dst] && --inDegree[dst] == 0) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h index 3636f3f..46378b9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h @@ -197,7 +197,7 @@ public: // Sets the iterate to the specified position. 
void seek(ValueRange vals) { assert(vals.size() == cursorValsCnt); - std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin()); + llvm::copy(vals, cursorValsStorageRef.begin()); // Now that the iterator is re-positioned, the coordinate becomes invalid. crd = nullptr; } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp index 4ec13e1..686f6ee 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -77,6 +77,9 @@ namespace { struct ReifyExpandShapeOp : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, ExpandShapeOp> { + using Base = + ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, + ExpandShapeOp>; LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifyResultShapes) const { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 110bfdc..204e9bb 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) { assert(!inputTypes.empty() && "cannot concatenate 0 tensors"); auto tensorTypes = - llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) { - return llvm::cast<RankedTensorType>(type); - })); + llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>); int64_t concatRank = tensorTypes[0].getRank(); // The concatenation dim must be in the range [0, rank). @@ -2293,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames( /// An extract_slice result type can be inferred, when it is not /// rank-reduced, from the source type and the static representation of /// offsets, sizes and strides. Special sentinels encode the dynamic case. -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets, - ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef<int64_t> staticSizes) { // An extract_slice op may specify only a leading subset of offset/sizes/ // strides in which case we complete with offset=0, sizes from memref type // and strides=1. @@ -2307,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType( } // TODO: This uses neither offsets nor strides! -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef<OpFoldResult> sizes) { SmallVector<int64_t> staticSizes; std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes); + assert(static_cast<int64_t>(staticSizes.size()) == sourceTensorType.getRank() && "unexpected staticSizes not equal to rank of source"); @@ -2329,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType( /// To disambiguate, this function always drops the first 1 sizes occurrences. 
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, - ArrayRef<int64_t> strides) { + ArrayRef<int64_t> sizes) { // Type inferred in the absence of rank-reducing behavior. auto inferredType = llvm::cast<RankedTensorType>( - inferResultType(sourceRankedTensorType, offsets, sizes, strides)); + inferResultType(sourceRankedTensorType, sizes)); int rankDiff = inferredType.getRank() - desiredResultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); @@ -2352,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, - ArrayRef<OpFoldResult> strides) { - SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; - SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + ArrayRef<OpFoldResult> sizes) { + SmallVector<int64_t> staticSizes; + SmallVector<Value> dynamicSizes; dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); return ExtractSliceOp::inferCanonicalRankReducedResultType( - desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, - staticStrides); + desiredResultRank, sourceRankedTensorType, staticSizes); } /// Build an ExtractSliceOp with mixed static and dynamic entries and custom @@ -2380,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType( - sourceRankedTensorType, staticOffsets, staticSizes, staticStrides)); + resultType = llvm::cast<RankedTensorType>( + ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes)); } result.addAttributes(attrs); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, @@ -2451,13 +2445,26 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, } } +/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred +/// result type, offsets set to 0 and strides set to 1. +void ExtractSliceOp::build(OpBuilder &b, OperationState &result, + RankedTensorType resultType, Value source, + ArrayRef<OpFoldResult> sizes, + ArrayRef<NamedAttribute> attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr); + SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr); + build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs); +} + /// Verifier for ExtractSliceOp. LogicalResult ExtractSliceOp::verify() { RankedTensorType sourceType = getSourceType(); // Verify result type against inferred type. 
- RankedTensorType expectedType = ExtractSliceOp::inferResultType( - sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides()); + RankedTensorType expectedType = + ExtractSliceOp::inferResultType(sourceType, getMixedSizes()); SliceVerificationResult result = isRankReducedType(expectedType, getType()); if (result != SliceVerificationResult::Success) return produceSliceErrorMsg(result, *this, expectedType); @@ -2697,8 +2704,7 @@ struct SliceReturnTypeCanonicalizer { ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { return ExtractSliceOp::inferCanonicalRankReducedResultType( - op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes, - mixedStrides); + op.getType().getRank(), op.getSourceType(), mixedSizes); } }; @@ -2839,8 +2845,8 @@ static SliceVerificationResult verifyInsertSliceOp( ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) { // insert_slice is the inverse of extract_slice, use the same type // inference. - RankedTensorType expected = ExtractSliceOp::inferResultType( - dstType, staticOffsets, staticSizes, staticStrides); + RankedTensorType expected = + ExtractSliceOp::inferResultType(dstType, staticSizes); if (expectedType) *expectedType = expected; return isRankReducedType(expected, srcType); @@ -2968,7 +2974,7 @@ public: // Create the new op in canonical form. auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), - mixedOffsets, mixedSizes, mixedStrides); + mixedSizes); Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) { OpBuilder::InsertionGuard g(rewriter); @@ -3896,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set +// to 0, strides set to 1 and inferred result type. 
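// For example (sketch; assumes `b`, `loc`, `src`, `dst`, and `dynSize` exist
// and that the trailing attribute list defaults to empty):
//
//   SmallVector<OpFoldResult> sizes = {OpFoldResult(b.getIndexAttr(4)),
//                                      OpFoldResult(dynSize)};
//   auto inserted = tensor::InsertSliceOp::create(b, loc, src, dst, sizes);
//
// is equivalent to spelling out zero offsets and unit strides explicitly.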
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, + Value dest, ArrayRef<OpFoldResult> sizes, + ArrayRef<NamedAttribute> attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr); + build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs); +} + LogicalResult ParallelInsertSliceOp::verify() { if (!isa<InParallelOpInterface>(getOperation()->getParentOp())) return this->emitError("expected InParallelOpInterface parent, got:") diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index c607ece..310e725 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1132,35 +1132,22 @@ struct ConcatOpInterface // Extract the dimension for the concat op uint64_t concatDim = concatOp.getDim(); - bool dynamicConcatDim = false; SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1)); - SmallVector<OpFoldResult> sizes; - - for (const auto &[dimIdx, dimSize] : - llvm::enumerate(tensorType.getShape())) { - if (dimSize == ShapedType::kDynamic) { - auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx); - sizes.push_back(dimOp.getResult()); - if (dimIdx == concatDim) - dynamicConcatDim = true; - } else { - sizes.push_back(rewriter.getIndexAttr(dimSize)); - } - } - - int64_t concatDimOffset = 0; - std::optional<Value> dynamicOffset; - std::optional<Value> dynamicSize; - if (dynamicConcatDim) { - // One or more operands have dynamic size, so we must accumulate the - // offset with arith ops. - dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); - } + SmallVector<OpFoldResult> sizes = + memref::getMixedSizes(rewriter, loc, dstBuffer); + + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + auto sum = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1, + {v1, v2}); + }; + OpFoldResult concatDimOffset = rewriter.getIndexAttr(0); for (auto operand : concatOp.getInputs()) { // Get the buffer for the operand. FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state); @@ -1171,18 +1158,10 @@ struct ConcatOpInterface // so the offset on that axis must accumulate through the loop, and the // size must change to the size of the current operand. auto operandTensorType = cast<RankedTensorType>(operand.getType()); - int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim); - - if (dynamicConcatDim) { - offsets[concatDim] = dynamicOffset.value(); - dynamicSize = - memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim) - .getResult(); - sizes[concatDim] = dynamicSize.value(); - } else { - sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize); - offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset); - } + offsets[concatDim] = concatDimOffset; + OpFoldResult concatDimSize = + memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim); + sizes[concatDim] = concatDimSize; // Create a subview of the destination buffer. 
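      // For example, concatenating %a : tensor<MxK1xf32> and
      // %b : tensor<MxK2xf32> along dim 1 copies %a into the subview at
      // offset [0, 0] with sizes [M, K1], then %b at offset [0, K1] with
      // sizes [M, K2]; the affine `sum` above keeps the running offset folded
      // to a constant whenever both addends are static.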
auto dstMemrefType = cast<MemRefType>(memrefType); @@ -1197,12 +1176,7 @@ struct ConcatOpInterface if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview))) return failure(); - if (dynamicConcatDim) { - dynamicOffset = arith::AddIOp::create( - rewriter, loc, dynamicOffset.value(), dynamicSize.value()); - } else { - concatDimOffset += operandConcatDimSize; - } + concatDimOffset = sum(concatDimOffset, concatDimSize); } replaceOpWithBufferizedValues(rewriter, op, dstBuffer); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 7ec61c7..a53af98 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract // supported. Moreover, only simple cases where the resulting ExtractSliceOp // has no rank-reduction anymore are supported at the moment. RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType( - srcType, extractSliceOp.getStaticOffsets(), - extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides()); + srcType, extractSliceOp.getStaticSizes()); if (nonReducingExtractType != resultType) return failure(); @@ -533,8 +532,8 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo( getMixedSizes(b, loc, sliceOp.getSource()); // Helper variables and function for accumulating the size values. - AffineExpr d0, d1, d2; - bindDims(b.getContext(), d0, d1, d2); + AffineExpr d0, d1; + bindDims(b.getContext(), d0, d1); // Multiply two integers. auto mul = [&](OpFoldResult v1, OpFoldResult v2) { auto mulMap = AffineMap::get(2, 0, {d0 * d1}); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index 753cb95..d35f458 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -155,13 +155,15 @@ struct ExtractSliceOpInterface RankedTensorType sourceType = extractSliceOp.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride < + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(extractSliceOp); Value offset = getValueOrCreateConstantIndexOp( @@ -170,46 +172,63 @@ struct ExtractSliceOpInterface builder, loc, extractSliceOp.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = builder.createOrFold<tensor::DimOp>( loc, extractSliceOp.getSource(), i); - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). 
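      // For a source dimension of extent 4 this means: an empty slice may use
      // any offset in [0, 4] (offset 4 with size 0 is legal), while a
      // non-empty slice must start at an offset in [0, 3].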
+ Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset < + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). 
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage( diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 293c6af..c420a4c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) { - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); } Attribute newMinValAttr, newMaxValAttr; @@ -1485,7 +1486,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { return {}; } +static bool +mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) { + const auto isDynamic = [](Type ty) { + const auto shapedTy = llvm::dyn_cast<ShapedType>(ty); + return !shapedTy || !shapedTy.hasStaticShape(); + }; + + return llvm::any_of(operandTypes, isDynamic) || + failed(verifyCompatibleShapes(operandTypes)); +} + OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { + // Select allows operand shapes to be broadcast to the output shape. For + // now, don't support folding when we cannot prove no broadcasting is + // involved. 
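// For example (illustrative IR), folding would be unsound here even though both
// branches are the same value:
//   %r = tosa.select %pred, %t, %t
//        : (tensor<4xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32>
// Returning %t would change the result type from tensor<4xf32> to
// tensor<1xf32>; the implicit broadcast has to be preserved, so the fold bails
// out whenever any operand shape is dynamic or the shapes are not provably
// identical.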
+ if (mayRequireBroadcast(getOperandTypes())) + return {}; + if (getOnTrue() == getOnFalse()) return getOnTrue(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 65e0a59..1c175f9ab 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) { static Type getStorageElementTypeOrSelf(Type type) { auto srcType = getElementTypeOrSelf(type); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType)) - srcType = quantType.getStorageType(); + srcType = getStorageElementTypeFromQuantized(quantType); return srcType; } @@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) { bool resultIsFloat = llvm::isa<FloatType>(resultEType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType)) - weightEType = quantType.getStorageType(); + weightEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType)) - biasEType = quantType.getStorageType(); + biasEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { // for now, only enforce bias element type == result element type for @@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() { if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>( outputType.getElementType())) { - if (result.getStorageType() == attrType.getElementType()) + if (getStorageElementTypeFromQuantized(result) == attrType.getElementType()) return success(); } @@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast<ShapedType>(op.getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); auto accType = op.getAccType(); if (inputEType.isInteger(8) && !accType.isInteger(32)) @@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast<ShapedType>(op.getResult().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); return success(); } @@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() { llvm::cast<ShapedType>(getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) { - inputETy = quantType.getStorageType(); + inputETy = getStorageElementTypeFromQuantized(quantType); } mlir::Type outputETy = llvm::cast<ShapedType>(getOutput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) { - outputETy = quantType.getStorageType(); + outputETy = getStorageElementTypeFromQuantized(quantType); } if (inputETy != outputETy) return emitOpError("input/output element types are incompatible."); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 
41b338d..091b481 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaAttachTarget.cpp + TosaArithConstantToConst.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp @@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaTypeConverters.cpp TosaProfileCompliance.cpp TosaValidation.cpp + TosaNarrowI64ToI32.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms @@ -21,7 +23,9 @@ add_mlir_dialect_library(MLIRTosaTransforms LINK_LIBS PUBLIC MLIRFuncDialect + MLIRFuncTransformOps MLIRPass MLIRTosaDialect MLIRTransformUtils + MLIRFuncTransforms ) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp new file mode 100644 index 0000000..73e1e2b --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp @@ -0,0 +1,111 @@ +//===- TosaArithConstantToConst.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that converts tensor-valued arith.constant ops +// into tosa.const so that TOSA pipelines operate on a uniform constant form. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +// NOTE: TOSA pipelines already lower their constants through shared Arith +// folding passes, so tensor literals often come back as `arith.constant` even +// after the IR is otherwise TOSA-only. Keep this normalization with the rest of +// the TOSA transforms so any client can re-establish a canonical `tosa.const` +// representation without needing a full Arith->TOSA conversion library. + +/// Returns true when `elementType` is natively representable by tosa.const. +static bool isSupportedElementType(Type elementType) { + if (isa<FloatType>(elementType)) + return true; + + if (auto intType = dyn_cast<IntegerType>(elementType)) + return intType.isSignless() || intType.isUnsigned(); + + if (isa<quant::QuantizedType>(elementType)) + return true; + + if (isa<tosa::mxint8Type>(elementType)) + return true; + + return false; +} + +class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ConstantOp constOp, + PatternRewriter &rewriter) const override { + // TOSA constant verification requires a ranked, statically shaped tensor. 
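// Schematically, the rewrite below turns
//   %c = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// into the equivalent
//   %c = tosa.const ... : tensor<4xi32>   // same dense<[1, 2, 3, 4]> payload
// (the exact tosa.const attribute spelling is elided here), so downstream TOSA
// passes see a single canonical constant form.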
+ auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + + if (!isSupportedElementType(resultType.getElementType())) + return failure(); + + Attribute attr = constOp.getValueAttr(); + auto elementsAttr = dyn_cast<ElementsAttr>(attr); + if (!elementsAttr) + return failure(); + + auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType()); + if (!attrType || !attrType.hasStaticShape()) + return failure(); + if (attrType != resultType) + return failure(); + + auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(), + resultType, elementsAttr); + rewriter.replaceOp(constOp, newConst.getResult()); + return success(); + } +}; + +struct TosaArithConstantToTosaConstPass + : public tosa::impl::TosaArithConstantToTosaConstPassBase< + TosaArithConstantToTosaConstPass> { + using Base::Base; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<arith::ArithDialect, tosa::TosaDialect>(); + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add<ArithConstantToTosaConst>(ctx); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 0bec0da..022476a2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { ShapedType weightType = cast<ShapedType>(weight.getType()); ShapedType resultType = cast<ShapedType>(op.getOutput().getType()); - if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && - resultType.hasStaticShape())) { + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i)) + return failure(); + } + + if (!weightType.hasStaticShape()) { return failure(); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index dc5c51b..8b23fd1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -49,8 +49,13 @@ public: if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) return failure(); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t kernelHeight = weightTy.getDimSize(1); @@ -113,8 +118,13 @@ public: if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) return rewriter.notifyMatchFailure(op, "non-one stride found."); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + +
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t batch = inputTy.getDimSize(0); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp new file mode 100644 index 0000000..ddaf7d8a --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp @@ -0,0 +1,310 @@ +//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass narrows TOSA operations with 64-bit integer tensor types to +// 32-bit integer tensor types. This can be useful for backends that do not +// support the EXT-INT64 extension of TOSA. The pass has two options: +// +// - aggressive-rewrite - If enabled, all TOSA operations are rewritten, +// regardless of whether the narrowing is safe. This option may lead to +// data loss if not used carefully. +// - convert-function-boundaries - If enabled, the pass will convert function +// I/O types as well. Otherwise casts will be inserted at the I/O +// boundaries. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +LogicalResult convertGenericOp(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter, + const TypeConverter *typeConverter) { + // Convert types of results + SmallVector<Type, 4> newResults; + if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults))) + return failure(); + + // Create a new operation state + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResults, {}, op->getSuccessors()); + + for (const NamedAttribute &namedAttribute : op->getAttrs()) { + const Attribute attribute = namedAttribute.getValue(); + + // Convert integer attribute type + if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) { + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(intAttr.getType(), attribute); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) { + Type type = typeAttr.getValue(); + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(type, attribute); + if (!convertedAttribute) + return rewriter.notifyMatchFailure(op, + "Failed to convert type attribute."); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) { + const Type type = denseElementsAttr.getType(); + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(type, denseElementsAttr); + if (!convertedAttribute) + return rewriter.notifyMatchFailure( + op, "Failed to convert dense elements 
attribute."); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + state.addAttribute(namedAttribute.getName(), attribute); + } + + for (Region ®ion : op->getRegions()) { + Region *newRegion = state.addRegion(); + rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); + if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter))) + return failure(); + } + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} + +// =========================== +// Aggressive rewrite patterns +// =========================== + +class ConvertGenericOp : public ConversionPattern { +public: + ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + if (!isa<tosa::TosaOp>(op)) + return rewriter.notifyMatchFailure( + op, + "Support for operations other than TOSA has not been implemented."); + + return convertGenericOp(op, operands, rewriter, typeConverter); + } +}; + +// =============================== +// Bounds checked rewrite patterns +// =============================== + +class ConvertArgMaxOpWithBoundsChecking + : public OpConversionPattern<tosa::ArgMaxOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // Output type can be narrowed based on the size of the axis dimension + const int32_t axis = op.getAxis(); + const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType()); + if (!inputType || !inputType.isStaticDim(axis)) + return rewriter.notifyMatchFailure( + op, "Requires a static axis dimension for bounds checking."); + const int64_t axisDim = inputType.getDimSize(axis); + if (axisDim >= std::numeric_limits<int32_t>::max()) + return rewriter.notifyMatchFailure( + op, "Axis dimension is too large to narrow safely."); + + const Type resultType = op.getOutput().getType(); + const Type newResultType = typeConverter->convertType(resultType); + rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType, + adaptor.getInput(), axis); + return success(); + } +}; + +class ConvertCastOpWithBoundsChecking + : public OpConversionPattern<tosa::CastOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType()); + const auto resultType = dyn_cast<ShapedType>(op.getResult().getType()); + if (!inputType || !resultType) + return failure(); + + const auto elementInputIntType = + dyn_cast<IntegerType>(inputType.getElementType()); + const auto elementResultIntType = + dyn_cast<IntegerType>(resultType.getElementType()); + if (elementInputIntType && elementResultIntType && + elementInputIntType.getWidth() > elementResultIntType.getWidth()) + return rewriter.notifyMatchFailure( + op, "Narrowing cast may lead to data loss."); + + rewriter.replaceOpWithNewOp<tosa::CastOp>( + op, typeConverter->convertType(resultType), adaptor.getInput()); + return success(); + } +}; + +template <typename OpTy> +class ConvertTypedOp : public OpConversionPattern<OpTy> { + using OpConversionPattern<OpTy>::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpTy op, typename 
OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + return convertGenericOp(op, adaptor.getOperands(), rewriter, + this->getTypeConverter()); + } +}; + +struct TosaNarrowI64ToI32 + : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> { +public: + explicit TosaNarrowI64ToI32() = default; + explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) + : TosaNarrowI64ToI32() { + this->aggressiveRewrite = options.aggressiveRewrite; + this->convertFunctionBoundaries = options.convertFunctionBoundaries; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) -> Type { return type; }); + typeConverter.addConversion([](IntegerType type) -> Type { + if (!type.isInteger(64)) + return type; + return IntegerType::get(type.getContext(), 32); + }); + typeConverter.addConversion( + [&typeConverter](RankedTensorType type) -> Type { + const Type elementType = type.getElementType(); + if (!elementType.isInteger(64)) + return type; + return RankedTensorType::get(type.getShape(), + typeConverter.convertType(elementType)); + }); + + const auto materializeCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + return tosa::CastOp::create(builder, loc, resultType, inputs.front()); + }; + typeConverter.addSourceMaterialization(materializeCast); + typeConverter.addTargetMaterialization(materializeCast); + + typeConverter.addTypeAttributeConversion( + [](IntegerType type, IntegerAttr attribute) -> Attribute { + const APInt value = attribute.getValue().truncSSat(32); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), + value); + }); + typeConverter.addTypeAttributeConversion( + [&typeConverter](ShapedType type, + DenseIntElementsAttr attr) -> Attribute { + const ShapedType newType = + cast<ShapedType>(typeConverter.convertType(type)); + const auto oldElementType = cast<IntegerType>(type.getElementType()); + const auto newElementType = + cast<IntegerType>(newType.getElementType()); + if (oldElementType.getWidth() == newElementType.getWidth()) + return attr; + + DenseElementsAttr mapped = + attr.mapValues(newElementType, [&](const APInt &v) { + return v.truncSSat(newElementType.getWidth()); + }); + return mapped; + }); + + ConversionTarget target(*context); + target.addDynamicallyLegalDialect<tosa::TosaDialect>( + [&typeConverter](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + if (convertFunctionBoundaries) { + target.addDynamicallyLegalOp<func::FuncOp>( + [&typeConverter](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) { + const FunctionType funcType = + op->getParentOfType<func::FuncOp>().getFunctionType(); + return llvm::equal(op.getOperandTypes(), funcType.getResults()); + }); + } else { + target.addDynamicallyLegalOp<func::FuncOp>( + [](func::FuncOp op) { return true; }); + target.addDynamicallyLegalOp<func::ReturnOp>( + [](func::ReturnOp op) { return true; }); + } + + RewritePatternSet patterns(context); + if (convertFunctionBoundaries) { + populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + } + if (aggressiveRewrite) { + 
patterns.add<ConvertGenericOp>(typeConverter, context); + } else { + // Tensor + patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context); + // Data layout + patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context); + // Type conversion + patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context); + // Controlflow + patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context); + } + + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index ac5d620..36e8940 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -70,6 +70,8 @@ namespace { // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c]. +// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?]. LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, @@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape, higherRankDim = higherRankShape[i + rankDiff]; lowerRankDim = lowerRankShape[i]; - if (lowerRankDim != 1 && higherRankDim != 1 && + auto isStaticDimAndNotEqualToOne = [](int64_t dim) { + return dim != 1 && dim != ShapedType::kDynamic; + }; + + if (isStaticDimAndNotEqualToOne(lowerRankDim) && + isStaticDimAndNotEqualToOne(higherRankDim) && lowerRankDim != higherRankDim) return failure(); @@ -216,22 +223,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) { bool mlir::tosa::hasUniqueConstantScatterIndices( ShapedType indicesType, DenseIntElementsAttr indicesAttr) { - llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape(); + const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape(); const unsigned int indicesRank = indicesShape.size(); const unsigned int lastDimSize = indicesShape[indicesRank - 1]; // check each batch of indices from the flat indicesAttr values // for duplicates - auto const indicesValues = indicesAttr.getValues<int32_t>(); + auto const indicesValues = indicesAttr.getValues<APInt>(); assert( (indicesValues.size() % lastDimSize == 0) && "Constant indices data length should be a multiple of indicesShape[-1]"); - std::vector<uint64_t> indices(lastDimSize); + std::vector<APInt> indices(lastDimSize); for (auto beg = indicesValues.begin(); beg < indicesValues.end(); beg += lastDimSize) { std::copy(beg, beg + lastDimSize, indices.begin()); - std::sort(indices.begin(), indices.end()); + std::sort(indices.begin(), indices.end(), + [](const APInt &a, const APInt &b) { return a.slt(b); }); if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) 
{ // found duplicate values in indices in batch return false; diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 02c86a0..c55b13d 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, maxAttr, quantBits, filterQuantDim, isSigned, narrowRange)); } + +Type mlir::tosa::getStorageElementTypeFromQuantized( + quant::QuantizedType quantType) { + auto quantEty = quantType.getStorageType(); + // StorageType doesn't capture the sign information + // Explicitly create unsigned type if needed + if (!quantType.isSigned()) { + quantEty = IntegerType::get(quantEty.getContext(), + quantEty.getIntOrFloatBitWidth(), + IntegerType::Unsigned); + } + return quantEty; +} diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 062606e..86233b0 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2062,6 +2062,10 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter, DiagnosedSilenceableFailure result = applySequenceBlock( callee.getBody().front(), getFailurePropagationMode(), state, results); + + if (!result.succeeded()) + return result; + mappings.clear(); detail::prepareValueMappings( mappings, callee.getBody().front().getTerminator()->getOperands(), state); diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 8859541..24b0487 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -1495,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp, template <typename T> static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) { - return llvm::to_vector(llvm::map_range( - range, [](transform::MappedValue value) { return cast<T>(value); })); + return llvm::map_to_vector(range, llvm::CastTo<T>); } void transform::detail::setApplyToOneResults( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index f727118..2bd6205 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -156,7 +156,7 @@ DiagnosedSilenceableFailure transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - std::optional<size_t> selectedRegionIdx; + std::optional<int64_t> selectedRegionIdx; if (auto selectedRegionAttr = getSelectedRegionAttr()) selectedRegionIdx = selectedRegionAttr->getSExtValue(); @@ -232,7 +232,7 @@ LogicalResult transform::tune::AlternativesOp::verify() { } if (auto selectedRegionAttr = getSelectedRegionAttr()) { - size_t regionIdx = selectedRegionAttr->getSExtValue(); + int64_t regionIdx = selectedRegionAttr->getSExtValue(); if (regionIdx < 0 || regionIdx >= getNumRegions()) return emitOpError() << "'selected_region' attribute specifies region at index " diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp index a26edac..2986f4c 100644 --- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp +++ 
b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp @@ -106,14 +106,12 @@ ScalableValueBoundsConstraintSet::computeScalableBound( AffineMap bound = [&] { if (boundType == BoundType::EQ && !invalidBound(lowerBound) && - lowerBound[0] == upperBound[0]) { + lowerBound[0] == upperBound[0]) return lowerBound[0]; - } - if (boundType == BoundType::LB && !invalidBound(lowerBound)) { + if (boundType == BoundType::LB && !invalidBound(lowerBound)) return lowerBound[0]; - } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { + if (boundType == BoundType::UB && !invalidBound(upperBound)) return upperBound[0]; - } return AffineMap{}; }(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index daef0ba..2789f63 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6066,19 +6066,21 @@ LogicalResult ScatterOp::verify() { VectorType indVType = getIndexVectorType(); VectorType maskVType = getMaskVectorType(); VectorType valueVType = getVectorType(); - MemRefType memType = getMemRefType(); + ShapedType baseType = getBaseType(); - if (valueVType.getElementType() != memType.getElementType()) + if (!llvm::isa<MemRefType, RankedTensorType>(baseType)) + return emitOpError("requires base to be a memref or ranked tensor type"); + + if (valueVType.getElementType() != baseType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getOffsets()) != memType.getRank()) - return emitOpError("requires ") << memType.getRank() << " indices"; + if (llvm::size(getOffsets()) != baseType.getRank()) + return emitOpError("requires ") << baseType.getRank() << " indices"; if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getShape() != maskVType.getShape()) return emitOpError("expected valueToStore dim to match mask dim"); return success(); } - namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: @@ -6241,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), argRanges.front()); } +std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultVectorType().getShape()); +} + LogicalResult ShapeCastOp::verify() { VectorType sourceType = getSourceVectorType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 546099c..352f477 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" using namespace mlir; using namespace mlir::bufferization; @@ -126,6 +127,54 @@ struct TransferWriteOpInterface } }; +/// Bufferization of vector.scatter. Replaced with a new vector.scatter that +/// operates on a memref. 
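/// A minimal sketch of enabling this model (the surrounding pass/pipeline setup
/// is assumed and elided):
///   DialectRegistry registry;
///   vector::registerBufferizableOpInterfaceExternalModels(registry);
///   // attach `registry` to the MLIRContext, then run One-Shot Bufferize;
///   // a vector.scatter on a tensor then becomes a vector.scatter on the
///   // buffer backing that tensor, and the tensor result is replaced by that
///   // buffer (an in-place write, see BufferRelation::Equivalent below).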
+struct ScatterOpInterface + : public BufferizableOpInterface::ExternalModel<ScatterOpInterface, + vector::ScatterOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + auto scatterOp = cast<vector::ScatterOp>(op); + if (&opOperand != &scatterOp.getBaseMutable()) + return {}; + return {{scatterOp.getResult(), BufferRelation::Equivalent}}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options, + BufferizationState &state) const { + auto scatterOp = cast<vector::ScatterOp>(op); + assert(isa<TensorType>(scatterOp.getBaseType()) && + "only tensor types expected"); + FailureOr<Value> buffer = + getBuffer(rewriter, scatterOp.getBase(), options, state); + if (failed(buffer)) + return failure(); + vector::ScatterOp::create(rewriter, scatterOp.getLoc(), + /*resultType=*/nullptr, *buffer, + scatterOp.getOffsets(), scatterOp.getIndices(), + scatterOp.getMask(), scatterOp.getValueToStore()); + replaceOpWithBufferizedValues(rewriter, op, *buffer); + return success(); + } +}; + /// Bufferization of vector.gather. Replaced with a new vector.gather that /// operates on a memref. struct GatherOpInterface @@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels( GatherOp::attachInterface<GatherOpInterface>(*ctx); MaskOp::attachInterface<MaskOpInterface>(*ctx); YieldOp::attachInterface<YieldOpInterface>(*ctx); + ScatterOp::attachInterface<ScatterOpInterface>(*ctx); }); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index 258f2cb..1af5523 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -111,7 +111,7 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { if (!isValidKind(isInt, scanOp.getKind())) return failure(); - VectorType resType = VectorType::get(destShape, elType); + VectorType resType = destType; Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); int64_t reductionDim = scanOp.getReductionDim(); @@ -121,8 +121,18 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { int64_t initialValueRank = initialValueType.getRank(); SmallVector<int64_t> reductionShape(destShape); + SmallVector<bool> reductionScalableDims(destType.getScalableDims()); + + if (reductionScalableDims[reductionDim]) + return rewriter.notifyMatchFailure( + scanOp, "Trying to reduce scalable dimension - not yet supported!"); + + // The reduction dimension, after reducing, becomes 1. It's a fixed-width + // dimension - no need to touch the scalability flag. 
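// For example, scanning a vector<4x[8]xf32>:
//   - reductionDim == 1 targets the scalable dimension [8] and is rejected by
//     the match failure above;
//   - reductionDim == 0 is fine: the reduction type below becomes
//     vector<1x[8]xf32>, i.e. the scalable flag of the remaining dimension is
//     preserved.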
reductionShape[reductionDim] = 1; - VectorType reductionType = VectorType::get(reductionShape, elType); + VectorType reductionType = + VectorType::get(reductionShape, elType, reductionScalableDims); + SmallVector<int64_t> offsets(destRank, 0); SmallVector<int64_t> strides(destRank, 1); SmallVector<int64_t> sizes(destShape); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 726da1e..ad16b80 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -453,6 +453,8 @@ struct ReorderCastOpsOnBroadcast PatternRewriter &rewriter) const override { if (op->getNumOperands() != 1) return failure(); + if (!isa<VectorType>(op->getResult(0).getType())) + return failure(); auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>(); if (!bcastOp) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae098..462bd8c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,286 @@ private: vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.create_mask` operations into smaller mask +/// operations based on the target unroll shape. Each unrolled slice computes +/// its local mask size in each dimension (d) as: +/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]). +/// Example: +/// Given a create_mask operation: +/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10 +/// elements +/// +/// and a target unroll shape of <4x8>, the pattern produces: +/// +/// %false = arith.constant dense<false> : vector<8x16xi1> +/// +/// Slice [0,0]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8 +/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1> +/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [0,8]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2 +/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1> +/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,0]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8 +/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1> +/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,8]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2 +/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1> +/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> { + UnrollCreateMaskPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::CreateMaskOp>(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, createMaskOp); + if (!targetShape) + return failure(); + + VectorType resultType = createMaskOp.getVectorType(); + SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll(); + Location loc = createMaskOp.getLoc(); + + Value result = 
arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + VectorType targetVectorType = + VectorType::get(*targetShape, rewriter.getI1Type()); + SmallVector<int64_t> strides(targetShape->size(), 1); + + // In each dimension (d), each unrolled vector computes its mask size as: + // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]). + for (SmallVector<int64_t> offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { + SmallVector<Value> unrolledOperands; + + for (auto [i, originalMaskOperand] : + llvm::enumerate(createMaskOp.getOperands())) { + Value offsetVal = + arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); + Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>( + loc, originalMaskOperand, offsetVal); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value unrolledDimSize = + arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); + Value nonNegative = + rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero); + Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>( + loc, nonNegative, unrolledDimSize); + unrolledOperands.push_back(unrolledOperand); + } + + auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>( + loc, targetVectorType, unrolledOperands); + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, unrolledMask, result, offsets, strides); + } + rewriter.replaceOp(createMaskOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + +/// Checks whether extractShape is a contiguous slice of shape. +/// For extractShape to be contiguous in shape: +/// 1) All but the leading dimension of extractShape and shape must match +/// exactly. 2) The total number of elements in shape must be evenly divisible +/// by +/// the total number of elements in extractShape. +/// Examples: +/// isContiguous([4, 4], [8, 4]) == true +/// isContiguous([2, 4], [8, 4]) == true +/// isContiguous([2, 2], [8, 4]) == false +/// Removes leading unit dimensions to handle cases like: +/// isContiguous([1, 16], [1, 32]) == true +static bool isContiguous(ArrayRef<int64_t> extractShape, + ArrayRef<int64_t> shape) { + + if (extractShape.size() > shape.size()) + return false; + + while (!extractShape.empty() && extractShape.front() == 1) { + extractShape = extractShape.drop_front(); + } + + while (!shape.empty() && shape.front() == 1) { + shape = shape.drop_front(); + } + + size_t rankDiff = shape.size() - extractShape.size(); + if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1))) + return false; + + int64_t extractElements = ShapedType::getNumElements(extractShape); + int64_t shapeElements = ShapedType::getNumElements(shape); + return shapeElements % extractElements == 0; +} + +/// Determines what shape to use with `vector.extract_strided_slice` to extract +/// a contiguous memory region from a source vector. The extraction must be +/// contiguous and contain exactly the specified number of elements. If such an +/// extraction shape cannot be determined, returns std::nullopt. 
+/// EXAMPLE 1: +/// sourceShape = [16], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 16) = 8 from only dim → extractShape = [8], +/// remaining = 8/8 = 1 +/// Result: [8] +/// +/// EXAMPLE 2: +/// sourceShape = [4, 4], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 4) = 4 from last dim → extractShape = [4], +/// remaining = 8/4 = 2 +/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4], +/// remaining = 2/2 = 1 +/// Result: [2, 4] +static std::optional<SmallVector<int64_t>> +calculateSourceExtractShape(ArrayRef<int64_t> sourceShape, + int64_t targetElements) { + SmallVector<int64_t> extractShape; + int64_t remainingElements = targetElements; + + // Build extract shape from innermost dimension outward to ensure contiguity. + for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { + int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); + extractShape.insert(extractShape.begin(), takeFromDim); + + if (remainingElements % takeFromDim != 0) + return std::nullopt; // Not evenly divisible. + remainingElements /= takeFromDim; + } + + // Fill remaining dimensions with 1. + while (extractShape.size() < sourceShape.size()) + extractShape.insert(extractShape.begin(), 1); + + if (ShapedType::getNumElements(extractShape) != targetElements) + return std::nullopt; + + return extractShape; +} + +// Convert result offsets to source offsets via linear position. +static SmallVector<int64_t> +calculateSourceOffsets(ArrayRef<int64_t> resultOffsets, + ArrayRef<int64_t> sourceShape, + ArrayRef<int64_t> resultShape) { + // Convert result offsets to linear position. + int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape)); + // Convert linear position to source offsets. + return delinearize(linearIndex, computeStrides(sourceShape)); +} + +/// This pattern unrolls `vector.shape_cast` operations according to the +/// provided target unroll shape. It unrolls a large shape cast into smaller +/// shape casts by extracting contiguous slices from the source vector, casting +/// each slice to the target shape, and assembling the result by inserting each +/// computed segment into the appropriate offset of the result vector. +/// +/// This pattern only applies when contiguous slices can be extracted from the +/// source vector and inserted into the result vector such that each slice +/// remains a valid vector (and not decompose to scalars). In these cases, the +/// unrolling proceeds as: +/// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> +/// vector.insert_strided_slice. 
+/// +/// Example: +/// Given a shape cast operation: +/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32> +/// +/// and a target unroll shape of <2x4>, the pattern produces: +/// +/// %zero = arith.constant dense<0.0> : vector<4x4xf32> +/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32> +/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32> +/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// +struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> { + UnrollShapeCastPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::ShapeCastOp>(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + std::optional<SmallVector<int64_t>> targetShape = + getTargetShape(options, shapeCastOp); + if (!targetShape) + return failure(); + + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType resultType = shapeCastOp.getResultVectorType(); + ArrayRef<int64_t> sourceShape = sourceType.getShape(); + ArrayRef<int64_t> resultShape = resultType.getShape(); + + if (!isContiguous(*targetShape, resultShape)) + return rewriter.notifyMatchFailure( + shapeCastOp, "Only supports cases where target shape is " + "contiguous in result vector shape"); + + int64_t targetElements = ShapedType::getNumElements(*targetShape); + + // Calculate the shape to extract from source. + std::optional<SmallVector<int64_t>> extractShape = + calculateSourceExtractShape(sourceShape, targetElements); + if (!extractShape) + return rewriter.notifyMatchFailure( + shapeCastOp, + "cannot extract target number of elements contiguously from source"); + + Location loc = shapeCastOp.getLoc(); + + // Create result vector initialized to zero. 
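// Offset mapping for the vector<8x2xf32> -> vector<4x4xf32> example above with
// target shape <2x4>: each tile holds 8 elements and extractShape is <4x2>, so
//   result tile [0, 0] -> linear element 0       -> source offsets [0, 0]
//   result tile [2, 0] -> linear element 2*4 = 8 -> source offsets [4, 0]
// which is exactly the %s0/%s1 extraction pattern shown in the comment above.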
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + + VectorType targetType = + VectorType::get(*targetShape, sourceType.getElementType()); + + SmallVector<int64_t> extractStrides(extractShape->size(), 1); + SmallVector<int64_t> insertStrides(targetShape->size(), 1); + + for (SmallVector<int64_t> resultOffsets : + StaticTileOffsetRange(resultShape, *targetShape)) { + SmallVector<int64_t> sourceOffsets = + calculateSourceOffsets(resultOffsets, sourceShape, resultShape); + Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, + extractStrides); + Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>( + loc, targetType, sourceChunk); + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, targetChunk, result, resultOffsets, insertStrides); + } + + rewriter.replaceOp(shapeCastOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -1013,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern, + UnrollCreateMaskPattern>(patterns.getContext(), options, + benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index c809c502..c307fb4 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -322,46 +322,61 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, std::optional<Value> padValue, bool useInBoundsInsteadOfMasking, ArrayRef<bool> inputScalableVecDims) { - assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) && + VectorType vecToReadTy = VectorType::get( + inputVectorSizes, cast<ShapedType>(source.getType()).getElementType(), + inputScalableVecDims); + + return createReadOrMaskedRead(builder, loc, source, vecToReadTy, padValue, + useInBoundsInsteadOfMasking); +} + +Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, + Value source, + const VectorType &vecToReadTy, + std::optional<Value> padValue, + bool useInBoundsInsteadOfMasking) { + assert(!llvm::is_contained(vecToReadTy.getScalableDims(), + ShapedType::kDynamic) && "invalid input vector sizes"); auto sourceShapedType = cast<ShapedType>(source.getType()); auto sourceShape = sourceShapedType.getShape(); - assert(sourceShape.size() == inputVectorSizes.size() && + + int64_t vecToReadRank = vecToReadTy.getRank(); + auto vecToReadShape = vecToReadTy.getShape(); + + assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) && "expected same ranks."); - auto vectorType = - VectorType::get(inputVectorSizes, sourceShapedType.getElementType(), - inputScalableVecDims); assert((!padValue.has_value() || padValue.value().getType() == sourceShapedType.getElementType()) && "expected same pad element type to match source element type"); - int64_t readRank = inputVectorSizes.size(); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); - SmallVector<bool> inBoundsVal(readRank, true); + SmallVector<bool> 
inBoundsVal(vecToReadRank, true); if (useInBoundsInsteadOfMasking) { // Update the inBounds attribute. // FIXME: This computation is too weak - it ignores the read indices. - for (unsigned i = 0; i < readRank; i++) - inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) && + for (unsigned i = 0; i < vecToReadRank; i++) + inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) && ShapedType::isStatic(sourceShape[i]); } auto transferReadOp = vector::TransferReadOp::create( builder, loc, - /*vectorType=*/vectorType, + /*vectorType=*/vecToReadTy, /*source=*/source, - /*indices=*/SmallVector<Value>(readRank, zero), + /*indices=*/SmallVector<Value>(vecToReadRank, zero), /*padding=*/padValue, /*inBounds=*/inBoundsVal); - if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking) + if (llvm::equal(vecToReadTy.getShape(), sourceShape) || + useInBoundsInsteadOfMasking) return transferReadOp; SmallVector<OpFoldResult> mixedSourceDims = isa<MemRefType>(source.getType()) ? memref::getMixedSizes(builder, loc, source) : tensor::getMixedSizes(builder, loc, source); - auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(), - inputScalableVecDims); + auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type()); Value mask = vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims); return mlir::vector::maskOperation(builder, transferReadOp, mask) diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt index 9f57627..cb1e9d0 100644 --- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..f4c9f8a --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRX86VectorTransformOps + X86VectorTransformOps.cpp + + DEPENDS + MLIRX86VectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRSideEffectInterfaces + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRX86VectorDialect + MLIRX86VectorTransforms + ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp new file mode 100644 index 0000000..95db208 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -0,0 +1,64 @@ +//===- X86VectorTransformOps.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/RegionKindInterface.h" + +using namespace mlir; +using namespace mlir::x86vector; +using namespace mlir::transform; + +void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + x86vector::populateVectorContractToFMAPatterns(patterns); +} + +void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class X86VectorTransformDialectExtension + : public transform::TransformDialectExtension< + X86VectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + X86VectorTransformDialectExtension) + + X86VectorTransformDialectExtension() { + declareGeneratedDialect<x86vector::X86VectorDialect>(); + declareGeneratedDialect<LLVM::LLVMDialect>(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + +void mlir::x86vector::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions<X86VectorTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt index c51266a..2cab50f 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt @@ -1,11 +1,14 @@ add_mlir_dialect_library(MLIRX86VectorTransforms AVXTranspose.cpp LegalizeForLLVMExport.cpp + VectorContractToFMA.cpp + VectorContractToPackedTypeDotProduct.cpp LINK_LIBS PUBLIC MLIRArithDialect MLIRX86VectorDialect MLIRIR + MLIRLinalgDialect MLIRLLVMCommonConversion MLIRLLVMDialect MLIRVectorDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp new file mode 100644 index 0000000..f3af5ca --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp @@ -0,0 +1,143 @@ +//===- VectorContractToFMA.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+// Implements outer product contraction as a sequence of broadcast and
+// FMA operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <16xf32>
+// vector.fma vector<16xf32>
+// ```
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 lowering is supported.");
+
+    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimLhs;
+    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+                  [](int64_t dim) { return dim != 1; });
+
+    VectorType rhsTy = contractOp.getRhsType();
+    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimRhs;
+    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+                  [](int64_t dim) { return dim != 1; });
+
+    if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects unit dimensions for either LHS or RHS shape.");
+
+    if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Expects one non-unit A/B dimension for either LHS or RHS shape.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Accumulator is not a vector type");
+
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Accumulator should be F32 type.");
+
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimAcc;
+    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+                  [](int64_t dim) { return dim != 1; });
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects exactly one non-unit accumulator dimension.");
+
+    // Lowers vector.contract into a broadcast+FMA sequence.
+    auto loc = contractOp.getLoc();
+    auto castAcc = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+        contractOp.getAcc());
+
+    vector::FMAOp fma;
+
+    // Broadcast the unit-dimension LHS or RHS to match the vector length of the
+    // corresponding non-unit dimension on the other operand. For example,
+    // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we
+    // broadcast the LHS to vector<1x16xf32>.
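    // Editorial illustration (not part of the pattern itself; SSA names are
    // hypothetical and shapes follow the example just given): for the
    // unit-dimension-LHS case the rewrite emits roughly
    //   %acc  = vector.shape_cast %accIn : vector<1x16xf32> to vector<16xf32>
    //   %a    = vector.shape_cast %lhs : vector<1x1xf32> to vector<1xf32>
    //   %b    = vector.shape_cast %rhs : vector<1x16xf32> to vector<16xf32>
    //   %bcst = vector.broadcast %a : vector<1xf32> to vector<16xf32>
    //   %fma  = vector.fma %bcst, %b, %acc : vector<16xf32>
    //   %res  = vector.shape_cast %fma : vector<16xf32> to vector<1x16xf32>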
In the opposite case (non-unit + // dimension on the LHS), we broadcast the RHS instead. + if (nonUnitDimRhs.size() > 0) { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastLhs = vector::BroadcastOp::create( + rewriter, loc, castRhs.getResult().getType(), castLhs); + fma = + vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc); + } else { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastRhs = vector::BroadcastOp::create( + rewriter, loc, castLhs.getResult().getType(), castRhs); + fma = + vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc); + } + + auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma); + rewriter.replaceOp(contractOp, castFma); + + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToFMAPatterns( + RewritePatternSet &patterns) { + patterns.add<VectorContractToFMA>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp new file mode 100644 index 0000000..1e64811 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp @@ -0,0 +1,301 @@ +//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +static FailureOr<SmallVector<mlir::utils::IteratorType>> +inferIteratorsFromOutMap(AffineMap map) { + if (!map.isProjectedPermutation()) + return failure(); + SmallVector<mlir::utils::IteratorType> iterators( + map.getNumDims(), mlir::utils::IteratorType::reduction); + for (auto expr : map.getResults()) + if (auto dim = dyn_cast<AffineDimExpr>(expr)) + iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel; + return iterators; +} + +// Returns true if the operation is in VNNI layout. +// Optionally, the check can be constrained to a specific VNNI blocking factor. +static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps, + std::optional<unsigned> blockingFactor) { + // Narrow down type operations - VNNI only applies to contractions. 
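  // Editorial illustration (hypothetical shapes and indexing maps, with a
  // VNNI factor of 2) of operands that satisfy the checks in this function:
  //   A : vector<1x1x2xbf16>,  map (d0, d1, d2, d3) -> (d0, d2, d3)
  //   B : vector<1x16x2xbf16>, map (d0, d1, d2, d3) -> (d2, d1, d3)
  //   C : vector<1x16xf32>,    map (d0, d1, d2, d3) -> (d0, d1)
  // Here d3 is the VNNI (innermost reduction) dimension, d2 is the split K
  // dimension, and d0/d1 are the parallel M/N dimensions.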
+ FailureOr<linalg::ContractionDimensions> dims = + linalg::inferContractionDims(indexingMaps); + if (failed(dims)) + return false; + + auto matA = op->getOperand(0); + auto matB = op->getOperand(1); + auto typeA = dyn_cast<ShapedType>(matA.getType()); + auto typeB = dyn_cast<ShapedType>(matB.getType()); + unsigned rankA = typeA.getRank(); + unsigned rankB = typeB.getRank(); + // VNNI format requires at least 1 parallel and 2 reduction dimensions. + if (rankA < 3 || rankB < 3) + return false; + + // At least two reduction dimensions are expected: + // one for the VNNI factor and one for the K dimension + if (dims->k.size() < 2) + return false; + + // Validate affine maps - VNNI computation should be defined by the two + // innermost reduction iterators. + // The input matrix dimensions layout must match the following: + // - matrix A - [...][K/vnniFactor][vnniFactor] + // - matrix B - [...][K/vnniFactor][N][vnniFactor] + auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]); + if (failed(maybeIters)) + return false; + SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters; + AffineMap mapA = indexingMaps[0]; + AffineMap mapB = indexingMaps[1]; + + auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1)); + auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1)); + if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB || + iteratorTypes[vnniDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2)); + auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3)); + if (!redDimA || !redDimB || redDimA != redDimB || + iteratorTypes[redDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2)); + if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] != + mlir::utils::IteratorType::parallel) + return false; + + // VNNI factor must be: + // - the innermost inputs' dimension + // - statically known + // - multiple of 2 or equal to the specified factor + auto vnniDimSize = typeB.getShape().back(); + if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 || + vnniDimSize % 2 != 0) + return false; + if (typeA.getShape().back() != vnniDimSize) + return false; + if (blockingFactor && vnniDimSize != *blockingFactor) + return false; + + // The split reduction dimension size should also match. + if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3]) + return false; + + return true; +} + +// Implements packed type outer product contraction as a sequence +// of broadcast and packed dot-product operations. 
+//
+// For example - for BF16 type:
+// ```
+// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <32xbf16>
+// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
+// ```
+struct VectorContractToPackedTypeDotProduct
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16() &&
+        !lhsTy.getElementType().isSignlessInteger(8))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only BF16/Int8 lowering is supported.");
+
+    unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
+    if (!isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), blockingFactor))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices not in VNNI format.");
+
+    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimLhs;
+    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+                  [](int64_t dim) { return dim != 1; });
+
+    VectorType rhsTy = contractOp.getRhsType();
+    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimRhs;
+    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+                  [](int64_t dim) { return dim != 1; });
+
+    if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+                                         "LHS or RHS shape other than VNNI.");
+
+    if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Expects one non-unit A/B dimension for either LHS or RHS shape.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Accumulator is not a vector type.");
+
+    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+        (lhsTy.getElementType().isSignlessInteger(8) &&
+         !accTy.getElementType().isSignlessInteger(32)))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 for BF16 or Int32 for Int8 "
+                                         "accumulation type is supported.");
+
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimAcc;
+    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+                  [](int64_t dim) { return dim != 1; });
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects exactly one non-unit accumulator dimension.");
+
+    // Non-unit dimensions should match the vector length of BF16 or Int8
+    // dot-product.
+    unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
+                                                        : nonUnitDimRhs.front();
+    if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
+        nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "BF16 dot-product operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8/16.");
+
+    if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
+        nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Int8 dot-product operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8.");
+
+    auto loc = contractOp.getLoc();
+    auto castAcc = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+        contractOp.getAcc());
+
+    Value dp;
+
+    // Broadcast the unit-dimension LHS or RHS to match the vector length of the
+    // corresponding non-unit dimension on the other operand. For example,
+    // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>,
+    // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit
+    // dimension on the LHS), we broadcast the RHS instead.
+    if ((nonUnitDimRhs.size() - 1) > 0) {
+      auto castRhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
+                          rhsTy.getElementType()),
+          contractOp.getRhs());
+      auto castLhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
+          contractOp.getLhs());
+      auto bitcastLhs = vector::BitCastOp::create(
+          rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+          castLhs);
+      auto broadcastLhs = vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
+          bitcastLhs);
+      auto bitcastLhsPkType = vector::BitCastOp::create(
+          rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+
+      if (lhsTy.getElementType().isBF16()) {
+        dp = x86vector::DotBF16Op::create(
+            rewriter, loc,
+            VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
+            castAcc, bitcastLhsPkType, castRhs);
+      }
+
+      if (lhsTy.getElementType().isSignlessInteger(8)) {
+        dp = x86vector::DotInt8Op::create(
+            rewriter, loc,
+            VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
+            castAcc, bitcastLhsPkType, castRhs);
+      }
+    } else {
+      auto castLhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
+                          lhsTy.getElementType()),
+          contractOp.getLhs());
+      auto castRhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
+          contractOp.getRhs());
+      auto bitcastRhs = vector::BitCastOp::create(
+          rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+          castRhs);
+      auto broadcastRhs = vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
+          bitcastRhs);
+      auto bitcastRhsPkType = vector::BitCastOp::create(
+          rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
+
+      if (lhsTy.getElementType().isBF16()) {
+        dp = x86vector::DotBF16Op::create(
+            rewriter, loc,
+            VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
+            castAcc, castLhs, bitcastRhsPkType);
+      }
+
+      if (lhsTy.getElementType().isSignlessInteger(8)) {
+        dp = x86vector::DotInt8Op::create(
+            rewriter, loc,
+            VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
+            castAcc, castLhs, bitcastRhsPkType);
+      }
+ } + + if (!dp) + return failure(); + + auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp); + rewriter.replaceOp(contractOp, castDp); + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToPackedTypeDotProductPatterns( + RewritePatternSet &patterns) { + patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index fb5d1e7..1a19ab5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" @@ -61,7 +60,7 @@ genCoordinates(OpBuilder &builder, Location loc, // Get the offset of `subShape` within a distribution unit. SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value { - return builder.createOrFold<index::MulOp>( + return builder.createOrFold<arith::MulIOp>( loc, std::get<0>(t), builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); }); @@ -84,7 +83,7 @@ genCoordinates(OpBuilder &builder, Location loc, // Do not go beyond `srcShape` bounds. SmallVector<Value> mods = llvm::map_to_vector( llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value { - return builder.createOrFold<index::RemUOp>( + return builder.createOrFold<arith::RemUIOp>( loc, std::get<0>(t), arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); }); @@ -343,7 +342,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within /// this dimension) result[dimIdx] = - builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal); + builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal); /// Update remaining for the next dimension by removing what we've already /// processed. 
Division tells us "how many complete groups of this dimension @@ -352,7 +351,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { /// no next dimension to process if (i < order.size() - 1) { remaining = - builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal); + builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal); } } return result; @@ -391,6 +390,86 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, return genCoordinates(builder, loc, ids, layout, subShape, shape); } +bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::SliceAttr>(other)) + return false; + + return *this == dyn_cast<xegpu::LayoutAttr>(other); +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) { + auto sgDataOpt = getSgData(); + auto instDataOpt = getInstData(); + auto laneDataOpt = getLaneData(); + + SmallVector<int32_t> sgData; + SmallVector<int32_t> instData; + SmallVector<int32_t> laneData; + + if (sgDataOpt) { + sgData = llvm::to_vector(sgDataOpt.asArrayRef()); + } + if (instDataOpt) { + instData = llvm::to_vector(instDataOpt.asArrayRef()); + } + if (laneDataOpt) { + laneData = llvm::to_vector(laneDataOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgData.size())) + sgData[dim] = 1; + if (dim < static_cast<int64_t>(instData.size())) + instData[dim] = 1; + if (dim < static_cast<int64_t>(laneData.size())) + laneData[dim] = 1; + } + + return LayoutAttr::get( + getContext(), getSgLayout(), + sgData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgData), + instData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), instData), + getLaneLayout(), + laneData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneData), + getOrder()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + auto sgLayoutOpt = getSgLayout(); + auto laneLayoutOpt = getLaneLayout(); + + SmallVector<int32_t> sgLayout; + SmallVector<int32_t> laneLayout; + + if (sgLayoutOpt) { + sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef()); + } + if (laneLayoutOpt) { + laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgLayout.size())) + sgLayout[dim] = 1; + if (dim < static_cast<int64_t>(laneLayout.size())) + laneLayout[dim] = 1; + } + + return LayoutAttr::get( + getContext(), + sgLayout.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgLayout), + getSgData(), getInstData(), + laneLayout.empty() ? 
DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneLayout), + getLaneData(), getOrder()); +} + //===----------------------------------------------------------------------===// // XeGPU_SliceAttr //===----------------------------------------------------------------------===// @@ -511,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { [&](int64_t dim) { return thisDims.contains(dim); }); } +bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::LayoutAttr>(other)) + return false; + + auto flattenedThis = flatten(); + auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten(); + + return ((flattenedThis.getParent() == flattenedOther.getParent()) && + (flattenedThis.getDims() == flattenedOther.getDims())); +} + +// Helper function to adjust unit dimensions from sliced space to parent space +static SetVector<int64_t> +adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims, + ArrayRef<int64_t> sliceDims) { + // Reconstruct parent's non-sliced dimensions + + int64_t parentRank = sliceDims.size() + unitDims.size(); + llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(), + sliceDims.end()); + SmallVector<int64_t> nonSlicedDims; + for (int64_t i = 0; i < parentRank; ++i) { + if (!slicedDimsSet.contains(i)) + nonSlicedDims.push_back(i); + } + + // Map unit dims from sliced space to parent space + SetVector<int64_t> adjustUnitDims; + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(nonSlicedDims.size())) { + adjustUnitDims.insert(nonSlicedDims[dim]); + } + } + + return adjustUnitDims; +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims), + attr.getDims()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims), + attr.getDims()); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 4dd10be..91ba07a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -465,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + 
xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -480,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult PrefetchNdOp::verify() { @@ -519,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); } void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, @@ -527,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -535,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, - packed, transpose, l1_hint, l2_hint, l3_hint); + packed, transpose, l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/layout); } LogicalResult LoadNdOp::verify() { @@ -638,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, value, tensorDesc, ValueRange(), - DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); + DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/nullptr); } void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -653,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult StoreNdOp::verify() { @@ -826,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, - IntegerAttr{}); + IntegerAttr{}, /*anchor_layout=*/nullptr); } //===----------------------------------------------------------------------===// @@ -876,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, valueType, source, Value(), mask, IntegerAttr(), - l1_hint, l2_hint, l3_hint, /*layout=*/nullptr); + l1_hint, l2_hint, 
l3_hint, /*anchor_layout=*/nullptr); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, @@ -892,7 +897,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, auto offset = vector::FromElementsOp::create(builder, loc, type, values); build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, - l2_hint, l3_hint, /*layout=*/nullptr); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, @@ -901,7 +906,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint, - xegpu::LayoutAttr layout) { + DistributeLayoutAttr layout) { auto loc = source.getLoc(); int64_t size = static_cast<int64_t>(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); @@ -960,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, - l2_hint, l3_hint, /*layout=*/nullptr); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void StoreScatterOp::build(OpBuilder &builder, OperationState &state, @@ -978,14 +983,14 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, // Call the correct builder overload that does not expect result types. build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, - l3_hint, /*layout=*/nullptr); + l3_hint, /*anchor_layout=*/nullptr); } void StoreScatterOp::build( OpBuilder &builder, OperationState &state, Value value, Value dest, ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) { + xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) { auto loc = dest.getLoc(); int64_t size = static_cast<int64_t>(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index 8943ba0..e6009d5 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -7,12 +7,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include <optional> +#include "llvm/Support/DebugLog.h" +#define DEBUG_TYPE "xegpu-transforms" + using namespace mlir; using namespace mlir::transform; @@ -76,6 +81,45 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt( return DiagnosedSilenceableFailure::success(); } +/// Find producer operation of type T for the given value. +/// It's assumed that producer ops are chained through their first operand. +/// Producer chain is traced trough loop block arguments (init values). +template <typename T> +static std::optional<T> findProducerOfType(Value val) { + Value currentValue = val; + if (!currentValue.getDefiningOp()) { + // Value may be a block argument initialized outside a loop. 
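    // Editorial illustration (schematic IR, hypothetical names) of a chain
    // this traversal resolves:
    //   %desc = xegpu.create_nd_tdesc %src ...          // producer we look for
    //   scf.for %i = ... iter_args(%arg = %desc) ... {  // %arg is a block argument
    //     %v = xegpu.load_nd %arg ...                   // search starts at %v
    //     ...
    //   }
    // Starting from %v we follow operand 0 to %arg, hop through the loop's
    // init value to %desc, and match it as the descriptor producer.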
+ if (val.getNumUses() == 0) { + LDBG() << "Failed to find producer op, value has no uses."; + return std::nullopt; + } + auto userOp = val.getUsers().begin(); + auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>(); + if (!parentLoop) { + LDBG() << "Failed to find producer op, not in a loop."; + return std::nullopt; + } + int64_t iterArgIdx; + if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) { + auto numInductionVars = parentLoop.getLoopInductionVars()->size(); + iterArgIdx = iterArg.getArgNumber() - numInductionVars; + currentValue = parentLoop.getInits()[iterArgIdx]; + } else { + LDBG() << "Failed to find producer op, value not in init values."; + return std::nullopt; + } + } + Operation *producerOp = currentValue.getDefiningOp(); + + if (auto matchingOp = dyn_cast<T>(producerOp)) + return matchingOp; + + if (producerOp->getNumOperands() == 0) + return std::nullopt; + + return findProducerOfType<T>(producerOp->getOperand(0)); +} + /// Create a layout attribute from the given parameters. static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, @@ -90,10 +134,41 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, /*order=*/nullptr); } +/// Generate `xegpu::LayoutAttr` from op mixed layout values. +DiagnosedSilenceableFailure +getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, + TransformOpInterface transformOp, + ArrayRef<::mlir::OpFoldResult> mixedSgLayout, + ArrayRef<::mlir::OpFoldResult> mixedSgData, + ArrayRef<::mlir::OpFoldResult> mixedInstData, + xegpu::LayoutAttr &layoutAttr) { + SmallVector<int32_t> sgLayout, sgData, instData; + auto status = + convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout); + if (!status.succeeded()) + return status; + + status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData); + if (!status.succeeded()) + return status; + + status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData); + if (!status.succeeded()) + return status; + auto maybeInstData = instData.empty() + ? std::nullopt + : std::optional<ArrayRef<int32_t>>(instData); + + layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData); + + return DiagnosedSilenceableFailure::success(); +} + /// Replace xegpu.create_nd_desc op with a new one with the given layout. 
static xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, - xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) { + xegpu::CreateNdDescOp descOp, + xegpu::DistributeLayoutAttr layout) { assert(descOp.getMixedOffsets().size() == 0 && "create desc op with offsets is not supported"); auto oldTensorDesc = descOp.getType(); @@ -111,11 +186,35 @@ setDescLayout(transform::TransformRewriter &rewriter, return newDescOp; } +DiagnosedSilenceableFailure +transform::GetDescOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) { + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + } + + auto maybeDescOp = + findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) + << "Could not find a matching descriptor op when walking the " + "producer chain of the first operand."; + } + + results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp}); + return DiagnosedSilenceableFailure::success(); +} + void transform::SetDescLayoutOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, - ArrayRef<OpFoldResult> mixedInstData) { + ArrayRef<OpFoldResult> mixedInstData, + ArrayRef<int64_t> sliceDims) { SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); @@ -128,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder, /*inst_data=*/dynamicInstData, /*static_sg_layout=*/staticSgLayout, /*static_sg_data=*/staticSgData, - /*static_inst_data=*/staticInstData); + /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims); } DiagnosedSilenceableFailure @@ -142,25 +242,20 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, } Operation *target = *targetOps.begin(); - SmallVector<int32_t> sgLayout; - DiagnosedSilenceableFailure status = - convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout()); - if (!status.succeeded()) - return status; - - SmallVector<int32_t> sgData; - status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData()); + xegpu::LayoutAttr layoutAttr = nullptr; + auto status = getLayoutAttrFromOperands(getContext(), state, (*this), + getMixedSgLayout(), getMixedSgData(), + getMixedInstData(), layoutAttr); if (!status.succeeded()) return status; - SmallVector<int32_t> instData; - status = - convertMixedValuesToInt(state, (*this), instData, getMixedInstData()); - if (!status.succeeded()) - return status; - auto maybeInstData = instData.empty() - ? std::nullopt - : std::optional<ArrayRef<int32_t>>(instData); + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } // For now only create_nd_desc op is supported. auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); @@ -173,9 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, } // Set layout attr in desc op's return type. Replaces old desc op. 
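  // Editorial note: the sg_layout/sg_data/inst_data operands above are folded
  // into a layout attribute roughly of the form (values are hypothetical)
  //   #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
  // and, when slice_dims is non-empty, wrapped as
  //   #xegpu.slice<#xegpu.layout<...>, dims = [0]>
  // before being attached to the new descriptor's tensor_desc type.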
- auto layoutAttr = - createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData); - auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); + auto newdescOp = setDescLayout(rewriter, descOp, layout); // Map result handles. results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()}); @@ -193,6 +286,383 @@ void transform::SetDescLayoutOp::getEffects( modifiesPayload(effects); } +void transform::SetOpLayoutAttrOp::build( + OpBuilder &builder, OperationState &ostate, Value target, int64_t index, + ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, + ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims, + bool result) { + SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; + SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; + dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); + dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); + dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); + build(builder, ostate, target.getType(), + /*target=*/target, + /*index=*/index, + /*sg_layout=*/dynamicSgLayout, + /*sg_data=*/dynamicSgData, + /*inst_data=*/dynamicInstData, + /*static_sg_layout=*/staticSgLayout, + /*static_sg_data=*/staticSgData, + /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims, + /*result=*/result); +} + +DiagnosedSilenceableFailure +transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + bool resultTarget = getResult(); + + int64_t index = getIndex(); + if (resultTarget && index >= target->getNumResults()) { + return emitSilenceableFailure(getLoc()) + << "Index exceeds the number of op results"; + } + if (!resultTarget && index >= target->getNumOperands()) { + return emitSilenceableFailure(getLoc()) + << "Index exceeds the number of op operands"; + } + + xegpu::LayoutAttr layoutAttr = nullptr; + auto status = getLayoutAttrFromOperands(getContext(), state, (*this), + getMixedSgLayout(), getMixedSgData(), + getMixedInstData(), layoutAttr); + if (!status.succeeded()) + return status; + + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. 
+ layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + + // Set layout attribute for the op result or operand + if (resultTarget) + xegpu::setDistributeLayoutAttr(target->getResult(index), layout); + else + xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetOpLayoutAttrOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getSgLayoutMutable(), effects); + onlyReadsHandle(getSgDataMutable(), effects); + onlyReadsHandle(getInstDataMutable(), effects); + modifiesPayload(effects); +} + +void transform::SetGPULaunchThreadsOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedThreads) { + SmallVector<int64_t> staticThreads; + SmallVector<Value> dynamicThreads; + dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads); + build(builder, ostate, target.getType(), + /*target=*/target, + /*threads=*/dynamicThreads, + /*static_threads=*/staticThreads); +} + +DiagnosedSilenceableFailure +transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + auto launchOp = dyn_cast<gpu::LaunchOp>(target); + if (!launchOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a gpu.launch op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + SmallVector<int32_t> threads; + DiagnosedSilenceableFailure status = + convertMixedValuesToInt(state, (*this), threads, getMixedThreads()); + if (!status.succeeded()) + return status; + + if (threads.size() != 3) { + return emitSilenceableFailure(getLoc()) + << "Expected threads argument to consist of three values (got " + << threads.size() << ")"; + } + + rewriter.setInsertionPoint(launchOp); + auto createConstValue = [&](int value) { + return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value); + }; + + // Replace threads in-place. + launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); + launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); + launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetGPULaunchThreadsOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getThreadsMutable(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + int64_t nbPrefetch = getStaticNbPrefetch(); + if (getDynamicNbPrefetch()) { + // Get dynamic prefetch count from transform param or handle. 
+    SmallVector<int32_t> dynamicNbPrefetch;
+    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+                                          {getDynamicNbPrefetch()});
+    if (!status.succeeded())
+      return status;
+    if (dynamicNbPrefetch.size() != 1)
+      return emitDefiniteFailure()
+             << "requires exactly one value for dynamic_nb_prefetch";
+    nbPrefetch = dynamicNbPrefetch[0];
+  }
+  if (nbPrefetch <= 0)
+    return emitSilenceableFailure(getLoc())
+           << "nb_prefetch must be a positive integer.";
+
+  // Find load operation of the operand.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp)
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  auto loadOp = *maybeLoadOp;
+  if (loadOp.getMixedOffsets().size() == 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op must have offsets.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the parent scf.for loop.
+  auto forOp = loadOp->getParentOfType<scf::ForOp>();
+  if (!forOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op is not contained in a scf.for loop.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find descriptor op.
+  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+  if (!maybeDescOp)
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  auto descOp = *maybeDescOp;
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+  // Clone desc op outside the loop.
+  rewriter.setInsertionPoint(forOp);
+  auto newDescOp =
+      cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+  // Clone reduction loop to emit initial prefetches.
+  // Compute upper bound of the init loop: start + nbPrefetch * step.
+  auto nbPrefetchCst =
+      arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+  auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+  auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+      forOp.getLoc(), forOp.getLowerBound(), nbStep);
+  auto initForOp =
+      scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                         initUpBound, forOp.getStep());
+
+  auto ctx = rewriter.getContext();
+  auto readCacheHint =
+      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+  // Modify loadOp mixedOffsets by replacing the for loop induction variable
+  // with the given value.
+  auto getPrefetchOffsets =
+      [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+    IRMapping mapping;
+    mapping.map(forOp.getInductionVar(), replacementVal);
+    SmallVector<Value> dynamicOffsets =
+        llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+          return mapping.lookupOrDefault(v);
+        }));
+    auto constOffsets = loadOp.getConstOffsets().value();
+    return getMixedValues(constOffsets, dynamicOffsets, ctx);
+  };
+
+  // Insert prefetch op in init loop.
+  // Replace induction var with the init loop induction var.
+  rewriter.setInsertionPointToStart(initForOp.getBody());
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(initForOp.getInductionVar()),
+                              readCacheHint, readCacheHint, readCacheHint,
+                              /*layout=*/nullptr);
+
+  // Insert prefetch op in main loop.
+  // Calculate prefetch offset after the init prefetches have been issued.
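  // Editorial illustration (schematic IR, hypothetical names) of the overall
  // structure this transform produces for nbPrefetch = N:
  //   %d = xegpu.create_nd_tdesc ...                     // cloned descriptor
  //   scf.for %i = %lb to (%lb + N * %step) step %step { // init loop, later fully unrolled
  //     xegpu.prefetch_nd %d[... %i ...]
  //   }
  //   scf.for %i = %lb to %ub step %step {               // original loop
  //     xegpu.prefetch_nd %d[... (%i + N * %step) ...]
  //     %v = xegpu.load_nd ...                           // existing load
  //     ...
  //   }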
+ rewriter.setInsertionPointToStart(forOp.getBody()); + auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(), + forOp.getInductionVar(), nbStep); + // Replace induction var with correct offset. + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(prefetchOffset), readCacheHint, + readCacheHint, readCacheHint, /*layout=*/nullptr); + + // Unroll the init loop. + if (failed(loopUnrollFull(initForOp))) + return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop"; + + results.set(llvm::cast<OpResult>(getResult()), {newDescOp}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::InsertPrefetchOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getDynamicNbPrefetchMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +void transform::ConvertLayoutOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedInputSgLayout, + ArrayRef<OpFoldResult> mixedInputSgData, + ArrayRef<OpFoldResult> mixedInputInstData, + ArrayRef<OpFoldResult> mixedTargetSgLayout, + ArrayRef<OpFoldResult> mixedTargetSgData, + ArrayRef<OpFoldResult> mixedTargetInstData) { + SmallVector<int64_t> staticInputSgLayout, staticInputSgData, + staticInputInstData; + SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData, + dynamicInputInstData; + dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout, + staticInputSgLayout); + dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData, + staticInputSgData); + dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData, + staticInputInstData); + SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData, + staticTargetInstData; + SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData, + dynamicTargetInstData; + dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout, + staticTargetSgLayout); + dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData, + staticTargetSgData); + dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData, + staticTargetInstData); + build(builder, ostate, target.getType(), + /*target=*/target, + /*input_sg_layout=*/dynamicInputSgLayout, + /*input_sg_data=*/dynamicInputSgData, + /*input_inst_data=*/dynamicInputInstData, + /*target_sg_layout=*/dynamicTargetSgLayout, + /*target_sg_data=*/dynamicTargetSgData, + /*target_inst_data=*/dynamicTargetInstData, + /*static_input_sg_layout=*/staticInputSgLayout, + /*static_input_sg_data=*/staticInputSgData, + /*static_input_inst_data=*/staticInputInstData, + /*static_target_sg_layout=*/staticTargetSgLayout, + /*static_target_sg_data=*/staticTargetSgData, + /*static_target_inst_data=*/staticTargetInstData); +} + +DiagnosedSilenceableFailure +transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + // Construct layout attributes. 
+ xegpu::LayoutAttr inputLayoutAttr = nullptr; + auto status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedInputSgLayout(), + getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr); + if (!status.succeeded()) + return status; + + xegpu::LayoutAttr targetLayoutAttr = nullptr; + status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedTargetSgLayout(), + getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr); + if (!status.succeeded()) + return status; + + // Find first user op to define insertion point for layout conversion. + if (value.use_empty()) + return emitSilenceableFailure(getLoc()) + << "Value has no users to insert layout conversion."; + Operation *userOp = *value.getUsers().begin(); + + // Emit convert_layout op. + rewriter.setInsertionPoint(userOp); + auto convLayoutOp = + xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(), + value, inputLayoutAttr, targetLayoutAttr); + // Replace load op result with the converted layout. + rewriter.replaceUsesWithIf( + value, convLayoutOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != convLayoutOp.getOperation(); + }); + + results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp}); + return DiagnosedSilenceableFailure::success(); +} + +void transform::ConvertLayoutOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getInputSgLayoutMutable(), effects); + onlyReadsHandle(getInputSgDataMutable(), effects); + onlyReadsHandle(getInputInstDataMutable(), effects); + onlyReadsHandle(getTargetSgLayoutMutable(), effects); + onlyReadsHandle(getTargetSgDataMutable(), effects); + onlyReadsHandle(getTargetInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + namespace { class XeGPUTransformDialectExtension : public transform::TransformDialectExtension< diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp index 4dc5ea4..ab41fe4 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp @@ -214,7 +214,7 @@ static Value generateLoads(ConversionPatternRewriter &rewriter, newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY}, origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(), origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(), - origLoadOp.getL3HintAttr()); + origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr()); // Set the layout for the loadOp. 
auto layoutAttr = newTensorDesc.getType().getLayoutAttr(); xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 4e1a539..dc9eb96 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -53,6 +53,8 @@ using namespace mlir::dataflow; namespace { +enum class LayoutKind { Lane, InstData }; + //===----------------------------------------------------------------------===// // LayoutInfo //===----------------------------------------------------------------------===// @@ -166,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) { llvm_unreachable("Join should not be triggered by layout propagation."); } -/// Construct a new layout with the transposed lane layout and lane data. +/// Construct a new layout with the transposed inst_data or lane_layout, +/// lane_data. LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const { if (!isAssigned()) return {}; @@ -186,12 +189,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const { SmallVector<int32_t> laneData; SmallVector<int32_t> instData; for (int64_t idx : permutation) { - laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); - laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); - instData.push_back(static_cast<int32_t>(getInstData()[idx])); + if (getLaneLayout().size()) { + laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); + laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); + } + if (getInstData().size()) + instData.push_back(static_cast<int32_t>(getInstData()[idx])); } - return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData, - laneLayout, laneData)); + xegpu::LayoutAttr layoutAttr; + if (getLaneLayout().size()) + layoutAttr = + xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData); + if (getInstData().size()) + layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData); + return LayoutInfo(layoutAttr); } //===----------------------------------------------------------------------===// @@ -213,15 +224,14 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> { /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1]. static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank, - const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData) { + const xegpu::uArch::uArch *uArch) { assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); if (rank == 1) { return LayoutInfo( - xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1})); + xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1})); } - return LayoutInfo(xegpu::LayoutAttr::get( - ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1})); + return LayoutInfo( + xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1})); } static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, @@ -236,7 +246,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData, unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. 
@@ -247,16 +256,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (vectorTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData); + return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { - return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), {uArch->getSubgroupSize(), 1}, {1, packingFactor})); } - return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor})); } @@ -264,7 +273,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData, unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. @@ -275,18 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (tdescTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData); + return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); int subgroupSize = uArch->getSubgroupSize(); int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor})); + tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor})); } return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor})); + tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor})); } /// Helper Function to get the expected layouts for DPAS operands. `lane_data` @@ -298,7 +306,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData, unsigned packingSize) { + unsigned packingSize) { Type elementTy = vectorTy.getElementType(); assert(elementTy.isIntOrFloat() && "Expected int or float type in DPAS operands"); @@ -310,10 +318,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()), 1}); return LayoutInfo( - xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data)); + xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data)); } // Otherwise, return the default layout for the vector type. 
- return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize); + return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize); } //===----------------------------------------------------------------------===// @@ -328,6 +336,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, class LayoutInfoPropagation : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> { private: + LayoutKind layoutKind; void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results); @@ -378,10 +387,14 @@ private: ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results); + bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout); + public: LayoutInfoPropagation(DataFlowSolver &solver, - SymbolTableCollection &symbolTable) - : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} + SymbolTableCollection &symbolTable, + LayoutKind layoutKind) + : SparseBackwardDataFlowAnalysis(solver, symbolTable), + layoutKind(layoutKind) {} using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; LogicalResult @@ -464,43 +477,71 @@ LogicalResult LayoutInfoPropagation::visitOperation( return success(); } +bool LayoutInfoPropagation::hasParamsOfLayoutKind( + xegpu::DistributeLayoutAttr anchorLayout) { + if (anchorLayout == nullptr) { + return false; + } + if (layoutKind == LayoutKind::InstData) { + return !(anchorLayout.getEffectiveInstDataAsInt().empty()); + } else if (layoutKind == LayoutKind::Lane) { + return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() || + anchorLayout.getEffectiveLaneDataAsInt().empty()); + } + return false; +} + void LayoutInfoPropagation::visitPrefetchNdOp( xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // Here we assign the default layout to the tensor descriptor operand of - // prefetch. - auto tdescTy = prefetch.getTensorDescType(); - - auto uArch = getUArch(getChipStr(prefetch).value_or("")); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( - uArch->getInstruction( - xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); - - auto blockWHC = - uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); - if (!blockWHC) - prefetch.emitWarning("No known block params found for the element type."); - auto [bWidth, bHeight, bCount] = blockWHC.value(); - SmallVector<int> instData; - int instWidth = xegpu::getLargestDivisor( - static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth, - bCount); - if (instWidth == -1) - prefetch.emitWarning( - "No suitable instruction multiple found for the given shape."); - if (tdescTy.getRank() == 1) - instData = {instWidth}; - else { - int instHeight = xegpu::getLargestDivisor( - static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); - if (instHeight == -1) + + LayoutInfo prefetchLayout; + xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + prefetchLayout = LayoutInfo(anchorLayout); + } else { + // Here we assign the default layout to the tensor descriptor operand of + // prefetch. 
+    auto tdescTy = prefetch.getTensorDescType();
+
+    auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+    const auto *uArchInstruction =
+        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+            uArch->getInstruction(
+                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+    auto blockWHC =
+        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+    if (!blockWHC)
+      prefetch.emitWarning("No known block params found for the element type.");
+    auto [bWidth, bHeight, bCount] = blockWHC.value();
+    SmallVector<int> instData;
+    int instWidth = xegpu::getLargestDivisor(
+        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
+    if (instWidth == -1)
       prefetch.emitWarning(
           "No suitable instruction multiple found for the given shape.");
-    instData = {instHeight, instWidth};
+    if (tdescTy.getRank() == 1)
+      instData = {instWidth};
+    else {
+      int instHeight = xegpu::getLargestDivisor(
+          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+      if (instHeight == -1)
+        prefetch.emitWarning(
+            "No suitable instruction multiple found for the given shape.");
+      instData = {instHeight, instWidth};
+    }
+
+    if (layoutKind == LayoutKind::InstData)
+      prefetchLayout =
+          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+    else
+      prefetchLayout = getDefaultSIMTLayoutInfo(
+          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+
+    prefetch.setLayoutAttr(
+        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
   }
-  auto prefetchLayout = getDefaultSIMTLayoutInfo(
-      tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
   // Propagate the layout to the source tensor descriptor.
   propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
 }
@@ -539,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   // Only consider vector to vector broadcasts for now.
   VectorType resultTy = broadcast.getResultVectorType();
   VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
-  if (!sourceTy) {
-    broadcast.emitWarning("Expecting source type to be a vector type.");
+  // Skip layout propagation for a non-vector source operand.
+  if (!sourceTy)
     return;
-  }
-  // Only consider nD -> nD broadcast.
+  // Handle the low-rank to high-rank (e.g., 1D to 2D) broadcast case.
if (sourceTy.getRank() != resultTy.getRank()) { - broadcast.emitWarning("Expecting source and result to have same rank."); + auto sourceDims = sourceTy.getShape(); + auto resultDims = resultTy.getShape(); + SmallVector<int64_t> bcastDims; + auto dimDiff = resultTy.getRank() - sourceTy.getRank(); + // adding the missing leading dims + for (int i = 0; i < dimDiff; i++) + bcastDims.push_back(i); + + // for the rest dims in the resultTy, if sourceTy dim is 1, then it's + // broadcasted dim + for (size_t i = 0; i < sourceDims.size(); i++) + if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1)) + bcastDims.push_back(i + dimDiff); + + // create a slice layout for the source + xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get( + broadcast->getContext(), + cast<xegpu::DistributeLayoutAttr>(resultLayout.get()), + DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims)); + + propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout))); return; } + SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims(); - if (broadcastUnitDims.size() != 1) { - broadcast.emitWarning("Expecting source type to be nD vector only with " - "one broadcasted dimension."); - return; - } - // Propagate the result layout to the source operand. + resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get()) + .setUnitDimData(broadcastUnitDims); propagateIfChanged(operands[0], operands[0]->meet(resultLayout)); } @@ -600,55 +657,97 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp( void LayoutInfoPropagation::visitDpasOp( xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - VectorType aTy = dpas.getLhsType(); - VectorType bTy = dpas.getRhsType(); - - auto uArch = getUArch(getChipStr(dpas).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( - xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); - - const unsigned dataALen = aTy.getShape().front(); - auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); - const int maxALen = - xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); - if (maxALen == -1) - dpas.emitWarning( - "No suitable instruction multiple found for the given shape."); - - const unsigned dataBLen = bTy.getShape().back(); - auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType()); - const int maxBLen = - xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); - if (maxBLen == -1) - dpas.emitWarning( - "No suitable instruction multiple found for the given shape."); - SmallVector<int> instDataA = {maxALen, subgroupSize}; - SmallVector<int> instDataB = {subgroupSize, maxBLen}; - - propagateIfChanged(operands[0], - operands[0]->meet(getSIMTLayoutInfoForDPASOperand( - aTy, 0, uArch, instDataA, - uArchInstruction->getPackedFormatBitSizeA()))); - propagateIfChanged(operands[1], - operands[1]->meet(getSIMTLayoutInfoForDPASOperand( - bTy, 1, uArch, instDataB, - uArchInstruction->getPackedFormatBitSizeB()))); - if (operands.size() > 2) { - VectorType cTy = dpas.getAccType(); - const unsigned dataCLen = bTy.getShape().back(); - auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType()); - const int maxCLen = - xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen)); - if (maxCLen == -1) + + LayoutInfo dpasALayout; + LayoutInfo dpasBLayout; + LayoutInfo dpasCDLayout; + + 
xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr(); + if (hasParamsOfLayoutKind(anchorLayoutCD)) { + xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr(); + xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr(); + assert(hasParamsOfLayoutKind(anchorLayoutA) && + "Expected anchor layout for DPAS A operand."); + assert(hasParamsOfLayoutKind(anchorLayoutB) && + "Expected anchor layout for DPAS B operand."); + dpasALayout = LayoutInfo(anchorLayoutA); + dpasBLayout = LayoutInfo(anchorLayoutB); + dpasCDLayout = LayoutInfo(anchorLayoutCD); + + } else { + + VectorType aTy = dpas.getLhsType(); + VectorType bTy = dpas.getRhsType(); + + auto uArch = getUArch(getChipStr(dpas).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( + xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); + + const unsigned dataALen = aTy.getShape().front(); + auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); + const int maxALen = + xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); + if (maxALen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + + const unsigned dataBLen = bTy.getShape().back(); + auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType()); + + const int maxBLen = + xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); + + if (maxBLen == -1) dpas.emitWarning( "No suitable instruction multiple found for the given shape."); - SmallVector<int> instDataC = {maxALen, maxCLen}; - propagateIfChanged(operands[2], - operands[2]->meet(getSIMTLayoutInfoForDPASOperand( - cTy, 2, uArch, instDataC, - uArchInstruction->getPackedFormatBitSizeB()))); + SmallVector<int> instDataA = {maxALen, subgroupSize}; + SmallVector<int> instDataB = {subgroupSize, maxBLen}; + + if (layoutKind == LayoutKind::InstData) { + dpasALayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA)); + dpasBLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB)); + } else { + dpasALayout = getSIMTLayoutInfoForDPASOperand( + aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA()); + dpasBLayout = getSIMTLayoutInfoForDPASOperand( + bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB()); + } + + if (operands.size() > 2) { + VectorType cTy = dpas.getAccType(); + if (layoutKind == LayoutKind::InstData) { + const unsigned dataCLen = bTy.getShape().back(); + auto supportedCLen = + uArchInstruction->getSupportedN(bTy.getElementType()); + const int maxCLen = xegpu::getLargestDivisor( + dataCLen, ArrayRef<unsigned>(supportedCLen)); + if (maxCLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataC = {maxALen, maxCLen}; + dpasCDLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC)); + } else + dpasCDLayout = getSIMTLayoutInfoForDPASOperand( + cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB()); + + dpas.setLayoutCdAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get())); + } + dpas.setLayoutAAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get())); + dpas.setLayoutBAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get())); + } + + propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); + propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout)); + if (operands.size() > 2) { + 
propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout)); } } @@ -657,37 +756,50 @@ void LayoutInfoPropagation::visitStoreNdOp( xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - auto uArch = getUArch(getChipStr(store).value_or("")); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( - uArch->getInstruction( - xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); - VectorType dataTy = store.getValueType(); - auto blockWHC = uArchInstruction->getBlockWidthHeightCount( - store.getValueType().getElementType()); - if (!blockWHC) - store.emitWarning("No known block params found for the element type."); - auto [bWidth, bHeight, bCount] = blockWHC.value(); - SmallVector<int> instData; - int instWidth = xegpu::getLargestDivisor( - static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth, - bCount); - if (instWidth == -1) - store.emitWarning( - "No suitable instruction multiple found for the given shape."); - if (dataTy.getRank() == 1) - instData = {instWidth}; - else { - int instHeight = xegpu::getLargestDivisor( - static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); - if (instHeight == -1) + LayoutInfo storeLayout; + xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + storeLayout = LayoutInfo(anchorLayout); + } else { + auto uArch = getUArch(getChipStr(store).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); + VectorType dataTy = store.getValueType(); + auto blockWHC = uArchInstruction->getBlockWidthHeightCount( + store.getValueType().getElementType()); + if (!blockWHC) + store.emitWarning("No known block params found for the element type."); + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth); + if (instWidth == -1) store.emitWarning( "No suitable instruction multiple found for the given shape."); - instData = {instHeight, instWidth}; + if (dataTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); + if (instHeight == -1) + store.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } + + if (layoutKind == LayoutKind::InstData) + storeLayout = + LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData)); + else + storeLayout = + getDefaultSIMTLayoutInfo(store.getValueType(), uArch, + uArchInstruction->getPackedFormatBitSize()); + store.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get())); } - LayoutInfo storeLayout = - getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData, - uArchInstruction->getPackedFormatBitSize()); + // Propagate the layout to the value operand. 
// Both operands should have the same layout for (LayoutInfoLattice *operand : operands) propagateIfChanged(operand, operand->meet(storeLayout)); @@ -698,21 +810,30 @@ void LayoutInfoPropagation::visitStoreNdOp( void LayoutInfoPropagation::visitLoadNdOp( xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - LayoutInfo valueLayout = results[0]->getValue(); - // Need the layout of the value to propagate to the tensor descriptor. - if (!valueLayout.isAssigned()) - return; - LayoutInfo tensorDescLayout = valueLayout; - // LoadNdOp has the transpose effect. However, at the stage of this analysis - // this effect is not expected and should be abstracted away. Emit a - // warning. - if (auto transpose = load.getTranspose()) { - load.emitWarning("Transpose effect is not expected for LoadNdOp at " - "LayoutInfoPropagation stage."); - tensorDescLayout = valueLayout.transpose(transpose.value()); + + LayoutInfo loadLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + } else { + + LayoutInfo valueLayout = results[0]->getValue(); + // Need the layout of the value to propagate to the tensor descriptor. + if (!valueLayout.isAssigned()) + return; + loadLayout = valueLayout; + // LoadNdOp has the transpose effect. However, at the stage of this analysis + // this effect is not expected and should be abstracted away. Emit a + // warning. + if (auto transpose = load.getTranspose()) { + load.emitWarning("Transpose effect is not expected for LoadNdOp at " + "LayoutInfoPropagation stage."); + loadLayout = valueLayout.transpose(transpose.value()); + } + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); } // Propagate the new layout to the tensor descriptor operand. - propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); } /// For vector::TransposeOp, the layout of the result is transposed and @@ -802,33 +923,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp( void LayoutInfoPropagation::visitLoadGatherOp( xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // The layout is strictly determined by the payload type. - auto payloadTy = dyn_cast<VectorType>(load.getValueType()); - if (!payloadTy) { - load.emitWarning("Not propagating, non-vector payload supplied."); - return; - } - auto uArch = getUArch(getChipStr(load).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - SmallVector<int> instData{subgroupSize}; - if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) - instData.push_back(chunkSize); - else if (auto srcTdescTy = - dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { - if (srcTdescTy.getChunkSizeAsInt() > 1) + + LayoutInfo loadLayout; + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + maskLayout = loadLayout; + } else { + + // The layout is strictly determined by the payload type. 
+ VectorType payloadTy = load.getValueType(); + if (!payloadTy) { + load.emitWarning("Not propagating, non-vector payload supplied."); + return; + } + auto uArch = getUArch(getChipStr(load).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) instData.push_back(chunkSize); - } - LayoutInfo layout = getDefaultSIMTLayoutInfo( - payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), - /*scattered*/ true); + else if (auto srcTdescTy = + dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { + if (srcTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } + + if (layoutKind == LayoutKind::InstData) + loadLayout = + LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData)); + else + loadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered*/ true); - // Mask operand should have 1D default layout. - LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); + // Mask operand should have 1D default layout. + maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); + } // Propagate the new layout to the tensor descriptor operand. if (isa<xegpu::TensorDescType>(load.getSourceType())) - propagateIfChanged(operands[0], operands[0]->meet(layout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); if (load.getOffsets()) @@ -856,45 +992,56 @@ void LayoutInfoPropagation::visitCreateDescOp( void LayoutInfoPropagation::visitStoreScatterOp( xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // Currently, for 2D StoreScatterOp we expect that the height dimension of - // the tensor descriptor is equal to the subgroup size. This is ensured by - // the op verifier. 
- auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType()); - if (!payloadTy) { - storeScatter.emitWarning("Not propagating, non-vector payload supplied."); - return; - } - auto uArch = getUArch(getChipStr(storeScatter).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - - auto payloadShape = payloadTy.getShape(); - if (payloadShape.size() > 1) - assert( - payloadShape[0] == subgroupSize && - "Expected the first dimension of 2D tensor descriptor to be equal to " - "subgroup size."); - - SmallVector<int> instData{subgroupSize}; - if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1) - instData.push_back(chunkSize); - else if (auto dstTdescTy = - dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) { - if (dstTdescTy.getChunkSizeAsInt() > 1) - instData.push_back(chunkSize); - } LayoutInfo payloadLayout; - - if (auto layout = storeScatter.getLayoutAttr()) { - payloadLayout = LayoutInfo(layout); + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + payloadLayout = LayoutInfo(anchorLayout); + maskLayout = payloadLayout; } else { - payloadLayout = getDefaultSIMTLayoutInfo( - payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), - /*scattered=*/true); - } + // Currently, for 2D StoreScatterOp we expect that the height dimension of + // the tensor descriptor is equal to the subgroup size. This is ensured by + // the op verifier. + VectorType payloadTy = storeScatter.getValueType(); + if (!payloadTy) { + storeScatter.emitWarning("Not propagating, non-vector payload supplied."); + return; + } - LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); + auto uArch = getUArch(getChipStr(storeScatter).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + + if (layoutKind == LayoutKind::InstData) { + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = storeScatter.getChunkSize().value_or(0); + chunkSize > 1) + instData.push_back(chunkSize); + else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>( + storeScatter.getDestType())) { + if (dstTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } + payloadLayout = LayoutInfo( + xegpu::LayoutAttr::get(storeScatter.getContext(), instData)); + } else { + auto payloadShape = payloadTy.getShape(); + if (payloadShape.size() > 1) + assert(payloadShape[0] == subgroupSize && + "Expected the first dimension of 2D tensor descriptor to be " + "equal to " + "subgroup size."); + payloadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered=*/true); + } + + maskLayout = + getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); + + storeScatter.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get())); + } // Propagate the payload operand layout propagateIfChanged(operands[0], operands[0]->meet(payloadLayout)); // Propagate the destination (if tdesc) operand layout @@ -916,10 +1063,10 @@ class RunLayoutInfoPropagation { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation) - RunLayoutInfoPropagation(Operation *op) : target(op) { + RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) { SymbolTableCollection symbolTable; loadBaselineAnalyses(solver); - solver.load<LayoutInfoPropagation>(symbolTable); + solver.load<LayoutInfoPropagation>(symbolTable, layoutKind); 
(void)solver.initializeAndRun(op); } @@ -1159,7 +1306,18 @@ struct XeGPUPropagateLayoutPass final } // namespace void XeGPUPropagateLayoutPass::runOnOperation() { - auto &analysis = getAnalysis<RunLayoutInfoPropagation>(); + LayoutKind layoutKind; + if (this->layoutKind == "lane") { + layoutKind = LayoutKind::Lane; + } else if (this->layoutKind == "inst") { + layoutKind = LayoutKind::InstData; + } else { + getOperation()->emitError("Unsupported layout kind option: " + + this->layoutKind); + signalPassFailure(); + return; + } + RunLayoutInfoPropagation analysis(getOperation(), layoutKind); // Print the analysis result and exit. (for debugging purposes) if (printOnly) { auto &os = llvm::outs(); @@ -1173,8 +1331,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() { return {}; xegpu::DistributeLayoutAttr layoutAttr = cast<xegpu::DistributeLayoutAttr>(layout.get()); - if (this->layoutKind == "lane") - layoutAttr = layoutAttr.dropInstData(); if (layout.isSliceLayout()) return cast<xegpu::SliceAttr>(layoutAttr); return cast<xegpu::LayoutAttr>(layoutAttr); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index bbd7733..ca81c3c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -99,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, for (auto [i, dim] : llvm::enumerate(originalType.getShape())) { if (i < distributionStart) continue; - // Check if the dimension can be distributed evenly. if (dim % effectiveLaneLayout[i - distributionStart] != 0) return failure(); @@ -174,6 +173,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout, return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1; } +/// Given a vector type and its distributed vector type, return the list of +/// dimensions that are distributed. +static SmallVector<int64_t> getDistributedDims(VectorType originalType, + VectorType distributedType) { + assert(originalType.getRank() == distributedType.getRank() && + "sequential and distributed vector types must have the same rank"); + SmallVector<int64_t> distributedDims; + for (int64_t i = 0; i < originalType.getRank(); ++i) { + if (distributedType.getDimSize(i) != originalType.getDimSize(i)) { + distributedDims.push_back(i); + } + } + return distributedDims; +} + /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is /// contained within a WarpExecuteOnLane0Op. 
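A minimal sketch of how the new `getDistributedDims` helper above behaves, assuming the helper is visible in the same translation unit and an `mlir::MLIRContext` has been created (the wrapper function name below is hypothetical and not part of the patch): with a 32-lane warp distributing the innermost dimension, a per-warp `vector<8x32xf32>` value becomes `vector<8x1xf32>` per lane, so only dimension 1 is reported.

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

// Hypothetical check; assumes getDistributedDims from this file is in scope.
static void illustrateGetDistributedDims(mlir::MLIRContext &ctx) {
  auto f32 = mlir::Float32Type::get(&ctx);
  auto original = mlir::VectorType::get({8, 32}, f32);   // per-warp value
  auto distributed = mlir::VectorType::get({8, 1}, f32); // per-lane value
  // The helper compares the two shapes dimension by dimension and collects
  // every index whose size differs; here that is only dimension 1.
  llvm::SmallVector<int64_t> dims = getDistributedDims(original, distributed);
  assert(dims.size() == 1 && dims[0] == 1);
}
```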
@@ -926,8 +940,7 @@ static SmallVector<Value> computeDistributedCoordinatesForMatrixOp( SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned( rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]), getAsOpFoldResult(origOffsets)); - newCoods = llvm::to_vector(llvm::map_range( - ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); })); + newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>); return newCoods; } @@ -990,9 +1003,8 @@ struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern { SmallVector<Value> newOperands = llvm::map_to_vector( newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); - SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()}; - std::fill(newConstOffsets.begin(), newConstOffsets.end(), - ShapedType::kDynamic); + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); DenseI64ArrayAttr newConstOffsetsAttr = rewriter.getDenseI64ArrayAttr(newConstOffsets); ValueRange currentOffsets = @@ -1067,9 +1079,8 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern { SmallVector<Value> newOperands = llvm::map_to_vector( newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); - SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()}; - std::fill(newConstOffsets.begin(), newConstOffsets.end(), - ShapedType::kDynamic); + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); DenseI64ArrayAttr newConstOffsetsAttr = rewriter.getDenseI64ArrayAttr(newConstOffsets); ValueRange currentOffsets = @@ -1412,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { } }; +/// This pattern distributes the `vector.broadcast` operation across lanes in a +/// warp. The pattern supports three use cases: +/// +/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input +/// vector +/// must have a slice layout of the result. If the distributed source and +/// target vector types are identical, this lowers to a no-op; otherwise, it +/// remains a broadcast but operates on distributed vectors. +/// +/// 2) Broadcast a same-rank vector with identical layouts for source and +/// target: +/// The source vector must have unit dimensions, and lane_data must be unit +/// size for those unit dims. This always lowers to a no-op. +/// +/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from +/// scalar to distributed result type. 
+///
+/// Example 1 (lowering to a broadcast with distributed types):
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///     %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///     %2 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
+///       : vector<32xf32> to vector<8x32xf32>
+///     gpu.yield %2 : vector<8x32xf32>
+///   }
+/// ```
+/// is lowered to:
+/// ```
+///   %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///     %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///     gpu.yield %0 : vector<32xf32>
+///   }
+///   %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
+/// ```
+///
+/// Example 2 (no-op):
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
+///     %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///     %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8xf32> to vector<8x1xf32>
+///     %2 = vector.broadcast %1
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8x1xf32> to vector<8x32xf32>
+///     gpu.yield %2 : vector<8x32xf32>
+///   }
+/// ```
+/// is lowered to:
+/// ```
+///   %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///     %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///     %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8xf32> to vector<8x1xf32>
+///     gpu.yield %1 : vector<8x1xf32>
+///   }
+///   // The broadcast is implicit through layout transformation (no-op)
+///   "some_use"(%r#0)
+/// ```
+struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto broadcastOp =
+        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandIdx = yieldOperand->getOperandNumber();
+
+    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType destType =
+        dyn_cast<VectorType>(broadcastOp.getResult().getType());
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
+
+    FailureOr<VectorType> sourceDistType;
+    Type sourceElemOrDistType;
+    if (sourceType) {
+
+      // Case 1 and 2: source is a vector type.
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: source is lower-rank than result.
+ bool isSliceOf = sourceLayout.isSliceOf(resultLayout); + if (!isSliceOf) + return rewriter.notifyMatchFailure( + warpOp, + "Broadcast input layout must be a slice of result layout."); + } + // case 2: source and result have same rank + if (rankDiff == 0) { + SetVector<int64_t> broadcastUnitDims = + broadcastOp.computeBroadcastedUnitDims(); + resultLayout = resultLayout.setUnitDimData(broadcastUnitDims); + bool isEqualTo = sourceLayout.isEqualTo(resultLayout); + if (!isEqualTo) + return rewriter.notifyMatchFailure( + warpOp, "For same-rank broadcast, source must be identical to " + "adjusted result layouts with unit dims."); + sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims); + } + + sourceDistType = + getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType); + if (failed(sourceDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the source vector type."); + } + sourceElemOrDistType = sourceDistType.value(); + + } else { + // Case 3: source is a scalar type. + if (sourceLayout) { + return rewriter.notifyMatchFailure( + warpOp, "Broadcast from scalar must not have a layout attribute."); + } + sourceElemOrDistType = broadcastOp.getSourceType(); + } + FailureOr<VectorType> destDistType = + getDistVecTypeBasedOnLaneLayout(resultLayout, destType); + if (failed(destDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the dest vector type."); + } + + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType, + newRetIndices); + + Value distributedSource = newWarpOp.getResult(newRetIndices[0]); + + Value newBroadcast = distributedSource; + + if (sourceElemOrDistType != destDistType.value()) { + rewriter.setInsertionPointAfter(newWarpOp); + newBroadcast = + vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(), + destDistType.value(), distributedSource); + } + + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast); + return success(); + } +}; + /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing /// `gpu.warp_execute_on_lane_0` region. struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { @@ -1472,6 +1643,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { } }; +// Distribute a `vector.extract_strided_slice` op feeding into yield op of an +// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers +// advanced cases where the distributed dimension is partially extracted and +// currently not supported by the generic vector distribution patterns. +struct VectorExtractStridedSliceDistribution + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>); + if (!operand) + return failure(); + auto extractOp = + cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp()); + unsigned operandIdx = operand->getOperandNumber(); + auto distributedType = + cast<VectorType>(warpOp.getResult(operandIdx).getType()); + // Find the distributed dimensions. + auto extractResultType = cast<VectorType>(operand->get().getType()); + auto distributedDims = + getDistributedDims(extractResultType, distributedType); + // Collect updated source type, sizes and offsets. 
They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    VectorType updatedSourceType = extractOp.getSourceVectorType();
+    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust the updated source type, sizes, and
+    // offsets accordingly.
+    if (distributedDims.size() > 0) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source can not be distributed in multiple dimensions.");
+      int64_t distributedDim = distributedDims[0];
+      int sourceDistrDimSize =
+          extractOp.getSourceVectorType().getShape()[distributedDim];
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source of extract_strided_slice op lacks distribution "
+                    "layout");
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize = sourceLaneLayout[distributedDim];
+      // Check if the source size in the distributed dimension is a multiple of
+      // subgroup size.
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Source size along distributed dimension is not a multiple of "
+            "subgroup size.");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // We expect lane data to be all ones in this case.
+      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source layout");
+      // The offsets in the distributed dimension must be a multiple of subgroup
+      // size.
+      int64_t distrDimOffset =
+          cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+      if (distrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Offset along distributed dimension "
+                    "is not a multiple of subgroup size.");
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, extractOp.getSourceVectorType())
+                              .value();
+      // Update the distributed sizes to match the distributed type.
+      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[distributedDim] =
+          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source of the extract op from
+    // the warp op and creating a new extract op outside the warp op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Create a new extract op outside the warp op.
+ Value newExtractOp = vector::ExtractStridedSliceOp::create( + rewriter, extractOp.getLoc(), distributedType, source, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + ArrayAttr::get(rewriter.getContext(), updatedSizes), + extractOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp); + return success(); + } +}; + +/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an +/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers +/// advanced cases where the distributed dimension is partially inserted and +/// currently not supported by the generic vector distribution patterns. +struct VectorInsertStridedSliceDistribution + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>); + if (!operand) + return failure(); + unsigned int operandNumber = operand->getOperandNumber(); + auto insertOp = + operand->get().getDefiningOp<vector::InsertStridedSliceOp>(); + auto distributedType = + cast<VectorType>(warpOp.getResult(operandNumber).getType()); + // Find the distributed dimensions of the dest vector. + auto insertResultType = cast<VectorType>(operand->get().getType()); + auto destDistributedDims = + getDistributedDims(insertResultType, distributedType); + // Collect updated offsets, source type and dest type. They may be adjusted + // later if the data is distributed to lanes (as opposed to being owned by + // all lanes uniformly). + SmallVector<Attribute> updatedOffsets = llvm::map_to_vector( + insertOp.getOffsets(), [](Attribute attr) { return attr; }); + VectorType updatedSourceType = insertOp.getSourceVectorType(); + VectorType updatedDestType = insertOp.getDestVectorType(); + if (destDistributedDims.size() > 0) { + // Only single dimension distribution is supported. + if (destDistributedDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting source to be distributed in a single dimension."); + int64_t destDistributedDim = destDistributedDims[0]; + + VectorType srcType = insertOp.getSourceVectorType(); + VectorType destType = insertOp.getDestVectorType(); + // Currently we require that both source (kD) and dest (nD) vectors are + // distributed. This requires that distributedDim (d) is contained in the + // last k dims of the dest vector (d >= n - k). + int64_t sourceDistributedDim = + destDistributedDim - (destType.getRank() - srcType.getRank()); + if (sourceDistributedDim < 0) + return rewriter.notifyMatchFailure( + insertOp, + "distributed dimension must be in the last k (i.e. source " + "rank) dims of dest vector"); + int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim); + // Obtain the source and dest layouts. + auto destLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1)); + auto sourceLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0)); + if (!destLayout || !sourceLayout || + destLayout.getEffectiveLaneLayoutAsInt().empty() || + sourceLayout.getEffectiveLaneLayoutAsInt().empty()) + return rewriter.notifyMatchFailure( + warpOp, "the source or dest of insert_strided_slice op lacks " + "distribution layout"); + // Because only single dimension distribution is supported, lane layout + // size at the distributed dim must be the subgroup size. 
+ int subgroupSize = + destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim]; + // We require that source and dest lane data are all ones to ensure + // uniform round robin distribution. + auto destLaneData = destLayout.getEffectiveLaneDataAsInt(); + auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt(); + if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) || + !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; })) + return rewriter.notifyMatchFailure( + warpOp, "Expecting unit lane data in source and dest layouts"); + // Source distributed dim size must be multiples of subgroup size. + if (srcDistrDimSize % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, "Distributed dimension size in source is not a multiple of " + "subgroup size."); + // Offsets in the distributed dimension must be multiples of subgroup + // size. + int64_t destDistrDimOffset = + cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt(); + if (destDistrDimOffset % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, + "Offset along distributed dimension in dest is not a multiple of " + "subgroup size."); + // Update the source and dest types based on their layouts. + updatedSourceType = getDistVecTypeBasedOnLaneLayout( + sourceLayout, insertOp.getSourceVectorType()) + .value(); + updatedDestType = getDistVecTypeBasedOnLaneLayout( + destLayout, insertOp.getDestVectorType()) + .value(); + // Update the distributed offsets to match round robin distribution (i.e. + // each lane owns data at `subgroupSize` stride given unit lane data). + updatedOffsets[destDistributedDim] = + rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize); + } + // Do the distribution by yielding the source and dest of the insert op + // from the warp op and creating a new insert op outside the warp op. + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, + {updatedSourceType, updatedDestType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + + Value valueToStore = newWarpOp.getResult(newRetIndices[0]); + Value dest = newWarpOp.getResult(newRetIndices[1]); + // Create a new insert op outside the warp op. + Value newInsertOp = vector::InsertStridedSliceOp::create( + rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + insertOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), + newInsertOp); + return success(); + } +}; + /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op /// outside of the warp op. @@ -1629,9 +2020,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( MemrefExtractAlignedPointerAsIndexDistribution>( patterns.getContext(), /*pattern benefit=*/regularPatternBenefit); - patterns.add<VectorShapeCastDistribution>( - patterns.getContext(), - /*pattern benefit=*/highPatternBenefit); + // For following patterns, we need to override the regular vector distribution + // patterns. Therefore, assign higher benefit. 
+ patterns + .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution, + VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>( + patterns.getContext(), + /*pattern benefit=*/highPatternBenefit); } void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index c3bf960..af63f09 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value { xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { return xegpu::LoadNdOp::create( rewriter, loc, newValueTy, convertedTdescs[0], offsets, op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); }; newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape, createLoad, loc, rewriter); @@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++], convertedTdescs[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -678,7 +687,7 @@ struct UnrollLoadGatherOpWithOffset pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter); } - auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr()); + auto layout = op.getLayoutAttr(); if (layout) layout = layout.dropInstData(); @@ -778,7 +787,7 @@ struct UnrollStoreScatterOpWithOffsets SmallVector<Value> convertedValues = pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); - auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr()); + auto layout = op.getLayoutAttr(); if (layout) layout = layout.dropInstData(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp 
index 0a9ef0a..be82cda 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -86,8 +86,16 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, if (origOffsets.empty()) return failure(); + // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr() + xegpu::DistributeLayoutAttr layout; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> || + std::is_same_v<OpType, xegpu::StoreMatrixOp>) { + layout = op.getLayoutAttr(); + } else { + layout = op.getDescLayoutAttr(); + } + // not applicable to ops without workgroup layout attributes - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -190,7 +198,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { xegpu::TensorDescType tdescTy = op.getType(); ArrayRef<int64_t> wgShape = tdescTy.getShape(); Type elemTy = tdescTy.getElementType(); - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; auto newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), @@ -309,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); SmallVector<Value> newOps; for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { @@ -318,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { auto newOp = xegpu::LoadNdOp::create( rewriter, op.getLoc(), newResTy, tdesc, offsets, /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); newOps.push_back(newOp); } rewriter.replaceOpWithMultiple(op, {newOps}); @@ -339,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [v, tdesc, offsets] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -363,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -489,10 +506,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + 
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -738,12 +753,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Location loc = op.getLoc(); auto eltType = vecType.getElementType(); - auto setLayoutIfNeeded = [&](Value val) { - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), - layout.dropSgLayoutAndData()); - } + auto setLayout = [&](Value val) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), + layout.dropSgLayoutAndData()); }; if (vecAttr.isSplat()) { @@ -751,14 +763,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Attribute singleVal = vecAttr.getSplatValue<Attribute>(); auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); - setLayoutIfNeeded(cstOp->getResult(0)); + setLayout(cstOp->getResult(0)); rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all // subgroups, don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); - setLayoutIfNeeded(newConstOp->getResult(0)); + setLayout(newConstOp->getResult(0)); rewriter.replaceOp(op, newConstOp); return success(); } else { @@ -860,9 +872,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); - setLayoutIfNeeded(baseConstVec); - setLayoutIfNeeded(bcastOffset); - setLayoutIfNeeded(finalConst); + setLayout(baseConstVec); + setLayout(bcastOffset); + setLayout(finalConst); newConstOps.push_back(finalConst); } rewriter.replaceOpWithMultiple(op, {newConstOps}); @@ -889,8 +901,8 @@ struct WgToSgLoadGatherOpWithOffset return failure(); ArrayRef<int64_t> wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>( - xegpu::getDistributeLayoutAttr(op.getResult())); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -913,10 +925,12 @@ struct WgToSgLoadGatherOpWithOffset VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); for (auto [offsets, mask] : llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + auto newLayout = layout.dropSgLayoutAndData(); auto newLoadOp = xegpu::LoadGatherOp::create( rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), - layout.dropSgLayoutAndData()); + newLayout); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); @@ -941,8 +955,8 @@ struct WgToSgStoreScatterOpWithOffset if (!valueType) return failure(); - xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>( - xegpu::getDistributeLayoutAttr(op.getOperand(0))); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -967,14 +981,11 @@ struct WgToSgStoreScatterOpWithOffset op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout.dropSgLayoutAndData()); // Update the layout 
attribute to drop sg_layout and sg_data. - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - for (OpOperand &operand : store->getOpOperands()) { - // Skip for operand one (memref) - if (operand.getOperandNumber() == 1) - continue; - xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); - } + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); } } rewriter.eraseOp(op); @@ -1067,15 +1078,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> { vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); auto finalSteps = arith::AddIOp::create(rewriter, loc, steps, bcastOffset); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(steps->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), - layout.dropSgLayoutAndData()); - } + xegpu::setDistributeLayoutAttr(steps->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), + layout.dropSgLayoutAndData()); newOps.push_back(finalSteps); } @@ -1143,10 +1151,8 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), newResultType, src); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), + layout.dropSgLayoutAndData()); newShapeCastOps.push_back(newShapeCast.getResult()); } @@ -1207,10 +1213,8 @@ struct WgToSgMultiDimReductionOp auto newOp = vector::MultiDimReductionOp::create( rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0], op.getReductionDims()); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newOp->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newOp->getResult(0), + layout.dropSgLayoutAndData()); newReductions.push_back(newOp.getResult()); } @@ -1283,6 +1287,78 @@ struct WgToSgVectorTransposeOp } }; +// Distribute vector mask ops to work at subgroup level. 
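The pattern below clamps a workgroup-level mask to the portion owned by each subgroup. As a rough standalone sketch of that arithmetic (plain C++, detached from the rewriter plumbing; the helper name computeSgMaskSizes is purely illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch of the per-subgroup mask-size arithmetic that the pattern below
// materializes with arith ops:
//   min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
static std::vector<int64_t>
computeSgMaskSizes(const std::vector<int64_t> &wgMaskDimSizes,
                   const std::vector<int64_t> &sgOffsets,
                   const std::vector<int64_t> &sgShape) {
  std::vector<int64_t> sizes;
  for (size_t d = 0; d < wgMaskDimSizes.size(); ++d) {
    int64_t remaining = std::max<int64_t>(wgMaskDimSizes[d] - sgOffsets[d], 0);
    sizes.push_back(std::min(remaining, sgShape[d]));
  }
  return sizes;
}
// For example, a workgroup mask of 40 rows split into 16-row subgroup tiles
// at offsets 0, 16, and 32 yields local mask sizes 16, 16, and 8.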
+template <typename MaskOpType> +struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> { + using OpConversionPattern<MaskOpType>::OpConversionPattern; + + LogicalResult matchAndRewrite( + MaskOpType op, + typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + Location loc = op.getLoc(); + VectorType type = op.getResult().getType(); + auto wgShape = type.getShape(); + + SmallVector<Value> wgMaskDimSizes; + if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) { + for (int64_t maskSize : op.getMaskDimSizes()) { + wgMaskDimSizes.push_back( + arith::ConstantIndexOp::create(rewriter, loc, maskSize)); + } + } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) { + wgMaskDimSizes = llvm::to_vector(op.getOperands()); + } + + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType resultType = VectorType::get(sgShape, type.getElementType()); + + // In each dimension, each subgroup computes its local mask size as: + // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]) + SmallVector<Value> newCreateMaskOps; + for (auto offsetSet : *sgOffsets) { + SmallVector<Value> maskOperands; + + for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) { + Value dimSizeVal = + arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); + Value offset = offsetSet[i]; + Value adjustedMaskSize = + arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value nonNegative = + arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); + Value sgMaskSize = + arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal); + maskOperands.push_back(sgMaskSize); + } + + auto newCreateMaskOp = + vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands); + xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0), + layout.dropSgLayoutAndData()); + newCreateMaskOps.push_back(newCreateMaskOp.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newCreateMaskOps}); + return success(); + } +}; + +using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>; +using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>; } // namespace namespace mlir { @@ -1297,7 +1373,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, - WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>( + WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp, + WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>( patterns.getContext()); } } // namespace xegpu @@ -1427,7 +1504,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, vector::TransposeOp, vector::BroadcastOp, - vector::MultiDimReductionOp>( + vector::MultiDimReductionOp, + vector::ConstantMaskOp, vector::CreateMaskOp>( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the 
result. auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0)); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index de9e09d..9f126fe 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -140,7 +139,6 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { // for StoreMatrixOp, the layout is attached to the property of the op if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp)) return storeOp.getLayoutAttr(); - std::string layoutName = getLayoutName(result); if (defOp->hasAttr(layoutName)) return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); @@ -308,7 +306,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, int64_t rankDiff = srcShapeRank - targetShapeRank; std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff, 1); - std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff); + llvm::copy(shape, adjustedTargetShape.begin() + rankDiff); SmallVector<Value> result; for (SmallVector<int64_t> offsets : @@ -528,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder, for (auto [l, r] : llvm::zip_equal(lhs, rhs)) { auto lval = getValueOrCreateConstantIndexOp(builder, loc, l); auto rval = getValueOrCreateConstantIndexOp(builder, loc, r); - results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval)); + results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval)); } return results; } diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp new file mode 100644 index 0000000..f3e38eb --- /dev/null +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -0,0 +1,174 @@ +//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the APFloat infrastructure to MLIR programs as a runtime +// library. APFloat is a software implementation of floating-point arithmetic. +// +// On the MLIR side, floating-point values must be bitcasted to 64-bit integers +// before calling a runtime function. If a floating-point type has less than +// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an +// integer. +// +// Runtime functions receive the floating-point operands of the arithmetic +// operation in the form of 64-bit integers, along with the APFloat semantics +// in the form of a 32-bit integer, which will be interpreted as an +// APFloatBase::Semantics enum value.
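A minimal host-side caller following this convention could look as sketched below (illustrative only: packF32 is an assumed helper, the program is expected to link against the mlir_apfloat_wrappers library, and in practice these symbols are called from IR produced by arith-to-apfloat rather than hand-written C++):

#include <cstdint>
#include <cstring>

#include "llvm/ADT/APFloat.h"

extern "C" uint64_t _mlir_apfloat_add(int32_t semantics, uint64_t a, uint64_t b);

// Bitcast an f32 to its 32-bit pattern and zero-extend to 64 bits, as the
// convention above requires for types narrower than 64 bits.
static uint64_t packF32(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint64_t>(bits);
}

int main() {
  auto sem = static_cast<int32_t>(llvm::APFloatBase::S_IEEEsingle);
  uint64_t sum = _mlir_apfloat_add(sem, packF32(1.5f), packF32(2.25f));
  // The low 32 bits of `sum` hold the f32 bit pattern of 3.75.
  (void)sum;
  return 0;
}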
+// +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APSInt.h" + +#ifdef _WIN32 +#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT +#ifdef mlir_apfloat_wrappers_EXPORTS +// We are building this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllexport) +#else +// We are using this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllimport) +#endif // mlir_apfloat_wrappers_EXPORTS +#endif // MLIR_APFLOAT_WRAPPERS_EXPORT +#else +// Non-windows: use visibility attributes. +#define MLIR_APFLOAT_WRAPPERS_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 + +/// Binary operations without rounding mode. +#define APFLOAT_BINARY_OP(OP) \ + MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +/// Binary operations with rounding mode. +#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ + MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs, ROUNDING_MODE); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +extern "C" { + +#define BIN_OPS_WITH_ROUNDING(X) \ + X(add, llvm::RoundingMode::NearestTiesToEven) \ + X(subtract, llvm::RoundingMode::NearestTiesToEven) \ + X(multiply, llvm::RoundingMode::NearestTiesToEven) \ + X(divide, llvm::RoundingMode::NearestTiesToEven) + +BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE) +#undef BIN_OPS_WITH_ROUNDING +#undef APFLOAT_BINARY_OP_ROUNDING_MODE + +APFLOAT_BINARY_OP(remainder) + +#undef APFLOAT_BINARY_OP + +MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + double d = x.convertToDouble(); + fprintf(stdout, "%lg", d); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t +_mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) { + const llvm::fltSemantics &inSem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(inSemantics)); + const llvm::fltSemantics &outSem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(outSemantics)); + unsigned bitWidthIn = llvm::APFloatBase::semanticsSizeInBits(inSem); + llvm::APFloat val(inSem, llvm::APInt(bitWidthIn, a)); + // TODO: Custom rounding modes are not supported yet. 
+ bool losesInfo; + val.convert(outSem, llvm::RoundingMode::NearestTiesToEven, &losesInfo); + llvm::APInt result = val.bitcastToAPInt(); + return result.getZExtValue(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int( + int32_t semantics, int32_t resultWidth, bool isUnsigned, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned inputWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat val(sem, llvm::APInt(inputWidth, a)); + llvm::APSInt result(resultWidth, isUnsigned); + bool isExact; + // TODO: Custom rounding modes are not supported yet. + val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact); + // This function always returns uint64_t, regardless of the desired result + // width. It does not matter whether we zero-extend or sign-extend the APSInt + // to 64 bits because the generated IR in arith-to-apfloat will truncate the + // result to the desired result width. + return result.getZExtValue(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int( + int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) { + llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned); + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + llvm::APFloat result(sem); + // TODO: Custom rounding modes are not supported yet. + result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned, + llvm::RoundingMode::NearestTiesToEven); + return result.bitcastToAPInt().getZExtValue(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics, + uint64_t a, + uint64_t b) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + llvm::APFloat y(sem, llvm::APInt(bitWidth, b)); + return static_cast<int8_t>(x.compare(y)); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + x.changeSign(); + return x.bitcastToAPInt().getZExtValue(); +} + +/// Min/max operations. 
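The four flavors wrapped below differ only in their NaN handling, which a small APFloat-level sketch (separate from these wrappers, using LLVM's APFloat API directly) makes concrete:

#include "llvm/ADT/APFloat.h"

// minimum/maximum propagate NaN; minnum/maxnum return the non-NaN operand
// (IEEE-754 minNum/maxNum semantics).
static void minFlavors() {
  llvm::APFloat nan = llvm::APFloat::getNaN(llvm::APFloat::IEEEsingle());
  llvm::APFloat one(1.0f);
  llvm::APFloat a = llvm::minimum(nan, one); // NaN
  llvm::APFloat b = llvm::minnum(nan, one);  // 1.0
  (void)a;
  (void)b;
}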
+#define APFLOAT_MIN_MAX_OP(OP) \ + MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + llvm::APFloat result = llvm::OP(lhs, rhs); \ + return result.bitcastToAPInt().getZExtValue(); \ + } + +APFLOAT_MIN_MAX_OP(minimum) +APFLOAT_MIN_MAX_OP(maximum) +APFLOAT_MIN_MAX_OP(minnum) +APFLOAT_MIN_MAX_OP(maxnum) + +#undef APFLOAT_MIN_MAX_OP +} diff --git a/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp b/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp index 9868ffa..9b1c39e 100644 --- a/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp @@ -49,7 +49,7 @@ extern "C" { /// The recommended strategy is to call `setArmVectorLength` only from functions /// that do not access SVE registers, either by themselves or by inlining other /// functions. -static void setArmVectorLength(std::string_view helper_name, int option, +static void setArmVectorLength(std::string_view helperName, int option, uint32_t bits) { #if defined(__linux__) && defined(__aarch64__) if (bits < 128 || bits > 2048 || !llvm::isPowerOf2_32(bits)) { @@ -63,7 +63,7 @@ static void setArmVectorLength(std::string_view helper_name, int option, abort(); } #else - std::cerr << "[error] " << helper_name << " is unsupported" << std::endl; + std::cerr << "[error] " << helperName << " is unsupported" << std::endl; abort(); #endif } diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt index fdeb4dac..a615352 100644 --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -2,6 +2,7 @@ # is a big dependency which most don't need. set(LLVM_OPTIONAL_SOURCES + APFloatWrappers.cpp ArmRunnerUtils.cpp ArmSMEStubs.cpp AsyncRuntime.cpp @@ -167,6 +168,26 @@ if(LLVM_ENABLE_PIC) set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17) target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # TODO: This support library is only used on Linux builds until we figure + # out how to hide LLVM symbols in a way that works for all platforms. + add_mlir_library(mlir_apfloat_wrappers + SHARED + APFloatWrappers.cpp + + EXCLUDE_FROM_LIBMLIR + ) + set_target_properties( + mlir_apfloat_wrappers + PROPERTIES CXX_STANDARD 17 + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON + ) + target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS) + # Hide LLVM symbols to avoid ODR violations. + target_link_options(mlir_apfloat_wrappers PRIVATE "-Wl,--exclude-libs,ALL") + endif() + add_subdirectory(SparseTensor) add_mlir_library(mlir_c_runner_utils @@ -184,6 +205,11 @@ if(LLVM_ENABLE_PIC) set_property(TARGET mlir_c_runner_utils PROPERTY CXX_STANDARD 17) target_compile_definitions(mlir_c_runner_utils PRIVATE mlir_c_runner_utils_EXPORTS) + # Conditionally link apfloat wrappers only on Linux. 
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_link_libraries(mlir_c_runner_utils PUBLIC mlir_apfloat_wrappers) + endif() + add_mlir_library(mlir_runner_utils SHARED RunnerUtils.cpp @@ -195,6 +221,11 @@ if(LLVM_ENABLE_PIC) ) target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS) + # Conditionally link apfloat wrappers only on Linux. + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_link_libraries(mlir_runner_utils PUBLIC mlir_apfloat_wrappers) + endif() + add_mlir_library(mlir_async_runtime SHARED AsyncRuntime.cpp @@ -323,7 +354,6 @@ if(LLVM_ENABLE_PIC) endif() string(STRIP AGENTS_STRING ${AGENTS_STRING}) string(REPLACE "\n" ";" AGENTS_LIST ${AGENTS_STRING}) - list(FILTER AGENTS_LIST EXCLUDE REGEX "gfx000") if (AGENTS_LIST STREQUAL "") message(SEND_ERROR "No non-CPU ROCm agents found on the system, and ROCM_TEST_CHIPSET is not defined") else() diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp index 6cc2b7fd..f203363 100644 --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -57,7 +57,7 @@ thread_local static int32_t defaultDevice = 0; /// Helper method that checks environment value for debugging. -bool isDebugEnabled() { +static bool isDebugEnabled() { const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG"; static bool isEnabled = getenv(kDebugEnvironmentVariable) != nullptr; return isEnabled; @@ -71,7 +71,7 @@ bool isDebugEnabled() { } while (0) // Returns default CUdevice -CUdevice getDefaultCuDevice() { +static CUdevice getDefaultCuDevice() { CUdevice device; CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); return device; diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 2255633..287c52a 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -146,12 +146,10 @@ static void packFunctionArguments(Module *module) { llvm::IRBuilder<> builder(ctx); DenseSet<llvm::Function *> interfaceFunctions; for (auto &func : module->getFunctionList()) { - if (func.isDeclaration()) { + if (func.isDeclaration() || func.hasLocalLinkage()) continue; - } - if (interfaceFunctions.count(&func)) { + if (interfaceFunctions.count(&func)) continue; - } // Given a function `foo(<...>)`, define the interface function // `mlir_foo(i8**)`. diff --git a/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp index ddea230..ff0dd54 100644 --- a/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp @@ -156,7 +156,7 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/, size_t /*smem*/, void *vkRuntimeManager, void **params, void ** /*extra*/, size_t paramsCount) { - auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); + auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); // GpuToLLVMConversionPass with the kernelBarePtrCallConv and // kernelIntersperseSizeCallConv options will set up the params array like: @@ -180,7 +180,7 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, static_cast<uint32_t>(gridY), static_cast<uint32_t>(gridZ)}); - auto function = static_cast<VulkanFunction *>(vkKernel); + auto *function = static_cast<VulkanFunction *>(vkKernel); // Expected size should be in bytes. 
manager->setShaderModule( function->module->blobData(), diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 9b23dd6..fd846e4 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2032,7 +2032,7 @@ private: }; template <typename Range> -void printDimensionList(raw_ostream &stream, Range &&shape) { +static void printDimensionList(raw_ostream &stream, Range &&shape) { llvm::interleave( shape, stream, [&stream](const auto &dimSize) { diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index 031eae2..4cce16b 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -31,6 +31,11 @@ Remark::Arg::Arg(llvm::StringRef k, Type t) : key(k) { os << t; } +Remark::Arg::Arg(llvm::StringRef k, Attribute a) : key(k), attr(a) { + llvm::raw_string_ostream os(val); + os << a; +} + void Remark::insert(llvm::StringRef s) { args.emplace_back(s); } void Remark::insert(Arg a) { args.push_back(std::move(a)); } diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index e438631..199744d2 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -118,8 +118,7 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { /// have compatible dimensions. Dimensions are compatible if all non-dynamic /// dims are equal. The element type does not matter. LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { - auto shapedTypes = llvm::map_to_vector<8>( - types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }); + auto shapedTypes = llvm::map_to_vector<8>(types, llvm::DynCastTo<ShapedType>); // Return failure if some, but not all are not shaped. Return early if none // are shaped also. if (llvm::none_of(shapedTypes, [](auto t) { return t; })) diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index 9f4f672..c31e0ae7 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op, return status; } +FailureOr<SmallVector<OpFoldResult>> +mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) { + auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); + if (!reifiableOp) + return failure(); + return reifiableOp.reifyShapeOfResult(b, resultIndex); +} + +FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op, + int resultIndex, int dim) { + auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); + if (!reifiableOp) + return failure(); + return reifiableOp.reifyDimOfResult(b, resultIndex, dim); +} + bool ShapeAdaptor::hasRank() const { if (val.isNull()) return false; diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index a5bfde1..cfe808b 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -129,7 +129,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map, assert(var.map.getNumDims() == 0 && "expected only symbols"); SmallVector<AffineExpr> symReplacements; for (auto valueDim : var.mapOperands) { - auto it = llvm::find(this->mapOperands, valueDim); + auto *it = llvm::find(this->mapOperands, valueDim); if (it != this->mapOperands.end()) { // There is already a symbol for this operand. 
symReplacements.push_back(b.getAffineSymbolExpr( diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 521c7c6..75f8826 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -559,9 +559,9 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, return op->emitOpError() << "trying to schedule a pass on an operation not " "marked as 'IsolatedFromAbove'"; } - if (!pass->canScheduleOn(*op->getName().getRegisteredInfo())) { - return op->emitOpError() - << "trying to schedule a pass on an unsupported operation"; + if (!pass->canScheduleOn(op)) { + return op->emitOpError() << "trying to schedule pass '" << pass->getName() + << "' on an unsupported operation"; } // Initialize the pass state with a callback for the pass to dynamically diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index e392a88..7bfe03d 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -27,7 +27,7 @@ struct Parser::TokenInfo { } // Known identifiers. - static const char *const ID_Extract; + static const char *const idExtract; llvm::StringRef text; TokenKind kind = TokenKind::Eof; @@ -35,7 +35,7 @@ struct Parser::TokenInfo { VariantValue value; }; -const char *const Parser::TokenInfo::ID_Extract = "extract"; +const char *const Parser::TokenInfo::idExtract = "extract"; class Parser::CodeTokenizer { public: @@ -452,13 +452,13 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, } if (chainCallToken.kind != TokenKind::Ident || - chainCallToken.text != TokenInfo::ID_Extract) { + chainCallToken.text != TokenInfo::idExtract) { error->addError(chainCallToken.range, ErrorType::ParserMalformedChainedExpr); return false; } - if (chainCallToken.text == TokenInfo::ID_Extract && + if (chainCallToken.text == TokenInfo::idExtract && !parseChainedExpression(functionName)) return false; } diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp index 5b49204..1e00ed6 100644 --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -175,9 +175,12 @@ public: using Base::Base; // Collect the reduce patterns defined by each dialect. 
- void populateReductionPatterns(RewritePatternSet &pattern) const { - for (const DialectReductionPatternInterface &interface : *this) + void populateReductionPatterns(RewritePatternSet &pattern, + Tester &tester) const { + for (const DialectReductionPatternInterface &interface : *this) { interface.populateReductionPatterns(pattern); + interface.populateReductionPatternsWithTester(pattern, tester); + } } }; @@ -201,15 +204,21 @@ public: private: LogicalResult reduceOp(ModuleOp module, Region ®ion); + Tester tester; FrozenRewritePatternSet reducerPatterns; }; } // namespace LogicalResult ReductionTreePass::initialize(MLIRContext *context) { + tester.setTestScript(testerName); + tester.setTestScriptArgs(testerArgs); + RewritePatternSet patterns(context); + ReductionPatternInterfaceCollection reducePatternCollection(context); - reducePatternCollection.populateReductionPatterns(patterns); + reducePatternCollection.populateReductionPatterns(patterns, tester); + reducerPatterns = std::move(patterns); return success(); } @@ -244,11 +253,10 @@ void ReductionTreePass::runOnOperation() { } LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { - Tester test(testerName, testerArgs); switch (traversalModeId) { case TraversalMode::SinglePath: return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( - module, region, reducerPatterns, test); + module, region, reducerPatterns, tester); default: return module.emitError() << "unsupported traversal mode detected"; } diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index c857c38..4312100 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -56,6 +56,7 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" @@ -113,6 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { transform::registerSMTExtension(registry); transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); + x86vector::registerTransformDialectExtension(registry); xegpu::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 42843ea..159aa54 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1099,12 +1099,12 @@ public: MutableArrayRef<PDLValue> getResults() { return results; } /// Return the type ranges allocated by this list. - MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { + MutableArrayRef<std::vector<Type>> getAllocatedTypeRanges() { return allocatedTypeRanges; } /// Return the value ranges allocated by this list. - MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { + MutableArrayRef<std::vector<Value>> getAllocatedValueRanges() { return allocatedValueRanges; } }; @@ -1112,19 +1112,20 @@ public: /// This class provides support for executing a bytecode stream. 
class ByteCodeExecutor { public: - ByteCodeExecutor( - const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, - MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, - MutableArrayRef<TypeRange> typeRangeMemory, - std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, - MutableArrayRef<ValueRange> valueRangeMemory, - std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, - MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, - ArrayRef<ByteCodeField> code, - ArrayRef<PatternBenefit> currentPatternBenefits, - ArrayRef<PDLByteCodePattern> patterns, - ArrayRef<PDLConstraintFunction> constraintFunctions, - ArrayRef<PDLRewriteFunction> rewriteFunctions) + ByteCodeExecutor(const ByteCodeField *curCodeIt, + MutableArrayRef<const void *> memory, + MutableArrayRef<std::vector<Operation *>> opRangeMemory, + MutableArrayRef<TypeRange> typeRangeMemory, + std::vector<std::vector<Type>> &allocatedTypeRangeMemory, + MutableArrayRef<ValueRange> valueRangeMemory, + std::vector<std::vector<Value>> &allocatedValueRangeMemory, + MutableArrayRef<unsigned> loopIndex, + ArrayRef<const void *> uniquedMemory, + ArrayRef<ByteCodeField> code, + ArrayRef<PatternBenefit> currentPatternBenefits, + ArrayRef<PDLByteCodePattern> patterns, + ArrayRef<PDLConstraintFunction> constraintFunctions, + ArrayRef<PDLRewriteFunction> rewriteFunctions) : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), @@ -1367,13 +1368,9 @@ private: if (range.empty()) { rangeMemory[rangeIndex] = {}; } else { - // Allocate a buffer for this type range. - llvm::OwningArrayRef<T> storage(llvm::size(range)); - llvm::copy(range, storage.begin()); - // Assign this to the range slot and use the range as the value for the // memory index. - allocatedRangeMemory.emplace_back(std::move(storage)); + allocatedRangeMemory.emplace_back(range.begin(), range.end()); rangeMemory[rangeIndex] = allocatedRangeMemory.back(); } memory[memIndex] = &rangeMemory[rangeIndex]; @@ -1397,11 +1394,11 @@ private: /// The current execution memory. MutableArrayRef<const void *> memory; - MutableArrayRef<OwningOpRange> opRangeMemory; + MutableArrayRef<std::vector<Operation *>> opRangeMemory; MutableArrayRef<TypeRange> typeRangeMemory; - std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; + std::vector<std::vector<Type>> &allocatedTypeRangeMemory; MutableArrayRef<ValueRange> valueRangeMemory; - std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; + std::vector<std::vector<Value>> &allocatedValueRangeMemory; /// The current loop indices. MutableArrayRef<unsigned> loopIndex; @@ -1907,10 +1904,10 @@ void ByteCodeExecutor::executeGetUsers() { LDBG() << "Executing GetUsers:"; unsigned memIndex = read(); unsigned rangeIndex = read(); - OwningOpRange &range = opRangeMemory[rangeIndex]; + std::vector<Operation *> &range = opRangeMemory[rangeIndex]; memory[memIndex] = ⦥ - range = OwningOpRange(); + range.clear(); if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { // Read the value. Value value = read<Value>(); @@ -1918,9 +1915,7 @@ void ByteCodeExecutor::executeGetUsers() { return; LDBG() << " * Value: " << value; - // Extract the users of a single value. - range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); - llvm::copy(value.getUsers(), range.begin()); + range.assign(value.user_begin(), value.user_end()); } else { // Read a range of values. 
ValueRange *values = read<ValueRange *>(); @@ -1929,12 +1924,8 @@ void ByteCodeExecutor::executeGetUsers() { LDBG() << " * Values (" << values->size() << "): " << llvm::interleaved(*values); - // Extract all the users of a range of values. - SmallVector<Operation *> users; for (Value value : *values) - users.append(value.user_begin(), value.user_end()); - range = OwningOpRange(users.size()); - llvm::copy(users, range.begin()); + range.insert(range.end(), value.user_begin(), value.user_end()); } LDBG() << " * Result: " << range.size() << " operations"; @@ -2174,7 +2165,8 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, executeEraseOp(rewriter); break; case ExtractOp: - executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); + executeExtract<Operation *, std::vector<Operation *>, + PDLValue::Kind::Operation>(); break; case ExtractType: executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h index 4aceac7..566c1cb 100644 --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -30,7 +30,6 @@ class PDLByteCode; /// entries. ByteCodeAddr refers to size of indices into the bytecode. using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; -using OwningOpRange = llvm::OwningArrayRef<Operation *>; //===----------------------------------------------------------------------===// // PDLByteCodePattern @@ -94,21 +93,21 @@ private: /// the bytecode to store ranges of operations. These are always stored by /// owning references, because at no point in the execution of the byte code /// we get an indexed range (view) of operations. - std::vector<OwningOpRange> opRangeMemory; + std::vector<std::vector<Operation *>> opRangeMemory; /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector<TypeRange> typeRangeMemory; /// A set of type ranges that have been allocated by the byte code interpreter /// to provide a guaranteed lifetime. - std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory; + std::vector<std::vector<Type>> allocatedTypeRangeMemory; /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of values. std::vector<ValueRange> valueRangeMemory; /// A set of value ranges that have been allocated by the byte code /// interpreter to provide a guaranteed lifetime. - std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory; + std::vector<std::vector<Value>> allocatedValueRangeMemory; /// The current index of ranges being iterated over for each level of nesting. 
/// These are always maintained at 0 for the loops that are not active, so we diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index b0ad3ee..77a6cec 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) { bool TypeInterface::classof(const Interface *interface) { return interface->getDef().isSubClassOf("TypeInterface"); } + +//===----------------------------------------------------------------------===// +// DialectInterface +//===----------------------------------------------------------------------===// + +bool DialectInterface::classof(const Interface *interface) { + return interface->getDef().isSubClassOf("DialectInterface"); +} diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 1a1a58a..ce09f5c 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -771,15 +772,27 @@ int Pattern::getBenefit() const { return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue(); } -std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { +std::vector<Pattern::IdentifierLine> +Pattern::getLocation(bool forSourceOutput) const { std::vector<std::pair<StringRef, unsigned>> result; result.reserve(def.getLoc().size()); for (auto loc : def.getLoc()) { unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); assert(buf && "invalid source location"); + + StringRef bufferName = + llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(); + // If we're emitting a generated file, we'd like to have some indication of + // where our patterns came from. However, LLVM's build rules use absolute + // paths as arguments to TableGen, and naively echoing such paths makes the + // contents of the generated source file depend on the build location, + // making MLIR builds substantially less reproducible. As a compromise, we + // trim absolute paths back to only the filename component. + if (forSourceOutput && llvm::sys::path::is_absolute(bufferName)) + bufferName = llvm::sys::path::filename(bufferName); + + result.emplace_back(bufferName, + llvm::SrcMgr.getLineAndColumn(loc, buf).first); - result.emplace_back( - llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), - llvm::SrcMgr.getLineAndColumn(loc, buf).first); } return result; } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 1243511..15c23c6 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -70,6 +70,7 @@ static inline LogicalResult interleaveCommaWithError(const Container &c, /// imply higher precedence. static FailureOr<int> getOperatorPrecedence(Operation *operation) { return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation) + .Case<emitc::AddressOfOp>([&](auto op) { return 15; }) .Case<emitc::AddOp>([&](auto op) { return 12; }) .Case<emitc::ApplyOp>([&](auto op) { return 15; }) .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; }) @@ -111,6 +112,8 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) { .Default([](auto op) { return op->emitError("unsupported operation"); }); } +static bool shouldBeInlined(Operation *op); + namespace { /// Emitter that uses dialect specific emitters to emit C++ code.
struct CppEmitter { @@ -173,8 +176,11 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); - /// Emits value as an operands of an operation - LogicalResult emitOperand(Value value); + /// Emits value as an operand of some operation. Unless \p isInBrackets is + /// true, operands emitted as sub-expressions will be parenthesized if needed + /// in order to enforce correct evaluation based on precedence and + /// associativity. + LogicalResult emitOperand(Value value, bool isInBrackets = false); /// Emit an expression as a C expression. LogicalResult emitExpression(ExpressionOp expressionOp); @@ -189,15 +195,6 @@ struct CppEmitter { /// emitc::ForOp. StringRef getOrCreateInductionVarName(Value val); - // Returns the textual representation of a subscript operation. - std::string getSubscriptName(emitc::SubscriptOp op); - - // Returns the textual representation of a member (of object) operation. - std::string createMemberAccess(emitc::MemberOp op); - - // Returns the textual representation of a member of pointer operation. - std::string createMemberAccess(emitc::MemberOfPtrOp op); - /// Return the existing or a new label of a Block. StringRef getOrCreateName(Block &block); @@ -259,25 +256,20 @@ struct CppEmitter { return !fileId.empty() && file.getId() == fileId; } - /// Get expression currently being emitted. - ExpressionOp getEmittedExpression() { return emittedExpression; } + /// Is expression currently being emitted. + bool isEmittingExpression() { return !emittedExpressionPrecedence.empty(); } /// Determine whether given value is part of the expression potentially being /// emitted. bool isPartOfCurrentExpression(Value value) { - if (!emittedExpression) - return false; Operation *def = value.getDefiningOp(); - if (!def) - return false; - return isPartOfCurrentExpression(def); + return def ? isPartOfCurrentExpression(def) : false; } /// Determine whether given operation is part of the expression potentially /// being emitted. bool isPartOfCurrentExpression(Operation *def) { - auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp()); - return operandExpression && operandExpression == emittedExpression; + return isEmittingExpression() && shouldBeInlined(def); }; // Resets the value counter to 0. @@ -324,7 +316,6 @@ private: unsigned int valueCount{0}; /// State of the current expression being emitted. - ExpressionOp emittedExpression; SmallVector<int> emittedExpressionPrecedence; void pushExpressionPrecedence(int precedence) { @@ -342,17 +333,28 @@ private: /// Determine whether expression \p op should be emitted in a deferred way. static bool hasDeferredEmission(Operation *op) { - return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp, + return isa_and_nonnull<emitc::DereferenceOp, emitc::GetGlobalOp, + emitc::LiteralOp, emitc::MemberOp, emitc::MemberOfPtrOp, emitc::SubscriptOp, emitc::GetFieldOp>(op); } -/// Determine whether expression \p expressionOp should be emitted inline, i.e. +/// Determine whether operation \p op should be emitted inline, i.e. /// as part of its user. This function recommends inlining of any expressions /// that can be inlined unless it is used by another expression, under the /// assumption that any expression fusion/re-materialization was taken care of /// by transformations run by the backend. 
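To make the parenthesization rule concrete: assuming addition and subtraction share the same precedence in the emitter's table, an inlined add operand of a sub must be wrapped in parentheses, otherwise C's left-to-right associativity reassociates the tree. A plain-C-style illustration (not literal emitter output):

// Why equal-or-lower-precedence sub-expressions are parenthesized when
// emitted as operands of an enclosing expression.
static int parenthesizationExample(int a, int b, int c) {
  int reassociated = a - b + c;   // parses as (a - b) + c
  int intended = a - (b + c);     // the nested expression kept intact
  return intended - reassociated; // differs whenever c != 0
}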
-static bool shouldBeInlined(ExpressionOp expressionOp) { +static bool shouldBeInlined(Operation *op) { + // CExpression operations are inlined if and only if they reside within an + // ExpressionOp. + if (isa<CExpressionInterface>(op)) + return isa<ExpressionOp>(op->getParentOp()); + + // Only other inlinable operation is ExpressionOp itself. + ExpressionOp expressionOp = dyn_cast<ExpressionOp>(op); + if (!expressionOp) + return false; + // Do not inline if expression is marked as such. if (expressionOp.getDoNotInline()) return false; @@ -402,6 +404,66 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { return false; } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::DereferenceOp dereferenceOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << "*" << emitter.getOrCreateName(dereferenceOp.getPointer()); + emitter.cacheDeferredOpResult(dereferenceOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetFieldOp getFieldOp) { + emitter.cacheDeferredOpResult(getFieldOp.getResult(), + getFieldOp.getFieldName()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetGlobalOp getGlobalOp) { + emitter.cacheDeferredOpResult(getGlobalOp.getResult(), getGlobalOp.getName()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LiteralOp literalOp) { + emitter.cacheDeferredOpResult(literalOp.getResult(), literalOp.getValue()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOp memberOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(memberOp.getOperand()); + ss << "." << memberOp.getMember(); + emitter.cacheDeferredOpResult(memberOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOfPtrOp memberOfPtrOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(memberOfPtrOp.getOperand()); + ss << "->" << memberOfPtrOp.getMember(); + emitter.cacheDeferredOpResult(memberOfPtrOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::SubscriptOp subscriptOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(subscriptOp.getValue()); + for (auto index : subscriptOp.getIndices()) { + ss << "[" << emitter.getOrCreateName(index) << "]"; + } + emitter.cacheDeferredOpResult(subscriptOp.getResult(), out); + return success(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -435,6 +497,17 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, } static LogicalResult printOperation(CppEmitter &emitter, + emitc::AddressOfOp addressOfOp) { + raw_ostream &os = emitter.ostream(); + Operation &op = *addressOfOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + os << "&"; + return emitter.emitOperand(addressOfOp.getReference()); +} + +static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); Attribute value = constantOp.getValue(); @@ -1336,32 +1409,6 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, labelInScopeCount.push(0); } -std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { - 
std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getValue()); - for (auto index : op.getIndices()) { - ss << "[" << getOrCreateName(index) << "]"; - } - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "." << op.getMember(); - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "->" << op.getMember(); - return out; -} - void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) { if (!valueMapper.count(value)) valueMapper.insert(value, str.str()); @@ -1545,7 +1592,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { "Expected precedence stack to be empty"); Operation *rootOp = expressionOp.getRootOp(); - emittedExpression = expressionOp; FailureOr<int> precedence = getOperatorPrecedence(rootOp); if (failed(precedence)) return failure(); @@ -1557,12 +1603,11 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { popExpressionPrecedence(); assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - emittedExpression = nullptr; return success(); } -LogicalResult CppEmitter::emitOperand(Value value) { +LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) { if (isPartOfCurrentExpression(value)) { Operation *def = value.getDefiningOp(); assert(def && "Expected operand to be defined by an operation"); @@ -1570,10 +1615,12 @@ LogicalResult CppEmitter::emitOperand(Value value) { if (failed(precedence)) return failure(); - // Sub-expressions with equal or lower precedence need to be parenthesized, - // as they might be evaluated in the wrong order depending on the shape of - // the expression tree. - bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence(); + // Unless already in brackets, sub-expressions with equal or lower + // precedence need to be parenthesized as they might be evaluated in the + // wrong order depending on the shape of the expression tree. + bool encloseInParenthesis = + !isInBrackets && precedence.value() <= getExpressionPrecedence(); + if (encloseInParenthesis) os << "("; pushExpressionPrecedence(precedence.value()); @@ -1596,14 +1643,8 @@ LogicalResult CppEmitter::emitOperand(Value value) { // If this operand is a block argument of an expression, emit instead the // matching expression parameter. Operation *argOp = arg.getParentBlock()->getParentOp(); - if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) { - // This scenario is only expected when one of the operations within the - // expression being emitted references one of the expression's block - // arguments. - assert(expressionOp == emittedExpression && - "Expected expression being emitted"); - value = expressionOp->getOperand(arg.getArgNumber()); - } + if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) + return emitOperand(expressionOp->getOperand(arg.getArgNumber())); } os << getOrCreateName(value); @@ -1612,15 +1653,9 @@ LogicalResult CppEmitter::emitOperand(Value value) { LogicalResult CppEmitter::emitOperands(Operation &op) { return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { - // If an expression is being emitted, push lowest precedence as these - // operands are either wrapped by parenthesis. 
- if (getEmittedExpression()) - pushExpressionPrecedence(lowestPrecedence()); - if (failed(emitOperand(operand))) - return failure(); - if (getEmittedExpression()) - popExpressionPrecedence(); - return success(); + // Emit operand under guarantee that if it's part of an expression then it + // is being emitted within brackets. + return emitOperand(operand, /*isInBrackets=*/true); }); } @@ -1702,7 +1737,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { // If op is being emitted as part of an expression, bail out. - if (getEmittedExpression()) + if (isEmittingExpression()) return success(); switch (op.getNumResults()) { @@ -1753,49 +1788,27 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case<cf::BranchOp, cf::CondBranchOp>( [&](auto op) { return printOperation(*this, op); }) // EmitC ops. - .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, - emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp, + .Case<emitc::AddressOfOp, emitc::AddOp, emitc::ApplyOp, + emitc::AssignOp, emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp, emitc::BitwiseNotOp, emitc::BitwiseOrOp, emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp, emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, - emitc::DeclareFuncOp, emitc::DivOp, emitc::DoOp, - emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, - emitc::ForOp, emitc::FuncOp, emitc::GlobalOp, emitc::IfOp, - emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp, - emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, - emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp, - emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, - emitc::VerbatimOp>( + emitc::DeclareFuncOp, emitc::DereferenceOp, emitc::DivOp, + emitc::DoOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, + emitc::ForOp, emitc::FuncOp, emitc::GetFieldOp, + emitc::GetGlobalOp, emitc::GlobalOp, emitc::IfOp, + emitc::IncludeOp, emitc::LiteralOp, emitc::LoadOp, + emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, + emitc::MemberOfPtrOp, emitc::MemberOp, emitc::MulOp, + emitc::RemOp, emitc::ReturnOp, emitc::SubscriptOp, emitc::SubOp, + emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. 
.Case<func::CallOp, func::FuncOp, func::ReturnOp>( [&](auto op) { return printOperation(*this, op); }) - .Case<emitc::GetGlobalOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getName()); - return success(); - }) - .Case<emitc::GetFieldOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getFieldName()); - return success(); - }) - .Case<emitc::LiteralOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getValue()); - return success(); - }) - .Case<emitc::MemberOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case<emitc::MemberOfPtrOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case<emitc::SubscriptOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); - return success(); - }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); }); @@ -1806,7 +1819,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (hasDeferredEmission(&op)) return success(); - if (getEmittedExpression() || + if (isEmittingExpression() || (isa<emitc::ExpressionOp>(op) && shouldBeInlined(cast<emitc::ExpressionOp>(op)))) return success(); diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 2dd0640..5be33c4 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -30,6 +30,14 @@ void registerFromLLVMIRTranslation() { llvm::cl::desc("Emit expensive warnings during LLVM IR import " "(discouraged: testing only!)"), llvm::cl::init(false)); + static llvm::cl::opt<bool> convertDebugRecToIntrinsics( + "convert-debug-rec-to-intrinsics", + llvm::cl::desc("Change the input LLVM module to use old debug intrinsics " + "instead of records " + "via convertFromNewDbgValues, this happens " + "before importing the debug information" + "(discouraged: to be removed soon!)"), + llvm::cl::init(false)); static llvm::cl::opt<bool> dropDICompositeTypeElements( "drop-di-composite-type-elements", llvm::cl::desc( @@ -69,8 +77,10 @@ void registerFromLLVMIRTranslation() { if (llvm::verifyModule(*llvmModule, &llvm::errs())) return nullptr; - // Debug records are not currently supported in the LLVM IR translator. - llvmModule->convertFromNewDbgValues(); + // Now that the translation supports importing debug records directly, + // make it the default, but allow the user to override to old behavior. + if (convertDebugRecToIntrinsics) + llvmModule->convertFromNewDbgValues(); return translateLLVMIRToModule( std::move(llvmModule), context, emitExpensiveWarnings, diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp index d3216d9..d9bfe65 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp @@ -124,10 +124,10 @@ static LogicalResult embedBinaryImpl(StringRef moduleName, } IRBuilder<> builder(module.getContext()); - auto i32Ty = builder.getInt32Ty(); - auto i64Ty = builder.getInt64Ty(); - auto ptrTy = builder.getPtrTy(0); - auto voidTy = builder.getVoidTy(); + auto *i32Ty = builder.getInt32Ty(); + auto *i64Ty = builder.getInt64Ty(); + auto *ptrTy = builder.getPtrTy(0); + auto *voidTy = builder.getVoidTy(); // Embed the module as a global object. 
auto *modulePtr = new GlobalVariable( @@ -147,13 +147,12 @@ static LogicalResult embedBinaryImpl(StringRef moduleName, "mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false)); Constant *optValue = ConstantInt::get(i32Ty, optLevel); return builder.CreateCall(moduleLoadFn, {serializedObj, optValue}); - } else { - FunctionCallee moduleLoadFn = module.getOrInsertFunction( - "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false)); - Constant *binarySize = - ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0)); - return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize}); } + FunctionCallee moduleLoadFn = module.getOrInsertFunction( + "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false)); + Constant *binarySize = + ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0)); + return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize}); }(); builder.CreateStore(moduleObj, modulePtr); builder.CreateRetVoid(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index 44732d5..2d4a18c 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -80,8 +80,9 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder, /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM /// dialect attributes. -static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) { - static const SmallVector<unsigned> convertibleMetadata = { +static SmallVector<unsigned> +getSupportedMetadataImpl(llvm::LLVMContext &llvmContext) { + SmallVector<unsigned> convertibleMetadata = { llvm::LLVMContext::MD_prof, llvm::LLVMContext::MD_tbaa, llvm::LLVMContext::MD_access_group, @@ -91,10 +92,10 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) { llvm::LLVMContext::MD_dereferenceable, llvm::LLVMContext::MD_dereferenceable_or_null, llvm::LLVMContext::MD_mmra, - context.getMDKindID(vecTypeHintMDName), - context.getMDKindID(workGroupSizeHintMDName), - context.getMDKindID(reqdWorkGroupSizeMDName), - context.getMDKindID(intelReqdSubGroupSizeMDName)}; + llvmContext.getMDKindID(vecTypeHintMDName), + llvmContext.getMDKindID(workGroupSizeHintMDName), + llvmContext.getMDKindID(reqdWorkGroupSizeMDName), + llvmContext.getMDKindID(intelReqdSubGroupSizeMDName)}; return convertibleMetadata; } @@ -113,7 +114,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, return failure(); // Handle function entry count metadata. - if (name->getString() == "function_entry_count") { + if (name->getString() == llvm::MDProfLabels::FunctionEntryCount) { // TODO support function entry count metadata with GUID fields. if (node->getNumOperands() != 2) @@ -131,15 +132,28 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, << "expected function_entry_count to be attached to a function"; } - if (name->getString() != "branch_weights") + if (name->getString() != llvm::MDProfLabels::BranchWeights) return failure(); + // The branch_weights metadata must have at least 2 operands. 
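  // [Editorial sketch, not part of the patch] For reference, the two !prof
  // shapes handled here look roughly like:
  //   !0 = !{!"branch_weights", i32 7, i32 3}                  ; plain weights
  //   !1 = !{!"branch_weights", !"expected", i32 2000, i32 1}  ; via llvm.expect
  // The optional !"expected" string sits in operand 1; since the MLIR
  // WeightedBranchOpInterface has no equivalent field, it is dropped before
  // the weights are read.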
+ if (node->getNumOperands() < 2) + return failure(); + + ArrayRef<llvm::MDOperand> branchWeightOperands = + node->operands().drop_front(); + if (auto *mdString = dyn_cast<llvm::MDString>(node->getOperand(1))) { + if (mdString->getString() != llvm::MDProfLabels::ExpectedBranchWeights) + return failure(); + // The MLIR WeightedBranchOpInterface does not support the + // ExpectedBranchWeights field, so it is dropped. + branchWeightOperands = branchWeightOperands.drop_front(); + } // Handle branch weights metadata. SmallVector<int32_t> branchWeights; - branchWeights.reserve(node->getNumOperands() - 1); - for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) { + branchWeights.reserve(branchWeightOperands.size()); + for (const llvm::MDOperand &operand : branchWeightOperands) { llvm::ConstantInt *branchWeight = - llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(i)); + llvm::mdconst::dyn_extract<llvm::ConstantInt>(operand); if (!branchWeight) return failure(); branchWeights.push_back(branchWeight->getZExtValue()); @@ -492,9 +506,9 @@ public: /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR /// LLVM dialect attributes. - ArrayRef<unsigned> - getSupportedMetadata(llvm::LLVMContext &context) const final { - return getSupportedMetadataImpl(context); + SmallVector<unsigned> + getSupportedMetadata(llvm::LLVMContext &llvmContext) const final { + return getSupportedMetadataImpl(llvmContext); } }; } // namespace diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index eaf1d20..b6ea4ba 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -222,14 +222,14 @@ static void convertLinkerOptionsOp(ArrayAttr options, llvm::LLVMContext &context = llvmModule->getContext(); llvm::NamedMDNode *linkerMDNode = llvmModule->getOrInsertNamedMetadata("llvm.linker.options"); - SmallVector<llvm::Metadata *> MDNodes; - MDNodes.reserve(options.size()); + SmallVector<llvm::Metadata *> mdNodes; + mdNodes.reserve(options.size()); for (auto s : options.getAsRange<StringAttr>()) { - auto *MDNode = llvm::MDString::get(context, s.getValue()); - MDNodes.push_back(MDNode); + auto *mdNode = llvm::MDString::get(context, s.getValue()); + mdNodes.push_back(mdNode); } - auto *listMDNode = llvm::MDTuple::get(context, MDNodes); + auto *listMDNode = llvm::MDTuple::get(context, mdNodes); linkerMDNode->addOperand(listMDNode); } @@ -243,16 +243,16 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr, if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) { for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) { - llvm::Metadata *fromMetadata = - entry.getFrom() - ? llvm::ValueAsMetadata::get(moduleTranslation.lookupFunction( - entry.getFrom().getValue())) - : nullptr; - llvm::Metadata *toMetadata = - entry.getTo() - ? 
llvm::ValueAsMetadata::get( - moduleTranslation.lookupFunction(entry.getTo().getValue())) - : nullptr; + auto getFuncMetadata = [&](FlatSymbolRefAttr sym) -> llvm::Metadata * { + if (!sym) + return nullptr; + if (llvm::Function *fn = + moduleTranslation.lookupFunction(sym.getValue())) + return llvm::ValueAsMetadata::get(fn); + return nullptr; + }; + llvm::Metadata *fromMetadata = getFuncMetadata(entry.getFrom()); + llvm::Metadata *toMetadata = getFuncMetadata(entry.getTo()); llvm::Metadata *vals[] = { fromMetadata, toMetadata, @@ -439,7 +439,14 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, llvm::MemoryEffects::Location::InaccessibleMem, convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) | llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, - convertModRefInfoToLLVM(memAttr.getOther())); + convertModRefInfoToLLVM(memAttr.getOther())) | + llvm::MemoryEffects(llvm::MemoryEffects::Location::ErrnoMem, + convertModRefInfoToLLVM(memAttr.getErrnoMem())) | + llvm::MemoryEffects( + llvm::MemoryEffects::Location::TargetMem0, + convertModRefInfoToLLVM(memAttr.getTargetMem0())) | + llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem1, + convertModRefInfoToLLVM(memAttr.getTargetMem1())); call->setMemoryEffects(memEffects); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index cecff51..b7427a5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -411,6 +411,41 @@ getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) { llvm_unreachable("unhandled tcgen05.st lowering"); } +static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order) { + return order == NVVM::MemOrderKind::ACQUIRE + ? llvm::Intrinsic:: + nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster + : llvm::Intrinsic:: + nvvm_fence_release_sync_restrict_space_cta_scope_cluster; +} + +static llvm::Intrinsic::ID +getFenceProxyID(NVVM::ProxyKind kind, std::optional<NVVM::SharedSpace> space) { + switch (kind) { + case NVVM::ProxyKind::alias: + return llvm::Intrinsic::nvvm_fence_proxy_alias; + case NVVM::ProxyKind::async: + return llvm::Intrinsic::nvvm_fence_proxy_async; + case NVVM::ProxyKind::async_global: + return llvm::Intrinsic::nvvm_fence_proxy_async_global; + case NVVM::ProxyKind::async_shared: + return *space == NVVM::SharedSpace::shared_cta + ? llvm::Intrinsic::nvvm_fence_proxy_async_shared_cta + : llvm::Intrinsic::nvvm_fence_proxy_async_shared_cluster; + default: + llvm_unreachable("unsupported proxy kind"); + } +} + +static llvm::Intrinsic::ID +getFenceProxySyncRestrictID(NVVM::MemOrderKind order) { + return order == NVVM::MemOrderKind::ACQUIRE + ? llvm::Intrinsic:: + nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster + : llvm::Intrinsic:: + nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster; +} + namespace { /// Implementation of the dialect interface that converts operations belonging /// to the NVVM dialect to LLVM IR. 
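As an illustrative aside (not part of the patch): the convertOperationImpl hunk above folds each location's ModRef info into one llvm::MemoryEffects value with operator|, now including the new ErrnoMem and TargetMem0/TargetMem1 locations. A minimal standalone sketch of that composition idiom, with locations and ModRef values chosen purely for exposition:

#include "llvm/Support/ModRef.h"

// Sketch: build per-location effects and merge them, as the translation does
// before handing the result to the call's setMemoryEffects().
static llvm::MemoryEffects exampleCallEffects() {
  using llvm::MemoryEffects;
  using llvm::ModRefInfo;
  return MemoryEffects(MemoryEffects::Location::ArgMem, ModRefInfo::Ref) |
         MemoryEffects(MemoryEffects::Location::ErrnoMem, ModRefInfo::Mod) |
         MemoryEffects(MemoryEffects::Location::Other, ModRefInfo::NoModRef);
}

Unspecified locations default to the constructor's NoModRef semantics, so only the locations a call actually touches need to be spelled out.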
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8edec99..03d67a5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -61,6 +61,8 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) { return llvm::omp::OMP_SCHEDULE_Auto; case omp::ClauseScheduleKind::Runtime: return llvm::omp::OMP_SCHEDULE_Runtime; + case omp::ClauseScheduleKind::Distribute: + return llvm::omp::OMP_SCHEDULE_Distribute; } llvm_unreachable("unhandled schedule clause argument"); } @@ -135,28 +137,31 @@ class LinearClauseProcessor { private: SmallVector<llvm::Value *> linearPreconditionVars; SmallVector<llvm::Value *> linearLoopBodyTemps; - SmallVector<llvm::AllocaInst *> linearOrigVars; SmallVector<llvm::Value *> linearOrigVal; SmallVector<llvm::Value *> linearSteps; + SmallVector<llvm::Type *> linearVarTypes; llvm::BasicBlock *linearFinalizationBB; llvm::BasicBlock *linearExitBB; llvm::BasicBlock *linearLastIterExitBB; public: + // Register type for the linear variables + void registerType(LLVM::ModuleTranslation &moduleTranslation, + mlir::Attribute &ty) { + linearVarTypes.push_back(moduleTranslation.convertType( + mlir::cast<mlir::TypeAttr>(ty).getValue())); + } + // Allocate space for linear variabes void createLinearVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - mlir::Value &linearVar) { - if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>( - moduleTranslation.lookupValue(linearVar))) { - linearPreconditionVars.push_back(builder.CreateAlloca( - linearVarAlloca->getAllocatedType(), nullptr, ".linear_var")); - llvm::Value *linearLoopBodyTemp = builder.CreateAlloca( - linearVarAlloca->getAllocatedType(), nullptr, ".linear_result"); - linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar)); - linearLoopBodyTemps.push_back(linearLoopBodyTemp); - linearOrigVars.push_back(linearVarAlloca); - } + mlir::Value &linearVar, int idx) { + linearPreconditionVars.push_back( + builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var")); + llvm::Value *linearLoopBodyTemp = + builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result"); + linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar)); + linearLoopBodyTemps.push_back(linearLoopBodyTemp); } // Initialize linear step @@ -166,20 +171,15 @@ public: } // Emit IR for initialization of linear variables - llvm::OpenMPIRBuilder::InsertPointOrErrorTy - initLinearVar(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - llvm::BasicBlock *loopPreHeader) { + void initLinearVar(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::BasicBlock *loopPreHeader) { builder.SetInsertPoint(loopPreHeader->getTerminator()); - for (size_t index = 0; index < linearOrigVars.size(); index++) { - llvm::LoadInst *linearVarLoad = builder.CreateLoad( - linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]); + for (size_t index = 0; index < linearOrigVal.size(); index++) { + llvm::LoadInst *linearVarLoad = + builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]); builder.CreateStore(linearVarLoad, linearPreconditionVars[index]); } - llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = - moduleTranslation.getOpenMPBuilder()->createBarrier( - builder.saveIP(), llvm::omp::OMPD_barrier); - return afterBarrierIP; } // Emit 
IR for updating Linear variables @@ -188,20 +188,24 @@ public: builder.SetInsertPoint(loopBody->getTerminator()); for (size_t index = 0; index < linearPreconditionVars.size(); index++) { // Emit increments for linear vars - llvm::LoadInst *linearVarStart = - builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), - - linearPreconditionVars[index]); + llvm::LoadInst *linearVarStart = builder.CreateLoad( + linearVarTypes[index], linearPreconditionVars[index]); auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]); - auto addInst = builder.CreateAdd(linearVarStart, mulInst); - builder.CreateStore(addInst, linearLoopBodyTemps[index]); + if (linearVarTypes[index]->isIntegerTy()) { + auto addInst = builder.CreateAdd(linearVarStart, mulInst); + builder.CreateStore(addInst, linearLoopBodyTemps[index]); + } else if (linearVarTypes[index]->isFloatingPointTy()) { + auto cvt = builder.CreateSIToFP(mulInst, linearVarTypes[index]); + auto addInst = builder.CreateFAdd(linearVarStart, cvt); + builder.CreateStore(addInst, linearLoopBodyTemps[index]); + } } } // Linear variable finalization is conditional on the last logical iteration. // Create BB splits to manage the same. - void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder, - llvm::BasicBlock *loopExit) { + void splitLinearFiniBB(llvm::IRBuilderBase &builder, + llvm::BasicBlock *loopExit) { linearFinalizationBB = loopExit->splitBasicBlock( loopExit->getTerminator(), "omp_loop.linear_finalization"); linearExitBB = linearFinalizationBB->splitBasicBlock( @@ -225,11 +229,10 @@ public: llvm::Type::getInt32Ty(builder.getContext()), 0)); // Store the linear variable values to original variables. builder.SetInsertPoint(linearLastIterExitBB->getTerminator()); - for (size_t index = 0; index < linearOrigVars.size(); index++) { + for (size_t index = 0; index < linearOrigVal.size(); index++) { llvm::LoadInst *linearVarTemp = - builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), - linearLoopBodyTemps[index]); - builder.CreateStore(linearVarTemp, linearOrigVars[index]); + builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]); + builder.CreateStore(linearVarTemp, linearOrigVal[index]); } // Create conditional branch such that the linear variable @@ -253,7 +256,8 @@ public: users.push_back(user); for (auto *user : users) { if (auto *userInst = dyn_cast<llvm::Instruction>(user)) { - if (userInst->getParent()->getName().str() == BBName) + if (userInst->getParent()->getName().str().find(BBName) != + std::string::npos) user->replaceUsesOfWith(linearOrigVal[varIndex], linearLoopBodyTemps[varIndex]); } @@ -319,10 +323,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.getDevice()) result = todo("device"); }; - auto checkDistSchedule = [&todo](auto op, LogicalResult &result) { - if (op.getDistScheduleChunkSize()) - result = todo("dist_schedule with chunk_size"); - }; auto checkHint = [](auto op, LogicalResult &) { if (op.getHint()) op.emitWarning("hint clause discarded"); @@ -332,14 +332,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { op.getInReductionSyms()) result = todo("in_reduction"); }; - auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) { - if (!op.getIsDevicePtrVars().empty()) - result = todo("is_device_ptr"); - }; - auto checkLinear = [&todo](auto op, LogicalResult &result) { - if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty()) - result = todo("linear"); - }; auto checkNowait = [&todo](auto op, LogicalResult &result) { if 
(op.getNowait()) result = todo("nowait"); @@ -387,7 +379,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::DistributeOp op) { checkAllocate(op, result); - checkDistSchedule(op, result); checkOrder(op, result); }) .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); }) @@ -423,7 +414,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::WsloopOp op) { checkAllocate(op, result); - checkLinear(op, result); checkOrder(op, result); checkReduction(op, result); }) @@ -431,10 +421,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkReduction(op, result); }) - .Case([&](omp::SimdOp op) { - checkLinear(op, result); - checkReduction(op, result); - }) + .Case([&](omp::SimdOp op) { checkReduction(op, result); }) .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp, omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); }) .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>( @@ -444,7 +431,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkBare(op, result); checkDevice(op, result); checkInReduction(op, result); - checkIsDevicePtr(op, result); }) .Default([](Operation &) { // Assume all clauses for an operation can be translated unless they are @@ -953,6 +939,9 @@ using OwningAtomicReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy( llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *, llvm::Value *)>; +using OwningDataPtrPtrReductionGen = + std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy( + llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>; } // namespace /// Create an OpenMPIRBuilder-compatible reduction generator for the given @@ -1017,6 +1006,35 @@ makeAtomicReductionGen(omp::DeclareReductionOp decl, return atomicGen; } +/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for +/// the given reduction declaration. The generator uses `builder` but ignores +/// its insertion point. Returns null if there is no `data_ptr_ptr` region +/// available in the reduction declaration. +static OwningDataPtrPtrReductionGen +makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, bool isByRef) { + if (!isByRef) + return OwningDataPtrPtrReductionGen(); + + OwningDataPtrPtrReductionGen refDataPtrGen = + [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, + llvm::Value *byRefVal, llvm::Value *&result) mutable + -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy { + moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal); + builder.restoreIP(insertPoint); + SmallVector<llvm::Value *> phis; + if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(), + "omp.data_ptr_ptr.body", builder, + moduleTranslation, &phis))) + return llvm::createStringError( + "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`"); + result = llvm::getSingleElement(phis); + return builder.saveIP(); + }; + + return refDataPtrGen; +} + /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder. 
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, @@ -1170,6 +1188,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs, template <typename T> static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, DenseMap<Value, llvm::Value *> &reductionVariableMap, unsigned i) { @@ -1180,8 +1199,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, mlir::Value mlirSource = loop.getReductionVars()[i]; llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource); - assert(llvmSource && "lookup reduction var"); - moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource); + llvm::Value *origVal = llvmSource; + // If a non-pointer value is expected, load the value from the source pointer. + if (!isa<LLVM::LLVMPointerType>( + reduction.getInitializerMoldArg().getType()) && + isa<LLVM::LLVMPointerType>(mlirSource.getType())) { + origVal = + builder.CreateLoad(moduleTranslation.convertType( + reduction.getInitializerMoldArg().getType()), + llvmSource, "omp_orig"); + } + moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal); if (entry.getNumArguments() > 1) { llvm::Value *allocation = @@ -1254,7 +1282,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs, SmallVector<llvm::Value *, 1> phis; // map block argument to initializer region - mapInitializationArgs(op, moduleTranslation, reductionDecls, + mapInitializationArgs(op, moduleTranslation, builder, reductionDecls, reductionVariableMap, i); // TODO In some cases (specially on the GPU), the init regions may @@ -1310,8 +1338,10 @@ static void collectReductionInfo( SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, SmallVectorImpl<OwningReductionGen> &owningReductionGens, SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens, + SmallVector<OwningDataPtrPtrReductionGen> &owningDataPtrPtrReductionGens, const ArrayRef<llvm::Value *> privateReductionVariables, - SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) { + SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos, + ArrayRef<bool> isByRef) { unsigned numReductions = loop.getNumReductionVars(); for (unsigned i = 0; i < numReductions; ++i) { @@ -1319,6 +1349,8 @@ static void collectReductionInfo( makeReductionGen(reductionDecls[i], builder, moduleTranslation)); owningAtomicReductionGens.push_back( makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation)); + owningDataPtrPtrReductionGens.push_back(makeRefDataPtrGen( + reductionDecls[i], builder, moduleTranslation, isByRef[i])); } // Collect the reduction information. 
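An illustrative aside (not part of the patch): mapInitializationArgs above now dereferences the reduction variable before binding it to the initializer region's mold argument when that argument expects a value while the variable is a pointer. A standalone sketch of just that rule, with simplified LLVM-level types and a hypothetical helper name:

#include "llvm/IR/IRBuilder.h"

// Hypothetical helper: choose what to bind to the init region's mold argument.
static llvm::Value *bindMoldArg(llvm::IRBuilderBase &builder,
                                llvm::Type *moldTy,
                                llvm::Value *reductionVar) {
  // The mold wants a value but the reduction variable is a pointer: load it,
  // mirroring the "omp_orig" load emitted in the hunk above.
  if (!moldTy->isPointerTy() && reductionVar->getType()->isPointerTy())
    return builder.CreateLoad(moldTy, reductionVar, "omp_orig");
  // Otherwise the variable already has the shape the initializer expects.
  return reductionVar;
}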
@@ -1329,12 +1361,28 @@ static void collectReductionInfo( atomicGen = owningAtomicReductionGens[i]; llvm::Value *variable = moduleTranslation.lookupValue(loop.getReductionVars()[i]); + mlir::Type allocatedType; + reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) { + if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) { + allocatedType = alloca.getElemType(); + return mlir::WalkResult::interrupt(); + } + + return mlir::WalkResult::advance(); + }); + reductionInfos.push_back( {moduleTranslation.convertType(reductionDecls[i].getType()), variable, privateReductionVariables[i], /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar, owningReductionGens[i], - /*ReductionGenClang=*/nullptr, atomicGen}); + /*ReductionGenClang=*/nullptr, atomicGen, + owningDataPtrPtrReductionGens[i], + allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr, + reductionDecls[i].getByrefElementType() + ? moduleTranslation.convertType( + *reductionDecls[i].getByrefElementType()) + : nullptr}); } } @@ -1392,7 +1440,8 @@ static LogicalResult createReductionsAndCleanup( SmallVector<OwningReductionGen> owningReductionGens; SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens; - SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos; + SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens; + SmallVector<llvm::OpenMPIRBuilder::ReductionInfo, 2> reductionInfos; llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); @@ -1400,7 +1449,8 @@ static LogicalResult createReductionsAndCleanup( // ReductionInfo only accepts references to the generators. collectReductionInfo(op, builder, moduleTranslation, reductionDecls, owningReductionGens, owningAtomicReductionGens, - privateReductionVariables, reductionInfos); + owningReductionGenRefDataPtrGens, + privateReductionVariables, reductionInfos, isByRef); // The call to createReductions below expects the block to have a // terminator. Create an unreachable instruction to serve as terminator @@ -1907,7 +1957,7 @@ static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) { // If we are going to use distribute reduction then remove any debug uses of // the reduction parameters in teamsOp. Otherwise they will be left without // any mapped value in moduleTranslation and will eventually error out. 
- for (auto use : debugUses) + for (auto *use : debugUses) use->erase(); return true; } @@ -2484,6 +2534,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, chunk = builder.CreateSExtOrTrunc(chunkVar, ivType); } + omp::DistributeOp distributeOp = nullptr; + llvm::Value *distScheduleChunk = nullptr; + bool hasDistSchedule = false; + if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) { + distributeOp = cast<omp::DistributeOp>(opInst.getParentOp()); + hasDistSchedule = distributeOp.getDistScheduleStatic(); + if (distributeOp.getDistScheduleChunkSize()) { + llvm::Value *chunkVar = moduleTranslation.lookupValue( + distributeOp.getDistScheduleChunkSize()); + distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType); + } + } + PrivateVarsInfo privateVarsInfo(wsloopOp); SmallVector<omp::DeclareReductionOp> reductionDecls; @@ -2553,10 +2616,15 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // Initialize linear variables and linear step LinearClauseProcessor linearClauseProcessor; + if (!wsloopOp.getLinearVars().empty()) { - for (mlir::Value linearVar : wsloopOp.getLinearVars()) + auto linearVarTypes = wsloopOp.getLinearVarTypes().value(); + for (mlir::Attribute linearVarType : linearVarTypes) + linearClauseProcessor.registerType(moduleTranslation, linearVarType); + + for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars())) linearClauseProcessor.createLinearVar(builder, moduleTranslation, - linearVar); + linearVar, idx); for (mlir::Value linearStep : wsloopOp.getLinearStepVars()) linearClauseProcessor.initLinearStep(moduleTranslation, linearStep); } @@ -2571,16 +2639,17 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // Emit Initialization and Update IR for linear variables if (!wsloopOp.getLinearVars().empty()) { + linearClauseProcessor.initLinearVar(builder, moduleTranslation, + loopInfo->getPreheader()); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = - linearClauseProcessor.initLinearVar(builder, moduleTranslation, - loopInfo->getPreheader()); + moduleTranslation.getOpenMPBuilder()->createBarrier( + builder.saveIP(), llvm::omp::OMPD_barrier); if (failed(handleError(afterBarrierIP, *loopOp))) return failure(); builder.restoreIP(*afterBarrierIP); linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(), loopInfo->getIndVar()); - linearClauseProcessor.outlineLinearFinalizationBB(builder, - loopInfo->getExit()); + linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit()); } builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); @@ -2611,7 +2680,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, convertToScheduleKind(schedule), chunk, isSimd, scheduleMod == omp::ScheduleModifier::monotonic, scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, - workshareLoopType, noLoopMode); + workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk); if (failed(handleError(wsloopIP, opInst))) return failure(); @@ -2655,6 +2724,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref()); assert(isByRef.size() == opInst.getNumReductionVars()); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + bool isCancellable = constructIsCancellable(opInst); if (failed(checkImplementationStatus(*opInst))) return failure(); @@ -2729,10 +2799,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // Collect reduction info 
SmallVector<OwningReductionGen> owningReductionGens; SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens; - SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos; + SmallVector<OwningDataPtrPtrReductionGen> + owningReductionGenRefDataPtrGens; + SmallVector<llvm::OpenMPIRBuilder::ReductionInfo, 2> reductionInfos; collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls, owningReductionGens, owningAtomicReductionGens, - privateReductionVariables, reductionInfos); + owningReductionGenRefDataPtrGens, + privateReductionVariables, reductionInfos, isByRef); // Move to region cont block builder.SetInsertPoint((*regionBlock)->getTerminator()); @@ -2790,6 +2863,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, privateVarsInfo.privatizers))) return llvm::make_error<PreviouslyReportedError>(); + // If we could be performing cancellation, add the cancellation barrier on + // the way out of the outlined region. + if (isCancellable) { + auto IPOrErr = ompBuilder->createBarrier( + llvm::OpenMPIRBuilder::LocationDescription(builder), + llvm::omp::Directive::OMPD_unknown, + /* ForceSimpleCall */ false, + /* CheckCancelFlag */ false); + if (!IPOrErr) + return IPOrErr.takeError(); + } + builder.restoreIP(oldIP); return llvm::Error::success(); }; @@ -2803,7 +2888,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) pbKind = getProcBindKind(*bind); - bool isCancellable = constructIsCancellable(opInst); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); @@ -2858,6 +2942,20 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); + // Initialize linear variables and linear step + LinearClauseProcessor linearClauseProcessor; + + if (!simdOp.getLinearVars().empty()) { + auto linearVarTypes = simdOp.getLinearVarTypes().value(); + for (mlir::Attribute linearVarType : linearVarTypes) + linearClauseProcessor.registerType(moduleTranslation, linearVarType); + for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) + linearClauseProcessor.createLinearVar(builder, moduleTranslation, + linearVar, idx); + for (mlir::Value linearStep : simdOp.getLinearStepVars()) + linearClauseProcessor.initLinearStep(moduleTranslation, linearStep); + } + llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars( builder, moduleTranslation, privateVarsInfo, allocaIP); if (handleError(afterAllocas, opInst).failed()) @@ -2927,14 +3025,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, if (failed(handleError(regionBlock, opInst))) return failure(); - builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation); + // Emit Initialization for linear variables + if (simdOp.getLinearVars().size()) { + linearClauseProcessor.initLinearVar(builder, moduleTranslation, + loopInfo->getPreheader()); + + linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(), + loopInfo->getIndVar()); + } + builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); + ompBuilder->applySimd(loopInfo, alignedVars, simdOp.getIfExpr() ? 
moduleTranslation.lookupValue(simdOp.getIfExpr()) : nullptr, order, simdlen, safelen); + for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) + linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region", + index); + // We now need to reduce the per-simd-lane reduction variable into the // original variable. This works a bit differently to other reductions (e.g. // wsloop) because we don't need to call into the OpenMP runtime to handle @@ -3632,10 +3743,23 @@ convertToCaptureClauseKind( return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink; case mlir::omp::DeclareTargetCaptureClause::enter: return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter; + case mlir::omp::DeclareTargetCaptureClause::none: + return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone; } llvm_unreachable("unhandled capture clause"); } +static Operation *getGlobalOpFromValue(Value value) { + Operation *op = value.getDefiningOp(); + if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op)) + op = addrCast->getOperand(0).getDefiningOp(); + if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) { + auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>(); + return modOp.lookupSymbol(addressOfOp.getGlobalName()); + } + return nullptr; +} + static llvm::SmallString<64> getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder) { @@ -3658,62 +3782,58 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, return suffix; } -static bool isDeclareTargetLink(mlir::Value value) { - if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) { - auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>(); - Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName()); - if (auto declareTargetGlobal = - llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp)) - if (declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::link) - return true; - } +static bool isDeclareTargetLink(Value value) { + if (auto declareTargetGlobal = + dyn_cast_if_present<omp::DeclareTargetInterface>( + getGlobalOpFromValue(value))) + if (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::link) + return true; + return false; +} + +static bool isDeclareTargetTo(Value value) { + if (auto declareTargetGlobal = + dyn_cast_if_present<omp::DeclareTargetInterface>( + getGlobalOpFromValue(value))) + if (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::to || + declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::enter) + return true; return false; } -// Returns the reference pointer generated by the lowering of the declare target -// operation in cases where the link clause is used or the to clause is used in -// USM mode. +// Returns the reference pointer generated by the lowering of the declare +// target operation in cases where the link clause is used or the to clause is +// used in USM mode. 
static llvm::Value * -getRefPtrIfDeclareTarget(mlir::Value value, +getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - Operation *op = value.getDefiningOp(); - if (auto addrCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op)) - op = addrCast->getOperand(0).getDefiningOp(); - - // An easier way to do this may just be to keep track of any pointer - // references and their mapping to their respective operation - if (auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(op)) { - if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>( - addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol( - addressOfOp.getGlobalName()))) { - - if (auto declareTargetGlobal = - llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( - gOp.getOperation())) { - - // In this case, we must utilise the reference pointer generated by the - // declare target operation, similar to Clang - if ((declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::link) || - (declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::to && - ompBuilder->Config.hasRequiresUnifiedSharedMemory())) { - llvm::SmallString<64> suffix = - getDeclareTargetRefPtrSuffix(gOp, *ompBuilder); - - if (gOp.getSymName().contains(suffix)) - return moduleTranslation.getLLVMModule()->getNamedValue( - gOp.getSymName()); + if (auto gOp = + dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) { + if (auto declareTargetGlobal = + dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) { + // In this case, we must utilise the reference pointer generated by + // the declare target operation, similar to Clang + if ((declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::link) || + (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::to && + ompBuilder->Config.hasRequiresUnifiedSharedMemory())) { + llvm::SmallString<64> suffix = + getDeclareTargetRefPtrSuffix(gOp, *ompBuilder); + if (gOp.getSymName().contains(suffix)) return moduleTranslation.getLLVMModule()->getNamedValue( - (gOp.getSymName().str() + suffix.str()).str()); - } + gOp.getSymName()); + + return moduleTranslation.getLLVMModule()->getNamedValue( + (gOp.getSymName().str() + suffix.str()).str()); } } } - return nullptr; } @@ -3756,6 +3876,32 @@ struct MapInfoData : MapInfosTy { MapInfosTy::append(CurInfo); } }; + +enum class TargetDirectiveEnumTy : uint32_t { + None = 0, + Target = 1, + TargetData = 2, + TargetEnterData = 3, + TargetExitData = 4, + TargetUpdate = 5 +}; + +static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) { + return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op) + .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; }) + .Case([](omp::TargetEnterDataOp) { + return TargetDirectiveEnumTy::TargetEnterData; + }) + .Case([&](omp::TargetExitDataOp) { + return TargetDirectiveEnumTy::TargetExitData; + }) + .Case([&](omp::TargetUpdateOp) { + return TargetDirectiveEnumTy::TargetUpdate; + }) + .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; }) + .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; }); +} + } // namespace static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, @@ -3787,7 +3933,7 @@ static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type, // This calculates the size to transfer based on 
bounds and the underlying // element type, provided bounds have been specified (Fortran // pointers/allocatables/target and arrays that have sections specified fall - // into this as well). + // into this as well) if (!memberClause.getBounds().empty()) { llvm::Value *elementCount = builder.getInt64(1); for (auto bounds : memberClause.getBounds()) { @@ -3835,6 +3981,9 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) { auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) { return (mlirFlags & flag) == flag; }; + const bool hasExplicitMap = + (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) != + omp::ClauseMapFlags::none; llvm::omp::OpenMPOffloadMappingFlags mapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; @@ -3875,6 +4024,12 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) { if (mapTypeToBool(omp::ClauseMapFlags::attach)) mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH; + if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) { + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; + if (!hasExplicitMap) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL; + } + return mapType; } @@ -3910,10 +4065,12 @@ static void collectMapDataFromMapOperands( mapData.Pointers.push_back(mapData.OriginalValue.back()); if (llvm::Value *refPtr = - getRefPtrIfDeclareTarget(offloadPtr, - moduleTranslation)) { // declare target + getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) { mapData.IsDeclareTarget.push_back(true); mapData.BasePointers.push_back(refPtr); + } else if (isDeclareTargetTo(offloadPtr)) { + mapData.IsDeclareTarget.push_back(true); + mapData.BasePointers.push_back(mapData.OriginalValue.back()); } else { // regular mapped variable mapData.IsDeclareTarget.push_back(false); mapData.BasePointers.push_back(mapData.OriginalValue.back()); @@ -3996,6 +4153,9 @@ static void collectMapDataFromMapOperands( llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr); auto mapType = convertClauseMapFlags(mapOp.getMapType()); auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + bool isDevicePtr = + (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) != + omp::ClauseMapFlags::none; mapData.OriginalValue.push_back(origValue); mapData.BasePointers.push_back(origValue); @@ -4022,14 +4182,18 @@ static void collectMapDataFromMapOperands( mapData.Mappers.push_back(nullptr); } } else { + // For is_device_ptr we need the map type to propagate so the runtime + // can materialize the device-side copy of the pointer container. mapData.Types.push_back( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL); + isDevicePtr ? mapType + : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL); mapData.Mappers.push_back(nullptr); } mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back( - llvm::OpenMPIRBuilder::DeviceInfoTy::Address); + isDevicePtr ? 
llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer + : llvm::OpenMPIRBuilder::DeviceInfoTy::Address); mapData.IsAMapping.push_back(false); mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp)); } @@ -4042,41 +4206,66 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) { return std::distance(mapData.MapClause.begin(), res); } +static void sortMapIndices(llvm::SmallVectorImpl<size_t> &indices, + omp::MapInfoOp mapInfo) { + ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); + llvm::SmallVector<size_t> occludedChildren; + llvm::sort( + indices.begin(), indices.end(), [&](const size_t a, const size_t b) { + // Bail early if we are asked to look at the same index. If we do not + // bail early, we can end up mistakenly adding indices to + // occludedChildren. This can occur with some types of libc++ hardening. + if (a == b) + return false; + + auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); + auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); + + for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) { + int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt(); + int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt(); + + if (aIndex == bIndex) + continue; + + if (aIndex < bIndex) + return true; + + if (aIndex > bIndex) + return false; + } + + // Iterated up until the end of the smallest member and + // they were found to be equal up to that point, so select + // the member with the lowest index count, so the "parent" + bool memberAParent = memberIndicesA.size() < memberIndicesB.size(); + if (memberAParent) + occludedChildren.push_back(b); + else + occludedChildren.push_back(a); + return memberAParent; + }); + + // We remove children from the index list that are overshadowed by + // a parent, this prevents us retrieving these as the first or last + // element when the parent is the correct element in these cases. + for (auto v : occludedChildren) + indices.erase(std::remove(indices.begin(), indices.end(), v), + indices.end()); +} + static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first) { ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); // Only 1 member has been mapped, we can return it. if (indexAttr.size() == 1) return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()); - llvm::SmallVector<size_t> indices(indexAttr.size()); std::iota(indices.begin(), indices.end(), 0); - - llvm::sort(indices, [&](const size_t a, const size_t b) { - auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); - auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); - for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { - int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); - int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); - - if (aIndex == bIndex) - continue; - - if (aIndex < bIndex) - return first; - - if (aIndex > bIndex) - return !first; - } - - // Iterated the up until the end of the smallest member and - // they were found to be equal up to that point, so select - // the member with the lowest index count, so the "parent" - return memberIndicesA.size() < memberIndicesB.size(); - }); - + sortMapIndices(indices, mapInfo); return llvm::cast<omp::MapInfoOp>( - mapInfo.getMembers()[indices.front()].getDefiningOp()); + mapInfo.getMembers()[first ? 
indices.front() : indices.back()] + .getDefiningOp()); } /// This function calculates the array/pointer offset for map data provided @@ -4155,6 +4344,86 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, return idx; } +static void getAsIntegers(ArrayAttr values, llvm::SmallVector<int64_t> &ints) { + llvm::transform(values, std::back_inserter(ints), [](Attribute value) { + return cast<IntegerAttr>(value).getInt(); + }); +} + +// Gathers members that are overlapping in the parent, excluding members that +// themselves overlap, keeping the top-most (closest to parents level) map. +static void +getOverlappedMembers(llvm::SmallVectorImpl<size_t> &overlapMapDataIdxs, + omp::MapInfoOp parentOp) { + // No members mapped, no overlaps. + if (parentOp.getMembers().empty()) + return; + + // Single member, we can insert and return early. + if (parentOp.getMembers().size() == 1) { + overlapMapDataIdxs.push_back(0); + return; + } + + // 1) collect list of top-level overlapping members from MemberOp + llvm::SmallVector<std::pair<int, ArrayAttr>> memberByIndex; + ArrayAttr indexAttr = parentOp.getMembersIndexAttr(); + for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr)) + memberByIndex.push_back( + std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr))); + + // Sort the smallest first (higher up the parent -> member chain), so that + // when we remove members, we remove as much as we can in the initial + // iterations, shortening the number of passes required. + llvm::sort(memberByIndex.begin(), memberByIndex.end(), + [&](auto a, auto b) { return a.second.size() < b.second.size(); }); + + // Remove elements from the vector if there is a parent element that + // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1], + // [0,2].. etc. + llvm::SmallVector<std::pair<int, ArrayAttr>> skipList; + for (auto v : memberByIndex) { + llvm::SmallVector<int64_t> vArr(v.second.size()); + getAsIntegers(v.second, vArr); + skipList.push_back( + *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) { + if (v == x) + return false; + llvm::SmallVector<int64_t> xArr(x.second.size()); + getAsIntegers(x.second, xArr); + return std::equal(vArr.begin(), vArr.end(), xArr.begin()) && + xArr.size() >= vArr.size(); + })); + } + + // Collect the indices, as we need the base pointer etc. from the MapData + // structure which is primarily accessible via index at the moment. + for (auto v : memberByIndex) + if (find(skipList.begin(), skipList.end(), v) == skipList.end()) + overlapMapDataIdxs.push_back(v.first); +} + +// The intent is to verify if the mapped data being passed is a +// pointer -> pointee that requires special handling in certain cases, +// e.g. applying the OMP_MAP_PTR_AND_OBJ map type. +// +// There may be a better way to verify this, but unfortunately with +// opaque pointers we lose the ability to easily check if something is +// a pointer whilst maintaining access to the underlying type. +static bool checkIfPointerMap(omp::MapInfoOp mapOp) { + // If we have a varPtrPtr field assigned then the underlying type is a pointer + if (mapOp.getVarPtrPtr()) + return true; + + // If the map data is declare target with a link clause, then it's represented + // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has + // no relation to pointers. 
+ if (isDeclareTargetLink(mapOp.getVarPtr())) + return true; + + return false; +} + // This creates two insertions into the MapInfosTy data structure for the // "parent" of a set of members, (usually a container e.g. // class/structure/derived type) when subsequent members have also been @@ -4173,7 +4442,8 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, - MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) { + MapInfoData &mapData, uint64_t mapDataIndex, + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); @@ -4182,7 +4452,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // base entry so the mapper receives correct copy semantics via its 'type' // parameter. Also keep TARGET_PARAM when required for kernel arguments. llvm::omp::OpenMPOffloadMappingFlags baseFlag = - isTargetParams + (targetDirective == TargetDirectiveEnumTy::Target && + !mapData.IsDeclareTarget[mapDataIndex]) ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; @@ -4217,7 +4488,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // runtime information on the dynamically allocated data). auto parentClause = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]); - llvm::Value *lowAddr, *highAddr; if (!parentClause.getPartialMap()) { lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex], @@ -4263,39 +4533,85 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // further case specific flag modifications). For the moment, it handles // what we support as expected. llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex]; + bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) & + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) == + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); - combinedInfo.Types.emplace_back(mapFlag); - combinedInfo.DevicePointers.emplace_back( - llvm::OpenMPIRBuilder::DeviceInfoTy::None); - combinedInfo.Mappers.emplace_back(nullptr); - combinedInfo.Names.emplace_back(LLVM::createMappingInformation( - mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); - combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); - combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]); - combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]); - } - return memberOfFlag; -} - -// The intent is to verify if the mapped data being passed is a -// pointer -> pointee that requires special handling in certain cases, -// e.g. applying the OMP_MAP_PTR_AND_OBJ map type. -// -// There may be a better way to verify this, but unfortunately with -// opaque pointers we lose the ability to easily check if something is -// a pointer whilst maintaining access to the underlying type. 
-static bool checkIfPointerMap(omp::MapInfoOp mapOp) { - // If we have a varPtrPtr field assigned then the underlying type is a pointer - if (mapOp.getVarPtrPtr()) - return true; - // If the map data is declare target with a link clause, then it's represented - // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has - // no relation to pointers. - if (isDeclareTargetLink(mapOp.getVarPtr())) - return true; + if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) { + combinedInfo.Types.emplace_back(mapFlag); + combinedInfo.DevicePointers.emplace_back( + mapData.DevicePointers[mapDataIndex]); + combinedInfo.Names.emplace_back(LLVM::createMappingInformation( + mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); + combinedInfo.BasePointers.emplace_back( + mapData.BasePointers[mapDataIndex]); + combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]); + combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); + } else { + llvm::SmallVector<size_t> overlapIdxs; + // Find all of the members that "overlap", i.e. occlude other members that + // were mapped alongside the parent, e.g. member [0], occludes [0,1] and + // [0,2], but not [1,0]. + getOverlappedMembers(overlapIdxs, parentClause); + // We need to make sure the overlapped members are sorted in order of + // lowest address to highest address. + sortMapIndices(overlapIdxs, parentClause); + + lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex], + builder.getPtrTy()); + highAddr = builder.CreatePointerCast( + builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex], + mapData.Pointers[mapDataIndex], 1), + builder.getPtrTy()); + + // TODO: We may want to skip arrays/array sections in this as Clang does. + // It appears to be an optimisation rather than a necessity though, + // but this requires further investigation. However, we would have to make + // sure to not exclude maps with bounds that ARE pointers, as these are + // processed as separate components, i.e. pointer + data. + for (auto v : overlapIdxs) { + auto mapDataOverlapIdx = getMapDataMemberIdx( + mapData, + cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp())); + combinedInfo.Types.emplace_back(mapFlag); + combinedInfo.DevicePointers.emplace_back( + mapData.DevicePointers[mapDataOverlapIdx]); + combinedInfo.Names.emplace_back(LLVM::createMappingInformation( + mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); + combinedInfo.BasePointers.emplace_back( + mapData.BasePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); + combinedInfo.Pointers.emplace_back(lowAddr); + combinedInfo.Sizes.emplace_back(builder.CreateIntCast( + builder.CreatePtrDiff(builder.getInt8Ty(), + mapData.OriginalValue[mapDataOverlapIdx], + lowAddr), + builder.getInt64Ty(), /*isSigned=*/true)); + lowAddr = builder.CreateConstGEP1_32( + checkIfPointerMap(llvm::cast<omp::MapInfoOp>( + mapData.MapClause[mapDataOverlapIdx])) + ? 
builder.getPtrTy() + : mapData.BaseType[mapDataOverlapIdx], + mapData.BasePointers[mapDataOverlapIdx], 1); + } - return false; + combinedInfo.Types.emplace_back(mapFlag); + combinedInfo.DevicePointers.emplace_back( + mapData.DevicePointers[mapDataIndex]); + combinedInfo.Names.emplace_back(LLVM::createMappingInformation( + mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); + combinedInfo.BasePointers.emplace_back( + mapData.BasePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); + combinedInfo.Pointers.emplace_back(lowAddr); + combinedInfo.Sizes.emplace_back(builder.CreateIntCast( + builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr), + builder.getInt64Ty(), true)); + } + } + return memberOfFlag; } // This function is intended to add explicit mappings of members @@ -4303,7 +4619,8 @@ static void processMapMembersWithParent( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, - llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) { + llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); @@ -4348,8 +4665,15 @@ static void processMapMembersWithParent( mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF; ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); - if (checkIfPointerMap(memberClause)) + bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr() + ? parentClause.getVarPtr() + : parentClause.getVarPtrPtr()); + if (checkIfPointerMap(memberClause) && + (!isDeclTargetTo || + (targetDirective != TargetDirectiveEnumTy::TargetUpdate && + targetDirective != TargetDirectiveEnumTy::TargetData))) { mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ; + } combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( @@ -4375,7 +4699,8 @@ static void processMapMembersWithParent( } static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, - MapInfosTy &combinedInfo, bool isTargetParams, + MapInfosTy &combinedInfo, + TargetDirectiveEnumTy targetDirective, int mapDataParentIdx = -1) { // Declare Target Mappings are excluded from being marked as // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're @@ -4387,7 +4712,8 @@ static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, if (isPtrTy) mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ; - if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx]) + if (targetDirective == TargetDirectiveEnumTy::Target && + !mapData.IsDeclareTarget[mapDataIdx]) mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy && @@ -4416,7 +4742,7 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, - bool isTargetParams) { + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); @@ -4440,17 +4766,18 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, // Clang maps array without bounds as pointers 
(which we do not // currently do), whereas we treat them as arrays in all cases // currently. - processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams, + processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective, mapDataIndex); return; } llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag = mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl, - combinedInfo, mapData, mapDataIndex, isTargetParams); + combinedInfo, mapData, mapDataIndex, + targetDirective); processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl, combinedInfo, mapData, mapDataIndex, - memberOfParentFlag); + memberOfParentFlag, targetDirective); } // This is a variation on Clang's GenerateOpenMPCapturedVars, which @@ -4528,10 +4855,10 @@ createAlteredByCaptureMap(MapInfoData &mapData, static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, - MapInfoData &mapData, bool isTargetParams = false) { + MapInfoData &mapData, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); - // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can // involve generating new loads and stores, which changes the @@ -4561,22 +4888,24 @@ static void genMapInfos(llvm::IRBuilderBase &builder, auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]); if (!mapInfoOp.getMembers().empty()) { processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl, - combinedInfo, mapData, i, isTargetParams); + combinedInfo, mapData, i, targetDirective); continue; } - processIndividualMap(mapData, i, combinedInfo, isTargetParams); + processIndividualMap(mapData, i, combinedInfo, targetDirective); } } static llvm::Expected<llvm::Function *> emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - llvm::StringRef mapperFuncName); + llvm::StringRef mapperFuncName, + TargetDirectiveEnumTy targetDirective); static llvm::Expected<llvm::Function *> getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { + LLVM::ModuleTranslation &moduleTranslation, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); auto declMapperOp = cast<omp::DeclareMapperOp>(op); @@ -4588,13 +4917,14 @@ getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, return lookupFunc; return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation, - mapperFuncName); + mapperFuncName, targetDirective); } static llvm::Expected<llvm::Function *> emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - llvm::StringRef mapperFuncName) { + llvm::StringRef mapperFuncName, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); auto declMapperOp = cast<omp::DeclareMapperOp>(op); @@ -4622,10 +4952,11 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, MapInfoData mapData; collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, builder); - genMapInfos(builder, moduleTranslation, dl, combinedInfo, 
mapData); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData, + targetDirective); - // Drop the mapping that is no longer necessary so that the same region can - // be processed multiple times. + // Drop the mapping that is no longer necessary so that the same region + // can be processed multiple times. moduleTranslation.forgetMapping(declMapperOp.getRegion()); return combinedInfo; }; @@ -4634,7 +4965,7 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, if (!combinedInfo.Mappers[i]) return nullptr; return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper( @@ -4655,10 +4986,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, SmallVector<Value> useDeviceAddrVars; llvm::omp::RuntimeFunction RTLFn; DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>()); + TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true, - /*SeparateBeginEndCalls=*/true); + llvm::OpenMPIRBuilder::TargetDataInfo info( + /*RequiresDevicePointerInfo=*/true, + /*SeparateBeginEndCalls=*/true); bool isTargetDevice = ompBuilder->Config.isTargetDevice(); bool isOffloadEntry = isTargetDevice || !ompBuilder->Config.TargetTriples.empty(); @@ -4757,7 +5090,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, MapInfosTy combinedInfo; auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); - genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData); + genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData, + targetDirective); return combinedInfo; }; @@ -4873,7 +5207,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return nullptr; info.HasMapper = true; return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); @@ -4980,15 +5314,18 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) { // TODO: Add support for clauses which are valid for DISTRIBUTE // constructs. Static schedule is the default. - auto schedule = omp::ClauseScheduleKind::Static; - bool isOrdered = false; + bool hasDistSchedule = distributeOp.getDistScheduleStatic(); + auto schedule = hasDistSchedule ? 
omp::ClauseScheduleKind::Distribute
+ : omp::ClauseScheduleKind::Static;
+ // dist_schedule clauses are ordered - otherwise this should be false
+ bool isOrdered = hasDistSchedule;
 std::optional<omp::ScheduleModifier> scheduleMod;
 bool isSimd = false;
 llvm::omp::WorksharingLoopType workshareLoopType =
 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
 bool loopNeedsBarrier = false;
- llvm::Value *chunk = nullptr;
-
+ llvm::Value *chunk = moduleTranslation.lookupValue(
+ distributeOp.getDistScheduleChunkSize());
 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
@@ -4997,12 +5334,11 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
 convertToScheduleKind(schedule), chunk, isSimd,
 scheduleMod == omp::ScheduleModifier::monotonic,
 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType);
+ workshareLoopType, false, hasDistSchedule, chunk);
 if (!wsloopIP)
 return wsloopIP.takeError();
 }
-
 if (failed(cleanupPrivateVars(builder, moduleTranslation,
 distributeOp.getLoc(), privVarsInfo.llvmVars,
 privVarsInfo.privatizers)))
@@ -5135,11 +5471,16 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
 for (llvm::User *user : userVec) {
 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
 if (insn->getFunction() == func) {
- builder.SetCurrentDebugLocation(insn->getDebugLoc());
- auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
- mapData.BasePointers[i]);
- load->moveBefore(insn->getIterator());
- user->replaceUsesOfWith(mapData.OriginalValue[i], load);
+ auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
+ llvm::Value *substitute = mapData.BasePointers[i];
+ if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
+ : mapOp.getVarPtr())) {
+ builder.SetCurrentDebugLocation(insn->getDebugLoc());
+ substitute = builder.CreateLoad(
+ mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
+ cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
+ }
+ user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
 }
 }
 }
@@ -5431,8 +5772,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 int32_t minTeamsVal = 1, maxTeamsVal = -1;
 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
- // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
- // clang and set min and max to the same value.
+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
+ // match clang and set min and max to the same value.
 if (numTeamsUpper) {
 if (auto val = extractConstInteger(numTeamsUpper))
 minTeamsVal = maxTeamsVal = *val;
@@ -5624,9 +5965,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
 auto &targetRegion = targetOp.getRegion();
- // Holds the private vars that have been mapped along with the block argument
- // that corresponds to the MapInfoOp corresponding to the private var in
- // question. So, for instance:
+ // Holds the private vars that have been mapped along with the block
+ // argument that corresponds to the MapInfoOp corresponding to the private
+ // var in question. So, for instance:
 //
 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
// omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1) @@ -5641,6 +5982,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs(); ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs(); llvm::Function *llvmOutlinedFn = nullptr; + TargetDirectiveEnumTy targetDirective = + getTargetDirectiveEnumTyFromOp(&opInst); // TODO: It can also be false if a compile-time constant `false` IF clause is // specified. @@ -5802,7 +6145,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); - genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true); + genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, + targetDirective); return combinedInfos; }; @@ -5882,7 +6226,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, return nullptr; info.HasMapper = true; return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::Value *ifCond = nullptr; diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index d9891e3..d7d215b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -34,12 +34,14 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Comdat.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugProgramInstruction.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/ModRef.h" #include <optional> @@ -522,6 +524,11 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) { debugIntrinsics.insert(intrinsic); } +void ModuleImport::addDebugRecord(llvm::DbgVariableRecord *dbgRecord) { + if (!dbgRecords.contains(dbgRecord)) + dbgRecords.insert(dbgRecord); +} + static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule, llvm::MDTuple *mdTuple) { auto getLLVMFunction = @@ -1214,7 +1221,7 @@ static TypedAttr getScalarConstantAsAttr(OpBuilder &builder, llvm::Constant *constScalar) { MLIRContext *context = builder.getContext(); - // Convert scalar intergers. + // Convert scalar integers. 
if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) { return builder.getIntegerAttr( IntegerType::get(context, constInt->getBitWidth()), @@ -2003,9 +2010,15 @@ FloatAttr ModuleImport::matchFloatAttr(llvm::Value *value) { return floatAttr; } -DILocalVariableAttr ModuleImport::matchLocalVariableAttr(llvm::Value *value) { - auto *nodeAsVal = cast<llvm::MetadataAsValue>(value); - auto *node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata()); +DILocalVariableAttr ModuleImport::matchLocalVariableAttr( + llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> valOrVariable) { + llvm::DILocalVariable *node = nullptr; + if (auto *value = dyn_cast<llvm::Value *>(valOrVariable)) { + auto *nodeAsVal = cast<llvm::MetadataAsValue>(value); + node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata()); + } else { + node = cast<llvm::DILocalVariable *>(valOrVariable); + } return debugImporter->translate(node); } @@ -2544,6 +2557,41 @@ LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) { if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(inst)) return convertIntrinsic(intrinsic); + // Process debug records attached to this instruction. Debug variable records + // are stored for later processing after all SSA values are converted, while + // debug label records can be converted immediately. + if (inst->DebugMarker) { + for (llvm::DbgRecord &dbgRecord : inst->DebugMarker->getDbgRecordRange()) { + // Store debug variable records for later processing. + if (auto *dbgVariableRecord = + dyn_cast<llvm::DbgVariableRecord>(&dbgRecord)) { + addDebugRecord(dbgVariableRecord); + continue; + } + Location loc = translateLoc(dbgRecord.getDebugLoc()); + auto emitUnsupportedWarning = [&]() -> LogicalResult { + if (!emitExpensiveWarnings) + return success(); + std::string options; + llvm::raw_string_ostream optionsStream(options); + dbgRecord.print(optionsStream); + emitWarning(loc) << "unhandled debug record " << optionsStream.str(); + return success(); + }; + // Convert the debug label records in-place. + if (auto *dbgLabelRecord = dyn_cast<llvm::DbgLabelRecord>(&dbgRecord)) { + DILabelAttr labelAttr = + debugImporter->translate(dbgLabelRecord->getLabel()); + if (!labelAttr) + return emitUnsupportedWarning(); + LLVM::DbgLabelOp::create(builder, loc, labelAttr); + continue; + } + // Warn if an unsupported debug record is encountered. + return emitUnsupportedWarning(); + } + } + // Convert all remaining LLVM instructions to MLIR operations. return convertInstruction(inst); } @@ -2579,8 +2627,15 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) { memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); auto inaccessibleMem = convertModRefInfoFromLLVM( memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); - auto memAttr = MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem, - inaccessibleMem); + auto errnoMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::ErrnoMem)); + auto targetMem0 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem0)); + auto targetMem1 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem1)); + auto memAttr = + MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem, + inaccessibleMem, errnoMem, targetMem0, targetMem1); // Only set the attr when it does not match the default value. 
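The memory-effects import above now tracks errno and target-specific memory alongside argument, inaccessible, and other memory, and an attribute is only recorded when the combined summary is stricter than the read-write-everything default checked just below. As a rough standalone model of that idea (the enum, the location list, and the isReadWrite helper here are illustrative stand-ins, not the LLVM or MLIR types):

#include <array>
#include <cstdint>
#include <iostream>

// Simplified stand-in for a mod/ref lattice: bit 0 = may read, bit 1 = may write.
enum class ModRef : uint8_t { NoModRef = 0, Ref = 1, Mod = 2, ModRef = 3 };

// Illustrative location list; the real import converts argument, inaccessible,
// errno, target-specific, and "other" memory separately.
enum Location { ArgMem, InaccessibleMem, ErrnoMem, TargetMem0, TargetMem1, Other, NumLocations };

struct MemoryEffectsModel {
  std::array<ModRef, NumLocations> perLocation;

  // The default is "may read and write everything"; only a stricter summary
  // is worth recording as an attribute.
  bool isReadWrite() const {
    for (ModRef mr : perLocation)
      if (mr != ModRef::ModRef)
        return false;
    return true;
  }
};

int main() {
  MemoryEffectsModel effects;
  effects.perLocation.fill(ModRef::ModRef);
  effects.perLocation[ErrnoMem] = ModRef::NoModRef; // e.g. never touches errno
  effects.perLocation[ArgMem] = ModRef::Ref;        // only reads its arguments

  // Mirrors "only set the attr when it does not match the default value".
  if (!effects.isReadWrite())
    std::cout << "record a memory-effects attribute\n";
  else
    std::cout << "default effects, nothing to record\n";
}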
if (memAttr.isReadWrite()) return; @@ -2885,8 +2940,15 @@ LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst, memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM( memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); - auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem, - inaccessibleMem); + ModRefInfo errnoMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::ErrnoMem)); + ModRefInfo targetMem0 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem0)); + ModRefInfo targetMem1 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem1)); + auto memAttr = + MemoryEffectsAttr::get(op.getContext(), othermem, argMem, inaccessibleMem, + errnoMem, targetMem0, targetMem1); // Only set the attribute when it does not match the default value. if (!memAttr.isReadWrite()) op.setMemoryEffectsAttr(memAttr); @@ -3007,6 +3069,11 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) { if (failed(processDebugIntrinsics())) return failure(); + // Process the debug records that require a delayed conversion after + // everything else was converted. + if (failed(processDebugRecords())) + return failure(); + return success(); } @@ -3022,61 +3089,32 @@ static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic *dbgIntr) { return !isa<llvm::ValueAsMetadata>(nodeAsVal->getMetadata()); } -LogicalResult -ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, - DominanceInfo &domInfo) { - Location loc = translateLoc(dbgIntr->getDebugLoc()); - auto emitUnsupportedWarning = [&]() { - if (emitExpensiveWarnings) - emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr); - return success(); - }; - // Drop debug intrinsics with arg lists. - // TODO: Support debug intrinsics that have arg lists. - if (dbgIntr->hasArgList()) - return emitUnsupportedWarning(); - // Kill locations can have metadata nodes as location operand. This - // cannot be converted to poison as the type cannot be reconstructed. - // TODO: find a way to support this case. - if (isMetadataKillLocation(dbgIntr)) - return emitUnsupportedWarning(); - // Drop debug intrinsics if the associated variable information cannot be - // translated due to cyclic debug metadata. - // TODO: Support cyclic debug metadata. - DILocalVariableAttr localVariableAttr = - matchLocalVariableAttr(dbgIntr->getArgOperand(1)); - if (!localVariableAttr) - return emitUnsupportedWarning(); - FailureOr<Value> argOperand = convertMetadataValue(dbgIntr->getArgOperand(0)); - if (failed(argOperand)) - return emitError(loc) << "failed to convert a debug intrinsic operand: " - << diag(*dbgIntr); - - // Ensure that the debug intrinsic is inserted right after its operand is - // defined. Otherwise, the operand might not necessarily dominate the - // intrinsic. If the defining operation is a terminator, insert the intrinsic - // into a dominated block. - OpBuilder::InsertionGuard guard(builder); - if (Operation *op = argOperand->getDefiningOp(); +/// Ensure that the debug intrinsic is inserted right after the operand +/// definition. Otherwise, the operand might not necessarily dominate the +/// intrinsic. If the defining operation is a terminator, insert the intrinsic +/// into a dominated block. 
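Ahead of the insertion-point helper that follows, the deferred-conversion pattern used for debug variable records (collect them while walking instructions, materialize them only once every SSA value has a counterpart) can be seen in isolation in the sketch below. It is a minimal standalone model with standard-library stand-ins for the ModuleImport state, not the actual implementation:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for "LLVM value" and "imported MLIR value".
using SourceValue = std::string;
using ImportedValue = int;

// A deferred debug record only needs to remember which value it describes.
struct DeferredRecord {
  SourceValue trackedValue;
};

int main() {
  std::unordered_map<SourceValue, ImportedValue> valueMapping;
  std::vector<DeferredRecord> deferredRecords;

  // Phase 1: while walking instructions, only remember the records; the
  // values they refer to may not have been converted yet.
  deferredRecords.push_back({"%x"});
  deferredRecords.push_back({"%y"});

  // Converting the function body fills in the value mapping.
  valueMapping["%x"] = 1;
  valueMapping["%y"] = 2;

  // Phase 2: with the mapping complete, each record can be materialized
  // (or dropped if its value never got a counterpart).
  for (const DeferredRecord &record : deferredRecords) {
    auto it = valueMapping.find(record.trackedValue);
    if (it != valueMapping.end())
      std::cout << "emit debug op for imported value " << it->second << "\n";
    else
      std::cout << "drop record for " << record.trackedValue << "\n";
  }
}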
+static LogicalResult setDebugIntrinsicBuilderInsertionPoint( + mlir::OpBuilder &builder, DominanceInfo &domInfo, Value argOperand) { + if (Operation *op = argOperand.getDefiningOp(); op && op->hasTrait<OpTrait::IsTerminator>()) { // Find a dominated block that can hold the debug intrinsic. auto dominatedBlocks = domInfo.getNode(op->getBlock())->children(); // If no block is dominated by the terminator, this intrinisc cannot be // converted. if (dominatedBlocks.empty()) - return emitUnsupportedWarning(); + return failure(); // Set insertion point before the terminator, to avoid inserting something // before landingpads. Block *dominatedBlock = (*dominatedBlocks.begin())->getBlock(); builder.setInsertionPoint(dominatedBlock->getTerminator()); } else { - Value insertPt = *argOperand; - if (auto blockArg = dyn_cast<BlockArgument>(*argOperand)) { + Value insertPt = argOperand; + if (auto blockArg = dyn_cast<BlockArgument>(argOperand)) { // The value might be coming from a phi node and is now a block argument, // which means the insertion point is set to the start of the block. If // this block is a target destination of an invoke, the insertion point // must happen after the landing pad operation. - Block *insertionBlock = argOperand->getParentBlock(); + Block *insertionBlock = argOperand.getParentBlock(); if (!insertionBlock->empty() && isa<LandingpadOp>(insertionBlock->front())) insertPt = cast<LandingpadOp>(insertionBlock->front()).getRes(); @@ -3084,23 +3122,152 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, builder.setInsertionPointAfterValue(insertPt); } - auto locationExprAttr = - debugImporter->translateExpression(dbgIntr->getExpression()); - Operation *op = - llvm::TypeSwitch<llvm::DbgVariableIntrinsic *, Operation *>(dbgIntr) - .Case([&](llvm::DbgDeclareInst *) { - return LLVM::DbgDeclareOp::create( - builder, loc, *argOperand, localVariableAttr, locationExprAttr); - }) - .Case([&](llvm::DbgValueInst *) { - return LLVM::DbgValueOp::create( - builder, loc, *argOperand, localVariableAttr, locationExprAttr); - }); + return success(); +} + +std::tuple<DILocalVariableAttr, DIExpressionAttr, Value> +ModuleImport::processDebugOpArgumentsAndInsertionPt( + Location loc, + llvm::function_ref<FailureOr<Value>()> convertArgOperandToValue, + llvm::Value *address, + llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> variable, + llvm::DIExpression *expression, DominanceInfo &domInfo) { + // Drop debug intrinsics if the associated debug information cannot be + // translated due to an unsupported construct. 
+ DILocalVariableAttr localVarAttr = matchLocalVariableAttr(variable);
+ if (!localVarAttr)
+ return {};
+ FailureOr<Value> argOperand = convertArgOperandToValue();
+ if (failed(argOperand)) {
+ emitError(loc) << "failed to convert a debug operand: " << diag(*address);
+ return {};
+ }
+
+ if (setDebugIntrinsicBuilderInsertionPoint(builder, domInfo, *argOperand)
+ .failed())
+ return {};
+
+ return {localVarAttr, debugImporter->translateExpression(expression),
+ *argOperand};
+}
+
+LogicalResult
+ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
+ DominanceInfo &domInfo) {
+ Location loc = translateLoc(dbgIntr->getDebugLoc());
+ auto emitUnsupportedWarning = [&]() {
+ if (emitExpensiveWarnings)
+ emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr);
+ return success();
+ };
+
+ OpBuilder::InsertionGuard guard(builder);
+ auto convertArgOperandToValue = [&]() {
+ return convertMetadataValue(dbgIntr->getArgOperand(0));
+ };
+
+ // Drop debug intrinsics with an argument list.
+ // TODO: Support this case.
+ if (dbgIntr->hasArgList())
+ return emitUnsupportedWarning();
+
+ // Drop debug intrinsics with kill locations that have metadata nodes as
+ // location operand, which cannot be converted to poison as the type cannot be
+ // reconstructed.
+ // TODO: Support this case.
+ if (isMetadataKillLocation(dbgIntr))
+ return emitUnsupportedWarning();
+
+ auto [localVariableAttr, locationExprAttr, locVal] =
+ processDebugOpArgumentsAndInsertionPt(
+ loc, convertArgOperandToValue, dbgIntr->getArgOperand(0),
+ dbgIntr->getArgOperand(1), dbgIntr->getExpression(), domInfo);
+
+ if (!localVariableAttr)
+ return emitUnsupportedWarning();
+
+ if (!locVal) // Expected if localVariableAttr is present.
+ return failure();
+
+ Operation *op = nullptr;
+ if (isa<llvm::DbgDeclareInst>(dbgIntr))
+ op = LLVM::DbgDeclareOp::create(builder, loc, locVal, localVariableAttr,
+ locationExprAttr);
+ else if (isa<llvm::DbgValueInst>(dbgIntr))
+ op = LLVM::DbgValueOp::create(builder, loc, locVal, localVariableAttr,
+ locationExprAttr);
+ else
+ return emitUnsupportedWarning();
+
 mapNoResultOp(dbgIntr, op);
 setNonDebugMetadataAttrs(dbgIntr, op);
 return success();
 }
+LogicalResult
+ModuleImport::processDebugRecord(llvm::DbgVariableRecord &dbgRecord,
+ DominanceInfo &domInfo) {
+ OpBuilder::InsertionGuard guard(builder);
+ Location loc = translateLoc(dbgRecord.getDebugLoc());
+ auto emitUnsupportedWarning = [&]() -> LogicalResult {
+ if (!emitExpensiveWarnings)
+ return success();
+ std::string options;
+ llvm::raw_string_ostream optionsStream(options);
+ dbgRecord.print(optionsStream);
+ emitWarning(loc) << "unhandled debug variable record "
+ << optionsStream.str();
+ return success();
+ };
+
+ // Drop debug records with an argument list.
+ // TODO: Support this case.
+ if (dbgRecord.hasArgList())
+ return emitUnsupportedWarning();
+
+ // Drop all other debug records with an address operand that cannot be
+ // converted to an SSA value such as an empty metadata node.
+ // TODO: Support this case.
+ if (!dbgRecord.getAddress())
+ return emitUnsupportedWarning();
+
+ auto convertArgOperandToValue = [&]() -> FailureOr<Value> {
+ llvm::Value *value = dbgRecord.getAddress();
+
+ // Return the mapped value if it has been converted before.
+ auto it = valueMapping.find(value);
+ if (it != valueMapping.end())
+ return it->getSecond();
+
+ // Convert constants such as immediate values that have no mapping yet.
+ if (auto *constant = dyn_cast<llvm::Constant>(value)) + return convertConstantExpr(constant); + return failure(); + }; + + auto [localVariableAttr, locationExprAttr, locVal] = + processDebugOpArgumentsAndInsertionPt( + loc, convertArgOperandToValue, dbgRecord.getAddress(), + dbgRecord.getVariable(), dbgRecord.getExpression(), domInfo); + + if (!localVariableAttr) + return emitUnsupportedWarning(); + + if (!locVal) // Expected if localVariableAttr is present. + return failure(); + + if (dbgRecord.isDbgDeclare()) + LLVM::DbgDeclareOp::create(builder, loc, locVal, localVariableAttr, + locationExprAttr); + else if (dbgRecord.isDbgValue()) + LLVM::DbgValueOp::create(builder, loc, locVal, localVariableAttr, + locationExprAttr); + else // isDbgAssign + return emitUnsupportedWarning(); + + return success(); +} + LogicalResult ModuleImport::processDebugIntrinsics() { DominanceInfo domInfo; for (llvm::Instruction *inst : debugIntrinsics) { @@ -3111,6 +3278,15 @@ LogicalResult ModuleImport::processDebugIntrinsics() { return success(); } +LogicalResult ModuleImport::processDebugRecords() { + DominanceInfo domInfo; + for (llvm::DbgVariableRecord *dbgRecord : dbgRecords) + if (failed(processDebugRecord(*dbgRecord, domInfo))) + return failure(); + dbgRecords.clear(); + return success(); +} + LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb, Block *block) { builder.setInsertionPointToStart(block); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 64e3c5f..fad9bd6b7 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -588,10 +588,17 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( } // For integer types, we allow a mismatch in sizes as the index type in // MLIR might have a different size than the index type in the LLVM module. - if (auto intAttr = dyn_cast<IntegerAttr>(attr)) - return llvm::ConstantInt::get( - llvmType, - intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); + if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { + // If the attribute is an unsigned integer or a 1-bit integer, zero-extend + // the value to the bit width of the LLVM type. Otherwise, sign-extend. + auto intTy = dyn_cast<IntegerType>(intAttr.getType()); + APInt value; + if (intTy && (intTy.isUnsigned() || intTy.getWidth() == 1)) + value = intAttr.getValue().zextOrTrunc(llvmType->getIntegerBitWidth()); + else + value = intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()); + return llvm::ConstantInt::get(llvmType, value); + } if (auto floatAttr = dyn_cast<FloatAttr>(attr)) { const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); // Special case for 8-bit floats, which are represented by integers due to @@ -677,10 +684,10 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( } } } - // std::vector is used here to accomodate large number of elements that - // exceed SmallVector capacity. - std::vector<llvm::Constant *> constants(numElements, child); - return llvm::ConstantArray::get(arrayType, constants); + // std::vector is used here to accomodate large number of elements that + // exceed SmallVector capacity. 
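The signedness handling introduced above for integer attributes (APInt::zextOrTrunc for unsigned and 1-bit integer types, sextOrTrunc otherwise) matters because the two extensions produce different bit patterns whenever the top bit of the narrower value is set, the boolean i1 case included. A small standalone illustration with plain integers, assuming an 8-bit destination purely for the demonstration; the helpers below are illustrative, not APInt:

#include <cstdint>
#include <iostream>

// Extend the low `width` bits of `value` to 8 bits without sign extension.
uint8_t zext(uint8_t value, unsigned width) {
  return value & ((1u << width) - 1);
}

// Extend the low `width` bits of `value` to 8 bits with sign extension.
uint8_t sext(uint8_t value, unsigned width) {
  uint8_t masked = value & ((1u << width) - 1);
  uint8_t signBit = 1u << (width - 1);
  return (masked ^ signBit) - signBit; // standard sign-extension trick
}

int main() {
  // The i1 "true" case: zero-extension keeps it 1, sign-extension makes it -1.
  std::cout << int(zext(1, 1)) << "\n"; // 1
  std::cout << int(sext(1, 1)) << "\n"; // 255, i.e. -1 as a signed byte

  // A 4-bit pattern 0b1001: 9 when treated as unsigned, -7 when signed.
  std::cout << int(zext(0b1001, 4)) << "\n"; // 9
  std::cout << int(sext(0b1001, 4)) << "\n"; // 249, i.e. -7 as a signed byte
}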
+ std::vector<llvm::Constant *> constants(numElements, child); + return llvm::ConstantArray::get(arrayType, constants); } } @@ -892,10 +899,13 @@ void mlir::LLVM::detail::connectPHINodes(Region ®ion, llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) { - llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn = - llvm::Intrinsic::getOrInsertDeclaration(module, intrinsic, tys); - return builder.CreateCall(fn, args); + return builder.CreateIntrinsic(intrinsic, tys, args); +} + +llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( + llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, + llvm::Type *retTy, ArrayRef<llvm::Value *> args) { + return builder.CreateIntrinsic(retTy, intrinsic, args); } llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( @@ -1637,6 +1647,15 @@ static void convertFunctionMemoryAttributes(LLVMFuncOp func, newMemEffects |= llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, convertModRefInfoToLLVM(memEffects.getOther())); + newMemEffects |= + llvm::MemoryEffects(llvm::MemoryEffects::Location::ErrnoMem, + convertModRefInfoToLLVM(memEffects.getErrnoMem())); + newMemEffects |= + llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem0, + convertModRefInfoToLLVM(memEffects.getTargetMem0())); + newMemEffects |= + llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem1, + convertModRefInfoToLLVM(memEffects.getTargetMem1())); llvmFunc->setMemoryEffects(newMemEffects); } @@ -2122,8 +2141,16 @@ LogicalResult ModuleTranslation::createTBAAMetadata() { // LLVM metadata instances. AttrTypeWalker walker; walker.addWalk([&](TBAARootAttr root) { - tbaaMetadataMapping.insert( - {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))}); + llvm::MDNode *node; + if (StringAttr id = root.getId()) { + node = llvm::MDNode::get(ctx, llvm::MDString::get(ctx, id)); + } else { + // Anonymous root nodes are self-referencing. + auto selfRef = llvm::MDNode::getTemporary(ctx, {}); + node = llvm::MDNode::get(ctx, {selfRef.get()}); + node->replaceOperandWith(0, node); + } + tbaaMetadataMapping.insert({root, node}); }); walker.addWalk([&](TBAATypeDescriptorAttr descriptor) { @@ -2254,8 +2281,11 @@ llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { /* HasRequiresUnifiedSharedMemory = */ false, /* HasRequiresDynamicAllocators = */ false); unsigned int defaultAS = - getLLVMModule()->getDataLayout().getProgramAddressSpace(); + llvmModule->getDataLayout().getProgramAddressSpace(); config.setDefaultTargetAS(defaultAS); + config.setRuntimeCC(llvmModule->getTargetTriple().isSPIRV() + ? 
llvm::CallingConv::SPIR_FUNC + : llvm::CallingConv::C); ompBuilder->setConfig(std::move(config)); ompBuilder->initialize(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index c27f9aa..5b04a14 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction( return processLoopMerge(operands); case spirv::Opcode::OpPhi: return processPhi(operands); + case spirv::Opcode::OpSwitch: + return processSwitch(operands); case spirv::Opcode::OpUndef: return processUndef(operands); default: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 6492708..50883d9 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -346,6 +346,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { case spirv::Decoration::Constant: case spirv::Decoration::Invariant: case spirv::Decoration::Patch: + case spirv::Decoration::Coherent: if (words.size() != 2) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target <id>"; @@ -2292,6 +2293,38 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { return success(); } +LogicalResult spirv::Deserializer::processSwitch(ArrayRef<uint32_t> operands) { + if (!curBlock) + return emitError(unknownLoc, "OpSwitch must appear in a block"); + + if (operands.size() < 2) + return emitError(unknownLoc, "OpSwitch must at least specify selector and " + "a default target"); + + if (operands.size() % 2) + return emitError(unknownLoc, + "OpSwitch must at have an even number of operands: " + "selector, default target and any number of literal and " + "label <id> pairs"); + + Value selector = getValue(operands[0]); + Block *defaultBlock = getOrCreateBlock(operands[1]); + Location loc = createFileLineColLoc(opBuilder); + + SmallVector<int32_t> literals; + SmallVector<Block *> blocks; + for (unsigned i = 2, e = operands.size(); i < e; i += 2) { + literals.push_back(operands[i]); + blocks.push_back(getOrCreateBlock(operands[i + 1])); + } + + SmallVector<ValueRange> targetOperands(blocks.size(), {}); + spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock, + ArrayRef<Value>(), literals, blocks, targetOperands); + + return success(); +} + namespace { /// A class for putting all blocks in a structured selection/loop in a /// spirv.mlir.selection/spirv.mlir.loop op. 
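The OpSwitch word layout decoded by processSwitch above is a selector <id>, a default label <id>, and then any number of (literal, label <id>) pairs, with the malformed-layout checks shown in the implementation (fewer than two operands, or an odd operand count). A minimal standalone decoder for that layout might look like the sketch below; the struct and function names are illustrative, not deserializer APIs:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct DecodedSwitch {
  uint32_t selectorId = 0;
  uint32_t defaultLabelId = 0;
  std::vector<std::pair<int32_t, uint32_t>> cases; // (literal, label <id>)
};

// Returns false on the same malformed layouts the deserializer rejects.
bool decodeSwitch(const std::vector<uint32_t> &operands, DecodedSwitch &out) {
  if (operands.size() < 2 || operands.size() % 2 != 0)
    return false;
  out.selectorId = operands[0];
  out.defaultLabelId = operands[1];
  for (size_t i = 2; i < operands.size(); i += 2)
    out.cases.emplace_back(static_cast<int32_t>(operands[i]), operands[i + 1]);
  return true;
}

int main() {
  // selector %7, default label %20, case 0 -> %21, case 1 -> %22
  std::vector<uint32_t> words = {7, 20, 0, 21, 1, 22};
  DecodedSwitch sw;
  if (!decodeSwitch(words, sw))
    return 1;
  std::cout << "selector %" << sw.selectorId << ", default %"
            << sw.defaultLabelId << "\n";
  for (auto [literal, label] : sw.cases)
    std::cout << "  case " << literal << " -> %" << label << "\n";
}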
@@ -2799,6 +2832,23 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { branchCondOp.getFalseBlock()); branchCondOp.erase(); + } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) { + if (target == switchOp.getDefaultTarget()) { + SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands()); + DenseIntElementsAttr literals = + switchOp.getLiterals().value_or(DenseIntElementsAttr()); + spirv::SwitchOp::create( + opBuilder, switchOp.getLoc(), switchOp.getSelector(), + switchOp.getDefaultTarget(), blockArgs, literals, + switchOp.getTargets(), targetOperands); + switchOp.erase(); + } else { + SuccessorRange targets = switchOp.getTargets(); + auto it = llvm::find(targets, target); + assert(it != targets.end()); + size_t index = std::distance(targets.begin(), it); + switchOp.getTargetOperandsMutable(index).assign(blockArgs); + } } else { return emitError(unknownLoc, "unimplemented terminator for Phi creation"); } @@ -2819,7 +2869,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { return success(); } -LogicalResult spirv::Deserializer::splitConditionalBlocks() { +LogicalResult spirv::Deserializer::splitSelectionHeader() { // Create a copy, so we can modify keys in the original. BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo; for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end(); @@ -2836,7 +2886,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { Operation *terminator = block->getTerminator(); assert(terminator); - if (!isa<spirv::BranchConditionalOp>(terminator)) + if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator)) continue; // Check if the current header block is a merge block of another construct. @@ -2846,10 +2896,10 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { splitHeaderMergeBlock = true; } - // Do not split a block that only contains a conditional branch, unless it - // is also a merge block of another construct - in that case we want to - // split the block. We do not want two constructs to share header / merge - // block. + // Do not split a block that only contains a conditional branch / switch, + // unless it is also a merge block of another construct - in that case we + // want to split the block. We do not want two constructs to share header / + // merge block. if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) { Block *newBlock = block->splitBlock(terminator); OpBuilder builder(block, block->end()); @@ -2887,13 +2937,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() { logger.startLine() << "\n"; }); - if (failed(splitConditionalBlocks())) { + if (failed(splitSelectionHeader())) { return failure(); } - // TODO: This loop is non-deterministic. Iteration order may vary between runs - // for the same shader as the key to the map is a pointer. See: - // https://github.com/llvm/llvm-project/issues/128547 while (!blockMergeInfo.empty()) { Block *headerBlock = blockMergeInfo.begin()->first; BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 6027f1a..50c9350 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -58,7 +58,9 @@ struct DebugLine { }; /// Map from a selection/loop's header block to its merge (and continue) target. 
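The declaration that follows switches BlockMergeInfoMap from DenseMap to llvm::MapVector so that iterating over header blocks no longer depends on pointer values. The toy insertion-ordered map below shows why that restores determinism; it is a simplified stand-in for llvm::MapVector, not the real container:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy insertion-ordered map: entries live in a vector (stable order), the
// hash map only stores indices into it. llvm::MapVector follows the same idea.
template <typename K, typename V>
class OrderedMap {
  std::vector<std::pair<K, V>> storage;
  std::unordered_map<K, size_t> index;

public:
  void insert(const K &key, const V &value) {
    if (index.count(key))
      return;
    index[key] = storage.size();
    storage.emplace_back(key, value);
  }
  auto begin() const { return storage.begin(); }
  auto end() const { return storage.end(); }
};

int main() {
  // Keys standing in for header-block pointers; iteration follows the order
  // in which the constructs were discovered, so it is repeatable across runs.
  OrderedMap<std::string, std::string> blockMergeInfo;
  blockMergeInfo.insert("header_b", "merge_d");
  blockMergeInfo.insert("header_a", "merge_c");
  for (const auto &[header, merge] : blockMergeInfo)
    std::cout << header << " -> " << merge << "\n"; // always b then a
}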
-using BlockMergeInfoMap = DenseMap<Block *, BlockMergeInfo>; +/// Use `MapVector<>` to ensure a deterministic iteration order with a pointer +/// key. +using BlockMergeInfoMap = llvm::MapVector<Block *, BlockMergeInfo>; /// A "deferred struct type" is a struct type with one or more member types not /// known when the Deserializer first encounters the struct. This happens, for @@ -278,11 +280,11 @@ private: return opBuilder.getStringAttr(attrName); } - /// Move a conditional branch into a separate basic block to avoid unnecessary - /// sinking of defs that may be required outside a selection region. This - /// function also ensures that a single block cannot be a header block of one - /// selection construct and the merge block of another. - LogicalResult splitConditionalBlocks(); + /// Move a conditional branch or a switch into a separate basic block to avoid + /// unnecessary sinking of defs that may be required outside a selection + /// region. This function also ensures that a single block cannot be a header + /// block of one selection construct and the merge block of another. + LogicalResult splitSelectionHeader(); //===--------------------------------------------------------------------===// // Type @@ -472,6 +474,9 @@ private: /// Processes a SPIR-V OpPhi instruction with the given `operands`. LogicalResult processPhi(ArrayRef<uint32_t> operands); + /// Processes a SPIR-V OpSwitch instruction with the given `operands`. + LogicalResult processSwitch(ArrayRef<uint32_t> operands); + /// Creates block arguments on predecessors previously recorded when handling /// OpPhi instructions. LogicalResult wireUpBlockArgument(); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index 85e92c7..6397d2c 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -775,6 +775,27 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { return success(); } +LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) { + uint32_t selectorID = getValueID(switchOp.getSelector()); + uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget()); + SmallVector<uint32_t> arguments{selectorID, defaultLabelID}; + + std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals(); + BlockRange targets = switchOp.getTargets(); + if (literals) { + for (auto [literal, target] : llvm::zip_equal(*literals, targets)) { + arguments.push_back(literal.getLimitedValue()); + uint32_t targetLabelID = getOrCreateBlockID(target); + arguments.push_back(targetLabelID); + } + } + + if (failed(emitDebugLine(functionBody, switchOp.getLoc()))) + return failure(); + encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments); + return success(); +} + LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { auto varName = addressOfOp.getVariable(); auto variableID = getVariableID(varName); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 29ed5a4..c879a2b 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -373,6 +373,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::Block: case spirv::Decoration::Invariant: case spirv::Decoration::Patch: + case spirv::Decoration::Coherent: // For unit attributes and decoration 
attributes, the args list // has no values so we do nothing. if (isa<UnitAttr, DecorationAttr>(attr)) @@ -1443,7 +1444,20 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { assert(branchCondOp.getFalseTarget() == block); blockOperands = branchCondOp.getFalseTargetOperands(); } - + assert(!blockOperands->empty() && + "expected non-empty block operand range"); + predecessors.emplace_back(spirvPredecessor, *blockOperands); + } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) { + std::optional<OperandRange> blockOperands; + if (block == switchOp.getDefaultTarget()) { + blockOperands = switchOp.getDefaultOperands(); + } else { + SuccessorRange targets = switchOp.getTargets(); + auto it = llvm::find(targets, block); + assert(it != targets.end()); + size_t index = std::distance(targets.begin(), it); + blockOperands = switchOp.getTargetOperands(index); + } assert(!blockOperands->empty() && "expected non-empty block operand range"); predecessors.emplace_back(spirvPredecessor, *blockOperands); @@ -1579,6 +1593,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) { .Case([&](spirv::SpecConstantOperationOp op) { return processSpecConstantOperationOp(op); }) + .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index add372b..6e79c13 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -304,6 +304,8 @@ private: LogicalResult processBranchOp(spirv::BranchOp branchOp); + LogicalResult processSwitchOp(spirv::SwitchOp switchOp); + //===--------------------------------------------------------------------===// // Operations //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp index 6f2e4cd..e82807f 100644 --- a/mlir/lib/Tools/PDLL/AST/Context.cpp +++ b/mlir/lib/Tools/PDLL/AST/Context.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/AST/Context.h" -#include "TypeDetail.h" +#include "mlir/Tools/PDLL/AST/Types.h" using namespace mlir; using namespace mlir::pdll::ast; diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp index 5aa0937..4358ceb 100644 --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -21,7 +21,7 @@ static StringRef copyStringWithNull(Context &ctx, StringRef str) { return str; char *data = ctx.getAllocator().Allocate<char>(str.size() + 1); - std::copy(str.begin(), str.end(), data); + llvm::copy(str, data); data[str.size()] = 0; return StringRef(data, str.size()); } diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h deleted file mode 100644 index a0bd84e..0000000 --- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h +++ /dev/null @@ -1,141 +0,0 @@ -//===- TypeDetail.h ---------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_ -#define LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_ - -#include "mlir/Tools/PDLL/AST/Types.h" - -namespace mlir { -namespace pdll { -namespace ast { -//===----------------------------------------------------------------------===// -// Type -//===----------------------------------------------------------------------===// - -struct Type::Storage : public StorageUniquer::BaseStorage { - Storage(TypeID typeID) : typeID(typeID) {} - - /// The type identifier for the derived type class. - TypeID typeID; -}; - -namespace detail { - -/// A utility CRTP base class that defines many of the necessary utilities for -/// defining a PDLL AST Type. -template <typename ConcreteT, typename KeyT = void> -struct TypeStorageBase : public Type::Storage { - using KeyTy = KeyT; - using Base = TypeStorageBase<ConcreteT, KeyT>; - TypeStorageBase(KeyTy key) - : Type::Storage(TypeID::get<ConcreteT>()), key(key) {} - - /// Construct an instance with the given storage allocator. - static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, - const KeyTy &key) { - return new (alloc.allocate<ConcreteT>()) ConcreteT(key); - } - - /// Utility methods required by the storage allocator. - bool operator==(const KeyTy &key) const { return this->key == key; } - - /// Return the key value of this storage class. - const KeyTy &getValue() const { return key; } - -protected: - KeyTy key; -}; -/// A specialization of the storage base for singleton types. -template <typename ConcreteT> -struct TypeStorageBase<ConcreteT, void> : public Type::Storage { - using Base = TypeStorageBase<ConcreteT, void>; - TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {} -}; - -//===----------------------------------------------------------------------===// -// AttributeType -//===----------------------------------------------------------------------===// - -struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {}; - -//===----------------------------------------------------------------------===// -// ConstraintType -//===----------------------------------------------------------------------===// - -struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {}; - -//===----------------------------------------------------------------------===// -// OperationType -//===----------------------------------------------------------------------===// - -struct OperationTypeStorage - : public TypeStorageBase<OperationTypeStorage, - std::pair<StringRef, const ods::Operation *>> { - using Base::Base; - - static OperationTypeStorage * - construct(StorageUniquer::StorageAllocator &alloc, - const std::pair<StringRef, const ods::Operation *> &key) { - return new (alloc.allocate<OperationTypeStorage>()) OperationTypeStorage( - std::make_pair(alloc.copyInto(key.first), key.second)); - } -}; - -//===----------------------------------------------------------------------===// -// RangeType -//===----------------------------------------------------------------------===// - -struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> { - using Base::Base; -}; - -//===----------------------------------------------------------------------===// -// RewriteType -//===----------------------------------------------------------------------===// - -struct RewriteTypeStorage : public TypeStorageBase<RewriteTypeStorage> {}; - 
-//===----------------------------------------------------------------------===// -// TupleType -//===----------------------------------------------------------------------===// - -struct TupleTypeStorage - : public TypeStorageBase<TupleTypeStorage, - std::pair<ArrayRef<Type>, ArrayRef<StringRef>>> { - using Base::Base; - - static TupleTypeStorage * - construct(StorageUniquer::StorageAllocator &alloc, - std::pair<ArrayRef<Type>, ArrayRef<StringRef>> key) { - SmallVector<StringRef> names = llvm::to_vector(llvm::map_range( - key.second, [&](StringRef name) { return alloc.copyInto(name); })); - return new (alloc.allocate<TupleTypeStorage>()) - TupleTypeStorage(std::make_pair(alloc.copyInto(key.first), - alloc.copyInto(llvm::ArrayRef(names)))); - } -}; - -//===----------------------------------------------------------------------===// -// TypeType -//===----------------------------------------------------------------------===// - -struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {}; - -//===----------------------------------------------------------------------===// -// ValueType -//===----------------------------------------------------------------------===// - -struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {}; - -} // namespace detail -} // namespace ast -} // namespace pdll -} // namespace mlir - -#endif // LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_ diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp index 1468ac2..d5497b0 100644 --- a/mlir/lib/Tools/PDLL/AST/Types.cpp +++ b/mlir/lib/Tools/PDLL/AST/Types.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/AST/Types.h" -#include "TypeDetail.h" #include "mlir/Tools/PDLL/AST/Context.h" #include <optional> diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 9ef405d..018a188 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -681,17 +681,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, return success(); } -std::pair<std::string, std::string> -mlir::registerAndParseCLIOptions(int argc, char **argv, - llvm::StringRef toolName, - DialectRegistry ®istry) { - static cl::opt<std::string> inputFilename( - cl::Positional, cl::desc("<input file>"), cl::init("-")); - - static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"), - cl::value_desc("filename"), - cl::init("-")); - // Register any command line options. +std::string mlir::registerCLIOptions(llvm::StringRef toolName, + DialectRegistry ®istry) { MlirOptMainConfig::registerCLOptions(registry); registerAsmPrinterCLOptions(); registerMLIRContextCLOptions(); @@ -706,11 +697,29 @@ mlir::registerAndParseCLIOptions(int argc, char **argv, interleaveComma(registry.getDialectNames(), os, [&](auto name) { os << name; }); } - // Parse pass names in main to ensure static initialization completed. 
+ return helpHeader; +} + +std::pair<std::string, std::string> +mlir::parseCLIOptions(int argc, char **argv, llvm::StringRef helpHeader) { + static cl::opt<std::string> inputFilename( + cl::Positional, cl::desc("<input file>"), cl::init("-")); + + static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"), + cl::value_desc("filename"), + cl::init("-")); cl::ParseCommandLineOptions(argc, argv, helpHeader); return std::make_pair(inputFilename.getValue(), outputFilename.getValue()); } +std::pair<std::string, std::string> +mlir::registerAndParseCLIOptions(int argc, char **argv, + llvm::StringRef toolName, + DialectRegistry ®istry) { + auto helpHeader = registerCLIOptions(toolName, registry); + return parseCLIOptions(argc, argv, helpHeader); +} + static LogicalResult printRegisteredDialects(DialectRegistry ®istry) { llvm::outs() << "Available Dialects: "; interleave(registry.getDialectNames(), llvm::outs(), ","); diff --git a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp index 685e794..64e86f2 100644 --- a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp +++ b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp @@ -153,5 +153,12 @@ int mlir::MlirTblgenMain(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv); - return TableGenMain(argv[0], &mlirTableGenMain); + return TableGenMain( + argv[0], [](TableGenOutputFiles &OutFiles, const RecordKeeper &RK) { + std::string S; + raw_string_ostream OS(S); + bool Res = mlirTableGenMain(OS, RK); + OutFiles = {S, {}}; + return Res; + }); } diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 54b67f5..8907724 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -27,6 +27,7 @@ add_mlir_library(MLIRTransforms DEPENDS MLIRTransformsPassIncGen + MLIRTransformsDialectInterfaceIncGen LINK_LIBS PUBLIC MLIRAnalysis @@ -39,4 +40,5 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils + MLIRUBDialect ) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 41f3f9d..e9ced064 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -33,6 +33,7 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/LivenessAnalysis.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" @@ -260,6 +261,22 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { + // Operations that have dead operands can be erased regardless of their + // side effects. The liveness analysis would not have marked an SSA value as + // "dead" if it had a side-effecting user that is reachable. 
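The comment above states the invariant the check that follows relies on: liveness never marks a value dead while it still has a reachable side-effecting user, so any op with a dead operand can be erased outright, and otherwise only effect-free ops without live results are removed. A standalone sketch of that decision over a toy set of ops (the types are illustrative, not the MLIR analysis classes):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy op: a name, operand values, result values, and whether it has effects.
struct ToyOp {
  std::string name;
  std::vector<std::string> operands;
  std::vector<std::string> results;
  bool hasSideEffects = false;
};

int main() {
  // Values a (hypothetical) liveness analysis reported as live.
  std::unordered_set<std::string> liveValues = {"%a"};

  std::vector<ToyOp> ops = {
      {"producer", {}, {"%a", "%b"}, /*hasSideEffects=*/false},
      {"dead_chain", {"%b"}, {"%d"}, /*hasSideEffects=*/false},
      {"store_like", {"%a"}, {}, /*hasSideEffects=*/true},
  };

  for (const ToyOp &op : ops) {
    bool hasDeadOperand = false;
    for (const std::string &operand : op.operands)
      if (!liveValues.count(operand))
        hasDeadOperand = true;

    bool hasLiveResult = false;
    for (const std::string &result : op.results)
      if (liveValues.count(result))
        hasLiveResult = true;

    // Mirrors the simple-op handling: a dead operand makes the op removable
    // even if it has side effects; otherwise only effect-free ops without
    // live results are removed.
    bool removable = hasDeadOperand || (!op.hasSideEffects && !hasLiveResult);
    std::cout << op.name << (removable ? ": erase\n" : ": keep\n");
  }
}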
+ bool hasDeadOperand = + markLives(op->getOperands(), nonLiveSet, la).flip().any(); + if (hasDeadOperand) { + LDBG() << "Simple op has dead operands, so the op must be dead: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + assert(!hasLive(op->getResults(), nonLiveSet, la) && + "expected the op to have no live results"); + cl.operations.push_back(op); + collectNonLiveValues(nonLiveSet, op->getResults(), + BitVector(op->getNumResults(), true)); + return; + } + if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { LDBG() << "Simple op is not memory effect free or has live results, " "preserving it: " @@ -361,6 +378,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // block other than the entry block, because every block has a terminator. for (Block &block : funcOp.getBlocks()) { Operation *returnOp = block.getTerminator(); + if (!returnOp->hasTrait<OpTrait::ReturnLike>()) + continue; if (returnOp && returnOp->getNumOperands() == numReturns) cl.operands.push_back({returnOp, nonLiveRets}); } @@ -700,7 +719,11 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, } /// Steps to process a `BranchOpInterface` operation: -/// Iterate through each successor block of `branchOp`. +/// +/// When a non-forwarded operand is dead (e.g., the condition value of a +/// conditional branch op), the entire operation is dead. +/// +/// Otherwise, iterate through each successor block of `branchOp`. /// (1) For each successor block, gather all operands from all successors. /// (2) Fetch their associated liveness analysis data and collect for future /// removal. @@ -711,7 +734,22 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { LDBG() << "Processing branch op: " << *branchOp; + + // Check for dead non-forwarded operands. + BitVector deadNonForwardedOperands = + markLives(branchOp->getOperands(), nonLiveSet, la).flip(); unsigned numSuccessors = branchOp->getNumSuccessors(); + for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { + SuccessorOperands successorOperands = + branchOp.getSuccessorOperands(succIdx); + // Clear the bits of the forwarded operands so that only dead + // non-forwarded operands remain set. + for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands()) + deadNonForwardedOperands[opOperand.getOperandNumber()] = false; + } + if (deadNonForwardedOperands.any()) { + cl.operations.push_back(branchOp.getOperation()); + return; + } for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { Block *successorBlock = branchOp->getSuccessor(succIdx); @@ -742,23 +780,70 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, static void cleanUpDeadVals(RDVFinalCleanupList &list) { LDBG() << "Starting cleanup of dead values..."; - // 1. Operations + // 1. Blocks. We must remove the block arguments and successor operands before + // deleting the operation, as they may reside inside a region of that operation.
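The block-argument and successor-operand cleanups that follow erase indexed entries flagged in a BitVector, and they walk the indices from the back because erasing entry i shifts every later index. A generic sketch of that reverse-erase pattern, with a SmallVector standing in for the argument or operand list:

#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVector.h"

// Erase every element whose bit is set in `toErase`, back to front, so the
// remaining marked indices are still valid when they are reached.
static void eraseMarked(llvm::SmallVectorImpl<int> &vals,
                        const llvm::BitVector &toErase) {
  for (int i = static_cast<int>(toErase.size()) - 1; i >= 0; --i)
    if (toErase[i])
      vals.erase(vals.begin() + i);
}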
+ LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; + for (auto &b : list.blocks) { + // blocks that are accessed via multiple codepaths processed once + if (b.b->getNumArguments() != b.nonLiveArgs.size()) + continue; + LDBG() << "Erasing " << b.nonLiveArgs.count() + << " non-live arguments from block: " << b.b; + // it iterates backwards because erase invalidates all successor indexes + for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { + if (!b.nonLiveArgs[i]) + continue; + LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); + b.b->getArgument(i).dropAllUses(); + b.b->eraseArgument(i); + } + } + + // 2. Successor Operands + LDBG() << "Cleaning up " << list.successorOperands.size() + << " successor operand lists"; + for (auto &op : list.successorOperands) { + SuccessorOperands successorOperands = + op.branch.getSuccessorOperands(op.successorIndex); + // blocks that are accessed via multiple codepaths processed once + if (successorOperands.size() != op.nonLiveOperands.size()) + continue; + LDBG() << "Erasing " << op.nonLiveOperands.count() + << " non-live successor operands from successor " + << op.successorIndex << " of branch: " + << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); + // it iterates backwards because erase invalidates all successor indexes + for (int i = successorOperands.size() - 1; i >= 0; --i) { + if (!op.nonLiveOperands[i]) + continue; + LDBG() << " Erasing successor operand " << i << ": " + << successorOperands[i]; + successorOperands.erase(i); + } + } + + // 3. Operations LDBG() << "Cleaning up " << list.operations.size() << " operations"; - for (auto &op : list.operations) { + for (Operation *op : list.operations) { LDBG() << "Erasing operation: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); + if (op->hasTrait<OpTrait::IsTerminator>()) { + // When erasing a terminator, insert an unreachable op in its place. + OpBuilder b(op); + ub::UnreachableOp::create(b, op->getLoc()); + } op->dropAllUses(); op->erase(); } - // 2. Values + // 4. Values LDBG() << "Cleaning up " << list.values.size() << " values"; for (auto &v : list.values) { LDBG() << "Dropping all uses of value: " << v; v.dropAllUses(); } - // 3. Functions + // 5. Functions LDBG() << "Cleaning up " << list.functions.size() << " functions"; // Record which function arguments were erased so we can shrink call-site // argument segments for CallOpInterface operations (e.g. ops using @@ -780,7 +865,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { (void)f.funcOp.eraseResults(f.nonLiveRets); } - // 4. Operands + // 6. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { // Handle call-specific cleanup only when we have a cached callee reference. @@ -822,7 +907,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } } - // 5. Results + // 7. Results LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { LDBG() << "Erasing " << r.nonLive.count() @@ -830,48 +915,6 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(r.op, OpPrintingFlags().skipRegions()); dropUsesAndEraseResults(r.op, r.nonLive); } - - // 6. 
Blocks - LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; - for (auto &b : list.blocks) { - // blocks that are accessed via multiple codepaths processed once - if (b.b->getNumArguments() != b.nonLiveArgs.size()) - continue; - LDBG() << "Erasing " << b.nonLiveArgs.count() - << " non-live arguments from block: " << b.b; - // it iterates backwards because erase invalidates all successor indexes - for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { - if (!b.nonLiveArgs[i]) - continue; - LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); - b.b->getArgument(i).dropAllUses(); - b.b->eraseArgument(i); - } - } - - // 7. Successor Operands - LDBG() << "Cleaning up " << list.successorOperands.size() - << " successor operand lists"; - for (auto &op : list.successorOperands) { - SuccessorOperands successorOperands = - op.branch.getSuccessorOperands(op.successorIndex); - // blocks that are accessed via multiple codepaths processed once - if (successorOperands.size() != op.nonLiveOperands.size()) - continue; - LDBG() << "Erasing " << op.nonLiveOperands.count() - << " non-live successor operands from successor " - << op.successorIndex << " of branch: " - << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); - // it iterates backwards because erase invalidates all successor indexes - for (int i = successorOperands.size() - 1; i >= 0; --i) { - if (!op.nonLiveOperands[i]) - continue; - LDBG() << " Erasing successor operand " << i << ": " - << successorOperands[i]; - successorOperands.erase(i); - } - } - LDBG() << "Finished cleanup of dead values"; } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f8c38fa..09ad423 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" #include <optional> +#include <utility> using namespace mlir; using namespace mlir::detail; @@ -975,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues); /// Replace the uses of the given value with the given values. The specified - /// converter is used to build materializations (if necessary). - void replaceAllUsesWith(Value from, ValueRange to, - const TypeConverter *converter); + /// converter is used to build materializations (if necessary). If `functor` + /// is specified, only the uses that the functor returns "true" for are + /// replaced. + void replaceValueUses(Value from, ValueRange to, + const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor = nullptr); /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1051,7 +1055,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { MLIRContext *context, std::function<void(Operation *)> opErasedCallback = nullptr) : RewriterBase(context, /*listener=*/this), - opErasedCallback(opErasedCallback) {} + opErasedCallback(std::move(opErasedCallback)) {} /// Erase the given op (unless it was already erased). void eraseOp(Operation *op) override { @@ -1202,11 +1206,16 @@ void BlockTypeConversionRewrite::rollback() { } /// Replace all uses of `from` with `repl`. 
-static void performReplaceValue(RewriterBase &rewriter, Value from, - Value repl) { +static void +performReplaceValue(RewriterBase &rewriter, Value from, Value repl, + function_ref<bool(OpOperand &)> functor = nullptr) { if (isa<BlockArgument>(repl)) { // `repl` is a block argument. Directly replace all uses. - rewriter.replaceAllUsesWith(from, repl); + if (functor) { + rewriter.replaceUsesWithIf(from, repl, functor); + } else { + rewriter.replaceAllUsesWith(from, repl); + } return; } @@ -1237,7 +1246,11 @@ static void performReplaceValue(RewriterBase &rewriter, Value from, Block *replBlock = replOp->getBlock(); rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + bool result = + user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + if (result && functor) + result &= functor(operand); + return result; }); } @@ -1645,7 +1658,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, /*isPureTypeConversion=*/false) .front(); - replaceAllUsesWith(origArg, mat, converter); + replaceValueUses(origArg, mat, converter); continue; } @@ -1654,14 +1667,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - replaceAllUsesWith(origArg, inputMap->replacementValues, converter); + replaceValueUses(origArg, inputMap->replacementValues, converter); continue; } // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - replaceAllUsesWith(origArg, replArgs, converter); + replaceValueUses(origArg, replArgs, converter); } if (config.allowPatternRollback) @@ -1961,8 +1974,24 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::replaceAllUsesWith( - Value from, ValueRange to, const TypeConverter *converter) { +void ConversionPatternRewriterImpl::replaceValueUses( + Value from, ValueRange to, const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor) { + LLVM_DEBUG({ + logger.startLine() << "** Replace Value : '" << from << "'"; + if (auto blockArg = dyn_cast<BlockArgument>(from)) { + if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { + logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")"; + } else { + logger.getOStream() << " (unlinked block)"; + } + } + if (functor) { + logger.getOStream() << ", conditional replacement"; + } + }); + if (!config.allowPatternRollback) { SmallVector<Value> toConv = llvm::to_vector(to); SmallVector<Value> repls = @@ -1972,7 +2001,7 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( if (!repl) return; - performReplaceValue(r, from, repl); + performReplaceValue(r, from, repl, functor); return; } @@ -1991,6 +2020,9 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( replacedValues.insert(from); #endif // NDEBUG + if (functor) + llvm::report_fatal_error( + "conditional value replacement is not supported in rollback mode"); mapping.map(from, to); appendRewrite<ReplaceValueRewrite>(from, converter); } @@ -2189,18 +2221,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( } void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) { - LLVM_DEBUG({ - impl->logger.startLine() << "** Replace 
Value : '" << from << "'"; - if (auto blockArg = dyn_cast<BlockArgument>(from)) { - if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { - impl->logger.getOStream() << " (in region of '" << parentOp->getName() - << "' (" << parentOp << ")\n"; - } else { - impl->logger.getOStream() << " (unlinked block)\n"; - } - } - }); - impl->replaceAllUsesWith(from, to, impl->currentTypeConverter); + impl->replaceValueUses(from, to, impl->currentTypeConverter); +} + +void ConversionPatternRewriter::replaceUsesWithIf( + Value from, ValueRange to, function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced) { + assert(!allUsesReplaced && + "allUsesReplaced is not supported in a dialect conversion"); + impl->replaceValueUses(from, to, impl->currentTypeConverter, functor); } Value ConversionPatternRewriter::getRemappedValue(Value key) { @@ -2765,7 +2794,7 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { rewriterImpl.patternMaterializations.clear(); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Expensive pattern check that can detect API violations. - if (checkOp) { + if (checkOp && topLevelFingerPrint) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + @@ -2856,17 +2885,19 @@ LogicalResult OperationLegalizer::legalizePatternResult( assert(impl.pendingRootUpdates.empty() && "dangling root updates"); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - // Check that the root was either replaced or updated in place. - auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites); - auto replacedRoot = [&] { - return hasRewrite<ReplaceOperationRewrite>(newRewrites, op); - }; - auto updatedRootInPlace = [&] { - return hasRewrite<ModifyOperationRewrite>(newRewrites, op); - }; - if (!replacedRoot() && !updatedRootInPlace()) - llvm::report_fatal_error( - "expected pattern to replace the root operation or modify it in place"); + if (impl.config.allowPatternRollback) { + // Check that the root was either replaced or updated in place. + auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites); + auto replacedRoot = [&] { + return hasRewrite<ReplaceOperationRewrite>(newRewrites, op); + }; + auto updatedRootInPlace = [&] { + return hasRewrite<ModifyOperationRewrite>(newRewrites, op); + }; + if (!replacedRoot() && !updatedRootInPlace()) + llvm::report_fatal_error("expected pattern to replace the root operation " + "or modify it in place"); + } #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index 26c965c..4095031 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -613,8 +613,8 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, LLVM_DEBUG({ LDBG() << "* Inliner: Initial calls in SCC are: {"; - for (unsigned i = 0, e = calls.size(); i < e; ++i) - LDBG() << " " << i << ". " << calls[i].call << ","; + for (unsigned I = 0, E = calls.size(); I < E; ++I) + LDBG() << " " << I << ". 
" << calls[I].call << ","; LDBG() << "}"; }); diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 31ae1d1..330a2d3 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1149,9 +1149,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, // Remove the values that already dominate the insertion point. SmallVector<Value> prunedValues; for (auto value : values) { - if (dominance.properlyDominates(value, insertionPoint)) { + if (dominance.properlyDominates(value, insertionPoint)) continue; - } // Block arguments are not supported. if (isa<BlockArgument>(value)) { return rewriter.notifyMatchFailure( @@ -1178,8 +1177,13 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, // Since current support is to only move within a same basic block, // the slices dont need to look past block arguments. options.omitBlockArguments = true; + bool dependsOnSideEffectingOp = false; options.filter = [&](Operation *sliceBoundaryOp) { - return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); + bool mustMove = + !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); + if (mustMove && !isPure(sliceBoundaryOp)) + dependsOnSideEffectingOp = true; + return mustMove; }; llvm::SetVector<Operation *> slice; for (auto value : prunedValues) { @@ -1188,6 +1192,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, (void)result; } + // Check if any operation in the slice is side-effecting. + if (dependsOnSideEffectingOp) + return failure(); + // If the slice contains `insertionPoint` cannot move the dependencies. if (slice.contains(insertionPoint)) { return rewriter.notifyMatchFailure( @@ -1198,9 +1206,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, // Sort operations topologically before moving. mlir::topologicalSort(slice); - for (Operation *op : slice) { + for (Operation *op : slice) rewriter.moveOpBefore(op, insertionPoint); - } return success(); } |

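As a closing illustration of the RegionUtils.cpp change above: the new filter makes moveValueDefinitions bail out when the backward slice that would have to move contains a non-pure op. A condensed sketch of that filter in isolation; the helper name is hypothetical and the slice options simply mirror the ones already used in moveValueDefinitions:

#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

// Returns true if every defining op that would have to move before
// `insertionPoint` is pure (hypothetical helper, not part of the change).
static bool sliceIsMovable(mlir::Value value, mlir::DominanceInfo &dominance,
                           mlir::Operation *insertionPoint) {
  bool dependsOnSideEffectingOp = false;
  mlir::BackwardSliceOptions options;
  options.omitBlockArguments = true;
  options.filter = [&](mlir::Operation *sliceBoundaryOp) {
    bool mustMove =
        !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
    if (mustMove && !mlir::isPure(sliceBoundaryOp))
      dependsOnSideEffectingOp = true;
    return mustMove;
  };
  llvm::SetVector<mlir::Operation *> slice;
  (void)mlir::getBackwardSlice(value, &slice, options);
  return !dependsOnSideEffectingOp;
}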