aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-11-26 05:38:50 +0100
committerMatthias Springer <mspringer@nvidia.com>2024-11-26 06:22:24 +0100
commitd046db6686dc31a39511422ad8c35ae439f2a7ba (patch)
tree401f77e5cae0c108f41d7540eb8d686457472e04
parentd2d3eb9c9043a05325f2f9b056aaf27db21b3301 (diff)
downloadllvm-users/matthias-springer/replace_uses_of_bbarg.zip
llvm-users/matthias-springer/replace_uses_of_bbarg.tar.gz
llvm-users/matthias-springer/replace_uses_of_bbarg.tar.bz2
[mlir][Transforms] Dialect conversion: Fix `replaceUsesOfBlockArgument`users/matthias-springer/replace_uses_of_bbarg
-rw-r--r--mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp2
-rw-r--r--mlir/test/Transforms/test-legalizer.mlir20
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp46
4 files changed, 67 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c046ea1..9c776b9 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -295,7 +295,7 @@ static void restoreByValRefArgumentType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
- rewriter.replaceUsesOfBlockArgument(oldArg, valueArg);
+ rewriter.replaceUsesOfBlockArgument(arg, valueArg);
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 60b3656..8b7ffd7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1641,7 +1641,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
- impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+ impl->mapping.map(from, to);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 624add0..63d27fa 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -472,3 +472,23 @@ func.func @circular_mapping() {
%0 = "test.erase_op"() : () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
+
+// -----
+
+// CHECK-LABEL: func @test_replace_uses_of_block_arg() {
+// CHECK: "test.convert_block_and_replace_arg"() ({
+// CHECK: ^bb0(%[[arg0:.*]]: f64, %[[arg1:.*]]: f64):
+// CHECK: %[[producer:.*]] = "test.type_producer"() : () -> f64
+// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]], %[[arg1]]) : (f64, f64) -> f32
+// CHECK: "test.some_user"(%[[cast]]) : (f32) -> ()
+// CHECK: }) {legal} : () -> ()
+// CHECK: "test.return"() : () -> ()
+// CHECK: }
+func.func @test_replace_uses_of_block_arg() {
+ "test.convert_block_and_replace_arg"() ({
+ ^bb0(%arg0: f32):
+ // expected-remark @below{{'test.some_user' is not legalizable}}
+ "test.some_user"(%arg0) : (f32) -> ()
+ }) : () -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e931b39..f8df89f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -902,6 +902,45 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};
+struct TestConvertBlockAndReplaceArg : public ConversionPattern {
+ TestConvertBlockAndReplaceArg(MLIRContext *ctx,
+ const TypeConverter &converter)
+ : ConversionPattern(converter, "test.convert_block_and_replace_arg",
+ /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Expect single region with single block with single block argument.
+ if (op->getNumRegions() != 1)
+ return failure();
+ if (op->getRegion(0).getBlocks().size() != 1)
+ return failure();
+ Block *block = &op->getRegion(0).front();
+ if (block->getArguments().size() != 1)
+ return failure();
+
+ // Convert the block argument into to F64 block arguments.
+ TypeConverter::SignatureConversion result(1);
+ result.addInputs(0, {rewriter.getF64Type(), rewriter.getF64Type()});
+ Block *newBlock =
+ rewriter.applySignatureConversion(block, result, getTypeConverter());
+
+ // Create a replacement value.
+ rewriter.setInsertionPointToStart(newBlock);
+ Value repl = rewriter.create<TestTypeProducerOp>(op->getLoc(),
+ rewriter.getF64Type());
+ BlockArgument arg0 = newBlock->getArgument(0);
+ // Replace the block argument.
+ rewriter.replaceUsesOfBlockArgument(arg0, repl);
+
+ // Mark the op as legal.
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
/// This is to test the rollback logic.
struct TestUndoMoveOpBefore : public ConversionPattern {
@@ -1265,7 +1304,8 @@ struct TestLegalizePatternDriver
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp>(&getContext(), converter);
+ TestPassthroughInvalidOp, TestConvertBlockAndReplaceArg>(
+ &getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1317,6 +1357,10 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.convert_block_and_replace_arg", &getContext()),
+ [](Operation *op) { return op->hasAttr("legal"); });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;