diff options
Diffstat (limited to 'mlir/lib/Target')
4 files changed, 243 insertions, 28 deletions
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp index d6b8a8a..e3f075f 100644 --- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp +++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp @@ -54,6 +54,7 @@ struct OpStrings { std::string opCppName; SmallVector<std::string> opResultNames; SmallVector<std::string> opOperandNames; + SmallVector<std::string> opRegionNames; }; static std::string joinNameList(llvm::ArrayRef<std::string> names) { @@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) { /// Generates OpStrings from an OperatioOp static OpStrings getStrings(irdl::OperationOp op) { auto operandOp = op.getOp<irdl::OperandsOp>(); - auto resultOp = op.getOp<irdl::ResultsOp>(); + auto regionsOp = op.getOp<irdl::RegionsOp>(); OpStrings strings; strings.opName = op.getSymName(); @@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) { })); } + if (regionsOp) { + strings.opRegionNames = SmallVector<std::string>( + llvm::map_range(regionsOp->getNames(), [](Attribute attr) { + return llvm::formatv("{0}", cast<StringAttr>(attr)); + })); + } + return strings; } @@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict, static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { const auto operandCount = strings.opOperandNames.size(); const auto resultCount = strings.opResultNames.size(); + const auto regionCount = strings.opRegionNames.size(); dict["OP_NAME"] = strings.opName; dict["OP_CPP_NAME"] = strings.opCppName; @@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}"; dict["OP_RESULT_INITIALIZER_LIST"] = resultCount ? joinNameList(strings.opResultNames) : "{\"\"}"; + dict["OP_REGION_COUNT"] = std::to_string(regionCount); } /// Fills a dictionary with values from DialectStrings @@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings) { auto opGetters = std::string{}; auto resGetters = std::string{}; + auto regionGetters = std::string{}; + auto regionAdaptorGetters = std::string{}; for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) { const auto op = @@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, op, i); } + for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) { + const auto op = + llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true); + regionAdaptorGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; } + )", + op, i); + regionGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); } + )", + op, i); + } + dict["OP_OPERAND_GETTER_DECLS"] = opGetters; dict["OP_RESULT_GETTER_DECLS"] = resGetters; + dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters; + dict["OP_REGION_GETTER_DECLS"] = regionGetters; } static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, @@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, dict["OP_BUILD_DECLS"] = buildDecls; } +// add traits to the dictionary, return true if any were added +static SmallVector<std::string> generateTraits(irdl::OperationOp op, + const OpStrings &strings) { + SmallVector<std::string> cppTraitNames; + if (!strings.opRegionNames.empty()) { + cppTraitNames.push_back( + llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl", + strings.opRegionNames.size()) + .str()); + + // Requires verifyInvariantsImpl is implemented on the op + cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants"); + } + return cppTraitNames; +} + static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict) { @@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op, const auto opStrings = getStrings(op); fillDict(dict, opStrings); + SmallVector<std::string> traitNames = generateTraits(op, opStrings); + if (traitNames.empty()) + dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName; + else + dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName, + llvm::join(traitNames, ", ")); + generateOpGetterDeclarations(dict, opStrings); generateOpBuilderDeclarations(dict, opStrings); @@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect, return success(); } +static void generateRegionConstraintVerifiers( + irdl::detail::dictionary &dict, irdl::OperationOp op, + const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers, + SmallVectorImpl<std::string> &verifierCalls) { + auto regionsOp = op.getOp<irdl::RegionsOp>(); + if (strings.opRegionNames.empty() || !regionsOp) + return; + + for (size_t i = 0; i < strings.opRegionNames.size(); ++i) { + std::string regionName = strings.opRegionNames[i]; + std::string helperFnName = + llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}", + strings.opCppName, regionName) + .str(); + + // Extract the actual region constraint from the IRDL RegionOp + std::string condition = "true"; + std::string textualConditionName = "any region"; + + if (auto regionDefOp = + dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) { + // Generate constraint condition based on RegionOp attributes + SmallVector<std::string> conditionParts; + SmallVector<std::string> descriptionParts; + + // Check number of blocks constraint + if (auto blockCount = regionDefOp.getNumberOfBlocks()) { + conditionParts.push_back( + llvm::formatv("region.getBlocks().size() == {0}", + blockCount.value()) + .str()); + descriptionParts.push_back( + llvm::formatv("exactly {0} block(s)", blockCount.value()).str()); + } + + // Check entry block arguments constraint + if (regionDefOp.getConstrainedArguments()) { + size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size(); + conditionParts.push_back( + llvm::formatv("region.getNumArguments() == {0}", expectedArgCount) + .str()); + descriptionParts.push_back( + llvm::formatv("{0} entry block argument(s)", expectedArgCount) + .str()); + } + + // Combine conditions + if (!conditionParts.empty()) { + condition = llvm::join(conditionParts, " && "); + } + + // Generate descriptive error message + if (!descriptionParts.empty()) { + textualConditionName = + llvm::formatv("region with {0}", + llvm::join(descriptionParts, " and ")) + .str(); + } + } + + verifierHelpers.push_back(llvm::formatv( + R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, unsigned regionIndex) {{ + if (!({1})) {{ + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: {2}"; + } + return ::mlir::success(); +})", + helperFnName, condition, textualConditionName)); + + verifierCalls.push_back(llvm::formatv(R"( + if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1}))) + return ::mlir::failure();)", + helperFnName, i, regionName) + .str()); + } +} + +static void generateVerifiers(irdl::detail::dictionary &dict, + irdl::OperationOp op, const OpStrings &strings) { + SmallVector<std::string> verifierHelpers; + SmallVector<std::string> verifierCalls; + + generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers, + verifierCalls); + + // Add an overall verifier that sequences the helper calls + std::string verifierDef = + llvm::formatv(R"( +::llvm::LogicalResult {0}::verifyInvariantsImpl() {{ + if(::mlir::failed(verify())) + return ::mlir::failure(); + + {1} + + return ::mlir::success(); +})", + strings.opCppName, llvm::join(verifierCalls, "\n")); + + dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n"); + dict["OP_VERIFIER"] = verifierDef; +} + static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op) { static const auto perOpDefTemplate = mlir::irdl::detail::Template{ @@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, { dict["OP_BUILD_DEFS"] = buildDefinition; + generateVerifiers(dict, op, opStrings); + std::string str; llvm::raw_string_ostream stream{str}; perOpDefTemplate.render(stream, dict); @@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, dict["TYPE_PARSER"] = llvm::formatv( R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) { return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser) - {0} + {0} .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{ *mnemonic = keyword; return std::nullopt; @@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) { "IRDL C++ translation does not yet support variadic results"); })) .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); })) + .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); })) + .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); })) .Default([](mlir::Operation *op) -> LogicalResult { return op->emitError("IRDL C++ translation does not yet support " "translation of ") diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt index e9068e9..93ce0be 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt @@ -12,15 +12,15 @@ public: struct Properties { }; public: - __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) - : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), - odsRegions(op->getRegions()) + __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) + : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), + odsRegions(op->getRegions()) {} /// Return the unstructured operand index of a structured operand along with // the amount of unstructured operands it contains. std::pair<unsigned, unsigned> - getStructuredOperandIndexAndLength (unsigned index, + getStructuredOperandIndexAndLength (unsigned index, unsigned odsOperandsSize) { return {index, 1}; } @@ -32,6 +32,12 @@ public: ::mlir::DictionaryAttr getAttributes() { return odsAttrs; } + + __OP_REGION_ADAPTER_GETTER_DECLS__ + + ::mlir::RegionRange getRegions() { + return odsRegions; + } protected: ::mlir::DictionaryAttr odsAttrs; ::std::optional<::mlir::OperationName> odsOpName; @@ -42,28 +48,28 @@ protected: } // namespace detail template <typename RangeT> -class __OP_CPP_NAME__GenericAdaptor +class __OP_CPP_NAME__GenericAdaptor : public detail::__OP_CPP_NAME__GenericAdaptorBase { using ValueT = ::llvm::detail::ValueOfRange<RangeT>; using Base = detail::__OP_CPP_NAME__GenericAdaptorBase; public: __OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, - ::mlir::OpaqueProperties properties, - ::mlir::RegionRange regions = {}) - : __OP_CPP_NAME__GenericAdaptor(values, attrs, - (properties ? *properties.as<::mlir::EmptyProperties *>() + ::mlir::OpaqueProperties properties, + ::mlir::RegionRange regions = {}) + : __OP_CPP_NAME__GenericAdaptor(values, attrs, + (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} - __OP_CPP_NAME__GenericAdaptor(RangeT values, + __OP_CPP_NAME__GenericAdaptor(RangeT values, const __OP_CPP_NAME__GenericAdaptorBase &base) : Base(base), odsOperands(values) {} - // This template parameter allows using __OP_CPP_NAME__ which is declared + // This template parameter allows using __OP_CPP_NAME__ which is declared // later. template <typename LateInst = __OP_CPP_NAME__, typename = std::enable_if_t< std::is_same_v<LateInst, __OP_CPP_NAME__>>> - __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) + __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} /// Return the unstructured operand index of a structured operand along with @@ -77,7 +83,7 @@ public: RangeT getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(odsOperands.begin(), valueRange.first), - std::next(odsOperands.begin(), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; } @@ -91,7 +97,7 @@ private: RangeT odsOperands; }; -class __OP_CPP_NAME__Adaptor +class __OP_CPP_NAME__Adaptor : public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> { public: using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor; @@ -100,7 +106,7 @@ public: ::llvm::LogicalResult verify(::mlir::Location loc); }; -class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> { +class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> { public: using Op::Op; using Op::print; @@ -112,6 +118,8 @@ public: return {}; } + ::llvm::LogicalResult verifyInvariantsImpl(); + static constexpr ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__"); } @@ -147,7 +155,7 @@ public: ::mlir::Operation::operand_range getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(getOperation()->operand_begin(), valueRange.first), - std::next(getOperation()->operand_begin(), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; } @@ -162,18 +170,19 @@ public: ::mlir::Operation::result_range getStructuredResults(unsigned index) { auto valueRange = getStructuredResultIndexAndLength(index); return {std::next(getOperation()->result_begin(), valueRange.first), - std::next(getOperation()->result_begin(), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; } __OP_OPERAND_GETTER_DECLS__ __OP_RESULT_GETTER_DECLS__ - + __OP_REGION_GETTER_DECLS__ + __OP_BUILD_DECLS__ - static void build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, + static void build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder, diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt index 30ca420..f4a1b7a 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt @@ -6,12 +6,14 @@ R"( __NAMESPACE_OPEN__ +__OP_VERIFIER_HELPERS__ + __OP_BUILD_DEFS__ -void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, +void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { assert(operands.size() == __OP_OPERAND_COUNT__); @@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, odsState.addOperands(operands); odsState.addAttributes(attributes); odsState.addTypes(resultTypes); + for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) { + (void)odsState.addRegion(); + } } __OP_CPP_NAME__ @@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder, return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes); } +__OP_VERIFIER__ __NAMESPACE_CLOSE__ diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 53209a4..9fcb02e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3175,6 +3175,45 @@ applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, return success(); } +/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the +/// OpenMPIRBuilder. +static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::OpenMPIRBuilder::LocationDescription loc(builder); + + SmallVector<llvm::CanonicalLoopInfo *> translatedLoops; + SmallVector<llvm::Value *> translatedSizes; + + for (Value size : op.getSizes()) { + llvm::Value *translatedSize = moduleTranslation.lookupValue(size); + assert(translatedSize && + "sizes clause arguments must already be translated"); + translatedSizes.push_back(translatedSize); + } + + for (Value applyee : op.getApplyees()) { + llvm::CanonicalLoopInfo *consBuilderCLI = + moduleTranslation.lookupOMPLoop(applyee); + assert(applyee && "Canonical loop must already been translated"); + translatedLoops.push_back(consBuilderCLI); + } + + auto generatedLoops = + ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes); + if (!op.getGeneratees().empty()) { + for (auto [mlirLoop, genLoop] : + zip_equal(op.getGeneratees(), generatedLoops)) + moduleTranslation.mapOmpLoop(mlirLoop, genLoop); + } + + // CLIs can only be consumed once + for (Value applyee : op.getApplyees()) + moduleTranslation.invalidateOmpLoop(applyee); + + return success(); +} + /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering. static llvm::AtomicOrdering convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) { @@ -6227,6 +6266,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // the omp.canonical_loop. return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TileOp op) { + return applyTile(op, builder, moduleTranslation); + }) .Case([&](omp::TargetAllocMemOp) { return convertTargetAllocMemOp(*op, builder, moduleTranslation); }) |