diff options
author | jerryyin <zhuoryin@amd.com> | 2025-05-06 20:34:39 +0000 |
---|---|---|
committer | jerryyin <zhuoryin@amd.com> | 2025-05-06 21:14:55 +0000 |
commit | a3b34537c385bcd8b9574fe16471b094cb2a3291 (patch) | |
tree | 92fc169aff39a85f45853d1d80d8a0ec8744d38e | |
parent | 92d1875f4c43c9d9fc270e592e9812f675678c52 (diff) | |
download | llvm-users/zyin/cherry-pick-upstream-PR138332.zip llvm-users/zyin/cherry-pick-upstream-PR138332.tar.gz llvm-users/zyin/cherry-pick-upstream-PR138332.tar.bz2 |
Unconditionally fold pack(unpack) for push down unpack pass (branch: users/zyin/cherry-pick-upstream-PR138332)
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 43 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/data-layout-propagation.mlir | 42 |
2 files changed, 29 insertions, 56 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 19b590b..7b0abff 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -298,55 +298,37 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, return std::make_tuple(packedOperand, indexingMap); } -static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) { - int numDpsOuts = genericOp.getNumDpsInits(); - Block *block = genericOp.getBody(); - int numBlockArgs = block->getNumArguments(); - int initArgStartIndex = numBlockArgs - numDpsOuts; - for (int i = 0; i < numDpsOuts; ++i) { - int matchingInitArgIndex = initArgStartIndex + i; - return block->getArgument(matchingInitArgIndex).use_empty(); - } - return true; -} - -/// Pack a genericOp and return it. +/// This function is a helper subroutine to pack a genericOp and return it. It +/// will create a new generic op with the packed operand and the packed output +/// according to packInfo when we attempt to push down unpack or bubble up pack +/// around it. Implicitly this will only work when a packInfo can be obtained. +/// This makes sure that we are only using this function on parallel permuted +/// dimensions. 
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest, AffineMap packedOutIndexingMap, const PackInfo &packInfo, - bool canUnpackPackFold) { + bool isFoldableUnpackPack) { Location loc = genericOp.getLoc(); SmallVector<Value> inputOperands; SmallVector<Value> inputOperandsFromUnpackedSource; SmallVector<AffineMap> indexingMaps; - for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( rewriter, loc, packInfo, genericOp, inputOperand); - if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) { inputOperandsFromUnpackedSource.push_back(unpackOp.getSource()); } else { inputOperandsFromUnpackedSource.push_back(packedOperand); } - inputOperands.push_back(packedOperand); indexingMaps.push_back(packedIndexingMap); } - // Note: Whether or not the unpack pack sequence can fold also depends on - // the caller of this routine. - // 1) In push down unpack op pattern, this is true because the pack op is - // generated and we can guarantee they are compatible. - // 2) In bubble up pack op pattern, this is not true because the unpack op - // can be from an arbitrary domain so we need to keep both. - canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed(genericOp) && - !hasGatherSemantics(genericOp); // If The pack and unpack op can be folded: // 1) use unpack op source op for operand to fold unpack -> pack sequence. // 2) init tensor of the generic op can be replaced by the destination of the // pack op. 
- if (canUnpackPackFold) { + if (isFoldableUnpackPack) { inputOperands = inputOperandsFromUnpackedSource; if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) dest = destPack.getDest(); @@ -487,8 +469,10 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, .getDefiningOp<tensor::EmptyOp>()) { dest = packOpDest; } + // Here pack(unpack) isn't naively foldable because the unpack op can be from + // an arbitrary domain so we need to keep both. return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, - *packInfo, /*canUnpackPackFold=*/false); + *packInfo, /*isFoldableUnpackPack=*/false); } /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. @@ -1125,9 +1109,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, } // Pack the genericOp. + // pack(unpack) is foldable in this case. This is because in pushing down the + // unpack, by default we will populate an additional pack op after the unpack. + // This guarantees them to be foldable. 
GenericOp newGenericOp = packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo, - /*canUnpackPackFold=*/true); + /*isFoldableUnpackPack=*/true); Value newResult = newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index a7749e7..63f068d 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56 // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]] -// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_EMPTY_PACK]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> // CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP]]] -// CHECK-SAME: outs(%[[PACKED_ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[UNPACKED_ARG0]] @@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56 // CHECK-LABEL: func.func @unpack_on_input // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32> -// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]] -// CHECK: 
%[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG1_PACK_EMPTY]] -// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_PACK_EMPTY]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> // CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]] -// CHECK-SAME: ins(%[[ARG0_PACK]] -// CHECK-SAME: outs(%[[ARG1_PACK]] +// CHECK-SAME: ins(%[[ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[ARG1]] @@ -1407,19 +1393,21 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de } -> tensor<?x64xbf16> return %0 : tensor<?x64xbf16> } - // CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] // CHECK: %[[EMPTY:.+]] = tensor.empty // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>) // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>) // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]] +// CHECK-SAME: into %[[ARG2]] // CHECK: return %[[UNPACK]] : tensor<?x64xbf16> // ----- -func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> { +func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> { %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 
: tensor<8x8x4x8xf32> -> tensor<?x64xf32> %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) { ^bb0(%in: f32, %out: f32): @@ -1428,15 +1416,13 @@ func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, } -> tensor<?x64xf32> return %0 : tensor<?x64xf32> } - -// CHECK-LABEL: func.func @push_unpack_in_padded_domain_not_foldable +// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] -// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]] -// CHECK: %[[UNPACK1:.+]] = linalg.pack %[[UNPACK]] +// CHECK: %[[EMPTY:.+]] = tensor.empty // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[UNPACK1]] : tensor<?x8x4x8xf32>) -// CHECK-SAME: outs(%[[PACK]] : tensor<?x8x4x8xf32>) +// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>) // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]] +// CHECK-SAME: into %[[ARG1]] // CHECK: return %[[UNPACK2]] : tensor<?x64xf32> |