aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2020-04-08 14:53:37 -0400
committerNicolas Vasilache <ntv@google.com>2020-04-08 16:54:40 -0400
commit6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915 (patch)
tree85c82b5d3d0c85f359479281acb9010210e3455f
parentc6e917d2d3ea07960721923230c34abe3b6214cc (diff)
downloadllvm-6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915.zip
llvm-6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915.tar.gz
llvm-6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915.tar.bz2
[mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors
This revision builds a simple "fused pass" consisting of 2 levels of tiling, memory promotion and vectorization using linalg transformations written as composable pattern rewrites.
-rw-r--r--mlir/test/Dialect/Linalg/matmul-to-vector.mlir16
-rw-r--r--mlir/test/lib/DeclarativeTransforms/CMakeLists.txt4
-rw-r--r--mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td43
-rw-r--r--mlir/test/lib/Transforms/CMakeLists.txt2
-rw-r--r--mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp51
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp2
6 files changed, 118 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Linalg/matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
new file mode 100644
index 0000000..351b204
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -linalg-matmul-to-vector | FileCheck %s
+
+func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
+ linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
+ return
+}
+
+// CHECK-LABEL:func @matmul_perm
+// CHECK: vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
index 9672edb..f068542 100644
--- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
+++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
@@ -5,3 +5,7 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
+mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
new file mode 100644
index 0000000..7fa4a3d
--- /dev/null
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
@@ -0,0 +1,43 @@
+//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the pattern definition file for declarative Linalg transformations
+// tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+
+include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
+include "mlir/Dialect/Vector/VectorTransformPatterns.td"
+
+//===----------------------------------------------------------------------===//
+// Linalg tiling and permutation patterns.
+//===----------------------------------------------------------------------===//
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>),
+ [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>),
+ [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (PromoteSubviewsLinalgOp),
+ [(Constraint<HasOperandsOfType<"SubViewOp">>),
+ (Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
+
+//===----------------------------------------------------------------------===//
+// Linalg to vector contraction patterns.
+//===----------------------------------------------------------------------===//
+def : Pattern<(MatmulOp:$op $_, $_, $_),
+ [(VectorizeLinalgOp)],
+ [(Constraint<And<[
+ HasLinalgTransformMarker<"L1__with_perm__">,
+ PreconditionVectorizeLinalgOp]>>)]>;
+
+#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 904a472..23107f2 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms
TestGpuMemoryPromotion.cpp
TestGpuParallelLoopMapping.cpp
TestInlining.cpp
+ TestLinalgMatmulToVector.cpp
TestLinalgTransforms.cpp
TestLiveness.cpp
TestLoopMapping.cpp
@@ -24,6 +25,7 @@ add_llvm_library(MLIRTestTransforms
DEPENDS
MLIRStandardOpsIncGen
+ MLIRTestLinalgMatmulToVectorPatternsIncGen
MLIRTestLinalgTransformPatternsIncGen
MLIRTestVectorTransformPatternsIncGen
)
diff --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
new file mode 100644
index 0000000..6f49fab
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
@@ -0,0 +1,51 @@
+//===- TestLinalgMatmulToVector.cpp - Test VectorTransfers lowering -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::vector;
+
+namespace {
+#include "TestLinalgMatmulToVectorPatterns.h.inc"
+
+struct DeclarativeTransforms
+ : public PassWrapper<DeclarativeTransforms, FunctionPass> {
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+ AffineMinOp::getCanonicalizationPatterns(patterns, context);
+ AffineMaxOp::getCanonicalizationPatterns(patterns, context);
+ AllocOp::getCanonicalizationPatterns(patterns, context);
+ SubViewOp::getCanonicalizationPatterns(patterns, context);
+ ViewOp::getCanonicalizationPatterns(patterns, context);
+ populateWithGenerated(context, &patterns);
+ applyPatternsGreedily(getFunction(), patterns);
+ }
+};
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestLinalgMatmulToVectorPass() {
+ PassRegistration<DeclarativeTransforms> pass(
+ "linalg-matmul-to-vector",
+ "Test declarative transform patterns for matmul 3-D tiling + promotion"
+ " + vectorization");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e8b2f3d..50a9296 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -39,6 +39,7 @@ void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAllReduceLoweringPass();
+void registerTestLinalgMatmulToVectorPass();
void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
void registerTestConstantFold();
@@ -101,6 +102,7 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();
+ registerTestLinalgMatmulToVectorPass();
registerTestLoopPermutationPass();
registerTestCallGraphPass();
registerTestConstantFold();