# RUN: %PYTHON %s | FileCheck %s import mlir.dialects.arith as arith import mlir.dialects.memref as memref import mlir.extras.types as T from mlir.dialects.memref import _infer_memref_subview_result_type from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() return f # CHECK-LABEL: TEST: testSubViewAccessors @run def testSubViewAccessors(): ctx = Context() module = Module.parse( r""" func.func @f1(%arg0: memref) { %0 = arith.constant 0 : index %1 = arith.constant 1 : index %2 = arith.constant 2 : index %3 = arith.constant 3 : index %4 = arith.constant 4 : index %5 = arith.constant 5 : index memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref to memref> return } """, ctx, ) func_body = module.body.operations[0].regions[0].blocks[0] subview = func_body.operations[6] assert subview.source == subview.operands[0] assert len(subview.offsets) == 2 assert len(subview.sizes) == 2 assert len(subview.strides) == 2 assert subview.result == subview.results[0] # CHECK: SubViewOp print(type(subview).__name__) # CHECK: constant 0 print(subview.offsets[0]) # CHECK: constant 1 print(subview.offsets[1]) # CHECK: constant 2 print(subview.sizes[0]) # CHECK: constant 3 print(subview.sizes[1]) # CHECK: constant 4 print(subview.strides[0]) # CHECK: constant 5 print(subview.strides[1]) # CHECK-LABEL: TEST: testCustomBuidlers @run def testCustomBuidlers(): with Context() as ctx, Location.unknown(ctx): module = Module.parse( r""" func.func @f1(%arg0: memref, %arg1: index, %arg2: index) { return } """ ) f = module.body.operations[0] func_body = f.regions[0].blocks[0] with InsertionPoint.at_block_terminator(func_body): memref.LoadOp(f.arguments[0], f.arguments[1:]) # CHECK: func @f1(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] print(module) assert module.operation.verify() # CHECK-LABEL: TEST: testMemRefAttr @run def testMemRefAttr(): with Context() as ctx, Location.unknown(ctx): module = Module.create() with InsertionPoint(module.body): memref.global_("objFifo_in0", T.memref(16, T.i32())) # CHECK: memref.global @objFifo_in0 : memref<16xi32> print(module) # CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics @run def testSubViewOpInferReturnTypeSemantics(): with Context() as ctx, Location.unknown(ctx): module = Module.create() with InsertionPoint(module.body): x = memref.alloc(T.memref(10, 10, T.i32()), [], []) # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32> print(x.owner) y = memref.subview(x, [1, 1], [3, 3], [1, 1]) assert y.owner.verify() # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>> print(y.owner) z = memref.subview( x, [arith.constant(T.index(), 1), 1], [3, 3], [1, 1], ) # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>> print(z.owner) z = memref.subview( x, [arith.constant(T.index(), 3), arith.constant(T.index(), 4)], [3, 3], [1, 1], ) # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>> print(z.owner) s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4)) z = memref.subview( x, [s, 0], [3, 3], [1, 1], ) # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>> print(z) try: _infer_memref_subview_result_type( x.type, [arith.constant(T.index(), 3), arith.constant(T.index(), 4)], [ShapedType.get_dynamic_size(), 3], [1, 1], ) except ValueError as e: # CHECK: Only inferring from python or mlir integer constant is supported print(e) try: memref.subview( x, [arith.constant(T.index(), 3), arith.constant(T.index(), 4)], [ShapedType.get_dynamic_size(), 3], [1, 1], ) except ValueError as e: # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type print(e) layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1]) x = memref.alloc( T.memref( 10, 10, T.i32(), layout=layout, ), [], [arith.constant(T.index(), 42)], ) # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>> print(x.owner) y = memref.subview( x, [1, 1], [3, 3], [1, 1], result_type=T.memref(3, 3, T.i32(), layout=layout), ) # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>> print(y.owner) # CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing @run def testSubViewOpInferReturnTypeExtensiveSlicing(): def check_strides_offset(memref, np_view): layout = memref.type.layout dtype_size_in_bytes = np_view.dtype.itemsize golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist() golden_offset = ( np_view.ctypes.data - np_view.base.ctypes.data ) // dtype_size_in_bytes assert (layout.strides, layout.offset) == (golden_strides, golden_offset) with Context() as ctx, Location.unknown(ctx): module = Module.create() with InsertionPoint(module.body): shape = (10, 22, 3, 44) golden_mem = np.zeros(shape, dtype=np.int32) mem1 = memref.alloc(T.memref(*shape, T.i32()), [], []) # fmt: off check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, ...]) check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 3, 44), (1, 1, 1, 1)), golden_mem[:, 1:2]) check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 44), (1, 1, 1, 1)), golden_mem[:, :, 1:2]) check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 3, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2]) check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 3, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2]) check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2]) check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :]) check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2]) check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 44), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :]) check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :]) check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2]) check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2]) check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2]) check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :]) # fmt: on # default strides and offset means no stridedlayout attribute means affinemap layout assert memref.subview( mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1) ).type.layout == AffineMapAttr.get( AffineMap.get( 4, 0, [ AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2), AffineDimExpr.get(3), ], ) ) shape = (7, 22, 30, 44) golden_mem = np.zeros(shape, dtype=np.int32) mem2 = memref.alloc(T.memref(*shape, T.i32()), [], []) # fmt: off check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 3, 44), (1, 2, 1, 1)), golden_mem[:, 0:22:2]) check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 44), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30]) check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400]) # fmt: on shape = (8, 8) golden_mem = np.zeros(shape, dtype=np.int32) # fmt: off mem3 = memref.alloc(T.memref(*shape, T.i32()), [], []) check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4]) check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8]) # fmt: on