diff options
author | Matthias Springer <mspringer@nvidia.com> | 2025-04-23 14:56:37 +0200 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2025-04-23 14:56:37 +0200 |
commit | 03e54bf5a82f47c0845eadf30cbb82d7bc677313 (patch) | |
tree | ad94064176a451a7a90222bcd40ab3957bb1e653 | |
parent | 2d3bbb6aafbc74ef6fc51286f09def0f0e35fe14 (diff) | |
download | llvm-origin/users/matthias-springer/extractvalue_folder.zip llvm-origin/users/matthias-springer/extractvalue_folder.tar.gz llvm-origin/users/matthias-springer/extractvalue_folder.tar.bz2 |
[mlir][LLVM] Improve `llvm.extractvalue` folderorigin/users/matthias-springer/extractvalue_folder
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 43 | ||||
-rw-r--r-- | mlir/test/Dialect/LLVMIR/canonicalize.mlir | 16 |
2 files changed, 53 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 0022be8..5586f57 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1885,11 +1885,44 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) { auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); OpFoldResult result = {}; + ArrayRef<int64_t> extractPos = getPosition(); + bool switchedToInsertedValue = false; while (insertValueOp) { - if (getPosition() == insertValueOp.getPosition()) + ArrayRef<int64_t> insertPos = insertValueOp.getPosition(); + auto extractPosSize = extractPos.size(); + auto insertPosSize = insertPos.size(); + + // Case 1: Exact match of positions. + if (extractPos == insertPos) return insertValueOp.getValue(); - unsigned min = - std::min(getPosition().size(), insertValueOp.getPosition().size()); + + // Case 2: Insert position is a prefix of extract position. Continue + // traversal with the inserted value. Example: + // ``` + // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)> + // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)> + // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)> + // %3 = llvm.insertvalue %2, %foo[0] + // : !llvm.struct<(struct<(i32, i32, i32)>, i64)> + // %4 = llvm.extractvalue %3[0, 0] + // : !llvm.struct<(struct<(i32, i32, i32)>, i64)> + // ``` + // In the above example, %4 is folded to %arg1. + if (extractPosSize > insertPosSize && + extractPos.take_front(insertPosSize) == insertPos) { + insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>(); + extractPos = extractPos.drop_front(insertPosSize); + switchedToInsertedValue = true; + continue; + } + + // Case 3: Try to continue the traversal with the container value, in order + // to swap out the container operand. This does not work if we decided + // earlier to continue the traversal with the inserted value (Case 2). + if (switchedToInsertedValue) + return {}; + unsigned min = std::min(extractPosSize, insertPosSize); + // If one is fully prefix of the other, stop propagating back as it will // miss dependencies. For instance, %3 should not fold to %f0 in the // following example: @@ -1900,10 +1933,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) { // !llvm.array<4 x !llvm.array<4 x f32>> // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>> // ``` - if (getPosition().take_front(min) == - insertValueOp.getPosition().take_front(min)) + if (extractPos.take_front(min) == insertPos.take_front(min)) return result; - // If neither a prefix, nor the exact position, we can extract out of the // value being inserted into. Moreover, we can try again if that operand // is itself an insertvalue expression. diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir index a793cac..8accf6e 100644 --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -57,6 +57,22 @@ llvm.func @fold_extractvalue() -> i32 { // ----- +// CHECK-LABEL: fold_extractvalue( +// CHECK-SAME: %[[arg1:.*]]: i32, %[[arg2:.*]]: i32, %[[arg3:.*]]: i32) +// CHECK-NEXT: llvm.return %[[arg1]] : i32 +llvm.func @fold_extractvalue(%arg1: i32, %arg2: i32, %arg3: i32) -> i32{ + %3 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)> + %5 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32)> + %6 = llvm.insertvalue %arg1, %5[0] : !llvm.struct<(i32, i32, i32)> + %7 = llvm.insertvalue %arg1, %6[1] : !llvm.struct<(i32, i32, i32)> + %8 = llvm.insertvalue %arg1, %7[2] : !llvm.struct<(i32, i32, i32)> + %11 = llvm.insertvalue %8, %3[0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)> + %13 = llvm.extractvalue %11[0, 0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)> + llvm.return %13 : i32 +} + +// ----- + // CHECK-LABEL: no_fold_extractvalue llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 { %f0 = arith.constant 0.0 : f32 |