diff options
author | Han-Chung Wang <hanhan0912@gmail.com> | 2024-06-03 16:39:52 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-03 16:39:52 -0700 |
commit | 0ea1271ee13c8c3d765904dba16dd27b91584d66 (patch) | |
tree | 316ba86757929781be150423520cc1acd30312cc | |
parent | 43847c1de60ddba26d93c138ad81aa0d3b3c8c31 (diff) | |
download | llvm-0ea1271ee13c8c3d765904dba16dd27b91584d66.zip llvm-0ea1271ee13c8c3d765904dba16dd27b91584d66.tar.gz llvm-0ea1271ee13c8c3d765904dba16dd27b91584d66.tar.bz2 |
[mlir][vector] Add support for unrolling vector.bitcast ops. (#94064)
This revision adds a pattern that unrolls `vector.bitcast` ops, rewriting e.g.:
```mlir
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
```
to
```mlir
%cst = arith.constant dense<0> : vector<2x2xi64>
%0 = vector.extract %arg0[0] : vector<4xi32> from vector<2x4xi32>
%1 = vector.bitcast %0 : vector<4xi32> to vector<2xi64>
%2 = vector.insert %1, %cst [0] : vector<2xi64> into vector<2x2xi64>
%3 = vector.extract %arg0[1] : vector<4xi32> from vector<2x4xi32>
%4 = vector.bitcast %3 : vector<4xi32> to vector<2xi64>
%5 = vector.insert %4, %2 [1] : vector<2xi64> into vector<2x2xi64>
```
Scalable vectors are not supported because of a limitation of
`vector::createUnrollIterator`: the rank at which unrolling stops can differ
from `targetRank`, and there is no direct way to query that final rank from
the returned object.
8 files changed, 189 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index bc3c16d4..c91e8fb 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -89,6 +89,20 @@ def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect, let assemblyFormat = "attr-dict"; } +def ApplyLowerBitCastPatternsOp : Op<Transform_Dialect, + "apply_patterns.vector.lower_bitcast", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Indicates that vector bitcast operations should be lowered to + finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect, "apply_patterns.vector.lower_broadcast", [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 8fd9904..1976b83 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -276,6 +276,15 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populates the pattern set with the following patterns: +/// +/// [UnrollBitCastOp] +/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp + +/// BitCastOp (of `targetRank`) + InsertOp. 
+void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, + int64_t targetRank = 1, + PatternBenefit benefit = 1); + } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index e3a436c..55143d5 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() { { RewritePatternSet patterns(&getContext()); populateVectorToVectorCanonicalizationPatterns(patterns); + populateVectorBitCastLoweringPatterns(patterns); populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions()); populateVectorMaskOpLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 61fd6bd..2396026 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -79,6 +79,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns( vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); } +void transform::ApplyLowerBitCastPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorBitCastLoweringPatterns(patterns); +} + void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorBroadcastLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 4dbefdd..723b2f6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ 
add_mlir_dialect_library(MLIRVectorTransforms BufferizableOpInterfaceImpl.cpp + LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp LowerVectorGather.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp new file mode 100644 index 0000000..092ec92 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp @@ -0,0 +1,96 @@ +//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.bitcast' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "vector-bitcast-lowering" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +/// A one-shot unrolling of vector.bitcast to the `targetRank`. 
+/// +/// Example: +/// +/// vector.bitcast %a : vector<1x2x3x4xi64> to vector<1x2x3x8xi32> +/// +/// Would be unrolled to: +/// +/// %result = arith.constant dense<0> : vector<1x2x3x8xi32> +/// %0 = vector.extract %a[0, 0, 0] ─┐ +/// : vector<4xi64> from vector<1x2x3x4xi64> | +/// %1 = vector.bitcast %0 | - Repeated 6x for +/// : vector<4xi64> to vector<8xi32> | all leading positions +/// %2 = vector.insert %1, %result [0, 0, 0] | +/// : vector<8xi32> into vector<1x2x3x8xi32> ─┘ +/// +/// Note: Scalable vectors are not supported yet; the pattern fails to match +/// when the result vector type is scalable. +class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> { +public: + UnrollBitCastOp(int64_t targetRank, MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), targetRank(targetRank) {}; + + LogicalResult matchAndRewrite(vector::BitCastOp op, + PatternRewriter &rewriter) const override { + VectorType resultType = op.getResultVectorType(); + auto unrollIterator = vector::createUnrollIterator(resultType, targetRank); + if (!unrollIterator) + return failure(); + + // TODO: Support the scalable vector cases. It is not supported because + // the final rank could be values other than `targetRank`. It makes creating + // the result type of new vector.bitcast ops much harder. 
+ if (resultType.isScalable()) { + return rewriter.notifyMatchFailure(op, + "unrolling vector.bitcast on scalable " + "vectors is not yet implemented"); + } + + ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank); + auto bitcastResType = VectorType::get(shape, resultType.getElementType()); + + Location loc = op.getLoc(); + Value result = rewriter.create<arith::ConstantOp>( + loc, resultType, rewriter.getZeroAttr(resultType)); + for (auto position : *unrollIterator) { + Value extract = + rewriter.create<vector::ExtractOp>(loc, op.getSource(), position); + Value bitcast = + rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract); + result = + rewriter.create<vector::InsertOp>(loc, bitcast, result, position); + } + + rewriter.replaceOp(op, result); + return success(); + } + +private: + int64_t targetRank = 1; +}; + +} // namespace + +void mlir::vector::populateVectorBitCastLoweringPatterns( + RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) { + patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit); +} diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 245edb6..12121ea 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2564,3 +2564,13 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32> return %0, %1 : vector<[2]xi32>, vector<[2]xi32> } + +// ----- + +// CHECK-LABEL: func.func @vector_bitcast_2d +// CHECK: llvm.bitcast +// CHECK-NOT: vector.bitcast +func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> { + %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64> + return %0 : vector<2x2xi64> +} diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir new file mode 100644 index 0000000..23fece2 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +func.func @vector_bitcast_0d(%arg0: vector<i32>) -> vector<f32> { + %0 = vector.bitcast %arg0 : vector<i32> to vector<f32> + return %0 : vector<f32> +} +// CHECK-LABEL: func.func @vector_bitcast_0d +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] +// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<i32> to vector<f32> +// CHECK: return %[[RES]] + +func.func @vector_bitcast_1d(%arg0: vector<10xi64>) -> vector<20xi32> { + %0 = vector.bitcast %arg0 : vector<10xi64> to vector<20xi32> + return %0 : vector<20xi32> +} +// CHECK-LABEL: func.func @vector_bitcast_1d +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] +// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<10xi64> to vector<20xi32> +// CHECK: return %[[RES]] + +func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> { + %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64> + return %0 : vector<2x2xi64> +} +// CHECK-LABEL: func.func @vector_bitcast_2d +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] +// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64> +// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32> +// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64> +// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] +// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<4xi32> from vector<2x4xi32> +// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<4xi32> to vector<2xi64> +// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] +// CHECK: return %[[R2]] + +func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> { + %0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32> + return %0 : 
vector<1x2x[3]x8xi32> +} +// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim +// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.lower_bitcast + } : !transform.any_op + transform.yield + } +} |