aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Kramer <benny.kra@googlemail.com>2024-06-27 09:14:54 +0200
committerBenjamin Kramer <benny.kra@googlemail.com>2024-06-27 09:16:17 +0200
commit605098dcd4e79b27c86784b1a3d7fc6e3010ce00 (patch)
tree7dcff66e3c6187db822cf8217098e8bced02af0a
parentb5cc19e572855136eb4080208a9bd5ecef785aa3 (diff)
downloadllvm-605098dcd4e79b27c86784b1a3d7fc6e3010ce00.zip
llvm-605098dcd4e79b27c86784b1a3d7fc6e3010ce00.tar.gz
llvm-605098dcd4e79b27c86784b1a3d7fc6e3010ce00.tar.bz2
Revert "[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (#96329)"
This reverts commit c01ce797619359ee282773dfc4b1e91ff0a30435. It depends on f1e0657d144f5a3cfef4b625d0f875f4dacd21d1 which breaks SCF lowering.
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp133
1 files changed, 81 insertions, 52 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 47e0338..07ebd68 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -53,16 +53,6 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}
-/// Helper function that computes an insertion point where the given value is
-/// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(Value value) {
- Block *insertBlock = value.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(value))
- insertPt = ++inputRes.getOwner()->getIterator();
- return OpBuilder::InsertPoint(insertBlock, insertPt);
-}
-
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
@@ -455,9 +445,11 @@ public:
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- Block *getOrigBlock() const { return origBlock; }
-
- const TypeConverter *getConverter() const { return converter; }
+ /// Materialize any necessary conversions for converted arguments that have
+ /// live users, using the provided `findLiveUser` to search for a user that
+ /// survives the conversion process.
+ LogicalResult
+ materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
void commit(RewriterBase &rewriter) override;
@@ -849,10 +841,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value buildUnresolvedMaterialization(MaterializationKind kind,
- OpBuilder::InsertPoint ip, Location loc,
+ Block *insertBlock,
+ Block::iterator insertPt, Location loc,
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);
+ Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+ Type outputType,
+ const TypeConverter *converter);
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -985,6 +981,49 @@ void BlockTypeConversionRewrite::rollback() {
block->replaceAllUsesWith(origBlock);
}
+LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
+ function_ref<Operation *(Value)> findLiveUser) {
+ // Process the remapping for each of the original arguments.
+ for (auto it : llvm::enumerate(origBlock->getArguments())) {
+ BlockArgument origArg = it.value();
+ // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
+ OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
+ builder.setInsertionPointToStart(block);
+
+ // If the type of this argument changed and the argument is still live, we
+ // need to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ continue;
+ Operation *liveUser = findLiveUser(origArg);
+ if (!liveUser)
+ continue;
+
+ Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+ assert(replacementValue && "replacement value not found");
+ Value newArg;
+ if (converter) {
+ builder.setInsertionPointAfterValue(replacementValue);
+ newArg = converter->materializeSourceConversion(
+ builder, origArg.getLoc(), origArg.getType(), replacementValue);
+ assert((!newArg || newArg.getType() == origArg.getType()) &&
+ "materialization hook did not provide a value of the expected "
+ "type");
+ }
+ if (!newArg) {
+ InFlightDiagnostic diag =
+ emitError(origArg.getLoc())
+ << "failed to materialize conversion for block argument #"
+ << it.index() << " that remained live after conversion, type was "
+ << origArg.getType();
+ diag.attachNote(liveUser->getLoc())
+ << "see existing live user here: " << *liveUser;
+ return failure();
+ }
+ rewriterImpl.mapping.map(origArg, newArg);
+ }
+ return success();
+}
+
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
@@ -1157,10 +1196,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type newOperandType = newOperand.getType();
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(newOperand),
- operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
- /*origArgType=*/{}, currentTypeConverter);
+ Value castValue = buildUnresolvedTargetMaterialization(
+ operandLoc, newOperand, desiredType, currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
@@ -1288,9 +1325,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
- MaterializationKind::Source,
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*inputs=*/ValueRange(),
+ MaterializationKind::Source, newBlock, newBlock->begin(),
+ origArg.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*origArgType=*/{}, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1315,9 +1351,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value repl = buildUnresolvedMaterialization(
- MaterializationKind::Argument,
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*inputs=*/replArgs,
+ MaterializationKind::Argument, newBlock, newBlock->begin(),
+ origArg.getLoc(), /*inputs=*/replArgs,
/*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1339,8 +1374,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
- MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, Type outputType, Type origArgType,
+ MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+ Location loc, ValueRange inputs, Type outputType, Type origArgType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1348,13 +1383,25 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
- OpBuilder builder(ip.getBlock(), ip.getPoint());
+ OpBuilder builder(insertBlock, insertPt);
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
origArgType);
return convertOp.getResult(0);
}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+ Location loc, Value input, Type outputType,
+ const TypeConverter *converter) {
+ Block *insertBlock = input.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(input))
+ insertPt = ++inputRes.getOwner()->getIterator();
+
+ return buildUnresolvedMaterialization(
+ MaterializationKind::Target, insertBlock, insertPt, loc, input,
+ outputType, /*origArgType=*/{}, converter);
+}
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -2468,9 +2515,9 @@ LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
- failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
- inverseMapping)))
+ if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+ inverseMapping)) ||
+ failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
// Process requested operation replacements.
@@ -2526,28 +2573,10 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
- dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
- // Process the remapping for each of the original arguments.
- for (Value origArg :
- blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
- // If the type of this argument changed and the argument is still live,
- // we need to materialize a conversion.
- if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
- assert(replacementValue && "replacement value not found");
- Value repl = rewriterImpl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(replacementValue),
- origArg.getLoc(), /*inputs=*/replacementValue,
- /*outputType=*/origArg.getType(), /*origArgType=*/{},
- blockTypeConversionRewrite->getConverter());
- rewriterImpl.mapping.map(origArg, repl);
- }
- }
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+ if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+ findLiveUser)))
+ return failure();
}
return success();
}