aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjerryyin <zhuoryin@amd.com>2025-05-06 20:34:39 +0000
committerjerryyin <zhuoryin@amd.com>2025-05-06 21:14:55 +0000
commita3b34537c385bcd8b9574fe16471b094cb2a3291 (patch)
tree92fc169aff39a85f45853d1d80d8a0ec8744d38e
parent92d1875f4c43c9d9fc270e592e9812f675678c52 (diff)
downloadllvm-users/zyin/cherry-pick-upstream-PR138332.zip
llvm-users/zyin/cherry-pick-upstream-PR138332.tar.gz
llvm-users/zyin/cherry-pick-upstream-PR138332.tar.bz2
Unconditionally fold pack(unpack) for push down unpack passusers/zyin/cherry-pick-upstream-PR138332
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp43
-rw-r--r--mlir/test/Dialect/Linalg/data-layout-propagation.mlir42
2 files changed, 29 insertions, 56 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 19b590b..7b0abff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,55 +298,37 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
return std::make_tuple(packedOperand, indexingMap);
}
-static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
- int numDpsOuts = genericOp.getNumDpsInits();
- Block *block = genericOp.getBody();
- int numBlockArgs = block->getNumArguments();
- int initArgStartIndex = numBlockArgs - numDpsOuts;
- for (int i = 0; i < numDpsOuts; ++i) {
- int matchingInitArgIndex = initArgStartIndex + i;
- return block->getArgument(matchingInitArgIndex).use_empty();
- }
- return true;
-}
-
-/// Pack a genericOp and return it.
+/// This function is a helper subroutine to pack a genericOp and return it. It
+/// will create a new generic op with the packed operand and the packed output
+/// according to packInfo when we attempt to push down unpack or bubble up pack
+/// around it. Implicitly this will only work when a packInfo can be obtained.
+/// This makes sure that we are only using this function on parallel permuted
+/// dimensions.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value dest, AffineMap packedOutIndexingMap,
const PackInfo &packInfo,
- bool canUnpackPackFold) {
+ bool isFoldableUnpackPack) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
SmallVector<Value> inputOperandsFromUnpackedSource;
SmallVector<AffineMap> indexingMaps;
-
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);
-
if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
} else {
inputOperandsFromUnpackedSource.push_back(packedOperand);
}
-
inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}
- // Note: Whether or not the unpack pack sequence can fold also depends on
- // the caller of this routine.
- // 1) In push down unpack op pattern, this is true because the pack op is
- // generated and we can guarantee they are compatible.
- // 2) In bubble up pack op pattern, this is not true because the unpack op
- // can be from an arbitrary domain so we need to keep both.
- canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed(genericOp) &&
- !hasGatherSemantics(genericOp);
// If The pack and unpack op can be folded:
// 1) use unpack op source op for operand to fold unpack -> pack sequence.
// 2) init tensor of the generic op can be replaced by the destination of the
// pack op.
- if (canUnpackPackFold) {
+ if (isFoldableUnpackPack) {
inputOperands = inputOperandsFromUnpackedSource;
if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
dest = destPack.getDest();
@@ -487,8 +469,10 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
.getDefiningOp<tensor::EmptyOp>()) {
dest = packOpDest;
}
+ // Here pack(unpack) isn't naively foldable because the unpack op can be from
+ // an arbitrary domain so we need to keep both.
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
- *packInfo, /*canUnpackPackFold=*/false);
+ *packInfo, /*isFoldableUnpackPack=*/false);
}
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1125,9 +1109,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
}
// Pack the genericOp.
+ // pack(unpack) is foldable in this case. This is because in pushing down the
+ // unpack, by default we will populate an additional pack op after the unpack.
+// This guarantees that they are foldable.
GenericOp newGenericOp =
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
- /*canUnpackPackFold=*/true);
+ /*isFoldableUnpackPack=*/true);
Value newResult =
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index a7749e7..63f068d 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
-// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
-// CHECK-SAME: outs(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
// CHECK-LABEL: func.func @unpack_on_input
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -1407,19 +1393,21 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
} -> tensor<?x64xbf16>
return %0 : tensor<?x64xbf16>
}
-
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG2]]
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
// -----
-func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
+func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {
^bb0(%in: f32, %out: f32):
@@ -1428,15 +1416,13 @@ func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>,
} -> tensor<?x64xf32>
return %0 : tensor<?x64xf32>
}
-
-// CHECK-LABEL: func.func @push_unpack_in_padded_domain_not_foldable
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
-// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK: %[[UNPACK1:.+]] = linalg.pack %[[UNPACK]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[UNPACK1]] : tensor<?x8x4x8xf32>)
-// CHECK-SAME: outs(%[[PACK]] : tensor<?x8x4x8xf32>)
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG1]]
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>