diff options
author | Matthias Springer <mspringer@nvidia.com> | 2025-01-04 13:53:38 +0100 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2025-01-04 15:06:47 +0100 |
commit | bf57b8d0a3da1c9d383374399a36f766df3f255e (patch) | |
tree | e20a66606d7187c20f9d655c97527d8a95fecd83 | |
parent | 11026039f7ce600ff04fcf2e54a84035f4484678 (diff) | |
download | llvm-users/matthias-springer/fix_mapping_2.zip llvm-users/matthias-springer/fix_mapping_2.tar.gz llvm-users/matthias-springer/fix_mapping_2.tar.bz2 |
[mlir][Transforms] Detect mapping overwrites during block signature conversionusers/matthias-springer/fix_mapping_2
Add extra assertions to make sure that a value in the conversion value mapping is not overwritten during `applySignatureConversion`.
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4904d3c..94e61a2 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -176,6 +176,8 @@ struct ConversionValueMapping { template <typename OldVal, typename NewVal> std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value> map(OldVal &&oldVal, NewVal &&newVal) { + assert(!mapping.contains(oldVal) && + "attempting to overwrite existing mapping"); LLVM_DEBUG({ ValueVector next(newVal); while (true) { @@ -1412,6 +1414,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( for (unsigned i = 0; i != origArgCount; ++i) { BlockArgument origArg = block->getArgument(i); Type origArgType = origArg.getType(); + ValueVector currentMapping = mapping.lookupOrDefault(origArg); std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap = signatureConversion.getInputMapping(i); @@ -1421,7 +1424,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( buildUnresolvedMaterialization( MaterializationKind::Source, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(), + /*valuesToMap=*/currentMapping, /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*originalType=*/Type(), converter); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); continue; @@ -1432,7 +1435,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - mapping.map(origArg, repl); + mapping.map(currentMapping, repl); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); continue; } @@ -1441,7 +1444,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs); - mapping.map(origArg, std::move(replArgVals)); + mapping.map(currentMapping, std::move(replArgVals)); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); } @@ -1757,6 +1760,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(in region of '" << parentOp->getName() << "'(" << from.getOwner()->getParentOp() << ")\n"; }); + llvm::errs() << "replaceUsesOfBlockArgument: " << from.getOwner() << "\n"; + impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, impl->currentTypeConverter); impl->mapping.map(impl->mapping.lookupOrDefault(from), to); |