diff options
author | Matthias Springer <me@m-sp.org> | 2025-08-29 14:13:52 +0000 |
---|---|---|
committer | Matthias Springer <me@m-sp.org> | 2025-08-29 14:18:31 +0000 |
commit | 6477f75393aedd50463c89f883407a2060c7f490 (patch) | |
tree | b7151241760a5b07c4b8655a7cd136f7bc688deb | |
parent | 0d9c0ced14827ea15f3cc9a981195258c049b9a6 (diff) | |
download | llvm-users/matthias-springer/materialize_immediately.zip llvm-users/matthias-springer/materialize_immediately.tar.gz llvm-users/matthias-springer/materialize_immediately.tar.bz2 |
[mlir][Transforms] Add option to build materializations immediatelyusers/matthias-springer/materialize_immediately
-rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 28 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 122 | ||||
-rw-r--r-- | mlir/test/Transforms/test-legalizer.mlir | 2 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 17 |
4 files changed, 113 insertions, 56 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 14dfbf1..e139372 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1346,15 +1346,25 @@ struct ConversionConfig { // represented at the moment. RewriterBase::Listener *listener = nullptr; - /// If set to "true", the dialect conversion attempts to build source/target - /// materializations through the type converter API in lieu of - /// "builtin.unrealized_conversion_cast ops". The conversion process fails if - /// at least one materialization could not be built. - /// - /// If set to "false", the dialect conversion does not build any custom - /// materializations and instead inserts "builtin.unrealized_conversion_cast" - /// ops to ensure that the resulting IR is valid. - bool buildMaterializations = true; + enum class MaterializationMode { + /// Never build materializations with the type converter. Instead, insert + /// "builtin.unrealized_conversion_cast" ops to ensure that the types of + /// the resulting IR are valid. + Never, + /// Build materializations with the type converter immediately. (If that + /// fails, insert "builtin.unrealized_conversion_cast" ops.) + Immediate, + /// Insert "builtin.unrealized_conversion_cast" ops first. At the end of + /// the conversion, replace them with materializations built with the type + /// converter. This can result in fewer materializations because + /// "builtin.unrealized_conversion_cast" ops that cancel each other out are + /// folded away. + Delayed + }; + + /// This option controls whether and when materializations should be built + /// through the type converter API. + MaterializationMode buildMaterializations = MaterializationMode::Delayed; /// If set to "true", pattern rollback is allowed. The conversion driver /// rolls back IR modifications in the following situations. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c0685f5..18adee7 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -810,7 +810,7 @@ public: void rollback() override; UnrealizedConversionCastOp getOperation() const { - return cast<UnrealizedConversionCastOp>(op); + return dyn_cast_or_null<UnrealizedConversionCastOp>(op); } private: @@ -973,10 +973,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// "out of thin air" appear like unresolved materializations because they are /// unrealized_conversion_cast ops. However, they must be treated like /// regular value replacements.) - ValueRange buildUnresolvedMaterialization( - MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, - Type originalType, const TypeConverter *converter, + ValueVector buildUnresolvedMaterialization( + OpBuilder &builder, MaterializationKind kind, OpBuilder::InsertPoint ip, + Location loc, ValueVector valuesToMap, ValueRange inputs, + TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion = true); /// Find a replacement value for the given SSA value in the conversion value @@ -984,7 +984,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// value. If there is no replacement value with the correct type, find the /// latest replacement value (regardless of the type) and build a source /// materialization. - Value findOrBuildReplacementValue(Value value, + Value findOrBuildReplacementValue(OpBuilder &builder, Value value, const TypeConverter *converter); //===--------------------------------------------------------------------===// @@ -1188,7 +1188,8 @@ static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, } void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + Value repl = + rewriterImpl.findOrBuildReplacementValue(rewriter, arg, converter); if (!repl) return; performReplaceBlockArg(rewriter, arg, repl); @@ -1203,7 +1204,8 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { // Compute replacement values. SmallVector<Value> replacements = llvm::map_to_vector(op->getResults(), [&](OpResult result) { - return rewriterImpl.findOrBuildReplacementValue(result, converter); + return rewriterImpl.findOrBuildReplacementValue(rewriter, result, + converter); }); // Notify the listener that the operation is about to be replaced. @@ -1251,8 +1253,10 @@ void CreateOperationRewrite::rollback() { void UnresolvedMaterializationRewrite::rollback() { if (!mappedValues.empty()) rewriterImpl.mapping.erase(mappedValues); - rewriterImpl.unresolvedMaterializations.erase(getOperation()); - op->erase(); + if (getOperation()) { + rewriterImpl.unresolvedMaterializations.erase(getOperation()); + op->erase(); + } } void ConversionPatternRewriterImpl::applyRewrites() { @@ -1458,11 +1462,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // Create a materialization for the most recently mapped values. repl = lookupOrDefault(operand, /*desiredTypes=*/{}, /*skipPureTypeConversions=*/true); - ValueRange castValues = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(repl), operandLoc, + ValueVector castValues = buildUnresolvedMaterialization( + rewriter, MaterializationKind::Target, computeInsertPoint(repl), + operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, /*originalType=*/origType, currentTypeConverter); - remapped.push_back(castValues); + remapped.push_back(std::move(castValues)); } return success(); } @@ -1577,7 +1582,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // Materialize a replacement value "out of thin air". Value mat = buildUnresolvedMaterialization( - MaterializationKind::Source, + rewriter, MaterializationKind::Source, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*valuesToMap=*/{}, /*inputs=*/ValueRange(), @@ -1620,46 +1625,76 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// Build an unresolved materialization operation given an output type and set /// of input operands. -ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( - MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, - Type originalType, const TypeConverter *converter, +ValueVector ConversionPatternRewriterImpl::buildUnresolvedMaterialization( + OpBuilder &builder, MaterializationKind kind, OpBuilder::InsertPoint ip, + Location loc, ValueVector valuesToMap, ValueRange inputs, + TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); assert(TypeRange(inputs) != outputTypes && "materialization is not necessary"); + ValueVector results; + + // Build materializations with the type converter if requested. + if (converter && config.buildMaterializations == + ConversionConfig::MaterializationMode::Immediate) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); + if (kind == MaterializationKind::Source) { + assert(outputTypes.size() == 1 && "expected single output type"); + Value mat = converter->materializeSourceConversion( + builder, loc, outputTypes.front(), inputs); + if (mat) + results.push_back(mat); + } else { + assert(kind == MaterializationKind::Target && + "expected source or target materialization"); + SmallVector<Value> mat = converter->materializeTargetConversion( + builder, loc, outputTypes, inputs); + if (!mat.empty()) + llvm::append_range(results, mat); + } + } - // Create an unresolved materialization. We use a new OpBuilder to avoid - // tracking the materialization like we do for other operations. - OpBuilder builder(outputTypes.front().getContext()); - builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); - UnrealizedConversionCastOp convertOp = - UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); - if (config.attachDebugMaterializationKind) { - StringRef kindStr = - kind == MaterializationKind::Source ? "source" : "target"; - convertOp->setAttr("__kind__", builder.getStringAttr(kindStr)); + // Otherwise, or if the type converter failed to build a materialization, + // insert an unrealized_conversion_cast op and try to resolve it later. + UnrealizedConversionCastOp castOp; + if (results.empty()) { + assert(results.empty() && "expected no results"); + // Create an unresolved materialization. We use a new OpBuilder to avoid + // tracking the materialization like we do for other operations. + OpBuilder builder(outputTypes.front().getContext()); + builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); + castOp = + UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); + unresolvedMaterializations[castOp] = + UnresolvedMaterializationInfo(converter, kind, originalType); + if (config.attachDebugMaterializationKind) { + StringRef kindStr = + kind == MaterializationKind::Source ? "source" : "target"; + castOp->setAttr("__kind__", builder.getStringAttr(kindStr)); + } + if (isPureTypeConversion) + castOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); + llvm::append_range(results, castOp.getResults()); } - if (isPureTypeConversion) - convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); // Register the materialization. - unresolvedMaterializations[convertOp] = - UnresolvedMaterializationInfo(converter, kind, originalType); if (config.allowPatternRollback) { if (!valuesToMap.empty()) - mapping.map(valuesToMap, convertOp.getResults()); - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, + mapping.map(valuesToMap, results); + appendRewrite<UnresolvedMaterializationRewrite>(castOp, std::move(valuesToMap)); } else { - patternMaterializations.insert(convertOp); + if (castOp) + patternMaterializations.insert(castOp); } - return convertOp.getResults(); + return results; } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( - Value value, const TypeConverter *converter) { + OpBuilder &builder, Value value, const TypeConverter *converter) { assert(config.allowPatternRollback && "this code path is valid only in rollback mode"); @@ -1700,7 +1735,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // materialization must be valid for all future users that may be created // later in the conversion process. Value castValue = - buildUnresolvedMaterialization(MaterializationKind::Source, + buildUnresolvedMaterialization(rewriter, MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/value.getType(), @@ -1779,7 +1814,8 @@ getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, // The replaced value is dropped. Materialize a replacement value "out of // thin air". Value srcMat = impl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(from), from.getLoc(), + impl.rewriter, MaterializationKind::Source, computeInsertPoint(from), + from.getLoc(), /*valuesToMap=*/{}, /*inputs=*/ValueRange(), /*outputTypes=*/from.getType(), /*originalType=*/Type(), converter)[0]; @@ -1799,7 +1835,8 @@ getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, // materializations if possible. This would require an extension of the // `lookupOrDefault` API. Value srcMat = impl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(to), from.getLoc(), + impl.rewriter, MaterializationKind::Source, computeInsertPoint(to), + from.getLoc(), /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(), /*originalType=*/Type(), converter)[0]; repls.push_back(srcMat); @@ -1855,7 +1892,7 @@ void ConversionPatternRewriterImpl::replaceOp( // This result was dropped and no replacement value was provided. // Materialize a replacement value "out of thin air". buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(result), + rewriter, MaterializationKind::Source, computeInsertPoint(result), result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputTypes=*/result.getType(), /*originalType=*/Type(), currentTypeConverter, /*isPureTypeConversion=*/false); @@ -3234,7 +3271,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { castOp->removeAttr(kPureTypeConversionMarker); // Try to legalize all unresolved materializations. - if (rewriter.getConfig().buildMaterializations) { + if (rewriter.getConfig().buildMaterializations != + ConversionConfig::MaterializationMode::Never) { // Use a new rewriter, so the modifications are not tracked for rollback // purposes etc. IRRewriter irRewriter(rewriterImpl.rewriter.getContext(), diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 3fa42ff..5ec12be 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=never attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND // CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B" // CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B" diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 95f381e..d6126de 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1660,11 +1660,20 @@ struct TestLegalizePatternDriver llvm::cl::desc( "Attach materialization kind to unrealized_conversion_cast ops"), llvm::cl::init(false)}; - Option<bool> buildMaterializations{ + Option<ConversionConfig::MaterializationMode> buildMaterializations{ *this, "build-materializations", - llvm::cl::desc( - "If set to 'false', leave unrealized_conversion_cast ops in place"), - llvm::cl::init(true)}; + llvm::cl::desc("When to build unresovled materializations."), + llvm::cl::init(ConversionConfig::MaterializationMode::Delayed), + llvm::cl::values( + clEnumValN(ConversionConfig::MaterializationMode::Never, "never", + "Never build materialization with the type converter."), + clEnumValN(ConversionConfig::MaterializationMode::Delayed, "delayed", + "Build materializations with the type converter at the " + "end of the conversion."), + clEnumValN(ConversionConfig::MaterializationMode::Immediate, + "immediate", + "Build materializations with the type converter " + "immediately."))}; }; } // namespace |