aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Transforms
diff options
context:
space:
mode:
authoravarmapml <abhishek.varma@polymagelabs.com>2020-08-13 19:09:22 +0530
committerUday Bondhugula <uday@polymagelabs.com>2020-08-13 19:10:47 +0530
commit6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78 (patch)
tree3b12c1023615d9b2e33bbf80b1996173d83efb67 /mlir/test/Transforms
parentfc7f004b887594ed9b5e25497aa4475b3e74cc94 (diff)
downloadllvm-6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.zip
llvm-6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.tar.gz
llvm-6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.tar.bz2
[MLIR] Support for ReturnOps in memref map layout normalization
-- This commit handles the returnOp in memref map layout normalization. -- An initial filter is applied on FuncOps which helps us know which functions can be a suitable candidate for memref normalization which doesn't lead to invalid IR. -- Handles memref map normalization for external function assuming the external function is normalizable. Differential Revision: https://reviews.llvm.org/D85226
Diffstat (limited to 'mlir/test/Transforms')
-rw-r--r--mlir/test/Transforms/normalize-memrefs.mlir134
1 files changed, 121 insertions, 13 deletions
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
index 7d56c88..9c1e610 100644
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -126,14 +126,6 @@ func @symbolic_operands(%s : index) {
return
}
-// Memref escapes; no normalization.
-// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
-func @escaping() -> memref<64xf32, affine_map<(d0) -> (d0 + 2)>> {
- // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
- %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
- return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-}
-
// Semi-affine maps, normalization not implemented yet.
// CHECK-LABEL: func @semi_affine_layout_map
func @semi_affine_layout_map(%s0: index, %s1: index) {
@@ -205,9 +197,125 @@ func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
return %d : i1
}
-// Test case 4: No normalization should take place because the function is returning the memref.
-// CHECK-LABEL: func @memref_used_in_return
-// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
-func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
- return %A : memref<8xf64, #tile>
+// Test cases here onwards deal with normalization of memref in function signature, caller site.
+
+// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls.
+// CHECK-LABEL: func @ret_multiple_argument_type
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
+ %a = affine.load %A[0] : memref<16xf64, #tile>
+ %p = mulf %a, %a : f64
+ %cond = constant 1 : i1
+ cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ return %res2, %p: memref<8xf64, #tile>, f64
+ ^bb2:
+ return %C, %p: memref<8xf64, #tile>, f64
+}
+
+// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK: %true = constant true
+// CHECK: cond_br %true, ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
+// CHECK: ^bb2: // pred: ^bb0
+// CHECK: return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64
+
+// CHECK-LABEL: func @ret_single_argument_type
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>){
+ %a = alloc() : memref<8xf64, #tile>
+ %b = alloc() : memref<16xf64, #tile>
+ %d = constant 23.0 : f64
+ call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
+ %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile>
+}
+
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
+// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: %cst = constant 2.300000e+01 : f64
+// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>
+
+// Test case set #5: To check normalization in a chain of interconnected functions.
+// CHECK-LABEL: func @func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_A(%A: memref<8xf64, #tile>) {
+ call @func_B(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_B(%A: memref<8xf64, #tile>) {
+ call @func_C(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_C(%A: memref<8xf64, #tile>) {
+ return
+}
+
+// Test case set #6: Checking if no normalization takes place in a scenario: A -> B -> C and B has an unsupported type.
+// CHECK-LABEL: func @some_func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_A(%A: memref<8xf64, #tile>) {
+ call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_B(%A: memref<8xf64, #tile>) {
+ "test.test"(%A) : (memref<8xf64, #tile>) -> ()
+ call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_C(%A: memref<8xf64, #tile>) {
+ return
+}
+
+// Test case set #7: Check normalization in case of external functions.
+// CHECK-LABEL: func @external_func_A
+// CHECK-SAME: (memref<4x4xf64>)
+func @external_func_A(memref<16xf64, #tile>) -> ()
+
+// CHECK-LABEL: func @external_func_B
+// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
+func @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+
+// CHECK-LABEL: func @simply_call_external()
+func @simply_call_external() {
+ %a = alloc() : memref<16xf64, #tile>
+ call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
+ return
+}
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()
+
+// CHECK-LABEL: func @use_value_of_external
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64>
+func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
+ %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+ return %res : memref<8xf64, #tile>
}
+// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
+// CHECK: return %{{.*}} : memref<2x4xf64>