diff options
author | Jakub Lichman <limo@google.com> | 2020-09-30 07:13:59 +0000 |
---|---|---|
committer | Jakub Lichman <limo@google.com> | 2020-09-30 11:31:28 +0000 |
commit | 0b17d4754a94b7129c2483762acd586783802b12 (patch) | |
tree | af9dc10de06ba85e9b35392844693eabea805e99 /mlir/test/lib | |
parent | 3cbd01ddb9372b725dcea3dd5fed21ef5b3d9578 (diff) | |
download | llvm-0b17d4754a94b7129c2483762acd586783802b12.zip llvm-0b17d4754a94b7129c2483762acd586783802b12.tar.gz llvm-0b17d4754a94b7129c2483762acd586783802b12.tar.bz2 |
[mlir][Linalg] Tile sizes for Conv ops vectorization added as pass arguments
Current setup for conv op vectorization does not enable user to specify tile
sizes as well as dimensions for vectorization. In this commit we change that by
adding tile sizes as pass arguments. Every dimension with corresponding tile
size > 1 is automatically vectorized.
Differential Revision: https://reviews.llvm.org/D88533
Diffstat (limited to 'mlir/test/lib')
-rw-r--r-- | mlir/test/lib/Transforms/TestConvVectorization.cpp | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp index c90d805..79b6464 100644 --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -24,6 +24,13 @@ namespace { /// A pass converting MLIR Linalg ops into Vector ops. class TestConvVectorization : public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> { +public: + TestConvVectorization() = default; + TestConvVectorization(const TestConvVectorization &) {} + explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) { + tileSizes = tileSizesParam; + } + void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { @@ -33,6 +40,10 @@ class TestConvVectorization registry.insert<AffineDialect>(); registry.insert<StandardOpsDialect>(); } + + ListOption<int64_t> tileSizes{ + *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace @@ -47,7 +58,7 @@ void TestConvVectorization::runOnOperation() { target.addLegalOp<linalg::FillOp, linalg::YieldOp>(); SmallVector<OwningRewritePatternList, 4> stage1Patterns; - linalg::populateConvVectorizationPatterns(context, stage1Patterns); + linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); OwningRewritePatternList stage2Patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); |