aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp153
-rw-r--r--mlir/test/Dialect/Linalg/data-layout-propagation.mlir104
2 files changed, 200 insertions, 57 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 47145e3..dc132b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
@@ -29,10 +30,66 @@ using namespace mlir::linalg;
namespace {
+// Describes how the packing attributes of a pack op map onto the iteration
+// domain of a Linalg op.
+struct PackInfo {
+ int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }; // One entry per tiled domain dim.
+ // Tiled dims expressed on the iteration domain, in the order used by the pack op.
+ SmallVector<int64_t> tiledDimsPos;
+ // Tile size of each tiled iteration-domain dimension, keyed by that dimension.
+ llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
+ // Maps a tiled iteration-domain dimension to the appended inner (point) loop
+ // dimension on the extended iteration domain.
+ llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
+ // The permutation of outer dims, expressed on the iteration domain.
+ SmallVector<int64_t> outerDimsOnDomainPerm;
+ Optional<Value> paddingValue; // Optional padding value forwarded to newly created pack ops.
+};
+
+static PackInfo getPackingInfoFromConsumer(
+ AffineMap indexingMap, ArrayRef<OpFoldResult> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+ Optional<Value> paddingValue = llvm::None) {
+ LLVM_DEBUG(
+ { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
+ PackInfo packInfo; // Result: pack attributes re-expressed on indexingMap's iteration domain.
+ packInfo.paddingValue = paddingValue;
+ int64_t origNumDims = indexingMap.getNumDims(); // New point loops are numbered after these dims.
+ SmallVector<AffineExpr> exprs(indexingMap.getResults());
+ for (auto [index, innerDimPos, tileSize] :
+ llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
+ innerDimsPos, innerTileSizes)) {
+ int64_t domainDimPos =
+ exprs[innerDimPos].cast<AffineDimExpr>().getPosition(); // Assumes this map result is a plain dim expr.
+ packInfo.tiledDimsPos.push_back(domainDimPos);
+ packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
+ packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; // The index-th tile gets the index-th appended dim.
+ LLVM_DEBUG({
+ llvm::dbgs() << "map innerDimPos=" << innerDimPos
+ << " to iteration dimension (d" << domainDimPos << ", d"
+ << packInfo.tileToPointMapping[domainDimPos]
+ << "), which has size=("
+ << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
+ });
+ }
+
+ for (auto dim : outerDimsPerm) // Re-express each permutation entry as an iteration-domain dim.
+ packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
+ if (!packInfo.outerDimsOnDomainPerm.empty()) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "map outer dimsDimsPerm to ";
+ for (auto dim : packInfo.outerDimsOnDomainPerm)
+ llvm::dbgs() << dim << " ";
+ llvm::dbgs() << "\n";
+ });
+ }
+
+ return packInfo;
+}
+
/// Returns a tuple for packed operand and indexing_map with the assumptions:
/// 1) The generic op is the producer of the pack op.
/// 2) The generic op has only one result.
-/// 3) The indexing map of the output operand is identity.
/// If the operand is a scalar or packing dimensions are all irrelevant to the
/// operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
@@ -62,62 +119,57 @@ namespace {
/// inner_tiles = [8]
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
-getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc,
- tensor::PackOp packOp, GenericOp genericOp,
- OpOperand *opOperand) {
- int numOrigLoops = genericOp.getNumLoops();
- int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
+ GenericOp genericOp, OpOperand *opOperand) {
+ int64_t numOrigLoops = genericOp.getNumLoops();
+ int64_t numInnerLoops = packInfo.getNumTiledLoops();
int64_t numLoops = numOrigLoops + numInnerLoops;
AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
+ llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
-
if (genericOp.isScalar(opOperand))
- return std::make_tuple(
- opOperand->get(),
- AffineMap::get(numLoops, 0, exprs, packOp.getContext()));
-
- llvm::SetVector<int64_t> innerDimsPosSet(packOp.getInnerDimsPos().begin(),
- packOp.getInnerDimsPos().end());
- // Mapping from AffinDimExpr of indexing maps to the operand shape dimension.
- DenseMap<int64_t, int64_t> iterMapToDim;
- for (auto [index, expr] : llvm::enumerate(origIndexingMap.getResults())) {
+ return std::make_tuple(opOperand->get(),
+ AffineMap::get(numLoops, 0, exprs, b.getContext()));
+
+ // Step 1. Construct the information of packing data dimensions; append inner
+ // dimensions to the indexing maps for the operand.
+ for (auto [index, expr] : llvm::enumerate(exprs)) {
int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
- if (!innerDimsPosSet.contains(dimPos))
- continue;
- iterMapToDim[dimPos] = index;
+ domainDimToOperandDim[dimPos] = index;
}
-
- // Construct the information of packing data dimensions and new indexing maps
- // for the operand.
SmallVector<int64_t> innerDimsPos;
SmallVector<OpFoldResult> innerTileSizes;
- for (auto [index, value] : llvm::enumerate(
- llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) {
- int64_t dimPos = std::get<0>(value);
- if (!iterMapToDim.count(dimPos))
+ for (auto dimPos : packInfo.tiledDimsPos) {
+ if (!domainDimToOperandDim.count(dimPos))
continue;
- innerDimsPos.push_back(iterMapToDim[dimPos]);
- innerTileSizes.push_back(std::get<1>(value));
- exprs.push_back(b.getAffineDimExpr(numOrigLoops + index));
+ int64_t index = domainDimToOperandDim[dimPos];
+ innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
+ innerDimsPos.push_back(index);
+ exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
}
- auto indexingMap = AffineMap::get(numLoops, 0, exprs, packOp.getContext());
+ // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
+ // TODO: should we propagate the permutation of outer dims to the pack op?
SmallVector<int64_t> outerDimsPerm;
- for (auto outDim : packOp.getOuterDimsPerm()) {
- if (!iterMapToDim.count(outDim))
- continue;
- outerDimsPerm.push_back(iterMapToDim[outDim]);
+ if (!packInfo.outerDimsOnDomainPerm.empty()) {
+ SmallVector<int64_t> inversedOuterPerm =
+ invertPermutationVector(packInfo.outerDimsOnDomainPerm);
+ for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
+ int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
+ exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
+ }
}
+ auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
// The operand does not have dimensions that relates to pack op.
- if (innerDimsPos.empty() && outerDimsPerm.empty())
+ if (innerDimsPos.empty())
return std::make_tuple(opOperand->get(), indexingMap);
auto empty = tensor::PackOp::createDestinationTensor(
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
auto packedOperand = b.create<tensor::PackOp>(
loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
- packOp.getPaddingValue(), outerDimsPerm);
+ packInfo.paddingValue, outerDimsPerm);
return std::make_tuple(packedOperand, indexingMap);
}
@@ -187,34 +239,45 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
return failure();
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
- // TODO: Add support for all permutation indexing maps.
- if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity())
- return rewriter.notifyMatchFailure(
- packOp, "the result of generic op does not have identity indexing_map");
+ auto packInfo = getPackingInfoFromConsumer(
+ genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(),
+ packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(),
+ packOp.getPaddingValue());
Location loc = packOp.getLoc();
SmallVector<Value> inputOperands;
SmallVector<AffineMap> indexingMaps;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
- rewriter, loc, packOp, genericOp, inputOperand);
+ rewriter, loc, packInfo, genericOp, inputOperand);
inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}
int64_t numLoops = genericOp.getNumLoops();
- int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+ int64_t numInnerLoops = packInfo.getNumTiledLoops();
int64_t newNumLoops = numLoops + numInnerLoops;
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();
iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+ // Rebuild the indexing map for the corresponding init operand.
+ auto [packedOutOperand, packedOutIndexingMap] =
+ getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
+ opOperand);
SmallVector<AffineExpr> outExprs(
- genericOp.getMatchingIndexingMap(opOperand).getResults());
+ packedOutIndexingMap.getResults().drop_back(numInnerLoops));
+ // Apply transpose to the indexing map, because we'll replace the init operand
+ // with the destination of pack op.
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+ if (!outerDimsPerm.empty()) {
+ applyPermutationToVector<AffineExpr>(outExprs, outerDimsPerm);
+ }
for (int i = 0; i < numInnerLoops; ++i)
outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
- indexingMaps.push_back(
- AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()));
+ AffineMap outMap =
+ AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext());
+ indexingMaps.push_back(outMap);
auto newGenericOp = rewriter.create<linalg::GenericOp>(
loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index a5488d2..bb84272 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -96,16 +96,17 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten
into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
return %pack : tensor<16x4x32x16xi32>
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32>
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
// CHECK: %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[PACK_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
@@ -130,16 +131,17 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>,
into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
return %pack : tensor<16x4x16x32xi32>
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK: %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[PACK_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
@@ -200,6 +202,37 @@ func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %d
// -----
+#map = affine_map<(d0, d1, d2, d3) -> (d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> {
+ %0 = tensor.empty() : tensor<1x56x57x64xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map, #map1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<64xf32>)
+ outs(%0 : tensor<1x56x57x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<1x56x57x64xf32>
+ %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
+ return %2 : tensor<1x2x56x57x32xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
@@ -225,6 +258,53 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
return %4 : tensor<100x200x4x16x16x32xi32>
}
-// CHECK: func.func @transpose_pack
-// CHECK: linalg.generic
-// CHECK: tensor.pack
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
+// CHECK: func.func @transpose_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG2_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME: outs(%[[DEST]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
+{
+ %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
+ %transpose = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
+ outs(%init_transpose : tensor<100x200x128x256xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+ %0 = arith.addi %b0, %b1 : i32
+ %1 = arith.addi %0, %b2 : i32
+ linalg.yield %1 : i32
+ } -> tensor<100x200x128x256xi32>
+ %4 = tensor.pack %transpose
+ outer_dims_perm = [1, 2, 3, 0]
+ inner_dims_pos = [3, 2]
+ inner_tiles = [16, 32]
+ into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
+ return %4 : tensor<200x4x16x100x16x32xi32>
+}