aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorHan-Chung Wang <hanhan0912@gmail.com>2024-06-03 16:39:52 -0700
committerGitHub <noreply@github.com>2024-06-03 16:39:52 -0700
commit0ea1271ee13c8c3d765904dba16dd27b91584d66 (patch)
tree316ba86757929781be150423520cc1acd30312cc /mlir
parent43847c1de60ddba26d93c138ad81aa0d3b3c8c31 (diff)
downloadllvm-0ea1271ee13c8c3d765904dba16dd27b91584d66.zip
llvm-0ea1271ee13c8c3d765904dba16dd27b91584d66.tar.gz
llvm-0ea1271ee13c8c3d765904dba16dd27b91584d66.tar.bz2
[mlir][vector] Add support for unrolling vector.bitcast ops. (#94064)
The revision unrolls vector.bitcast like: ```mlir %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64> ``` to ```mlir %cst = arith.constant dense<0> : vector<2x2xi64> %0 = vector.extract %arg0[0] : vector<4xi32> from vector<2x4xi32> %1 = vector.bitcast %0 : vector<4xi32> to vector<2xi64> %2 = vector.insert %1, %cst [0] : vector<2xi64> into vector<2x2xi64> %3 = vector.extract %arg0[1] : vector<4xi32> from vector<2x4xi32> %4 = vector.bitcast %3 : vector<4xi32> to vector<2xi64> %5 = vector.insert %4, %2 [1] : vector<2xi64> into vector<2x2xi64> ``` The scalable vector is not supported because of the limitation of `vector::createUnrollIterator`. The targetRank could mismatch the final rank during unrolling; there is no direct way to query what the final rank is from the object.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td14
-rw-r--r--mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h9
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp1
-rw-r--r--mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp5
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp96
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir10
-rw-r--r--mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir53
8 files changed, 189 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d4..c91e8fb 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -89,6 +89,20 @@ def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyLowerBitCastPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.lower_bitcast",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector bitcast operations should be lowered to
+ finer-grained vector primitives.
+
+ This is usally a late step that is run after bufferization as part of the
+ process of lowering to e.g. LLVM or NVVM.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_broadcast",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8fd9904..1976b83 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -276,6 +276,15 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populates the pattern set with the following patterns:
+///
+/// [UnrollBitCastOp]
+/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
+/// BitCastOp (of `targetRank`) + InsertOp.
+void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
+ int64_t targetRank = 1,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a436c..55143d5 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
populateVectorMaskOpLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd..2396026 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -79,6 +79,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
+void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorBitCastLoweringPatterns(patterns);
+}
+
void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorBroadcastLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4dbefdd..723b2f6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
+ LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
LowerVectorGather.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
new file mode 100644
index 0000000..092ec92
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -0,0 +1,96 @@
+//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.bitcast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-bitcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// A one-shot unrolling of vector.bitcast to the `targetRank`.
+///
+/// Example:
+///
+/// vector.bitcast %a, %b : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
+///
+/// Would be unrolled to:
+///
+/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
+/// %0 = vector.extract %a[0, 0, 0] ─┐
+/// : vector<4xi64> from vector<1x2x3x4xi64> |
+/// %1 = vector.bitcast %0 | - Repeated 6x for
+/// : vector<4xi64> to vector<8xi32> | all leading positions
+/// %2 = vector.insert %1, %result [0, 0, 0] |
+/// : vector<8xi64> into vector<1x2x3x8xi32> ─┘
+///
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
+public:
+ UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), targetRank(targetRank) {};
+
+ LogicalResult matchAndRewrite(vector::BitCastOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+ if (!unrollIterator)
+ return failure();
+
+ // TODO: Support the scalable vector cases. It is not supported because
+ // the final rank could be values other than `targetRank`. It makes creating
+ // the result type of new vector.bitcast ops much harder.
+ if (resultType.isScalable()) {
+ return rewriter.notifyMatchFailure(op,
+ "unrolling vector.bitcast on scalable "
+ "vectors is not yet implemented");
+ }
+
+ ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
+ auto bitcastResType = VectorType::get(shape, resultType.getElementType());
+
+ Location loc = op.getLoc();
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resultType, rewriter.getZeroAttr(resultType));
+ for (auto position : *unrollIterator) {
+ Value extract =
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+ Value bitcast =
+ rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
+ result =
+ rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+private:
+ int64_t targetRank = 1;
+};
+
+} // namespace
+
+void mlir::vector::populateVectorBitCastLoweringPatterns(
+ RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
+ patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 245edb6..12121ea 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,13 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
%0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK: llvm.bitcast
+// CHECK-NOT: vector.bitcast
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+ %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+ return %0 : vector<2x2xi64>
+}
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
new file mode 100644
index 0000000..23fece2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+func.func @vector_bitcast_0d(%arg0: vector<i32>) -> vector<f32> {
+ %0 = vector.bitcast %arg0 : vector<i32> to vector<f32>
+ return %0 : vector<f32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_0d
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<i32> to vector<f32>
+// CHECK: return %[[RES]]
+
+func.func @vector_bitcast_1d(%arg0: vector<10xi64>) -> vector<20xi32> {
+ %0 = vector.bitcast %arg0 : vector<10xi64> to vector<20xi32>
+ return %0 : vector<20xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_1d
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<10xi64> to vector<20xi32>
+// CHECK: return %[[RES]]
+
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+ %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+ return %0 : vector<2x2xi64>
+}
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
+// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
+// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
+// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
+// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<4xi32> from vector<2x4xi32>
+// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<4xi32> to vector<2xi64>
+// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1]
+// CHECK: return %[[R2]]
+
+func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> {
+ %0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+ return %0 : vector<1x2x[3]x8xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
+// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_bitcast
+ } : !transform.any_op
+ transform.yield
+ }
+}