// RUN: mlir-opt %s --bubble-down-memory-space-casts | FileCheck %s

// Regression tests for the bubble-down-memory-space-casts pass.
// In the foldable cases below, a cast from memory space 1 to the default
// memory space is removed from in front of its users:
//  * users that only access memory through the operand (load/store, atomics,
//    vector transfers) are rewritten to use the space-1 memref directly;
//  * users that produce a new memref (cast, view, subview, reshape, ...) are
//    re-created on the space-1 memref and one cast to the default memory
//    space is placed on their result instead.

#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>

// Plain loads and stores: the cast is dropped entirely.
// CHECK-LABEL:   func.func @load_store(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) {
// CHECK:           %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
// CHECK:           memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
// CHECK:           return
// CHECK:         }
func.func @load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %0 = memref.load %memspacecast[%arg1] : memref<?xf32>
  memref.store %0, %memspacecast[%arg1] : memref<?xf32>
  return
}

// The cast here goes from the default memory space into memory space 1, so it
// is expected to stay in place and the accesses keep using its result.
// CHECK-LABEL:   func.func @load_store_unfoldable(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32>,
// CHECK-SAME:      %[[ARG1:.*]]: index) {
// CHECK:           %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref<?xf32> to memref<?xf32, 1>
// CHECK:           %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
// CHECK:           memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
// CHECK:           return
// CHECK:         }
func.func @load_store_unfoldable(%arg0: memref<?xf32>, %arg1: index) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32> to memref<?xf32, 1>
  %0 = memref.load %memspacecast[%arg1] : memref<?xf32, 1>
  memref.store %0, %memspacecast[%arg1] : memref<?xf32, 1>
  return
}

// memref.cast in both directions (ranked -> unranked, unranked -> ranked).
// CHECK-LABEL:   func.func @cast(
// CHECK-SAME:      %[[ARG0:.*]]: memref<2xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
// CHECK:           %[[VAL_0:.*]] = memref.cast %[[ARG0]] : memref<2xf32, 1> to memref<*xf32, 1>
// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<*xf32, 1> to memref<*xf32>
// CHECK:           %[[VAL_2:.*]] = memref.cast %[[ARG1]] : memref<*xf32, 1> to memref<3x2xf32, 1>
// CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<3x2xf32, 1> to memref<3x2xf32>
// CHECK:           return %[[VAL_1]], %[[VAL_3]] : memref<*xf32>, memref<3x2xf32>
// CHECK:         }
func.func @cast(%arg0: memref<2xf32, 1>, %arg1: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<2xf32, 1> to memref<2xf32>
  %1 = memref.cast %memspacecast : memref<2xf32> to memref<*xf32>
  %memspacecast_1 = memref.memory_space_cast %arg1 : memref<*xf32, 1> to memref<*xf32>
  %2 = memref.cast %memspacecast_1 : memref<*xf32> to memref<3x2xf32>
  return %1, %2 : memref<*xf32>, memref<3x2xf32>
}

// memref.view with two dynamic sizes; the size operands are kept as-is.
// CHECK-LABEL:   func.func @view(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xi8, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xf32> {
// CHECK:           %[[VAL_0:.*]] = arith.constant 100 : index
// CHECK:           %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref<?xi8, 1> to memref<?x?xf32, 1>
// CHECK:           %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref<?x?xf32, 1> to memref<?x?xf32>
// CHECK:           return %[[VAL_2]] : memref<?x?xf32>
// CHECK:         }
func.func @view(%arg0: memref<?xi8, 1>, %arg1: index, %arg2: index) -> memref<?x?xf32> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xi8, 1> to memref<?xi8>
  %c100 = arith.constant 100 : index
  %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref<?xi8> to memref<?x?xf32>
  return %view : memref<?x?xf32>
}

// memref.subview: the strided result layout is preserved across the rewrite.
// CHECK-LABEL:   func.func @subview(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
// CHECK:           %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref<?x?xf32, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>, 1>
// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>>
// CHECK:           return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>>
// CHECK:         }
func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
  %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref<?x?xf32> to memref<8x2xf32, strided<[?, 2], offset: ?>>
  return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>>
}

// memref.reinterpret_cast with mixed static/dynamic offset, sizes and strides.
// CHECK-LABEL:   func.func @reinterpret_cast(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant 10 : index
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
// CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
// CHECK:           return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
// CHECK:         }
func.func @reinterpret_cast(%arg0: memref<?xf32, 1>, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
  return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>>
}

// memref.reshape driven by a shape memref.
// CHECK-LABEL:   func.func @reshape(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: memref<1xindex>) -> memref<?xf32> {
// CHECK:           %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref<?x?xf32, 1>, memref<1xindex>) -> memref<?xf32, 1>
// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
// CHECK:           return %[[VAL_1]] : memref<?xf32>
// CHECK:         }
func.func @reshape(%arg0: memref<?x?xf32, 1>, %arg1: memref<1xindex>) -> memref<?xf32> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
  %reshape = memref.reshape %memspacecast(%arg1) : (memref<?x?xf32>, memref<1xindex>) -> memref<?xf32>
  return %reshape : memref<?xf32>
}

// CHECK-LABEL:   func.func @expand_shape(
// CHECK-SAME:      %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> {
// CHECK:           %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1>
// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32>
// CHECK:           return %[[VAL_1]] : memref<3x4xf32>
// CHECK:         }
func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32>
  %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32>
  return %expand_shape : memref<3x4xf32>
}

// CHECK-LABEL:   func.func @collapse_shape(
// CHECK-SAME:      %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> {
// CHECK:           %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1>
// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32>
// CHECK:           return %[[VAL_1]] : memref<12xf32>
// CHECK:         }
func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32>
  %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32>
  return %collapse_shape : memref<12xf32>
}

// memref.transpose: the result carries the #map layout declared at the top.
// CHECK-LABEL:   func.func @transpose(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>) -> memref<?x?xf32, #[[$ATTR_0]]> {
// CHECK:           %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref<?x?xf32, 1> to memref<?x?xf32, #[[$ATTR_0]], 1>
// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x?xf32, #[[$ATTR_0]], 1> to memref<?x?xf32, #[[$ATTR_0]]>
// CHECK:           return %[[VAL_1]] : memref<?x?xf32, #[[$ATTR_0]]>
// CHECK:         }
func.func @transpose(%arg0: memref<?x?xf32, 1>) -> memref<?x?xf32, #map> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
  %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #map>
  return %transpose : memref<?x?xf32, #map>
}

// Read-modify-write accesses take the space-1 memref directly.
// CHECK-LABEL:   func.func @atomic_rmw(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index,
// CHECK-SAME:      %[[ARG2:.*]]: f32) -> f32 {
// CHECK:           %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref<?xf32, 1>) -> f32
// CHECK:           return %[[VAL_0]] : f32
// CHECK:         }
func.func @atomic_rmw(%arg0: memref<?xf32, 1>, %arg1: index, %arg2: f32) -> f32 {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %0 = memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref<?xf32>) -> f32
  return %0 : f32
}

// memref.assume_alignment returns a memref, so its result gets the cast.
// CHECK-LABEL:   func.func @assume_alignment(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>) -> memref<?xf32> {
// CHECK:           %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref<?xf32, 1>
// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
// CHECK:           return %[[VAL_1]] : memref<?xf32>
// CHECK:         }
func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %1 = memref.assume_alignment %memspacecast, 16 : memref<?xf32>
  return %1 : memref<?xf32>
}

// A chain of view-like ops followed by accesses. The cast moves past the
// whole chain; loads, stores and atomics use the space-1 collapse_shape result,
// and only the returned value goes through the cast.
// CHECK-LABEL:   func.func @op_with_cast_sequence(
// CHECK-SAME:      %[[ARG0:.*]]: memref<4x4xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index,
// CHECK-SAME:      %[[ARG2:.*]]: f32) -> memref<16xf32> {
// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant 4 : index
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
// CHECK:           %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
// CHECK:           %[[VAL_5:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
// CHECK:           %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[ARG2]] : f32
// CHECK:           memref.store %[[VAL_6]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
// CHECK:           %[[VAL_7:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[VAL_3]]{{\[}}%[[VAL_0]]] : (f32, memref<16xf32, 1>) -> f32
// CHECK:           return %[[VAL_4]] : memref<16xf32>
// CHECK:         }
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
  %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
  %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
  %loaded = memref.load %collapsed[%c0] : memref<16xf32>
  %added = arith.addf %loaded, %arg2 : f32
  memref.store %added, %collapsed[%c0] : memref<16xf32>
  %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
  return %collapsed : memref<16xf32>
}

// Vector transfer ops on memrefs take the space-1 memref directly.
// CHECK-LABEL:   func.func @transfer_read_write(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) {
// CHECK:           %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_1:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xf32>
// CHECK:           vector.transfer_write %[[VAL_1]], %[[ARG0]]{{\[}}%[[ARG1]]] : vector<4xf32>, memref<?xf32, 1>
// CHECK:           return
// CHECK:         }
func.func @transfer_read_write(%arg0: memref<?xf32, 1>, %arg1: index) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %memspacecast[%arg1], %c0 : memref<?xf32>, vector<4xf32>
  vector.transfer_write %0, %memspacecast[%arg1] : vector<4xf32>, memref<?xf32>
  return
}

// NOTE: The operations disappear because they can get folded.
// On tensors there is no memory-space cast. The write of a just-read value
// back to the same position folds away, so only the return remains.
// CHECK-LABEL:   func.func @transfer_read_write_tensor(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?xf32>,
// CHECK-SAME:      %[[ARG1:.*]]: index) -> tensor<?xf32> {
// CHECK:           return %[[ARG0]] : tensor<?xf32>
// CHECK:         }
func.func @transfer_read_write_tensor(%arg0: tensor<?xf32>, %arg1: index) -> tensor<?xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%arg1], %c0 : tensor<?xf32>, vector<4xf32>
  %1 = vector.transfer_write %0, %arg0[%arg1] : vector<4xf32>, tensor<?xf32>
  return %1 : tensor<?xf32>
}

// vector.load / vector.store take the space-1 memref directly.
// CHECK-LABEL:   func.func @vector_load_store(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) {
// CHECK:           %[[VAL_0:.*]] = vector.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
// CHECK:           vector.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
// CHECK:           return
// CHECK:         }
func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %0 = vector.load %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
  vector.store %0, %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
  return
}

// Masked accesses with a partially-set constant mask.
// CHECK-LABEL:   func.func @masked_load_store(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) {
// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK:           %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK:           vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
// CHECK:           return
// CHECK:         }
func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
  %passthrough = arith.constant dense<0.0> : vector<4xf32>
  %0 = vector.maskedload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  vector.maskedstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
  return
}

// Indexed gather/scatter with an all-true mask.
// CHECK-LABEL:   func.func @gather_scatter(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) {
// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK:           vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
// CHECK:           return
// CHECK:         }
func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
  %mask = arith.constant dense<true> : vector<4xi1>
  %passthrough = arith.constant dense<0.0> : vector<4xf32>
  %0 = vector.gather %memspacecast[%c0] [%indices], %mask, %passthrough : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  vector.scatter %memspacecast[%c0] [%indices], %mask, %0 : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
  return
}

// Expanding load / compressing store with a partially-set constant mask.
// CHECK-LABEL:   func.func @expandload_compressstore(
// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME:      %[[ARG1:.*]]: index) {
// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK:           %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK:           vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
// CHECK:           return
// CHECK:         }
func.func @expandload_compressstore(%arg0: memref<?xf32, 1>, %arg1: index) {
  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
  %passthrough = arith.constant dense<0.0> : vector<4xf32>
  %0 = vector.expandload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  vector.compressstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
  return
}