aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Target
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target')
-rw-r--r--mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp162
-rw-r--r--mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt53
-rw-r--r--mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt14
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp42
4 files changed, 243 insertions, 28 deletions
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index d6b8a8a..e3f075f 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -54,6 +54,7 @@ struct OpStrings {
std::string opCppName;
SmallVector<std::string> opResultNames;
SmallVector<std::string> opOperandNames;
+ SmallVector<std::string> opRegionNames;
};
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
/// Generates OpStrings from an OperatioOp
static OpStrings getStrings(irdl::OperationOp op) {
auto operandOp = op.getOp<irdl::OperandsOp>();
-
auto resultOp = op.getOp<irdl::ResultsOp>();
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
OpStrings strings;
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
}));
}
+ if (regionsOp) {
+ strings.opRegionNames = SmallVector<std::string>(
+ llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
+ return llvm::formatv("{0}", cast<StringAttr>(attr));
+ }));
+ }
+
return strings;
}
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
const auto operandCount = strings.opOperandNames.size();
const auto resultCount = strings.opResultNames.size();
+ const auto regionCount = strings.opRegionNames.size();
dict["OP_NAME"] = strings.opName;
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
+ dict["OP_REGION_COUNT"] = std::to_string(regionCount);
}
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
const OpStrings &opStrings) {
auto opGetters = std::string{};
auto resGetters = std::string{};
+ auto regionGetters = std::string{};
+ auto regionAdaptorGetters = std::string{};
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
const auto op =
@@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
op, i);
}
+ for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
+ const auto op =
+ llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
+ regionAdaptorGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
+ )",
+ op, i);
+ regionGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
+ )",
+ op, i);
+ }
+
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
+ dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
+ dict["OP_REGION_GETTER_DECLS"] = regionGetters;
}
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
dict["OP_BUILD_DECLS"] = buildDecls;
}
+// add traits to the dictionary, return true if any were added
+static SmallVector<std::string> generateTraits(irdl::OperationOp op,
+ const OpStrings &strings) {
+ SmallVector<std::string> cppTraitNames;
+ if (!strings.opRegionNames.empty()) {
+ cppTraitNames.push_back(
+ llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
+ strings.opRegionNames.size())
+ .str());
+
+ // Requires verifyInvariantsImpl is implemented on the op
+ cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
+ }
+ return cppTraitNames;
+}
+
static LogicalResult generateOperationInclude(irdl::OperationOp op,
raw_ostream &output,
irdl::detail::dictionary &dict) {
@@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
const auto opStrings = getStrings(op);
fillDict(dict, opStrings);
+ SmallVector<std::string> traitNames = generateTraits(op, opStrings);
+ if (traitNames.empty())
+ dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
+ else
+ dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
+ llvm::join(traitNames, ", "));
+
generateOpGetterDeclarations(dict, opStrings);
generateOpBuilderDeclarations(dict, opStrings);
@@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
return success();
}
+static void generateRegionConstraintVerifiers(
+ irdl::detail::dictionary &dict, irdl::OperationOp op,
+ const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
+ SmallVectorImpl<std::string> &verifierCalls) {
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
+ if (strings.opRegionNames.empty() || !regionsOp)
+ return;
+
+ for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
+ std::string regionName = strings.opRegionNames[i];
+ std::string helperFnName =
+ llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
+ strings.opCppName, regionName)
+ .str();
+
+ // Extract the actual region constraint from the IRDL RegionOp
+ std::string condition = "true";
+ std::string textualConditionName = "any region";
+
+ if (auto regionDefOp =
+ dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
+ // Generate constraint condition based on RegionOp attributes
+ SmallVector<std::string> conditionParts;
+ SmallVector<std::string> descriptionParts;
+
+ // Check number of blocks constraint
+ if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
+ conditionParts.push_back(
+ llvm::formatv("region.getBlocks().size() == {0}",
+ blockCount.value())
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
+ }
+
+ // Check entry block arguments constraint
+ if (regionDefOp.getConstrainedArguments()) {
+ size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
+ conditionParts.push_back(
+ llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("{0} entry block argument(s)", expectedArgCount)
+ .str());
+ }
+
+ // Combine conditions
+ if (!conditionParts.empty()) {
+ condition = llvm::join(conditionParts, " && ");
+ }
+
+ // Generate descriptive error message
+ if (!descriptionParts.empty()) {
+ textualConditionName =
+ llvm::formatv("region with {0}",
+ llvm::join(descriptionParts, " and "))
+ .str();
+ }
+ }
+
+ verifierHelpers.push_back(llvm::formatv(
+ R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
+ if (!({1})) {{
+ return op->emitOpError("region #") << regionIndex
+ << (regionName.empty() ? " " : " ('" + regionName + "') ")
+ << "failed to verify constraint: {2}";
+ }
+ return ::mlir::success();
+})",
+ helperFnName, condition, textualConditionName));
+
+ verifierCalls.push_back(llvm::formatv(R"(
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
+ return ::mlir::failure();)",
+ helperFnName, i, regionName)
+ .str());
+ }
+}
+
+static void generateVerifiers(irdl::detail::dictionary &dict,
+ irdl::OperationOp op, const OpStrings &strings) {
+ SmallVector<std::string> verifierHelpers;
+ SmallVector<std::string> verifierCalls;
+
+ generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
+ verifierCalls);
+
+ // Add an overall verifier that sequences the helper calls
+ std::string verifierDef =
+ llvm::formatv(R"(
+::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
+ if(::mlir::failed(verify()))
+ return ::mlir::failure();
+
+ {1}
+
+ return ::mlir::success();
+})",
+ strings.opCppName, llvm::join(verifierCalls, "\n"));
+
+ dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
+ dict["OP_VERIFIER"] = verifierDef;
+}
+
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
irdl::OperationOp op) {
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
dict["OP_BUILD_DEFS"] = buildDefinition;
+ generateVerifiers(dict, op, opStrings);
+
std::string str;
llvm::raw_string_ostream stream{str};
perOpDefTemplate.render(stream, dict);
@@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
dict["TYPE_PARSER"] = llvm::formatv(
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
- {0}
+ {0}
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
*mnemonic = keyword;
return std::nullopt;
@@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
+ .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
+ .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
index e9068e9..93ce0be 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
@@ -12,15 +12,15 @@ public:
struct Properties {
};
public:
- __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
- : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
- odsRegions(op->getRegions())
+ __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
+ : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
+ odsRegions(op->getRegions())
{}
/// Return the unstructured operand index of a structured operand along with
// the amount of unstructured operands it contains.
std::pair<unsigned, unsigned>
- getStructuredOperandIndexAndLength (unsigned index,
+ getStructuredOperandIndexAndLength (unsigned index,
unsigned odsOperandsSize) {
return {index, 1};
}
@@ -32,6 +32,12 @@ public:
::mlir::DictionaryAttr getAttributes() {
return odsAttrs;
}
+
+ __OP_REGION_ADAPTER_GETTER_DECLS__
+
+ ::mlir::RegionRange getRegions() {
+ return odsRegions;
+ }
protected:
::mlir::DictionaryAttr odsAttrs;
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
} // namespace detail
template <typename RangeT>
-class __OP_CPP_NAME__GenericAdaptor
+class __OP_CPP_NAME__GenericAdaptor
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
public:
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
- ::mlir::OpaqueProperties properties,
- ::mlir::RegionRange regions = {})
- : __OP_CPP_NAME__GenericAdaptor(values, attrs,
- (properties ? *properties.as<::mlir::EmptyProperties *>()
+ ::mlir::OpaqueProperties properties,
+ ::mlir::RegionRange regions = {})
+ : __OP_CPP_NAME__GenericAdaptor(values, attrs,
+ (properties ? *properties.as<::mlir::EmptyProperties *>()
: ::mlir::EmptyProperties{}), regions) {}
- __OP_CPP_NAME__GenericAdaptor(RangeT values,
+ __OP_CPP_NAME__GenericAdaptor(RangeT values,
const __OP_CPP_NAME__GenericAdaptorBase &base)
: Base(base), odsOperands(values) {}
- // This template parameter allows using __OP_CPP_NAME__ which is declared
+ // This template parameter allows using __OP_CPP_NAME__ which is declared
// later.
template <typename LateInst = __OP_CPP_NAME__,
typename = std::enable_if_t<
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
- __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
+ __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
: Base(op), odsOperands(values) {}
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
RangeT getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(odsOperands.begin(), valueRange.first),
- std::next(odsOperands.begin(),
+ std::next(odsOperands.begin(),
valueRange.first + valueRange.second)};
}
@@ -91,7 +97,7 @@ private:
RangeT odsOperands;
};
-class __OP_CPP_NAME__Adaptor
+class __OP_CPP_NAME__Adaptor
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
public:
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
::llvm::LogicalResult verify(::mlir::Location loc);
};
-class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
+class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
public:
using Op::Op;
using Op::print;
@@ -112,6 +118,8 @@ public:
return {};
}
+ ::llvm::LogicalResult verifyInvariantsImpl();
+
static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__");
}
@@ -147,7 +155,7 @@ public:
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(getOperation()->operand_begin(), valueRange.first),
- std::next(getOperation()->operand_begin(),
+ std::next(getOperation()->operand_begin(),
valueRange.first + valueRange.second)};
}
@@ -162,18 +170,19 @@ public:
::mlir::Operation::result_range getStructuredResults(unsigned index) {
auto valueRange = getStructuredResultIndexAndLength(index);
return {std::next(getOperation()->result_begin(), valueRange.first),
- std::next(getOperation()->result_begin(),
+ std::next(getOperation()->result_begin(),
valueRange.first + valueRange.second)};
}
__OP_OPERAND_GETTER_DECLS__
__OP_RESULT_GETTER_DECLS__
-
+ __OP_REGION_GETTER_DECLS__
+
__OP_BUILD_DECLS__
- static void build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+ static void build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
index 30ca420..f4a1b7a 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
@@ -6,12 +6,14 @@ R"(
__NAMESPACE_OPEN__
+__OP_VERIFIER_HELPERS__
+
__OP_BUILD_DEFS__
-void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
{
assert(operands.size() == __OP_OPERAND_COUNT__);
@@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
odsState.addOperands(operands);
odsState.addAttributes(attributes);
odsState.addTypes(resultTypes);
+ for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) {
+ (void)odsState.addRegion();
+ }
}
__OP_CPP_NAME__
@@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
}
+__OP_VERIFIER__
__NAMESPACE_CLOSE__
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 53209a4..9fcb02e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3175,6 +3175,45 @@ applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
return success();
}
+/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
+/// OpenMPIRBuilder.
+static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+
+ SmallVector<llvm::CanonicalLoopInfo *> translatedLoops;
+ SmallVector<llvm::Value *> translatedSizes;
+
+ for (Value size : op.getSizes()) {
+ llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
+ assert(translatedSize &&
+ "sizes clause arguments must already be translated");
+ translatedSizes.push_back(translatedSize);
+ }
+
+ for (Value applyee : op.getApplyees()) {
+ llvm::CanonicalLoopInfo *consBuilderCLI =
+ moduleTranslation.lookupOMPLoop(applyee);
+ assert(applyee && "Canonical loop must already been translated");
+ translatedLoops.push_back(consBuilderCLI);
+ }
+
+ auto generatedLoops =
+ ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
+ if (!op.getGeneratees().empty()) {
+ for (auto [mlirLoop, genLoop] :
+ zip_equal(op.getGeneratees(), generatedLoops))
+ moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
+ }
+
+ // CLIs can only be consumed once
+ for (Value applyee : op.getApplyees())
+ moduleTranslation.invalidateOmpLoop(applyee);
+
+ return success();
+}
+
/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static llvm::AtomicOrdering
convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
@@ -6227,6 +6266,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
// the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation);
})
+ .Case([&](omp::TileOp op) {
+ return applyTile(op, builder, moduleTranslation);
+ })
.Case([&](omp::TargetAllocMemOp) {
return convertTargetAllocMemOp(*op, builder, moduleTranslation);
})