author     Alexander Belyaev <pifon@google.com>  2021-05-26 20:22:49 +0200
committer  Alexander Belyaev <pifon@google.com>  2021-05-27 08:45:20 +0200
commit     281ee4291110af5d1337d1da819a284eecf368ec
tree       a53d3af1d4c49f1b724f1949ca3ba24b2fb61a95
parent     51d334a845a082338735b0fdfc620a4b15fa26fe
[mlir] Add a pass to distribute linalg::TiledLoopOp.
Differential Revision: https://reviews.llvm.org/D103194
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h    7
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Utils/Utils.h              7
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt           1
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp        85
-rw-r--r--  mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir        39
-rw-r--r--  mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp    79
-rw-r--r--  mlir/tools/mlir-opt/mlir-opt.cpp                            2
7 files changed, 220 insertions, 0 deletions
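
For context, a minimal sketch of how the new entry point is meant to be driven, closely modeled on the TestLinalgDistribution.cpp pass added below. The helper name getDistributionOptions() and the "distributed" marker string are taken from that test; they are illustrative, not part of the public API introduced here.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative only: distribute all linalg.tiled_loop ops in a function,
// assuming getDistributionOptions() returns LinalgLoopDistributionOptions
// with a populated procInfoMap (see the test pass below for a real example).
static void distributeTiledLoops(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  OwningRewritePatternList patterns(ctx);
  linalg::populateLinalgDistributeTiledLoopPattern(
      patterns, getDistributionOptions(),
      linalg::LinalgTransformationFilter(
          ArrayRef<Identifier>{}, {Identifier::get("distributed", ctx)}));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

The filter marks rewritten loops so the greedy driver does not distribute the same loop twice; the test pass removes the marker again once rewriting is done.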
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e346903..d6cb5cb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -860,6 +860,13 @@ void populateLinalgConvGeneralizationPatterns(
RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
+/// Linalg distribution patterns
+///
+/// Populates `patterns` with patterns to distribute linalg.tiled_loop.
+void populateLinalgDistributeTiledLoopPattern(
+ RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
+ const LinalgTransformationFilter &marker);
+
//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 03728e3..55da21d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -184,6 +184,8 @@ struct ProcInfo {
};
using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
+using OneDimProcInfoCallBackFn =
+ std::function<ProcInfo(OpBuilder &b, Location loc)>;
/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
@@ -201,6 +203,11 @@ struct LinalgLoopDistributionOptions {
/// applied. If the vector is less than the number of `scf.parallel` loops
/// generated, then no distribution is applied.
SmallVector<DistributionMethod, 0> distributionMethod = {};
+
+ /// The map keyed by the distribution type that contains callback functions
+ /// that return the Values for processor ID (`procId`), and number of
+ /// processors (`nprocs`) used to execute the parallel loops.
+ DenseMap<StringRef, OneDimProcInfoCallBackFn> procInfoMap;
};
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
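
A sketch of how the new procInfoMap hook can be populated, mirroring the GPU callbacks in the test pass added below. The function name getGpuDistributionOptions is hypothetical; the "block_x" key simply has to match a string in the loop's distribution attribute.

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;

// Illustrative only: map the "block_x" distribution type to the GPU block id
// and grid size along x, as the test pass below does with a template helper.
static linalg::LinalgLoopDistributionOptions getGpuDistributionOptions() {
  linalg::LinalgLoopDistributionOptions opts;
  opts.procInfoMap.insert(std::make_pair(
      "block_x", [](OpBuilder &b, Location loc) -> linalg::ProcInfo {
        Type indexType = b.getIndexType();
        StringAttr dim = b.getStringAttr("x");
        return {b.create<gpu::BlockIdOp>(loc, indexType, dim),
                b.create<gpu::GridDimOp>(loc, indexType, dim)};
      }));
  return opts;
}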
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 1458c94..d954db9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
CodegenStrategy.cpp
ComprehensiveBufferize.cpp
Detensorize.cpp
+ Distribution.cpp
DropUnitDims.cpp
ElementwiseToLinalg.cpp
Fusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
new file mode 100644
index 0000000..994f7c7
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
@@ -0,0 +1,85 @@
+//===- Distribution.cpp - Linalg tiled_loop distribution -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Linalg distribution pass. It updates `tiled_loop`
+// control variables depending on the distribution type.
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#define DEBUG_TYPE "linalg-distribution"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+struct DistributeTiledLoopPattern
+ : public OpRewritePattern<linalg::TiledLoopOp> {
+ DistributeTiledLoopPattern(MLIRContext *context,
+ LinalgLoopDistributionOptions options,
+ LinalgTransformationFilter marker)
+ : OpRewritePattern<linalg::TiledLoopOp>(context), options(options),
+ marker(marker) {}
+ LogicalResult matchAndRewrite(linalg::TiledLoopOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(marker.checkAndNotify(rewriter, op)))
+ return failure();
+ if (!op.distribution_types().hasValue())
+ return failure();
+
+ Location loc = op.getLoc();
+ SmallVector<Value, 2> newLowerBounds = op.lowerBound();
+ SmallVector<Value, 2> newUpperBounds = op.upperBound();
+ SmallVector<Value, 2> newSteps = op.step();
+
+ // Update bounds and steps.
+ auto distributionTypes = op.distribution_types().getValue();
+ for (int i = 0, e = op.getNumLoops(); i < e; ++i) {
+ StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
+ auto procInfoCallback = options.procInfoMap.find(type);
+ if (procInfoCallback == options.procInfoMap.end())
+ continue;
+
+ if (!isParallelIteratorType(op.iterator_types()[i])) {
+ op.emitOpError("only support for parallel loops is implemented");
+ return failure();
+ }
+ ProcInfo info = procInfoCallback->second(rewriter, loc);
+ updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs,
+ newLowerBounds[i], newUpperBounds[i],
+ newSteps[i]);
+ }
+ rewriter.updateRootInPlace(op, [&] {
+ op.setLowerBounds(newLowerBounds);
+ op.setUpperBounds(newUpperBounds);
+ op.setSteps(newSteps);
+ });
+ marker.replaceLinalgTransformationFilter(rewriter, op);
+ return success();
+ }
+
+private:
+ LinalgLoopDistributionOptions options;
+ LinalgTransformationFilter marker;
+};
+
+} // namespace
+
+void mlir::linalg::populateLinalgDistributeTiledLoopPattern(
+ RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
+ const LinalgTransformationFilter &marker) {
+ patterns.add<DistributeTiledLoopPattern>(patterns.getContext(), opts, marker);
+}
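
The bound rewrite itself is delegated to the existing updateBoundsForCyclicDistribution helper. As a hand-written sketch (not code from this commit), the arithmetic it encodes per distributed loop, consistent with the affine maps checked in the test below, is roughly:

#include <cstdint>

// Sketch of the cyclic distribution arithmetic; the real helper emits
// affine.apply ops on SSA values rather than plain integer math.
struct LoopRange { int64_t lb, ub, step; };

static LoopRange distributeCyclically(LoopRange r, int64_t procId,
                                      int64_t nprocs) {
  // Each processor starts at its own offset and strides over all processors;
  // the upper bound is left unchanged.
  return {r.lb + procId * r.step, r.ub, r.step * nprocs};
}

With the test's %c0 lower bound and %c24 step, this is exactly the s0 * 24 maps applied to gpu.block_id and gpu.grid_dim in the CHECK lines below.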
diff --git a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
new file mode 100644
index 0000000..564db5ab
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -test-linalg-distribution %s | FileCheck %s
+
+func private @foo(%A: tensor<64x64xf32>,
+ %B: tensor<64x64xf32>) -> tensor<64x64xf32>
+
+func @distribute_for_gpu(%A: tensor<64x64xf32>,
+ %B: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ %c0 = constant 0 : index
+ %c16 = constant 16 : index
+ %c64 = constant 64 : index
+ %c24 = constant 24 : index
+ %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c64, %c64) step (%c24, %c16)
+ ins (%A_ = %A: tensor<64x64xf32>) outs (%B_ = %B:tensor<64x64xf32>)
+ distribution ["block_x", "block_y"] {
+ %0 = call @foo(%A_, %B_)
+ : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+ linalg.yield %0 : tensor<64x64xf32>
+ }
+ return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 24)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
+
+// CHECK-LABEL: func @distribute_for_gpu
+// CHECK: %[[C64:.*]] = constant 64 : index
+
+// CHECK-DAG: %[[GPU_BLOCK_X:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[GPU_GRID_DIM_X:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[LB_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_BLOCK_X]]]
+// CHECK-DAG: %[[STEP_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_GRID_DIM_X]]]
+
+// CHECK-DAG: %[[GPU_BLOCK_Y:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[GPU_GRID_DIM_Y:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[LB_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_BLOCK_Y]]]
+// CHECK-DAG: %[[STEP_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_GRID_DIM_Y]]]
+
+// CHECK: linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = (%[[LB_I]], %[[LB_J]])
+// CHECK-SAME: to (%[[C64]], %[[C64]]) step (%[[STEP_I]], %[[STEP_J]])
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
new file mode 100644
index 0000000..224d8ca
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
@@ -0,0 +1,79 @@
+//===- TestLinalgDistribution.cpp - Test Linalg distribution -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Linalg distribution.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+template <char dim>
+static linalg::ProcInfo getGpuBlockInfo(OpBuilder &b, Location loc) {
+ std::string d(1, dim);
+ StringAttr attr = b.getStringAttr(d);
+
+ Type indexType = b.getIndexType();
+ ProcInfo procInfo = {b.create<gpu::BlockIdOp>(loc, indexType, attr),
+ b.create<gpu::GridDimOp>(loc, indexType, attr)};
+ return procInfo;
+}
+
+static LinalgLoopDistributionOptions getDistributionOptions() {
+ LinalgLoopDistributionOptions opts;
+ opts.procInfoMap.insert(std::make_pair("block_x", getGpuBlockInfo<'x'>));
+ opts.procInfoMap.insert(std::make_pair("block_y", getGpuBlockInfo<'y'>));
+ return opts;
+}
+
+namespace {
+struct TestLinalgDistribution
+ : public PassWrapper<TestLinalgDistribution, FunctionPass> {
+ TestLinalgDistribution() = default;
+ TestLinalgDistribution(const TestLinalgDistribution &pass) {}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect, gpu::GPUDialect>();
+ }
+
+ void runOnFunction() override;
+};
+} // namespace
+
+void TestLinalgDistribution::runOnFunction() {
+ auto funcOp = getFunction();
+ OwningRewritePatternList distributeTiledLoopsPatterns(&getContext());
+ populateLinalgDistributeTiledLoopPattern(
+ distributeTiledLoopsPatterns, getDistributionOptions(),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>{},
+ {Identifier::get("distributed", funcOp.getContext())})
+ .addFilter([](Operation *op) {
+ return success(!op->getParentOfType<linalg::TiledLoopOp>());
+ }));
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(distributeTiledLoopsPatterns));
+ // Ensure we drop the marker in the end.
+ funcOp.walk([](LinalgOp op) {
+ op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
+ });
+}
+
+namespace mlir {
+namespace test {
+void registerTestLinalgDistribution() {
+ PassRegistration<TestLinalgDistribution> testTestLinalgDistributionPass(
+ "test-linalg-distribution", "Test Linalg distribution.");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 23bfe77..c2966e6 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@ void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
+void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
@@ -156,6 +157,7 @@ void registerTestPasses() {
test::registerTestIRVisitorsPass();
test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy();
+ test::registerTestLinalgDistribution();
test::registerTestLinalgElementwiseFusion();
test::registerTestPushExpandingReshape();
test::registerTestLinalgFusionTransforms();