From 42bba97fc24f045f593fc26f998bac9b08633255 Mon Sep 17 00:00:00 2001
From: harsh-nod
Date: Thu, 7 Dec 2023 15:01:55 -0800
Subject: [mlir] Extend CombineTransferReadOpTranspose pattern to handle extf
 ops. (#74754)

This patch modifies the CombineTransferReadOpTranspose pattern
to handle extf ops. Also adds a test which shows the transpose
getting folded into the transfer_read.
---
 mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp    |  8 ++++--
 .../Conversion/VectorToGPU/vector-to-mma-ops.mlir  | 30 ++++++++++++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 429d113..f151011 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -455,7 +455,8 @@ struct CombineTransferReadOpTranspose final
     Type resultType = op.getType();
     Operation *extOp;
     if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
-        (extOp = source.getDefiningOp<arith::ExtUIOp>())) {
+        (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
+        (extOp = source.getDefiningOp<arith::ExtFOp>())) {
       source = extOp->getOperand(0);
       resultType =
           VectorType::get(cast<VectorType>(resultType).getShape(),
@@ -493,9 +494,12 @@ struct CombineTransferReadOpTranspose final
       if (isa<arith::ExtSIOp>(extOp))
         result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                      .getResult();
-      else
+      else if (isa<arith::ExtUIOp>(extOp))
         result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                      .getResult();
+      else
+        result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
+                     .getResult();
     }
 
     rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index fa9fff2..962ed7d 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -460,3 +460,33 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
   vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
   return
 }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func @fold_transpose_into_transfer_read(
+// CHECK-SAME: %[[ALLOC:.+]]: memref<64x128xf16>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true], permutation_map = #[[$MAP]]}
+// CHECK: %[[EXTF1:.+]] = arith.extf %[[READ]]
+// CHECK-NOT: vector.transpose
+// CHECK: %[[RESULT:.+]] = vector.contract
+func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector: vector<32x128xf16>, %alloc2: memref<32x64xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %init = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+  %0 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
+  %1 = arith.extf %0 : vector<64x128xf16> to vector<64x128xf32>
+  %2 = arith.extf %vector : vector<32x128xf16> to vector<32x128xf32>
+  %3 = vector.transpose %1, [1, 0] : vector<64x128xf32> to vector<128x64xf32>
+  %4 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %init : vector<32x128xf32>, vector<128x64xf32> into vector<32x64xf32>
+  vector.transfer_write %4, %alloc2[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32>
+  return
+}
+
+// -----
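
Note: for reference, the net effect of the extended pattern on an extf-fed
read is sketched below. This is a minimal sketch distilled from the test
above and its CHECK lines, not verbatim compiler output; SSA names and
attribute spellings are illustrative.

Before the rewrite, the transpose follows the extend:

  %0 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
  %1 = arith.extf %0 : vector<64x128xf16> to vector<64x128xf32>
  %3 = vector.transpose %1, [1, 0] : vector<64x128xf32> to vector<128x64xf32>

After the rewrite, the transpose is folded into the transfer_read as a
permutation map (using the narrow f16 element type, per the resultType
computation in the pattern), and the extf is re-created on the transposed
read:

  %0 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<64x128xf16>, vector<128x64xf16>
  %1 = arith.extf %0 : vector<128x64xf16> to vector<128x64xf32>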