diff options
Diffstat (limited to 'mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp')
-rw-r--r-- | mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp | 90 |
1 files changed, 87 insertions, 3 deletions
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 2573e11..e5d7dc5 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -186,6 +187,51 @@ private: TransformationFilter filter; }; +/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using +/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles) +/// while using a `filter` to avoid recursive application. +struct TestTileUsingSCFForallOp + : public OpInterfaceRewritePattern<TilingInterface> { + TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern<TilingInterface>(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} + + /// Construct a generic pattern applied to `opName`. + TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context, + scf::SCFTilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern<TilingInterface>(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + FailureOr<scf::SCFTilingResult> tilingResult = + scf::tileUsingSCFForallOp(rewriter, op, options); + if (failed(tilingResult)) + return rewriter.notifyMatchFailure(op, "failed to tile operation"); + + if (op->getNumResults()) { + rewriter.replaceOp(op, tilingResult->replacements); + } else { + rewriter.eraseOp(op); + } + + for (auto *tiledOp : tilingResult->tiledOps) + filter.replaceTransformationFilter(rewriter, tiledOp); + return success(); + } + +private: + scf::SCFTilingOptions options; + TransformationFilter filter; +}; + /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` /// ops for iterating over the tiles) while using a `filter` to avoid recursive @@ -398,9 +444,9 @@ struct TestTilingInterfacePass TestTilingInterfacePass(const TestTilingInterfacePass &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert<affine::AffineDialect, linalg::LinalgDialect, - memref::MemRefDialect, scf::SCFDialect, - tensor::TensorDialect>(); + registry.insert<affine::AffineDialect, gpu::GPUDialect, + linalg::LinalgDialect, memref::MemRefDialect, + scf::SCFDialect, tensor::TensorDialect>(); linalg::registerTilingInterfaceExternalModels(registry); tensor::registerTilingInterfaceExternalModels(registry); } @@ -415,6 +461,12 @@ struct TestTilingInterfacePass "Test tiling using TilingInterface with scf.for operations"), llvm::cl::init(false)}; + Option<bool> testTilingForAll{ + *this, "tile-using-scf-forall", + llvm::cl::desc( + "Test tiling using TilingInterface with scf.forall operations"), + llvm::cl::init(false)}; + Option<bool> testTileConsumerFuseAndYieldProducer{ *this, "tile-consumer-fuse-and-yield-producer-using-scf-for", llvm::cl::desc( @@ -455,6 +507,21 @@ static void addPatternForTiling(MLIRContext *context, patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter); } +static void addPatternForTilingUsingForall( + MLIRContext *context, RewritePatternSet &patterns, StringRef filterName, + ArrayRef<int64_t> tileSizes, + ArrayRef<DeviceMappingAttrInterface> mapping = {}, + ArrayRef<int64_t> interchange = {}) { + scf::SCFTilingOptions tilingOptions; + SmallVector<OpFoldResult> tileSizesOfr = + getAsIndexOpFoldResult(context, tileSizes); + tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); + tilingOptions.setMapping(mapping); + TransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); + patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter); +} + static void addPatternForTileFuseAndYield(MLIRContext *context, RewritePatternSet &patterns, StringRef filterName, @@ -514,6 +581,23 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20}); return; } + if (testTilingForAll) { + // 1. Tiling M and N dims of `linalg.matmul` on tensors. + addPatternForTilingUsingForall( + context, patterns, "simple_gemm", {10, 20}, + {gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY), + gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)}); + // 2. Tiling 3D parallel generic op which implements a transpose. + addPatternForTilingUsingForall(context, patterns, + "parallel_generic_transpose", {10, 0, 20}); + // 3. Tiling 2D conv op. + addPatternForTilingUsingForall(context, patterns, "simple_conv", + {0, 0, 0, 0, 10, 20, 30}); + // 4. Tiling a simple op with `linalg.index` inside. + addPatternForTilingUsingForall(context, patterns, "indexed_semantics", + {10, 20}); + return; + } if (testTileConsumerAndFuseProducer) { // 1. Tile and fuse of gemm with fill producer and bias-add consumer. addPatternForTileAndFuse(context, patterns, "fusion", {10, 20}); |