aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2025-08-10 11:41:51 +0000
committerMatthias Springer <me@m-sp.org>2025-08-26 12:31:20 +0000
commit4ce5edf8e73d8e8c850ea9fd319d3692eaf1477e (patch)
tree6651930f08af515cd56ec0965c6991b8e3b3949f
parent769d5c2dfb9d1bde24f915d926f8ac17ffbe29a1 (diff)
downloadllvm-users/matthias-springer/migrate_detensorize.zip
llvm-users/matthias-springer/migrate_detensorize.tar.gz
llvm-users/matthias-springer/migrate_detensorize.tar.bz2
[mlir][linalg] Migrate Detensorize pass to new dialect conversion driverusers/matthias-springer/migrate_detensorize
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp34
-rw-r--r--mlir/test/Dialect/Linalg/detensorize_0d.mlir9
2 files changed, 37 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 8309054..221f95a8d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -458,6 +458,22 @@ struct LinalgDetensorize
}
};
+ /// A listener that forwards notifyBlockErased and notifyOperationErased to
+ /// the given callbacks.
+ struct CallbackListener : public RewriterBase::Listener {
+ CallbackListener(std::function<void(Operation *op)> onOperationErased,
+ std::function<void(Block *block)> onBlockErased)
+ : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
+
+ void notifyBlockErased(Block *block) override { onBlockErased(block); }
+ void notifyOperationErased(Operation *op) override {
+ onOperationErased(op);
+ }
+
+ std::function<void(Operation *op)> onOperationErased;
+ std::function<void(Block *block)> onBlockErased;
+ };
+
void runOnOperation() override {
MLIRContext *context = &getContext();
DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,22 @@ struct LinalgDetensorize
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
shouldConvertBranchOperand);
- if (failed(
- applyFullConversion(getOperation(), target, std::move(patterns))))
+ ConversionConfig config;
+ auto onOperationErased = [&](Operation *op) {
+ opsToDetensor.erase(op);
+ detensorableBranchOps.erase(op);
+ };
+ auto onBlockErased = [&](Block *block) {
+ for (BlockArgument arg : block->getArguments()) {
+ blockArgsToDetensor.erase(arg);
+ }
+ };
+ CallbackListener listener(onOperationErased, onBlockErased);
+
+ config.listener = &listener;
+ config.allowPatternRollback = false;
+ if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
+ config)))
signalPassFailure();
RewritePatternSet canonPatterns(context);
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
index 74931cb..76e8c7e 100644
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso
}
// CHECK-LABEL: func @detensor_op_sequence
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
-// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
-// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
-// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
-// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
+// CHECK: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]]
+// CHECK: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
+// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
+// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]]
// CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
// CHECK: return %[[new_tensor_res]]