diff options
author | Srishti Srivastava <srishtisrivastava.ai@gmail.com> | 2023-08-11 17:25:53 +0000 |
---|---|---|
committer | Srishti Srivastava <srishtisrivastava.ai@gmail.com> | 2023-08-11 17:26:58 +0000 |
commit | 232f8eadae18889627bdca75c45e98b0c1460086 (patch) | |
tree | 0868bdc269bca75540e4a309bfc6d807c4bde5af /mlir/test/Analysis | |
parent | 4e1a23e3fcea102dbf6c5a4742d96b8dc9de59e6 (diff) | |
download | llvm-232f8eadae18889627bdca75c45e98b0c1460086.zip llvm-232f8eadae18889627bdca75c45e98b0c1460086.tar.gz llvm-232f8eadae18889627bdca75c45e98b0c1460086.tar.bz2 |
[MLIR][analysis] Fix call op handling in sparse backward dataflow
Currently, data in `AbstractSparseBackwardDataFlowAnalysis` is
considered to flow one-to-one, in order, from the operands of an op
implementing `CallOpInterface` to the arguments of the function it is
calling.
This understanding of the data flow is inaccurate. The operands of such
an op that forward to the function arguments are obtained using a
method provided by `CallOpInterface` called `getArgOperands()`.
This commit fixes this bug by using `getArgOperands()` instead of
`getOperands()` to get the mapping from operands to function arguments
because not all operands necessarily forward to the function arguments
and even if they do, they don't necessarily have to be in the order in
which they appear in the op. The operands that don't get forwarded are
handled by the newly introduced `visitCallOperand()` function, which
works analogous to the `visitBranchOperand()` function.
This fix is also propagated to liveness analysis that earlier relied on
this incorrect implementation of the sparse backward dataflow analysis
framework and corrects some incorrect assumptions made in it.
Extra cleanup: Improved a comment and removed an unnecessary code line.
Signed-off-by: Srishti Srivastava <srishtisrivastava.ai@gmail.com>
Reviewed By: matthiaskramm, jcai19
Differential Revision: https://reviews.llvm.org/D157261
Diffstat (limited to 'mlir/test/Analysis')
-rw-r--r-- | mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir | 28 | ||||
-rw-r--r-- | mlir/test/Analysis/DataFlow/test-written-to.mlir | 19 |
2 files changed, 40 insertions, 7 deletions
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir index a040fb3..b6aed1c 100644 --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -59,11 +59,27 @@ func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref<i32>, %ar // ----- +func.func private @private(%arg0 : i32, %arg1 : i32) { + func.return +} + +// Positive test: Type (1.c) "is a non-forwarded call operand" +// CHECK-LABEL: test_tag: call +// CHECK-LABEL: operand #0: not live +// CHECK-LABEL: operand #1: not live +// CHECK-LABEL: operand #2: live +func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) { + test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> () + return +} + +// ----- + // Positive test: Type (2) "is returned by a public function" // zero is live because it is returned by a public function. // CHECK-LABEL: test_tag: zero: // CHECK-NEXT: result #0: live -func.func @test_4_type_2() -> (f32){ +func.func @test_5_type_2() -> (f32){ %0 = arith.constant {tag = "zero"} 0.0 : f32 return %0 : f32 } @@ -90,7 +106,7 @@ func.func @test_4_type_2() -> (f32){ // CHECK-NEXT: operand #3: live // CHECK-LABEL: test_tag: add: // CHECK-NEXT: operand #0: live -func.func @test_5_RegionBranchTerminatorOpInterface_type_3(%arg0: memref<i32>, %arg1: i1) -> (i32) { +func.func @test_6_RegionBranchTerminatorOpInterface_type_3(%arg0: memref<i32>, %arg1: i1) -> (i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 @@ -135,7 +151,7 @@ func.func private @private0(%0 : i32) -> i32 { // CHECK-NEXT: result #0: live // CHECK-LABEL: test_tag: y: // CHECK-NEXT: result #0: not live -func.func @test_6_type_3(%arg0: memref<i32>) { +func.func @test_7_type_3(%arg0: memref<i32>) { %c0 = arith.constant {tag = "zero"} 0 : index %c10 = arith.constant {tag = "ten"} 10 : index %c1 = arith.constant {tag = "one"} 1 : index @@ -190,7 +206,7 @@ func.func private @private2(%0 : i32) -> i32 { // CHECK-NEXT: operand #0: live // CHECK-NEXT: operand #1: live // CHECK-NEXT: result #0: live -func.func @test_7_type_3(%arg: i32) -> (i32) { +func.func @test_8_type_3(%arg: i32) -> (i32) { %0 = func.call @private1(%arg) : (i32) -> i32 %final = arith.muli %0, %arg {tag = "final"} : i32 return %final : i32 @@ -205,7 +221,7 @@ func.func @test_7_type_3(%arg: i32) -> (i32) { // CHECK-NEXT: result #0: not live // CHECK-LABEL: test_tag: one: // CHECK-NEXT: result #0: live -func.func @test_8_negative() -> (f32){ +func.func @test_9_negative() -> (f32){ %0 = arith.constant {tag = "zero"} 0.0 : f32 %1 = arith.constant {tag = "one"} 1.0 : f32 return %1 : f32 @@ -230,7 +246,7 @@ func.func private @private_1() -> (i32, i32) { %1 = arith.addi %0, %0 {tag = "one"} : i32 return %0, %1 : i32, i32 } -func.func @test_9_negative() -> (i32) { +func.func @test_10_negative() -> (i32) { %0:2 = func.call @private_1() : () -> (i32, i32) return %0#0 : i32 } diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir index 1ff92f5..82fe755 100644 --- a/mlir/test/Analysis/DataFlow/test-written-to.mlir +++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir @@ -286,4 +286,21 @@ llvm.func @decl(i64) llvm.func @func(%lb : i64) -> () { llvm.call @decl(%lb) : (i64) -> () llvm.return -} +} + +// ----- + +func.func private @callee(%arg0 : i32, %arg1 : i32) -> i32 { + func.return %arg0 : i32 +} + +// CHECK-LABEL: test_tag: a +// CHECK-LABEL: operand #0: [b] +// CHECK-LABEL: operand #1: [] +// CHECK-LABEL: operand #2: [callarg2] +// CHECK-LABEL: result #0: [b] +func.func @test_call_on_device(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) { + %0 = test.call_on_device @callee(%arg0, %arg1), %device {tag = "a"} : (i32, i32, i32) -> (i32) + memref.store %0, %m0[] {tag_name = "b"} : memref<i32> + return +} |