Diffstat (limited to 'mlir/lib')
34 files changed, 809 insertions, 272 deletions
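Two of the changes below alter observable behavior rather than just form: the Python binding's Operation.location property becomes read-write (IRCore.cpp switches it from def_prop_ro to def_prop_rw, backed by a new C API entry point, mlirOperationSetLocation, in CAPI/IR/IR.cpp), and arith.mulf gains a mulf(x, 0) -> 0 fold guarded by the nnan and nsz fastmath flags, since x * 0 is NaN when x is NaN or infinite and -0.0 when x is negative. A minimal sketch of the new C API in use; the helper name and file name are illustrative, not part of the patch:

    #include "mlir-c/IR.h"

    /* Re-point an operation at a refined source location.
     * mlirLocationFileLineColGet is pre-existing C API;
     * mlirOperationSetLocation is the entry point this patch adds. */
    static void refineLocation(MlirContext ctx, MlirOperation op) {
      MlirLocation loc = mlirLocationFileLineColGet(
          ctx, mlirStringRefCreateFromCString("refined.mlir"),
          /*line=*/7, /*col=*/3);
      mlirOperationSetLocation(op, loc);
    }

From Python this surfaces as plain attribute assignment, e.g. op.location = Location.file("refined.mlir", 7, 3) inside a Context (a sketch of the intended usage, not a test from the patch).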
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 83a8757..32b2b0c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3219,13 +3219,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("end_line"), nb::arg("end_col"), nb::arg("context") = nb::none(), kContextGetFileRangeDocstring) .def("is_a_file", mlirLocationIsAFileLineColRange) - .def_prop_ro( - "filename", - [](MlirLocation loc) { - return mlirIdentifierStr( - mlirLocationFileLineColRangeGetFilename(loc)); - }, - nb::sig("def filename(self) -> str")) + .def_prop_ro("filename", + [](MlirLocation loc) { + return mlirIdentifierStr( + mlirLocationFileLineColRangeGetFilename(loc)); + }) .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine) .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn) .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine) @@ -3274,12 +3272,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("name"), nb::arg("childLoc") = nb::none(), nb::arg("context") = nb::none(), kContextGetNameLocationDocString) .def("is_a_name", mlirLocationIsAName) - .def_prop_ro( - "name_str", - [](MlirLocation loc) { - return mlirIdentifierStr(mlirLocationNameGetName(loc)); - }, - nb::sig("def name_str(self) -> str")) + .def_prop_ro("name_str", + [](MlirLocation loc) { + return mlirIdentifierStr(mlirLocationNameGetName(loc)); + }) .def_prop_ro("child_loc", [](PyLocation &self) { return PyLocation(self.getContext(), @@ -3453,15 +3449,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { return concreteOperation.getContext().getObject(); }, "Context that owns the Operation") - .def_prop_ro( - "name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = concreteOperation.get(); - return mlirIdentifierStr(mlirOperationGetName(operation)); - }, - nb::sig("def name(self) -> str")) + .def_prop_ro("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + return mlirIdentifierStr(mlirOperationGetName(operation)); + }) .def_prop_ro("operands", [](PyOperationBase &self) { return PyOpOperandList(self.getOperation().getRef()); @@ -3485,15 +3479,21 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_prop_ro( + .def_prop_rw( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); return PyLocation(operation.getContext(), mlirOperationGetLocation(operation.get())); }, - "Returns the source location the operation was defined or derived " - "from.") + [](PyOperationBase &self, const PyLocation &location) { + PyOperation &operation = self.getOperation(); + mlirOperationSetLocation(operation.get(), location.get()); + }, + nb::for_getter("Returns the source location the operation was " + "defined or derived from."), + nb::for_setter("Sets the source location the operation was defined " + "or derived from.")) .def_prop_ro("parent", [](PyOperationBase &self) -> std::optional<nb::typed<nb::object, PyOperation>> { @@ -3597,12 +3597,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Reports if the operation is attached to its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) - .def( - "walk", &PyOperationBase::walk, nb::arg("callback"), 
- nb::arg("walk_order") = MlirWalkPostOrder, - // clang-format off - nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = " MAKE_MLIR_PYTHON_QUALNAME("ir.WalkOrder.POST_ORDER") ") -> None") - // clang-format on + .def("walk", &PyOperationBase::walk, nb::arg("callback"), + nb::arg("walk_order") = MlirWalkPostOrder, + // clang-format off + nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None") + // clang-format on ); nb::class_<PyOperation, PyOperationBase>(m, "Operation") @@ -4118,7 +4117,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyNamedAttribute &self) { return mlirIdentifierStr(self.namedAttr.name); }, - nb::sig("def name(self) -> str"), "The name of the NamedAttribute binding") .def_prop_ro( "attr", @@ -4336,17 +4334,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { kValueReplaceAllUsesWithDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, PyOperation &exception) { + [](PyValue &self, PyValue &with, PyOperation &exception) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, nb::arg("with_"), nb::arg("exceptions"), - nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: " - "Operation) -> None"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, nb::list exceptions) { + [](PyValue &self, PyValue &with, const nb::list &exceptions) { // Convert Python list to a SmallVector of MlirOperations llvm::SmallVector<MlirOperation> exceptionOps; for (nb::handle exception : exceptions) { @@ -4358,8 +4354,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { exceptionOps.data()); }, nb::arg("with_"), nb::arg("exceptions"), - nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: " - "Sequence[Operation]) -> None"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 598ae01..edbd73e 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -273,8 +273,7 @@ class DefaultingPyMlirContext : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - MAKE_MLIR_PYTHON_QUALNAME("ir.Context"); + static constexpr const char kTypeDescription[] = "Context"; static PyMlirContext &resolve(); }; @@ -500,8 +499,7 @@ class DefaultingPyLocation : public Defaulting<DefaultingPyLocation, PyLocation> { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - MAKE_MLIR_PYTHON_QUALNAME("ir.Location"); + static constexpr const char kTypeDescription[] = "Location"; static PyLocation &resolve(); operator MlirLocation() const { return *get(); } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 3488d92..34c5b8d 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -1010,7 +1010,7 @@ public: }, nb::arg("elements"), nb::arg("context") = nb::none(), // clang-format off - nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"), + nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"), // clang-format on "Create a tuple type"); c.def( @@ -1070,7 +1070,7 @@ public: }, nb::arg("inputs"), nb::arg("results"), 
nb::arg("context") = nb::none(), // clang-format off - nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"), + nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"), // clang-format on "Gets a FunctionType from a list of input and result types"); c.def_prop_ro( diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 52656138..a14f09f 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -115,9 +115,6 @@ NB_MODULE(_mlir, m) { }); }, "typeid"_a, nb::kw_only(), "replace"_a = false, - // clang-format off - nb::sig("def register_type_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"), - // clang-format on "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, @@ -130,9 +127,6 @@ NB_MODULE(_mlir, m) { }); }, "typeid"_a, nb::kw_only(), "replace"_a = false, - // clang-format off - nb::sig("def register_value_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"), - // clang-format on "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e9844a7..1881865 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -656,6 +656,10 @@ MlirLocation mlirOperationGetLocation(MlirOperation op) { return wrap(unwrap(op)->getLoc()); } +void mlirOperationSetLocation(MlirOperation op, MlirLocation loc) { + unwrap(op)->setLoc(unwrap(loc)); +} + MlirTypeID mlirOperationGetTypeID(MlirOperation op) { if (auto info = unwrap(op)->getRegisteredInfo()) return wrap(info->getTypeID()); diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3..898d76c 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); + if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan | + arith::FastMathFlags::nsz)) { + // mulf(x, 0) -> 0 + if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat())) + return getRhs(); + } + return constFoldBinaryOp<FloatAttr>( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a * b; }); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 7626d35..c64e10f5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -123,7 +123,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( vector::OuterProductOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, - arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>(); + arith::ConstantOp, arith::SelectOp, vector::SplatOp, + vector::BroadcastOp>(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index a173cf1..5672942 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ 
b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/InterleavedRange.h" #include <cstddef> #include <iterator> #include <optional> @@ -77,6 +78,232 @@ struct LLVMPointerPointerLikeModel }; } // namespace +/// Generate a name for a canonical loop nest of the format +/// `<prefix>(_r<idx>_s<idx>)*`. Here, `_r<idx>` identifies the index of a +/// region within its parent operation; it is omitted if the operation has +/// only one region. +/// `_s<idx>` identifies the position of an operation within a region, where +/// only operations that may potentially contain loops ("container operations", +/// i.e. operations that have regions) are counted. Again, it is omitted if there is +/// only one such operation in a region. If canonical loops are nested +/// inside each other, the name may also use the format `_d<num>`, where <num> is the +/// nesting depth of the loop. +/// +/// The generated name is a best effort to make canonical loop names unique within an +/// SSA namespace. This also means that regions with the IsolatedFromAbove property +/// do not consider any parents or siblings. +static std::string generateLoopNestingName(StringRef prefix, + CanonicalLoopOp op) { + struct Component { + /// If true, this component describes a region operand of an operation (the + /// operand's owner). If false, this component describes an operation located + /// in a parent region. + bool isRegionArgOfOp; + bool skip = false; + bool isUnique = false; + + size_t idx; + Operation *op; + Region *parentRegion; + size_t loopDepth; + + Operation *&getOwnerOp() { + assert(isRegionArgOfOp && "Must describe a region operand"); + return op; + } + size_t &getArgIdx() { + assert(isRegionArgOfOp && "Must describe a region operand"); + return idx; + } + + Operation *&getContainerOp() { + assert(!isRegionArgOfOp && "Must describe an operation of a region"); + return op; + } + size_t &getOpPos() { + assert(!isRegionArgOfOp && "Must describe an operation of a region"); + return idx; + } + bool isLoopOp() const { + assert(!isRegionArgOfOp && "Must describe an operation of a region"); + return isa<CanonicalLoopOp>(op); + } + Region *&getParentRegion() { + assert(!isRegionArgOfOp && "Must describe an operation of a region"); + return parentRegion; + } + size_t &getLoopDepth() { + assert(!isRegionArgOfOp && "Must describe an operation of a region"); + return loopDepth; + } + + void skipIf(bool v = true) { skip = skip || v; } + }; + + // List of ancestors, from inner to outer.
+ // Alternates between + // * region argument of an operation + // * operation within a region + SmallVector<Component> components; + + // Gather a list of parent regions and operations, and the position within + // their parent + Operation *o = op.getOperation(); + while (o) { + // Operation within a region + Region *r = o->getParentRegion(); + if (!r) + break; + + llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front()); + size_t idx = 0; + bool found = false; + size_t sequentialIdx = -1; + bool isOnlyContainerOp = true; + for (Block *b : traversal) { + for (Operation &op : *b) { + if (&op == o && !found) { + sequentialIdx = idx; + found = true; + } + if (op.getNumRegions()) { + idx += 1; + if (idx > 1) + isOnlyContainerOp = false; + } + if (found && !isOnlyContainerOp) + break; + } + } + + Component &containerOpInRegion = components.emplace_back(); + containerOpInRegion.isRegionArgOfOp = false; + containerOpInRegion.isUnique = isOnlyContainerOp; + containerOpInRegion.getContainerOp() = o; + containerOpInRegion.getOpPos() = sequentialIdx; + containerOpInRegion.getParentRegion() = r; + + Operation *parent = r->getParentOp(); + + // Region argument of an operation + Component ®ionArgOfOperation = components.emplace_back(); + regionArgOfOperation.isRegionArgOfOp = true; + regionArgOfOperation.isUnique = true; + regionArgOfOperation.getArgIdx() = 0; + regionArgOfOperation.getOwnerOp() = parent; + + // The IsolatedFromAbove trait of the parent operation implies that each + // individual region argument has its own separate namespace, so no + // ambiguity. + if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) + break; + + // Component only needed if operation has multiple region operands. Region + // arguments may be optional, but we currently do not consider this. 
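+ // Illustration (hypothetical nesting, not from the patch): a canonical loop
+ // that is the second container operation in region #1 of a two-region parent
+ // receives the suffix "_r1_s1"; two sibling loops directly inside it would
+ // then be disambiguated as "..._s0" and "..._s1".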
+ if (parent->getRegions().size() > 1) { + auto getRegionIndex = [](Operation *o, Region *r) { + for (auto [idx, region] : llvm::enumerate(o->getRegions())) { + if (&region == r) + return idx; + } + llvm_unreachable("Region not child of its parent operation"); + }; + regionArgOfOperation.isUnique = false; + regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r); + } + + // next parent + o = parent; + } + + // Determine whether a region-argument component is not needed + for (Component &c : components) + c.skipIf(c.isRegionArgOfOp && c.isUnique); + + // Find runs of nested loops and determine each loop's depth in the loop nest + size_t numSurroundingLoops = 0; + for (Component &c : llvm::reverse(components)) { + if (c.skip) + continue; + + // Non-skipped region-argument components interrupt the loop nest + if (c.isRegionArgOfOp) { + numSurroundingLoops = 0; + continue; + } + + // Multiple loops in a region means each of them is the outermost loop of a + // new loop nest + if (!c.isUnique) + numSurroundingLoops = 0; + + c.getLoopDepth() = numSurroundingLoops; + + // Next loop is surrounded by one more loop + if (isa<CanonicalLoopOp>(c.getContainerOp())) + numSurroundingLoops += 1; + } + + // In loop nests, skip all but the innermost loop, which carries the depth + // number + bool isLoopNest = false; + for (Component &c : components) { + if (c.skip || c.isRegionArgOfOp) + continue; + + if (!isLoopNest && c.getLoopDepth() >= 1) { + // Innermost loop of a loop nest of at least two loops + isLoopNest = true; + } else if (isLoopNest) { + // Non-innermost loop of a loop nest + c.skipIf(c.isUnique); + + // If there is no surrounding loop left, this must have been the outermost + // loop; leave loop-nest mode for the next iteration + if (c.getLoopDepth() == 0) + isLoopNest = false; + } + } + + // Skip non-loop unambiguous regions (but they should interrupt loop nests, so + // we mark them as skipped only after computing loop nests) + for (Component &c : components) + c.skipIf(!c.isRegionArgOfOp && c.isUnique && + !isa<CanonicalLoopOp>(c.getContainerOp())); + + // Components can be skipped if they are already disambiguated by their parent + // (or have no parent) + bool newRegion = true; + for (Component &c : llvm::reverse(components)) { + c.skipIf(newRegion && c.isUnique); + + // Non-skipped components disambiguate unique children + if (!c.skip) + newRegion = true; + + // ...except canonical loops that need a suffix for each nest + if (!c.isRegionArgOfOp && c.getContainerOp()) + newRegion = false; + } + + // Compile the nesting name string + SmallString<64> Name{prefix}; + llvm::raw_svector_ostream NameOS(Name); + for (auto &c : llvm::reverse(components)) { + if (c.skip) + continue; + + if (c.isRegionArgOfOp) + NameOS << "_r" << c.getArgIdx(); + else if (c.getLoopDepth() >= 1) + NameOS << "_d" << c.getLoopDepth(); + else + NameOS << "_s" << c.getOpPos(); + } + + return NameOS.str().str(); +} + void OpenMPDialect::initialize() { addOperations< #define GET_OP_LIST @@ -3159,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { Value result = getResult(); auto [newCli, gen, cons] = decodeCli(result); + // Structured binding `gen` cannot be captured in lambdas before C++20 + OpOperand *generator = gen; + // Derive the CLI variable name from its generator: // * "canonloop" for omp.canonical_loop // * custom name for loop transformation generatees @@ -3172,71 +3402,29 @@ cliName = TypeSwitch<Operation *,
std::string>(gen->getOwner()) .Case([&](CanonicalLoopOp op) { - // Find the canonical loop nesting: For each ancestor add a - // "+_r<idx>" suffix (in reverse order) - SmallVector<std::string> components; - Operation *o = op.getOperation(); - while (o) { - if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) - break; - - Region *r = o->getParentRegion(); - if (!r) - break; - - auto getSequentialIndex = [](Region *r, Operation *o) { - llvm::ReversePostOrderTraversal<Block *> traversal( - &r->getBlocks().front()); - size_t idx = 0; - for (Block *b : traversal) { - for (Operation &op : *b) { - if (&op == o) - return idx; - // Only consider operations that are containers as - // possible children - if (!op.getRegions().empty()) - idx += 1; - } - } - llvm_unreachable("Operation not part of the region"); - }; - size_t sequentialIdx = getSequentialIndex(r, o); - components.push_back(("s" + Twine(sequentialIdx)).str()); - - Operation *parent = r->getParentOp(); - if (!parent) - break; - - // If the operation has more than one region, also count in - // which of the regions - if (parent->getRegions().size() > 1) { - auto getRegionIndex = [](Operation *o, Region *r) { - for (auto [idx, region] : - llvm::enumerate(o->getRegions())) { - if (®ion == r) - return idx; - } - llvm_unreachable("Region not child its parent operation"); - }; - size_t regionIdx = getRegionIndex(parent, r); - components.push_back(("r" + Twine(regionIdx)).str()); - } - - // next parent - o = parent; - } - - SmallString<64> Name("canonloop"); - for (const std::string &s : reverse(components)) { - Name += '_'; - Name += s; - } - - return Name; + return generateLoopNestingName("canonloop", op); }) .Case([&](UnrollHeuristicOp op) -> std::string { llvm_unreachable("heuristic unrolling does not generate a loop"); }) + .Case([&](TileOp op) -> std::string { + auto [generateesFirst, generateesCount] = + op.getGenerateesODSOperandIndexAndLength(); + unsigned firstGrid = generateesFirst; + unsigned firstIntratile = generateesFirst + generateesCount / 2; + unsigned end = generateesFirst + generateesCount; + unsigned opnum = generator->getOperandNumber(); + // In the OpenMP apply and looprange clauses, indices are 1-based + if (firstGrid <= opnum && opnum < firstIntratile) { + unsigned gridnum = opnum - firstGrid + 1; + return ("grid" + Twine(gridnum)).str(); + } + if (firstIntratile <= opnum && opnum < end) { + unsigned intratilenum = opnum - firstIntratile + 1; + return ("intratile" + Twine(intratilenum)).str(); + } + llvm_unreachable("Unexpected generatee argument"); + }) .Default([&](Operation *op) { assert(false && "TODO: Custom name for this operation"); return "transformed"; @@ -3323,7 +3511,8 @@ void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) { void CanonicalLoopOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { - setNameFn(region.getArgument(0), "iv"); + std::string ivName = generateLoopNestingName("iv", *this); + setNameFn(region.getArgument(0), ivName); } void CanonicalLoopOp::print(OpAsmPrinter &p) { @@ -3465,6 +3654,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() { } //===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, + OperandRange generatees, + OperandRange applyees) { + if (!generatees.empty()) + p << '(' << llvm::interleaved(generatees) << ')'; + + if (!applyees.empty()) + p 
<< " <- (" << llvm::interleaved(applyees) << ')'; +} + +static ParseResult parseLoopTransformClis( + OpAsmParser &parser, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) { + if (parser.parseOptionalLess()) { + // Syntax 1: generatees present + + if (parser.parseOperandList(generateesOperands, + mlir::OpAsmParser::Delimiter::Paren)) + return failure(); + + if (parser.parseLess()) + return failure(); + } else { + // Syntax 2: generatees omitted + } + + // Parse `<-` (`<` has already been parsed) + if (parser.parseMinus()) + return failure(); + + if (parser.parseOperandList(applyeesOperands, + mlir::OpAsmParser::Delimiter::Paren)) + return failure(); + + return success(); +} + +LogicalResult TileOp::verify() { + if (getApplyees().empty()) + return emitOpError() << "must apply to at least one loop"; + + if (getSizes().size() != getApplyees().size()) + return emitOpError() << "there must be one tile size for each applyee"; + + if (!getGeneratees().empty() && + 2 * getSizes().size() != getGeneratees().size()) + return emitOpError() + << "expecting two times the number of generatees than applyees"; + + DenseSet<Value> parentIVs; + + Value parent = getApplyees().front(); + for (auto &&applyee : llvm::drop_begin(getApplyees())) { + auto [parentCreate, parentGen, parentCons] = decodeCli(parent); + auto [create, gen, cons] = decodeCli(applyee); + + if (!parentGen) + return emitOpError() << "applyee CLI has no generator"; + + auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner()); + if (!parentGen) + return emitOpError() + << "currently only supports omp.canonical_loop as applyee"; + + parentIVs.insert(parentLoop.getInductionVar()); + + if (!gen) + return emitOpError() << "applyee CLI has no generator"; + auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner()); + if (!loop) + return emitOpError() + << "currently only supports omp.canonical_loop as applyee"; + + // Canonical loop must be perfectly nested, i.e. the body of the parent must + // only contain the omp.canonical_loop of the nested loops, and + // omp.terminator + bool isPerfectlyNested = [&]() { + auto &parentBody = parentLoop.getRegion(); + if (!parentBody.hasOneBlock()) + return false; + auto &parentBlock = parentBody.getBlocks().front(); + + auto nestedLoopIt = parentBlock.begin(); + if (nestedLoopIt == parentBlock.end() || + (&*nestedLoopIt != loop.getOperation())) + return false; + + auto termIt = std::next(nestedLoopIt); + if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt)) + return false; + + if (std::next(termIt) != parentBlock.end()) + return false; + + return true; + }(); + if (!isPerfectlyNested) + return emitOpError() << "tiled loop nest must be perfectly nested"; + + if (parentIVs.contains(loop.getTripCount())) + return emitOpError() << "tiled loop nest must be rectangular"; + + parent = applyee; + } + + // TODO: The tile sizes must be computed before the loop, but checking this + // requires dominance analysis. 
For instance: + // + // %canonloop = omp.new_cli + // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // // write to %x + // omp.terminator + // } + // %ts = llvm.load %x + // omp.tile <- (%canonloop) sizes(%ts : i32) + + return success(); +} + +std::pair<unsigned, unsigned> TileOp::getApplyeesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_applyees); +} + +std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_generatees); +} + +//===----------------------------------------------------------------------===// // Critical construct (2.17.1) //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 132ed81..3385b2a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -616,11 +616,10 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( if (diag.succeeded()) { // Tracking failure is the only failure. return trackingFailure; - } else { - diag.attachNote() << "tracking listener also failed: " - << trackingFailure.getMessage(); - (void)trackingFailure.silence(); } + diag.attachNote() << "tracking listener also failed: " + << trackingFailure.getMessage(); + (void)trackingFailure.silence(); } if (!diag.succeeded()) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb46869..b0132e8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -580,7 +580,7 @@ namespace { // ElideSingleElementReduction for ReduceOp. struct ElideUnitDimsInMultiDimReduction : public OpRewritePattern<MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -730,7 +730,7 @@ std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() { namespace { struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -2197,7 +2197,7 @@ namespace { // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2220,7 +2220,7 @@ public: // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask. class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2546,7 +2546,7 @@ rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { @@ -2938,7 +2938,7 @@ namespace { // Fold broadcast1(broadcast2(x)) into broadcast1(x).
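// E.g. (sketch, not from the patch): %a = vector.broadcast %s : f32 to
// vector<4xf32> followed by %b = vector.broadcast %a : vector<4xf32> to
// vector<2x4xf32> folds to the single op vector.broadcast %s : f32 to
// vector<2x4xf32>.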
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { @@ -3109,7 +3109,7 @@ namespace { // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector // to a broadcast. struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { @@ -3165,7 +3165,7 @@ static Value getScalarSplatSource(Value value) { /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v). class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3182,7 +3182,7 @@ public: /// vector.interleave. class ShuffleInterleave : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3326,7 +3326,7 @@ namespace { // broadcast. class InsertToBroadcast final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -3344,7 +3344,7 @@ public: /// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v). class InsertSplatToSplat final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3380,7 +3380,7 @@ public: /// %result = vector.from_elements %c1, %c2 : vector<2xi32> class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3748,7 +3748,7 @@ namespace { class FoldInsertStridedSliceSplat final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3768,7 +3768,7 @@ public: class FoldInsertStridedSliceOfExtract final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3798,7 +3798,7 @@ public: class InsertStridedSliceConstantFolder final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; // Do not create constants with more than `vectorSizeFoldThreashold` elements, // unless the source vector constant has a single use. 
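The dozens of one-line changes in this file, and in the Vector transform files below, are one mechanical migration: OpRewritePattern and OpConversionPattern expose a Base alias for themselves, so subclasses can inherit constructors without restating the template argument. A minimal sketch of the idiom, assuming that alias; the pattern itself is a placeholder, not part of the patch:

    #include "mlir/Dialect/Vector/IR/VectorOps.h"
    #include "mlir/IR/PatternMatch.h"

    namespace {
    /// Placeholder pattern: `Base` names OpRewritePattern<vector::TransposeOp>,
    /// so `using Base::Base;` inherits its (context, benefit) constructors
    /// without repeating the op type.
    struct ExamplePattern : mlir::OpRewritePattern<mlir::vector::TransposeOp> {
      using Base::Base; // was: using OpRewritePattern::OpRewritePattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::vector::TransposeOp op,
                      mlir::PatternRewriter &rewriter) const override {
        return mlir::failure(); // intentionally does nothing
      }
    };
    } // namespace

Unlike spelling out using OpRewritePattern<SomeOp>::OpRewritePattern;, the Base form stays correct if the pattern is later re-targeted to another op.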
@@ -4250,7 +4250,7 @@ namespace { // %mask = vector.create_mask %new_ub : vector<8xi1> class StridedSliceCreateMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, @@ -4310,7 +4310,7 @@ public: class StridedSliceConstantMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { @@ -4365,7 +4365,7 @@ public: class StridedSliceBroadcast final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4416,7 +4416,7 @@ public: /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v). class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4448,7 +4448,7 @@ public: class ContiguousExtractStridedSliceToExtract final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -5023,7 +5023,7 @@ namespace { /// ``` struct TransferReadAfterWriteToBroadcast : public OpRewritePattern<TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -5458,7 +5458,7 @@ namespace { /// any other uses. 
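/// E.g. (sketch): with %w1 = vector.transfer_write %v1, %t0 and
/// %w2 = vector.transfer_write %v2, %w1 at the same indices and vector type,
/// %w1 is dead and %w2 may write into %t0 directly.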
class FoldWaw final : public OpRewritePattern<TransferWriteOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { if (!llvm::isa<RankedTensorType>(writeOp.getShapedType())) @@ -5514,7 +5514,7 @@ public: struct SwapExtractSliceOfTransferWrite : public OpRewritePattern<tensor::InsertSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -5737,7 +5737,7 @@ LogicalResult MaskedLoadOp::verify() { namespace { class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const override { switch (getMaskFormat(load.getMask())) { @@ -5794,7 +5794,7 @@ LogicalResult MaskedStoreOp::verify() { namespace { class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { switch (getMaskFormat(store.getMask())) { @@ -5890,7 +5890,7 @@ static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { namespace { class GatherFolder final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { switch (getMaskFormat(gather.getMask())) { @@ -5910,7 +5910,7 @@ public: /// maskedload. Only 1D fixed vectors are supported for now. class FoldContiguousGather final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { if (!isa<MemRefType>(op.getBase().getType())) @@ -5962,7 +5962,7 @@ LogicalResult ScatterOp::verify() { namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { switch (getMaskFormat(scatter.getMask())) { @@ -5982,7 +5982,7 @@ public: /// maskedstore. Only 1D fixed vectors are supported for now. 
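/// E.g. (sketch): a vector.scatter whose index vector is the constant
/// sequence [0, 1, 2, 3] touches a contiguous block, so it is equivalent to a
/// vector.maskedstore of the same mask and value at the scatter's base.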
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { if (failed(isZeroBasedContiguousSeq(op.getIndices()))) @@ -6030,7 +6030,7 @@ LogicalResult ExpandLoadOp::verify() { namespace { class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { switch (getMaskFormat(expand.getMask())) { @@ -6081,7 +6081,7 @@ LogicalResult CompressStoreOp::verify() { namespace { class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { switch (getMaskFormat(compress.getMask())) { @@ -6260,7 +6260,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) { class ShapeCastCreateMaskFolderTrailingOneDim final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeOp, PatternRewriter &rewriter) const override { @@ -6330,7 +6330,7 @@ public: /// If both (i) and (ii) are possible, (i) is chosen. class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -6614,7 +6614,7 @@ namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6646,7 +6646,7 @@ public: /// Replace transpose(splat-like(v)) with broadcast(v) class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6663,7 +6663,7 @@ public: /// Folds transpose(create_mask) into a new transposed create_mask. class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transpOp, PatternRewriter &rewriter) const override { @@ -6700,7 +6700,7 @@ public: /// Folds transpose(shape_cast) into a new shape_cast. class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6750,7 +6750,7 @@ public: /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6). 
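/// E.g. (sketch): a transpose of a broadcast from a scalar always folds, since
/// every element is identical; any permutation yields the same broadcast to
/// the permuted result type.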
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} @@ -6971,7 +6971,7 @@ namespace { /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { @@ -7300,7 +7300,7 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, /// %0 = arith.select %mask, %a, %passthru : vector<8xf32> /// class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskOp maskOp, PatternRewriter &rewriter) const override { @@ -7410,7 +7410,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { // vector.broadcast. class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> { public: - using OpRewritePattern<SplatOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(SplatOp splatOp, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index dedc3b3..61d9357 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -34,7 +34,7 @@ namespace { /// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 65702ff..efe8d14 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1151,7 +1151,7 @@ FailureOr<Value> ContractionOpLowering::lowerReduction( /// class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 1f96a3a..6bc8347 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -50,7 +50,7 @@ namespace { /// /// Supports vector types with a fixed leading dimension. struct UnrollGather : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef, /// but should be fairly straightforward to extend beyond that. 
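/// E.g. (sketch): a gather whose base is a memref with a non-unit innermost
/// stride can instead address the underlying contiguous buffer, with the
/// index vector scaled by that stride.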
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these /// loads/extracts are made conditional using `scf.if` ops. struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index 9d6a865..479fc0c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -163,7 +163,7 @@ private: /// : vector<7xi16>, vector<7xi16> /// ``` struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InterleaveOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 5617b06..7730c4e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -48,7 +48,7 @@ namespace { /// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { @@ -100,7 +100,7 @@ public: /// will be folded at LLVM IR level. class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { @@ -184,7 +184,7 @@ namespace { /// and actually match the traits of its the nested `MaskableOpInterface`. 
template <class SourceOp> struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { - using OpRewritePattern<MaskOp>::OpRewritePattern; + using Base::Base; private: LogicalResult matchAndRewrite(MaskOp maskOp, diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 4773732d..e86e2a9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -39,7 +39,7 @@ namespace { class InnerOuterDimReductionConversion : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit InnerOuterDimReductionConversion( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -136,7 +136,7 @@ private: class ReduceMultiDimReductionRank : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit ReduceMultiDimReductionRank( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -304,7 +304,7 @@ private: /// and combines results struct TwoDimMultiReductionToElementWise : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise /// a sequence of vector.reduction ops. struct TwoDimMultiReductionToReduction : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction /// separately. 
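/// E.g. (sketch): a 1-D multi_reduction over vector<8xf32> is shape_cast to
/// vector<1x8xf32>, reduced along the inner dimension to vector<1xf32>, and
/// the single element extracted as the scalar result.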
struct OneDimMultiReductionToTwoDim : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index af4851e..258f2cb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -99,7 +99,7 @@ namespace { /// return %7, %8 : vector<2x3xi32>, vector<2xi32> /// ``` struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ScanOp scanOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 603ea41..c5f22b2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { } public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -356,7 +356,7 @@ public: class ScalableShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 78102f7..8f46ad6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -44,7 +44,7 @@ namespace { /// struct MixedSizeInputShuffleOpRewrite final : OpRewritePattern<vector::ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp index ee5568a..08e7c89 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp @@ -24,7 +24,7 @@ using namespace mlir::vector; namespace { struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StepOp stepOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 6407a86..7521e24 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp, struct ToFromElementsToShuffleTreeRewrite final : OpRewritePattern<vector::FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, 
PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 9e7d0ce..c3f7de0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -300,7 +300,7 @@ namespace { /// %x = vector.insert .., .. [.., ..] class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, MLIRContext *context, PatternBenefit benefit = 1) @@ -395,7 +395,7 @@ private: class Transpose2DWithUnitDimToShapeCast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; Transpose2DWithUnitDimToShapeCast(MLIRContext *context, PatternBenefit benefit = 1) @@ -433,7 +433,7 @@ public: class TransposeOp2DToShuffleLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOp2DToShuffleLowering( vector::VectorTransposeLowering vectorTransposeLowering, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index cab1289..963b2c8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -54,7 +54,7 @@ namespace { // input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim // inputs by inserting vector.broadcast. struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern<vector::InsertStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Casts away leading one dimensions in vector.insert's vector inputs by // inserting vector.broadcast. struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, // 1 dimensions. struct CastAwayTransferReadLeadingOneDim : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { @@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim // 1 dimensions. struct CastAwayTransferWriteLeadingOneDim : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { @@ -541,7 +541,7 @@ public: // vector.broadcast back to the original shape. 
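// E.g. (sketch): vector.constant_mask [1, 3] : vector<1x4xi1> becomes
// vector.constant_mask [3] : vector<4xi1> plus a vector.broadcast back to
// vector<1x4xi1>.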
struct CastAwayConstantMaskLeadingOneDim : public OpRewritePattern<vector::ConstantMaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp index bdbb792..7acc120 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp @@ -48,7 +48,7 @@ namespace { /// struct VectorMaskedLoadOpConverter final : OpRewritePattern<vector::MaskedLoadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, PatternRewriter &rewriter) const override { @@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final /// struct VectorMaskedStoreOpConverter final : OpRewritePattern<vector::MaskedStoreOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 264cbc1..3a6684f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -548,7 +548,7 @@ namespace { // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to // `false` to generate non-atomic RMW sequences. struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW) : OpConversionPattern<vector::StoreOp>(context), @@ -827,7 +827,7 @@ private: /// adjusted mask . struct ConvertVectorMaskedStore final : OpConversionPattern<vector::MaskedStoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor, @@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final /// those cases, loads are converted to byte-aligned, byte-sized loads and the /// target vector is extracted from the loaded vector. struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor, @@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { /// bitcasting, since each `i8` container element holds two `i4` values. struct ConvertVectorMaskedLoad final : OpConversionPattern<vector::MaskedLoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor, @@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy, // TODO: Document-me struct ConvertVectorTransferRead final : OpConversionPattern<vector::TransferReadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor, @@ -1942,7 +1942,7 @@ namespace { /// advantage of high-level information to avoid leaving LLVM to scramble with /// peephole optimizations. 
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, PatternRewriter &rewriter) const override { @@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> { /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4> /// struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { - using OpRewritePattern<arith::TruncIOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::TruncIOp truncOp, PatternRewriter &rewriter) const override { @@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4> /// struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; + using Base::Base; RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index f6d6555..9e49873 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -34,7 +34,7 @@ using namespace mlir::vector; class DecomposeDifferentRankInsertStridedSlice : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -84,7 +84,7 @@ public: class ConvertSameRankInsertStridedSliceIntoShuffle : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive InsertStridedSliceOp, but the recursion is @@ -183,7 +183,7 @@ public: class Convert1DExtractStridedSliceIntoShuffle : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -271,7 +271,7 @@ private: class DecomposeNDExtractStridedSlice : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive ExtractStridedSliceOp, but the recursion diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 82bac8c..71fba71c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -214,7 +214,7 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices( /// vector.extract_strided_slice operation. 
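// The `initialize()` bodies cut off by the hunk context above use the
// standard hook for declaring terminating self-recursion; a minimal sketch
// (setHasBoundedRewriteRecursion is part of the Pattern API; the stubbed
// rewrite is illustrative):
struct RecursiveDecompositionSketch final
    : mlir::OpRewritePattern<mlir::vector::InsertStridedSliceOp> {
  using Base::Base;
  void initialize() {
    // Each application re-creates the matched op kind at strictly smaller
    // rank, so the greedy driver may treat the recursion as bounded.
    setHasBoundedRewriteRecursion();
  }
  mlir::LogicalResult
  matchAndRewrite(mlir::vector::InsertStridedSliceOp op,
                  mlir::PatternRewriter &rewriter) const override {
    return mlir::failure(); // decomposition elided in this sketch
  }
};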
struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -285,7 +285,7 @@ struct LinearizeVectorExtractStridedSlice final /// struct LinearizeVectorInsertStridedSlice final : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -348,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final /// of the original shuffle operation. struct LinearizeVectorShuffle final : public OpConversionPattern<vector::ShuffleOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorShuffle(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -423,7 +423,7 @@ struct LinearizeVectorShuffle final /// struct LinearizeVectorExtract final : public OpConversionPattern<vector::ExtractOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtract(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -501,7 +501,7 @@ struct LinearizeVectorExtract final /// struct LinearizeVectorInsert final : public OpConversionPattern<vector::InsertOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsert(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -575,7 +575,7 @@ struct LinearizeVectorInsert final /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> struct LinearizeVectorBitCast final : public OpConversionPattern<vector::BitCastOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorBitCast(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -598,7 +598,7 @@ struct LinearizeVectorBitCast final /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> struct LinearizeVectorSplat final : public OpConversionPattern<vector::SplatOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -629,7 +629,7 @@ struct LinearizeVectorSplat final /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern<vector::CreateMaskOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorCreateMask(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -684,7 +684,7 @@ struct LinearizeVectorCreateMask final /// For generic cases, the vector unroll pass should be used to unroll the load /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, 
PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -731,7 +731,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorStore final : public OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -778,7 +778,7 @@ struct LinearizeVectorStore final /// struct LinearizeVectorFromElements final : public OpConversionPattern<vector::FromElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorFromElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -814,7 +814,7 @@ struct LinearizeVectorFromElements final /// struct LinearizeVectorToElements final : public OpConversionPattern<vector::ToElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorToElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c364a8b..1121d95 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -1081,7 +1081,7 @@ private: /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) /// to memref.store. class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 866f789..d6a6d7cd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -78,7 +78,7 @@ namespace { /// ``` struct MultiReduceToContract : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, PatternRewriter &rewriter) const override { @@ -138,7 +138,7 @@ struct MultiReduceToContract /// ``` struct CombineContractABTranspose final : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -202,7 +202,7 @@ struct CombineContractABTranspose final /// ``` struct CombineContractResultTranspose final : public OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp resTOp, PatternRewriter &rewriter) const override { @@ -568,7 +568,7 @@ static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { // %2 = vector.extract %1[1] : f16 from vector<2xf16> struct BubbleDownVectorBitCastForExtract : public OpRewritePattern<vector::ExtractOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp 
extractOp, PatternRewriter &rewriter) const override { @@ -643,7 +643,7 @@ struct BubbleDownVectorBitCastForExtract // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> struct BubbleDownBitCastForStridedSliceExtract : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -721,7 +721,7 @@ struct BubbleDownBitCastForStridedSliceExtract // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> // struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -794,7 +794,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> struct BubbleUpBitCastForStridedSliceInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -892,7 +892,7 @@ struct BubbleUpBitCastForStridedSliceInsert // %7 = vector.insert_strided_slice %6, %cst { // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: BreakDownVectorBitCast(MLIRContext *context, @@ -1131,7 +1131,7 @@ struct ReorderElementwiseOpsOnBroadcast final class ExtractOpFromElementwise final : public OpRewritePattern<vector::ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1206,7 +1206,7 @@ static bool isSupportedMemSinkElementType(Type type) { /// ``` class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1285,7 +1285,7 @@ public: class StoreOpFromSplatOrBroadcast final : public OpRewritePattern<vector::StoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StoreOp op, PatternRewriter &rewriter) const override { @@ -1476,7 +1476,7 @@ static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { /// InstCombine seems to handle vectors with multiple elements but not the /// single element ones. struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { - using OpRewritePattern<arith::SelectOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::SelectOp selectOp, PatternRewriter &rewriter) const override { @@ -1560,7 +1560,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { /// Drop inner most contiguous unit dimensions from transfer_read operand. 
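// For context, these rewrite patterns are normally exercised through the
// greedy driver; a sketch under the assumption of the current entry-point
// name (applyPatternsGreedily, from GreedyPatternRewriteDriver.h), with the
// pattern list abbreviated:
void applyVectorTransformsSketch(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  patterns.add<MultiReduceToContract, CombineContractABTranspose>(
      root->getContext());
  // Applies to a fixed point; failure only signals non-convergence.
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}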
class DropInnerMostUnitDimsTransferRead : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -1651,7 +1651,7 @@ class DropInnerMostUnitDimsTransferRead /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). class DropInnerMostUnitDimsTransferWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { @@ -1728,7 +1728,7 @@ class DropInnerMostUnitDimsTransferWrite /// with the RHS transposed) lowering. struct CanonicalizeContractMatmulToMMT final : OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -1845,7 +1845,7 @@ private: template <typename ExtOp> struct FoldArithExtIntoContractionOp : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -1878,7 +1878,7 @@ struct FoldArithExtIntoContractionOp /// %b = vector.reduction <add> %a, %acc /// ``` struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { @@ -2033,7 +2033,7 @@ struct DropUnitDimFromElementwiseOps final /// ``` struct DropUnitDimsFromTransposeOp final : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { @@ -2110,7 +2110,7 @@ struct DropUnitDimsFromTransposeOp final /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> /// ``` struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { @@ -2155,7 +2155,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { /// %c = vector.reduction <add> %b, %acc /// ``` struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a92..784e5d6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset return failure(); xegpu::DistributeLayoutAttr layout = - xegpu::getDistributeLayoutAttr(op.getValue()); + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); for (auto [val, offs, mask] : llvm::zip( adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { - xegpu::StoreScatterOp::create(rewriter, loc, val, 
op.getDest(), offs, - mask, chunkSizeAttr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + auto store = xegpu::StoreScatterOp::create( + rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); // Update the layout attribute to drop sg_layout and sg_data. - if (auto newLayout = layout.dropSgLayoutAndData()) - op->setAttr("layout", newLayout); + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + for (OpOperand &operand : store->getOpOperands()) { + // Skip operand one (the memref). + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); + } + } } rewriter.eraseOp(op); return success(); @@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( [=](xegpu::StoreScatterOp op) -> bool { - // Check if the layout attribute is present on the result. - auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout"); - if (!layout) - return true; + auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0)); return isLegal(layout); }); diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp index d6b8a8a..e3f075f 100644 --- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp +++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp @@ -54,6 +54,7 @@ struct OpStrings { std::string opCppName; SmallVector<std::string> opResultNames; SmallVector<std::string> opOperandNames; + SmallVector<std::string> opRegionNames; }; static std::string joinNameList(llvm::ArrayRef<std::string> names) { @@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) { /// Generates OpStrings from an OperationOp static OpStrings getStrings(irdl::OperationOp op) { auto operandOp = op.getOp<irdl::OperandsOp>(); - auto resultOp = op.getOp<irdl::ResultsOp>(); + auto regionsOp = op.getOp<irdl::RegionsOp>(); OpStrings strings; strings.opName = op.getSymName(); @@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) { })); } + if (regionsOp) { + strings.opRegionNames = SmallVector<std::string>( + llvm::map_range(regionsOp->getNames(), [](Attribute attr) { + return llvm::formatv("{0}", cast<StringAttr>(attr)); + })); + } + return strings; } @@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict, static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { const auto operandCount = strings.opOperandNames.size(); const auto resultCount = strings.opResultNames.size(); + const auto regionCount = strings.opRegionNames.size(); dict["OP_NAME"] = strings.opName; dict["OP_CPP_NAME"] = strings.opCppName; @@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}"; dict["OP_RESULT_INITIALIZER_LIST"] = resultCount ?
joinNameList(strings.opResultNames) : "{\"\"}"; + dict["OP_REGION_COUNT"] = std::to_string(regionCount); } /// Fills a dictionary with values from DialectStrings @@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings) { auto opGetters = std::string{}; auto resGetters = std::string{}; + auto regionGetters = std::string{}; + auto regionAdaptorGetters = std::string{}; for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) { const auto op = @@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, op, i); } + for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) { + const auto op = + llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true); + regionAdaptorGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; } + )", + op, i); + regionGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); } + )", + op, i); + } + dict["OP_OPERAND_GETTER_DECLS"] = opGetters; dict["OP_RESULT_GETTER_DECLS"] = resGetters; + dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters; + dict["OP_REGION_GETTER_DECLS"] = regionGetters; } static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, @@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, dict["OP_BUILD_DECLS"] = buildDecls; } +// Compute the C++ traits implied by the op definition; returns the list of +// trait names (empty if none apply). +static SmallVector<std::string> generateTraits(irdl::OperationOp op, + const OpStrings &strings) { + SmallVector<std::string> cppTraitNames; + if (!strings.opRegionNames.empty()) { + cppTraitNames.push_back( + llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl", + strings.opRegionNames.size()) + .str()); + + // Requires verifyInvariantsImpl to be implemented on the op + cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants"); + } + return cppTraitNames; +} + static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict) { @@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op, const auto opStrings = getStrings(op); fillDict(dict, opStrings); + SmallVector<std::string> traitNames = generateTraits(op, opStrings); + if (traitNames.empty()) + dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName; + else + dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName, + llvm::join(traitNames, ", ")); + generateOpGetterDeclarations(dict, opStrings); generateOpBuilderDeclarations(dict, opStrings); @@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect, return success(); } +static void generateRegionConstraintVerifiers( + irdl::detail::dictionary &dict, irdl::OperationOp op, + const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers, + SmallVectorImpl<std::string> &verifierCalls) { + auto regionsOp = op.getOp<irdl::RegionsOp>(); + if (strings.opRegionNames.empty() || !regionsOp) + return; + + for (size_t i = 0; i < strings.opRegionNames.size(); ++i) { + std::string regionName = strings.opRegionNames[i]; + std::string helperFnName = + llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}", + strings.opCppName, regionName) + .str(); + + // Extract the actual region constraint from the IRDL RegionOp + std::string condition = "true"; + std::string textualConditionName = "any region"; + + if (auto regionDefOp = + dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp()))
{ + // Generate constraint condition based on RegionOp attributes + SmallVector<std::string> conditionParts; + SmallVector<std::string> descriptionParts; + + // Check number of blocks constraint + if (auto blockCount = regionDefOp.getNumberOfBlocks()) { + conditionParts.push_back( + llvm::formatv("region.getBlocks().size() == {0}", + blockCount.value()) + .str()); + descriptionParts.push_back( + llvm::formatv("exactly {0} block(s)", blockCount.value()).str()); + } + + // Check entry block arguments constraint + if (regionDefOp.getConstrainedArguments()) { + size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size(); + conditionParts.push_back( + llvm::formatv("region.getNumArguments() == {0}", expectedArgCount) + .str()); + descriptionParts.push_back( + llvm::formatv("{0} entry block argument(s)", expectedArgCount) + .str()); + } + + // Combine conditions + if (!conditionParts.empty()) { + condition = llvm::join(conditionParts, " && "); + } + + // Generate descriptive error message + if (!descriptionParts.empty()) { + textualConditionName = + llvm::formatv("region with {0}", + llvm::join(descriptionParts, " and ")) + .str(); + } + } + + verifierHelpers.push_back(llvm::formatv( + R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{ + if (!({1})) {{ + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: {2}"; + } + return ::mlir::success(); +})", + helperFnName, condition, textualConditionName)); + + verifierCalls.push_back(llvm::formatv(R"( + if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1}))) + return ::mlir::failure();)", + helperFnName, i, regionName) + .str()); + } +} + +static void generateVerifiers(irdl::detail::dictionary &dict, + irdl::OperationOp op, const OpStrings &strings) { + SmallVector<std::string> verifierHelpers; + SmallVector<std::string> verifierCalls; + + generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers, + verifierCalls); + + // Add an overall verifier that sequences the helper calls + std::string verifierDef = + llvm::formatv(R"( +::llvm::LogicalResult {0}::verifyInvariantsImpl() {{ + if (::mlir::failed(verify())) + return ::mlir::failure(); + + {1} + + return ::mlir::success(); +})", + strings.opCppName, llvm::join(verifierCalls, "\n")); + + dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n"); + dict["OP_VERIFIER"] = verifierDef; +} + static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op) { static const auto perOpDefTemplate = mlir::irdl::detail::Template{ @@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, { dict["OP_BUILD_DEFS"] = buildDefinition; + generateVerifiers(dict, op, opStrings); + std::string str; llvm::raw_string_ostream stream{str}; perOpDefTemplate.render(stream, dict); @@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, dict["TYPE_PARSER"] = llvm::formatv( R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) { return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser) - {0} + {0} .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{ *mnemonic = keyword; return std::nullopt; @@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) { "IRDL C++ translation does not yet support
variadic results"); })) .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); })) + .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); })) + .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); })) .Default([](mlir::Operation *op) -> LogicalResult { return op->emitError("IRDL C++ translation does not yet support " "translation of ") diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt index e9068e9..93ce0be 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt @@ -12,15 +12,15 @@ public: struct Properties { }; public: - __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) - : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), - odsRegions(op->getRegions()) + __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) + : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), + odsRegions(op->getRegions()) {} /// Return the unstructured operand index of a structured operand along with // the amount of unstructured operands it contains. std::pair<unsigned, unsigned> - getStructuredOperandIndexAndLength (unsigned index, + getStructuredOperandIndexAndLength (unsigned index, unsigned odsOperandsSize) { return {index, 1}; } @@ -32,6 +32,12 @@ public: ::mlir::DictionaryAttr getAttributes() { return odsAttrs; } + + __OP_REGION_ADAPTER_GETTER_DECLS__ + + ::mlir::RegionRange getRegions() { + return odsRegions; + } protected: ::mlir::DictionaryAttr odsAttrs; ::std::optional<::mlir::OperationName> odsOpName; @@ -42,28 +48,28 @@ protected: } // namespace detail template <typename RangeT> -class __OP_CPP_NAME__GenericAdaptor +class __OP_CPP_NAME__GenericAdaptor : public detail::__OP_CPP_NAME__GenericAdaptorBase { using ValueT = ::llvm::detail::ValueOfRange<RangeT>; using Base = detail::__OP_CPP_NAME__GenericAdaptorBase; public: __OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, - ::mlir::OpaqueProperties properties, - ::mlir::RegionRange regions = {}) - : __OP_CPP_NAME__GenericAdaptor(values, attrs, - (properties ? *properties.as<::mlir::EmptyProperties *>() + ::mlir::OpaqueProperties properties, + ::mlir::RegionRange regions = {}) + : __OP_CPP_NAME__GenericAdaptor(values, attrs, + (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} - __OP_CPP_NAME__GenericAdaptor(RangeT values, + __OP_CPP_NAME__GenericAdaptor(RangeT values, const __OP_CPP_NAME__GenericAdaptorBase &base) : Base(base), odsOperands(values) {} - // This template parameter allows using __OP_CPP_NAME__ which is declared + // This template parameter allows using __OP_CPP_NAME__ which is declared // later. 
template <typename LateInst = __OP_CPP_NAME__, typename = std::enable_if_t< std::is_same_v<LateInst, __OP_CPP_NAME__>>> - __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) + __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} /// Return the unstructured operand index of a structured operand along with @@ -77,7 +83,7 @@ public: RangeT getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(odsOperands.begin(), valueRange.first), - std::next(odsOperands.begin(), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; } @@ -91,7 +97,7 @@ private: RangeT odsOperands; }; -class __OP_CPP_NAME__Adaptor +class __OP_CPP_NAME__Adaptor : public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> { public: using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor; @@ -100,7 +106,7 @@ public: ::llvm::LogicalResult verify(::mlir::Location loc); }; -class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> { +class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> { public: using Op::Op; using Op::print; @@ -112,6 +118,8 @@ public: return {}; } + ::llvm::LogicalResult verifyInvariantsImpl(); + static constexpr ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__"); } @@ -147,7 +155,7 @@ public: ::mlir::Operation::operand_range getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(getOperation()->operand_begin(), valueRange.first), - std::next(getOperation()->operand_begin(), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; } @@ -162,18 +170,19 @@ public: ::mlir::Operation::result_range getStructuredResults(unsigned index) { auto valueRange = getStructuredResultIndexAndLength(index); return {std::next(getOperation()->result_begin(), valueRange.first), - std::next(getOperation()->result_begin(), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; } __OP_OPERAND_GETTER_DECLS__ __OP_RESULT_GETTER_DECLS__ - + __OP_REGION_GETTER_DECLS__ + __OP_BUILD_DECLS__ - static void build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, + static void build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder, diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt index 30ca420..f4a1b7a 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt @@ -6,12 +6,14 @@ R"( __NAMESPACE_OPEN__ +__OP_VERIFIER_HELPERS__ + __OP_BUILD_DEFS__ -void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, +void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { assert(operands.size() == __OP_OPERAND_COUNT__); @@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, odsState.addOperands(operands); odsState.addAttributes(attributes); 
odsState.addTypes(resultTypes); + for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) { + (void)odsState.addRegion(); + } } __OP_CPP_NAME__ @@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder, return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes); } +__OP_VERIFIER__ __NAMESPACE_CLOSE__ diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 53209a4..9fcb02e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3175,6 +3175,45 @@ applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, return success(); } +/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the +/// OpenMPIRBuilder. +static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::OpenMPIRBuilder::LocationDescription loc(builder); + + SmallVector<llvm::CanonicalLoopInfo *> translatedLoops; + SmallVector<llvm::Value *> translatedSizes; + + for (Value size : op.getSizes()) { + llvm::Value *translatedSize = moduleTranslation.lookupValue(size); + assert(translatedSize && + "sizes clause arguments must already be translated"); + translatedSizes.push_back(translatedSize); + } + + for (Value applyee : op.getApplyees()) { + llvm::CanonicalLoopInfo *consBuilderCLI = + moduleTranslation.lookupOMPLoop(applyee); + assert(consBuilderCLI && "canonical loop must already have been translated"); + translatedLoops.push_back(consBuilderCLI); + } + + auto generatedLoops = + ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes); + if (!op.getGeneratees().empty()) { + for (auto [mlirLoop, genLoop] : + zip_equal(op.getGeneratees(), generatedLoops)) + moduleTranslation.mapOmpLoop(mlirLoop, genLoop); + } + + // CLIs can only be consumed once + for (Value applyee : op.getApplyees()) + moduleTranslation.invalidateOmpLoop(applyee); + + return success(); +} + /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering. static llvm::AtomicOrdering convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) { @@ -6227,6 +6266,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // the omp.canonical_loop. return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TileOp op) { + return applyTile(op, builder, moduleTranslation); + }) .Case([&](omp::TargetAllocMemOp) { return convertTargetAllocMemOp(*op, builder, moduleTranslation); })
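// The OpenMPIRBuilder contract that applyTile builds on, sketched for two
// loops (assumption: tileLoops returns the floor loops followed by the tile
// loops and consumes its inputs, which is why applyTile calls
// invalidateOmpLoop on each applyee afterwards):
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
static void tileLoopsSketch(llvm::OpenMPIRBuilder &ompBuilder,
                            llvm::DebugLoc dl, llvm::CanonicalLoopInfo *outer,
                            llvm::CanonicalLoopInfo *inner, llvm::Value *ts0,
                            llvm::Value *ts1) {
  std::vector<llvm::CanonicalLoopInfo *> gen =
      ompBuilder.tileLoops(dl, {outer, inner}, {ts0, ts1});
  // gen[0]/gen[1] are the floor loops, gen[2]/gen[3] the tile loops. The
  // input CLIs (outer, inner) must not be fed to any later transformation.
  (void)gen;
}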