aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2025-08-29 14:13:52 +0000
committerMatthias Springer <me@m-sp.org>2025-08-29 14:18:31 +0000
commit6477f75393aedd50463c89f883407a2060c7f490 (patch)
treeb7151241760a5b07c4b8655a7cd136f7bc688deb
parent0d9c0ced14827ea15f3cc9a981195258c049b9a6 (diff)
downloadllvm-users/matthias-springer/materialize_immediately.zip
llvm-users/matthias-springer/materialize_immediately.tar.gz
llvm-users/matthias-springer/materialize_immediately.tar.bz2
[mlir][Transforms] Add option to build materializations immediatelyusers/matthias-springer/materialize_immediately
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h28
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp122
-rw-r--r--mlir/test/Transforms/test-legalizer.mlir2
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp17
4 files changed, 113 insertions, 56 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf1..e139372 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1346,15 +1346,25 @@ struct ConversionConfig {
// represented at the moment.
RewriterBase::Listener *listener = nullptr;
- /// If set to "true", the dialect conversion attempts to build source/target
- /// materializations through the type converter API in lieu of
- /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
- /// at least one materialization could not be built.
- ///
- /// If set to "false", the dialect conversion does not build any custom
- /// materializations and instead inserts "builtin.unrealized_conversion_cast"
- /// ops to ensure that the resulting IR is valid.
- bool buildMaterializations = true;
+ enum class MaterializationMode {
+ /// Never build materializations with the type converter. Instead, insert
+ /// "builtin.unrealized_conversion_cast" ops to ensure that the types of
+ /// the resulting IR are valid.
+ Never,
+ /// Build materializations with the type converter immediately. (If that
+ /// fails, insert "builtin.unrealized_conversion_cast" ops.)
+ Immediate,
+ /// Insert "builtin.unrealized_conversion_cast" ops first. At the end of
+ /// the conversion, replace them with materializations built with the type
+ /// converter. This can result in fewer materializations because
+ /// "builtin.unrealized_conversion_cast" ops that cancel each other out are
+ /// folded away.
+ Delayed
+ };
+
+ /// This option controls whether and when materializations should be built
+ /// through the type converter API.
+ MaterializationMode buildMaterializations = MaterializationMode::Delayed;
/// If set to "true", pattern rollback is allowed. The conversion driver
/// rolls back IR modifications in the following situations.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c0685f5..18adee7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -810,7 +810,7 @@ public:
void rollback() override;
UnrealizedConversionCastOp getOperation() const {
- return cast<UnrealizedConversionCastOp>(op);
+ return dyn_cast_or_null<UnrealizedConversionCastOp>(op);
}
private:
@@ -973,10 +973,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// "out of thin air" appear like unresolved materializations because they are
/// unrealized_conversion_cast ops. However, they must be treated like
/// regular value replacements.)
- ValueRange buildUnresolvedMaterialization(
- MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
- Type originalType, const TypeConverter *converter,
+ ValueVector buildUnresolvedMaterialization(
+ OpBuilder &builder, MaterializationKind kind, OpBuilder::InsertPoint ip,
+ Location loc, ValueVector valuesToMap, ValueRange inputs,
+ TypeRange outputTypes, Type originalType, const TypeConverter *converter,
bool isPureTypeConversion = true);
/// Find a replacement value for the given SSA value in the conversion value
@@ -984,7 +984,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// value. If there is no replacement value with the correct type, find the
/// latest replacement value (regardless of the type) and build a source
/// materialization.
- Value findOrBuildReplacementValue(Value value,
+ Value findOrBuildReplacementValue(OpBuilder &builder, Value value,
const TypeConverter *converter);
//===--------------------------------------------------------------------===//
@@ -1188,7 +1188,8 @@ static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
}
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ Value repl =
+ rewriterImpl.findOrBuildReplacementValue(rewriter, arg, converter);
if (!repl)
return;
performReplaceBlockArg(rewriter, arg, repl);
@@ -1203,7 +1204,8 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
// Compute replacement values.
SmallVector<Value> replacements =
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
- return rewriterImpl.findOrBuildReplacementValue(result, converter);
+ return rewriterImpl.findOrBuildReplacementValue(rewriter, result,
+ converter);
});
// Notify the listener that the operation is about to be replaced.
@@ -1251,8 +1253,10 @@ void CreateOperationRewrite::rollback() {
void UnresolvedMaterializationRewrite::rollback() {
if (!mappedValues.empty())
rewriterImpl.mapping.erase(mappedValues);
- rewriterImpl.unresolvedMaterializations.erase(getOperation());
- op->erase();
+ if (getOperation()) {
+ rewriterImpl.unresolvedMaterializations.erase(getOperation());
+ op->erase();
+ }
}
void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1458,11 +1462,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// Create a materialization for the most recently mapped values.
repl = lookupOrDefault(operand, /*desiredTypes=*/{},
/*skipPureTypeConversions=*/true);
- ValueRange castValues = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+ ValueVector castValues = buildUnresolvedMaterialization(
+ rewriter, MaterializationKind::Target, computeInsertPoint(repl),
+ operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
/*originalType=*/origType, currentTypeConverter);
- remapped.push_back(castValues);
+ remapped.push_back(std::move(castValues));
}
return success();
}
@@ -1577,7 +1582,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Materialize a replacement value "out of thin air".
Value mat =
buildUnresolvedMaterialization(
- MaterializationKind::Source,
+ rewriter, MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
@@ -1620,46 +1625,76 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
-ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
- MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
- Type originalType, const TypeConverter *converter,
+ValueVector ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ OpBuilder &builder, MaterializationKind kind, OpBuilder::InsertPoint ip,
+ Location loc, ValueVector valuesToMap, ValueRange inputs,
+ TypeRange outputTypes, Type originalType, const TypeConverter *converter,
bool isPureTypeConversion) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
assert(TypeRange(inputs) != outputTypes &&
"materialization is not necessary");
+ ValueVector results;
+
+ // Build materializations with the type converter if requested.
+ if (converter && config.buildMaterializations ==
+ ConversionConfig::MaterializationMode::Immediate) {
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
+ if (kind == MaterializationKind::Source) {
+ assert(outputTypes.size() == 1 && "expected single output type");
+ Value mat = converter->materializeSourceConversion(
+ builder, loc, outputTypes.front(), inputs);
+ if (mat)
+ results.push_back(mat);
+ } else {
+ assert(kind == MaterializationKind::Target &&
+ "expected source or target materialization");
+ SmallVector<Value> mat = converter->materializeTargetConversion(
+ builder, loc, outputTypes, inputs);
+ if (!mat.empty())
+ llvm::append_range(results, mat);
+ }
+ }
- // Create an unresolved materialization. We use a new OpBuilder to avoid
- // tracking the materialization like we do for other operations.
- OpBuilder builder(outputTypes.front().getContext());
- builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
- UnrealizedConversionCastOp convertOp =
- UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
- if (config.attachDebugMaterializationKind) {
- StringRef kindStr =
- kind == MaterializationKind::Source ? "source" : "target";
- convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
+ // Otherwise, or if the type converter failed to build a materialization,
+ // insert an unrealized_conversion_cast op and try to resolve it later.
+ UnrealizedConversionCastOp castOp;
+ if (results.empty()) {
+ assert(results.empty() && "expected no results");
+ // Create an unresolved materialization. We use a new OpBuilder to avoid
+ // tracking the materialization like we do for other operations.
+ OpBuilder builder(outputTypes.front().getContext());
+ builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
+ castOp =
+ UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
+ unresolvedMaterializations[castOp] =
+ UnresolvedMaterializationInfo(converter, kind, originalType);
+ if (config.attachDebugMaterializationKind) {
+ StringRef kindStr =
+ kind == MaterializationKind::Source ? "source" : "target";
+ castOp->setAttr("__kind__", builder.getStringAttr(kindStr));
+ }
+ if (isPureTypeConversion)
+ castOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
+ llvm::append_range(results, castOp.getResults());
}
- if (isPureTypeConversion)
- convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
// Register the materialization.
- unresolvedMaterializations[convertOp] =
- UnresolvedMaterializationInfo(converter, kind, originalType);
if (config.allowPatternRollback) {
if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ mapping.map(valuesToMap, results);
+ appendRewrite<UnresolvedMaterializationRewrite>(castOp,
std::move(valuesToMap));
} else {
- patternMaterializations.insert(convertOp);
+ if (castOp)
+ patternMaterializations.insert(castOp);
}
- return convertOp.getResults();
+ return results;
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
- Value value, const TypeConverter *converter) {
+ OpBuilder &builder, Value value, const TypeConverter *converter) {
assert(config.allowPatternRollback &&
"this code path is valid only in rollback mode");
@@ -1700,7 +1735,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// materialization must be valid for all future users that may be created
// later in the conversion process.
Value castValue =
- buildUnresolvedMaterialization(MaterializationKind::Source,
+ buildUnresolvedMaterialization(rewriter, MaterializationKind::Source,
computeInsertPoint(repl), value.getLoc(),
/*valuesToMap=*/repl, /*inputs=*/repl,
/*outputTypes=*/value.getType(),
@@ -1779,7 +1814,8 @@ getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
// The replaced value is dropped. Materialize a replacement value "out of
// thin air".
Value srcMat = impl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+ impl.rewriter, MaterializationKind::Source, computeInsertPoint(from),
+ from.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
/*outputTypes=*/from.getType(), /*originalType=*/Type(),
converter)[0];
@@ -1799,7 +1835,8 @@ getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
// materializations if possible. This would require an extension of the
// `lookupOrDefault` API.
Value srcMat = impl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+ impl.rewriter, MaterializationKind::Source, computeInsertPoint(to),
+ from.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
/*originalType=*/Type(), converter)[0];
repls.push_back(srcMat);
@@ -1855,7 +1892,7 @@ void ConversionPatternRewriterImpl::replaceOp(
// This result was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(result),
+ rewriter, MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter, /*isPureTypeConversion=*/false);
@@ -3234,7 +3271,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
castOp->removeAttr(kPureTypeConversionMarker);
// Try to legalize all unresolved materializations.
- if (rewriter.getConfig().buildMaterializations) {
+ if (rewriter.getConfig().buildMaterializations !=
+ ConversionConfig::MaterializationMode::Never) {
// Use a new rewriter, so the modifications are not tracked for rollback
// purposes etc.
IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3fa42ff..5ec12be 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=never attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381e..d6126de 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1660,11 +1660,20 @@ struct TestLegalizePatternDriver
llvm::cl::desc(
"Attach materialization kind to unrealized_conversion_cast ops"),
llvm::cl::init(false)};
- Option<bool> buildMaterializations{
+ Option<ConversionConfig::MaterializationMode> buildMaterializations{
*this, "build-materializations",
- llvm::cl::desc(
- "If set to 'false', leave unrealized_conversion_cast ops in place"),
- llvm::cl::init(true)};
+ llvm::cl::desc("When to build unresovled materializations."),
+ llvm::cl::init(ConversionConfig::MaterializationMode::Delayed),
+ llvm::cl::values(
+ clEnumValN(ConversionConfig::MaterializationMode::Never, "never",
+ "Never build materialization with the type converter."),
+ clEnumValN(ConversionConfig::MaterializationMode::Delayed, "delayed",
+ "Build materializations with the type converter at the "
+ "end of the conversion."),
+ clEnumValN(ConversionConfig::MaterializationMode::Immediate,
+ "immediate",
+ "Build materializations with the type converter "
+ "immediately."))};
};
} // namespace