diff options
Diffstat (limited to 'mlir/test/Dialect/Tensor/one-shot-bufferize.mlir')
-rw-r--r-- | mlir/test/Dialect/Tensor/one-shot-bufferize.mlir | 41 |
1 files changed, 35 insertions, 6 deletions
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index 5f95da2..f66cf7a 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -8,12 +8,12 @@ // Test bufferization using memref types that have no layout map. // RUN: mlir-opt %s -one-shot-bufferize="unknown-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null -// CHECK-LABEL: func @insert_slice_fun +// CHECK-LABEL: func private @insert_slice_fun // CHECK-SAME: %[[A0:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>, // CHECK-SAME: %[[A1:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>, // CHECK-SAME: %[[t0:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>, // CHECK-SAME: %[[t1:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>> -func.func @insert_slice_fun( +func.func private @insert_slice_fun( %A0 : tensor<?xf32> {bufferization.writable = false}, %A1 : tensor<?xf32> {bufferization.writable = true}, %t0 : tensor<4xf32> {bufferization.writable = false}, @@ -331,12 +331,12 @@ func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index) // ----- // CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)> -// CHECK-LABEL: func.func @cast_retains_buffer_layout( +// CHECK-LABEL: func.func private @cast_retains_buffer_layout( // CHECK-SAME: %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> { // CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]> // CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>> // CHECK: return %[[slice]] -func.func @cast_retains_buffer_layout( +func.func private @cast_retains_buffer_layout( %t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>}, %sz: index) @@ -353,12 +353,12 @@ func.func @cast_retains_buffer_layout( // ----- -// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided( +// CHECK-LABEL: func private @cast_retains_buffer_layout_strided( // CHECK-SAME: %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> { // CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>> // CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>> // CHECK: return %[[slice]] -func.func @cast_retains_buffer_layout_strided( +func.func private @cast_retains_buffer_layout_strided( %t: tensor<?xf32> {bufferization.buffer_layout = strided<[1], offset: 5>}, %sz: index) @@ -490,3 +490,32 @@ func.func @collapse_shape_regression( tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32> return } + +// ----- + +// CHECK-LABEL: func private @mult_return_callee( +// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1, +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index { +// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: return %[[A]] : index +// CHECK: ^bb2: +// CHECK: return %[[B]] : index +func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) { + %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32> + cf.cond_br %cond,^a, ^b +^a: + return %casted, %a : tensor<10xf32>, index +^b: + return %casted, %b : tensor<10xf32>, index +} + +// CHECK-LABEL: func @mult_return( +// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1, +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) { +func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) { + // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index + // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index + %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index) + return %t_res, %v : tensor<10xf32>, index +} |