aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzhicong zhong <zhiczhong@outlook.com>2024-06-28 08:50:18 +0800
committerGitHub <noreply@github.com>2024-06-28 08:50:18 +0800
commiteec9d0b6816e815fbe009941c1fda3b39c38adeb (patch)
tree459282af4011336c32c017eafd6707c189d23f7f
parentca06b610841c849eb1db43ad057310c8f7eea81e (diff)
downloadllvm-eec9d0b6816e815fbe009941c1fda3b39c38adeb.zip
llvm-eec9d0b6816e815fbe009941c1fda3b39c38adeb.tar.gz
llvm-eec9d0b6816e815fbe009941c1fda3b39c38adeb.tar.bz2
[mlir][Linalg] use linalg.reduce to simplify the mergeReductions in partialReductionInterface (#94579)
The current implementation of `mergeReduction` in `LinalgOpPartialReductionInterface` builds a `linalg.generic` from scratch. While we already have `linalg.reduce` op which has the same semantic as this generic op, this PR replaces the generic op with `linalg.reduce` to simplify the implementation.
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp37
-rw-r--r--mlir/test/Dialect/Linalg/transform-tile-reduction.mlir37
2 files changed, 21 insertions, 53 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index b2a1e7c..3978049 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -447,39 +447,10 @@ struct LinalgOpPartialReductionInterface
Location loc, ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
-
- // Step 1. Recover the dims that actually need to be merged from the
- // original operation. We can classify the original iterators as follows:
- //
- // parallel --> parallel
- // reduction + not in reductionDims --> parallel (already reduced)
- // reduction + in reductionDims --> reduction (will reduce now)
- SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
- utils::IteratorType::parallel);
- for (int redIdx : reductionDims)
- iterators[redIdx] = utils::IteratorType::reduction;
-
- // Step 2. For each partial result, create a map to index it. This map
- // is simply the indexing map for the original result with reductionDims
- // appended (as produced in tileToPartialReduction).
- int64_t numInits = linalgOp.getNumDpsInits();
- SmallVector<AffineMap> indexingMaps(numInits * 2);
- for (int idx : llvm::seq<int>(0, numInits)) {
- AffineMap &inputMap = indexingMaps[idx];
- AffineMap &outputMap = indexingMaps[numInits + idx];
-
- outputMap =
- linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
- inputMap = outputMap;
- for (int redPos : reductionDims) {
- inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
- inputMap.getNumResults());
- }
- }
-
- auto reduction = b.create<GenericOp>(
- loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
- indexingMaps, iterators,
+ SmallVector<int64_t> reductionDimsInt64(reductionDims.begin(),
+ reductionDims.end());
+ auto reduction = b.create<linalg::ReduceOp>(
+ loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
[&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
int64_t numInits = linalgOp.getNumDpsInits();
SmallVector<Value> yieldedValues;
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 8feb3c2..cce4b4e 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -23,9 +23,8 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
@@ -37,10 +36,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
-// CHECK: %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
+// CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
// CHECK: %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
-// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
+// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
@@ -48,10 +47,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
// CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>
// -----
@@ -81,7 +80,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: func @reduction_tile_transpose
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
@@ -91,7 +89,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
// CHECK: }
-// CHECK: linalg.generic
+// CHECK: linalg.reduce
// CHECK: return
// -----
@@ -150,10 +148,11 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
+// CHECK: {
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>
// -----
@@ -177,8 +176,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -203,10 +200,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) dimensions = [2]
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?x?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?x?xf32>
// -----
@@ -270,10 +267,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>
// -----
@@ -307,7 +304,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
// CHECK: expecting parallel reduction
- // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: linalg.reduce
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
transform.yield
@@ -401,7 +398,7 @@ module {
// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
// CHECK: scf.yield %[[OUT]] : tensor<4096x2x64xf32>
// CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
-// CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK: %[[OUT2:.*]] = linalg.reduce ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
// CHECK: return %[[OUT2]] : tensor<4096xf32>
// -----
@@ -445,6 +442,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
// CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
-// CHECK: linalg.generic
+// CHECK: linalg.reduce
// CHECK: arith.addf
// CHECK: arith.maximumf