aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib
diff options
context:
space:
mode:
authorJakub Lichman <limo@google.com>2020-09-30 07:13:59 +0000
committerJakub Lichman <limo@google.com>2020-09-30 11:31:28 +0000
commit0b17d4754a94b7129c2483762acd586783802b12 (patch)
treeaf9dc10de06ba85e9b35392844693eabea805e99 /mlir/test/lib
parent3cbd01ddb9372b725dcea3dd5fed21ef5b3d9578 (diff)
downloadllvm-0b17d4754a94b7129c2483762acd586783802b12.zip
llvm-0b17d4754a94b7129c2483762acd586783802b12.tar.gz
llvm-0b17d4754a94b7129c2483762acd586783802b12.tar.bz2
[mlir][Linalg] Tile sizes for Conv ops vectorization added as pass arguments
Current setup for conv op vectorization does not enable user to specify tile sizes as well as dimensions for vectorization. In this commit we change that by adding tile sizes as pass arguments. Every dimension with corresponding tile size > 1 is automatically vectorized. Differential Revision: https://reviews.llvm.org/D88533
Diffstat (limited to 'mlir/test/lib')
-rw-r--r--mlir/test/lib/Transforms/TestConvVectorization.cpp13
1 files changed, 12 insertions, 1 deletions
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index c90d805..79b6464 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -24,6 +24,13 @@ namespace {
/// A pass converting MLIR Linalg ops into Vector ops.
class TestConvVectorization
: public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
+public:
+ TestConvVectorization() = default;
+ TestConvVectorization(const TestConvVectorization &) {}
+ explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
+ tileSizes = tileSizesParam;
+ }
+
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
@@ -33,6 +40,10 @@ class TestConvVectorization
registry.insert<AffineDialect>();
registry.insert<StandardOpsDialect>();
}
+
+ ListOption<int64_t> tileSizes{
+ *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
} // namespace
@@ -47,7 +58,7 @@ void TestConvVectorization::runOnOperation() {
target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
- linalg::populateConvVectorizationPatterns(context, stage1Patterns);
+ linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
OwningRewritePatternList stage2Patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);