aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp')
-rw-r--r--mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp90
1 files changed, 87 insertions, 3 deletions
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2573e11..e5d7dc5 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -186,6 +187,51 @@ private:
TransformationFilter filter;
};
+/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using
+/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles)
+/// while using a `filter` to avoid recursive application.
+struct TestTileUsingSCFForallOp
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
+
+ /// Construct a generic pattern applied to `opName`.
+ TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context,
+ scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ scf::tileUsingSCFForallOp(rewriter, op, options);
+ if (failed(tilingResult))
+ return rewriter.notifyMatchFailure(op, "failed to tile operation");
+
+ if (op->getNumResults()) {
+ rewriter.replaceOp(op, tilingResult->replacements);
+ } else {
+ rewriter.eraseOp(op);
+ }
+
+ for (auto *tiledOp : tilingResult->tiledOps)
+ filter.replaceTransformationFilter(rewriter, tiledOp);
+ return success();
+ }
+
+private:
+ scf::SCFTilingOptions options;
+ TransformationFilter filter;
+};
+
/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
@@ -398,9 +444,9 @@ struct TestTilingInterfacePass
TestTilingInterfacePass(const TestTilingInterfacePass &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<affine::AffineDialect, linalg::LinalgDialect,
- memref::MemRefDialect, scf::SCFDialect,
- tensor::TensorDialect>();
+ registry.insert<affine::AffineDialect, gpu::GPUDialect,
+ linalg::LinalgDialect, memref::MemRefDialect,
+ scf::SCFDialect, tensor::TensorDialect>();
linalg::registerTilingInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
}
@@ -415,6 +461,12 @@ struct TestTilingInterfacePass
"Test tiling using TilingInterface with scf.for operations"),
llvm::cl::init(false)};
+ Option<bool> testTilingForAll{
+ *this, "tile-using-scf-forall",
+ llvm::cl::desc(
+ "Test tiling using TilingInterface with scf.forall operations"),
+ llvm::cl::init(false)};
+
Option<bool> testTileConsumerFuseAndYieldProducer{
*this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
llvm::cl::desc(
@@ -455,6 +507,21 @@ static void addPatternForTiling(MLIRContext *context,
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
}
+static void addPatternForTilingUsingForall(
+ MLIRContext *context, RewritePatternSet &patterns, StringRef filterName,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<DeviceMappingAttrInterface> mapping = {},
+ ArrayRef<int64_t> interchange = {}) {
+ scf::SCFTilingOptions tilingOptions;
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
+ tilingOptions.setMapping(mapping);
+ TransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
+ patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
+}
+
static void addPatternForTileFuseAndYield(MLIRContext *context,
RewritePatternSet &patterns,
StringRef filterName,
@@ -514,6 +581,23 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
return;
}
+ if (testTilingForAll) {
+ // 1. Tiling M and N dims of `linalg.matmul` on tensors.
+ addPatternForTilingUsingForall(
+ context, patterns, "simple_gemm", {10, 20},
+ {gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY),
+ gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)});
+ // 2. Tiling 3D parallel generic op which implements a transpose.
+ addPatternForTilingUsingForall(context, patterns,
+ "parallel_generic_transpose", {10, 0, 20});
+ // 3. Tiling 2D conv op.
+ addPatternForTilingUsingForall(context, patterns, "simple_conv",
+ {0, 0, 0, 0, 10, 20, 30});
+ // 4. Tiling a simple op with `linalg.index` inside.
+ addPatternForTilingUsingForall(context, patterns, "indexed_semantics",
+ {10, 20});
+ return;
+ }
if (testTileConsumerAndFuseProducer) {
// 1. Tile and fuse of gemm with fill producer and bias-add consumer.
addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});