diff options
author | avarmapml <abhishek.varma@polymagelabs.com> | 2020-08-13 19:09:22 +0530 |
---|---|---|
committer | Uday Bondhugula <uday@polymagelabs.com> | 2020-08-13 19:10:47 +0530 |
commit | 6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78 (patch) | |
tree | 3b12c1023615d9b2e33bbf80b1996173d83efb67 /mlir/test/Transforms | |
parent | fc7f004b887594ed9b5e25497aa4475b3e74cc94 (diff) | |
download | llvm-6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.zip llvm-6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.tar.gz llvm-6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.tar.bz2 |
[MLIR] Support for ReturnOps in memref map layout normalization
-- This commit handles the returnOp in memref map layout normalization.
-- An initial filter is applied on FuncOps which helps us know which functions can be
a suitable candidate for memref normalization which doesn't lead to invalid IR.
-- Handles memref map normalization for external function assuming the external function
is normalizable.
Differential Revision: https://reviews.llvm.org/D85226
Diffstat (limited to 'mlir/test/Transforms')
-rw-r--r-- | mlir/test/Transforms/normalize-memrefs.mlir | 134 |
1 files changed, 121 insertions, 13 deletions
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir index 7d56c88..9c1e610 100644 --- a/mlir/test/Transforms/normalize-memrefs.mlir +++ b/mlir/test/Transforms/normalize-memrefs.mlir @@ -126,14 +126,6 @@ func @symbolic_operands(%s : index) { return } -// Memref escapes; no normalization. -// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}> -func @escaping() -> memref<64xf32, affine_map<(d0) -> (d0 + 2)>> { - // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}> - %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>> - return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>> -} - // Semi-affine maps, normalization not implemented yet. // CHECK-LABEL: func @semi_affine_layout_map func @semi_affine_layout_map(%s0: index, %s1: index) { @@ -205,9 +197,125 @@ func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 { return %d : i1 } -// Test case 4: No normalization should take place because the function is returning the memref. -// CHECK-LABEL: func @memref_used_in_return -// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}> -func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) { - return %A : memref<8xf64, #tile> +// Test cases here onwards deal with normalization of memref in function signature, caller site. + +// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls. +// CHECK-LABEL: func @ret_multiple_argument_type +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64) +func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) { + %a = affine.load %A[0] : memref<16xf64, #tile> + %p = mulf %a, %a : f64 + %cond = constant 1 : i1 + cond_br %cond, ^bb1, ^bb2 + ^bb1: + %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>) + return %res2, %p: memref<8xf64, #tile>, f64 + ^bb2: + return %C, %p: memref<8xf64, #tile>, f64 +} + +// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64> +// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64 +// CHECK: %true = constant true +// CHECK: cond_br %true, ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>) +// CHECK: return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64 +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64 + +// CHECK-LABEL: func @ret_single_argument_type +// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>) +func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>){ + %a = alloc() : memref<8xf64, #tile> + %b = alloc() : memref<16xf64, #tile> + %d = constant 23.0 : f64 + call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>) + call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>) + %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) + %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>) + return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile> +} + +// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64> +// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64> +// CHECK: %cst = constant 2.300000e+01 : f64 +// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>) +// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>) +// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64) +// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>) +// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64> + +// Test case set #5: To check normalization in a chain of interconnected functions. +// CHECK-LABEL: func @func_A +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>) +func @func_A(%A: memref<8xf64, #tile>) { + call @func_B(%A) : (memref<8xf64, #tile>) -> () + return +} +// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> () + +// CHECK-LABEL: func @func_B +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>) +func @func_B(%A: memref<8xf64, #tile>) { + call @func_C(%A) : (memref<8xf64, #tile>) -> () + return +} +// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> () + +// CHECK-LABEL: func @func_C +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>) +func @func_C(%A: memref<8xf64, #tile>) { + return +} + +// Test case set #6: Checking if no normalization takes place in a scenario: A -> B -> C and B has an unsupported type. +// CHECK-LABEL: func @some_func_A +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) +func @some_func_A(%A: memref<8xf64, #tile>) { + call @some_func_B(%A) : (memref<8xf64, #tile>) -> () + return +} +// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> () + +// CHECK-LABEL: func @some_func_B +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) +func @some_func_B(%A: memref<8xf64, #tile>) { + "test.test"(%A) : (memref<8xf64, #tile>) -> () + call @some_func_C(%A) : (memref<8xf64, #tile>) -> () + return +} +// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> () + +// CHECK-LABEL: func @some_func_C +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) +func @some_func_C(%A: memref<8xf64, #tile>) { + return +} + +// Test case set #7: Check normalization in case of external functions. +// CHECK-LABEL: func @external_func_A +// CHECK-SAME: (memref<4x4xf64>) +func @external_func_A(memref<16xf64, #tile>) -> () + +// CHECK-LABEL: func @external_func_B +// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64> +func @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>) + +// CHECK-LABEL: func @simply_call_external() +func @simply_call_external() { + %a = alloc() : memref<16xf64, #tile> + call @external_func_A(%a) : (memref<16xf64, #tile>) -> () + return +} +// CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64> +// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> () + +// CHECK-LABEL: func @use_value_of_external +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64> +func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) { + %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>) + return %res : memref<8xf64, #tile> } +// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64> +// CHECK: return %{{.*}} : memref<2x4xf64> |