diff options
author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2024-06-21 14:38:19 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-21 14:38:19 +0100 |
commit | dc5d541081381e3dca80b982097596546e0619fc (patch) | |
tree | 0973e64f2910679324a9a26255010a7b5001f4f5 /mlir | |
parent | 747f9dacfe30114b492553e0c69a29328d246e4f (diff) | |
download | llvm-dc5d541081381e3dca80b982097596546e0619fc.zip llvm-dc5d541081381e3dca80b982097596546e0619fc.tar.gz llvm-dc5d541081381e3dca80b982097596546e0619fc.tar.bz2 |
[mlir][vector] Support scalable vectors when unrolling vector.bitcast (#94197)
Follow up to #94064.
Diffstat (limited to 'mlir')
3 files changed, 44 insertions, 12 deletions
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h index 9892253..b774359 100644 --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -287,6 +287,8 @@ public: return getDynamicTileOffsets(linearIndex); } + size_t getRank() const { return tileShape.size(); } + private: /// The sub-shape that divides the larger outer shape (which is provided to /// the constructor). @@ -388,6 +390,9 @@ public: /// Returns the total number of tiles that fit in the larger shape. size_t size() const { return params.getMaxLinearIndex(); } + /// Returns rank of the iterator's shape. + size_t getRank() const { return params.getRank(); } + private: const ParamsTy params; IteratorTy beginValue; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp index 092ec92..e5f11d8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp @@ -56,17 +56,12 @@ public: if (!unrollIterator) return failure(); - // TODO: Support the scalable vector cases. It is not supported because - // the final rank could be values other than `targetRank`. It makes creating - // the result type of new vector.bitcast ops much harder. - if (resultType.isScalable()) { - return rewriter.notifyMatchFailure(op, - "unrolling vector.bitcast on scalable " - "vectors is not yet implemented"); - } - - ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank); - auto bitcastResType = VectorType::get(shape, resultType.getElementType()); + auto unrollRank = unrollIterator->getRank(); + ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank); + ArrayRef<bool> scalableDims = + resultType.getScalableDims().drop_front(unrollRank); + auto bitcastResType = + VectorType::get(shape, resultType.getElementType(), scalableDims); Location loc = op.getLoc(); Value result = rewriter.create<arith::ConstantOp>( diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir index 23fece2..3462910 100644 --- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir @@ -38,7 +38,39 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> return %0 : vector<1x2x[3]x8xi32> } // CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim -// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32> +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] +// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64> +// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32> +// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[IN]][0, 1] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64> +// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[3]x4xi64> to vector<[3]x8xi32> +// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [0, 1] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32> +// CHECK: return %[[R2]] : vector<1x2x[3]x8xi32> + +func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) -> vector<2x[4]xi32> { + %0 = vector.bitcast %arg0 : vector<2x[2]xi64> to vector<2x[4]xi32> + return %0 : vector<2x[4]xi32> +} +// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] +// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64> +// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vector<2x[4]xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<[2]xi64> from vector<2x[2]xi64> +// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] : vector<[4]xi32> into vector<2x[4]xi32> +// CHECK: return %[[R2]] : vector<2x[4]xi32> + +func.func @negative_vector_bitcast_2d_leading_scalable_dim(%arg0: vector<[2]x2xi64>) -> vector<[2]x4xi32> +{ + %0 = vector.bitcast %arg0 : vector<[2]x2xi64> to vector<[2]x4xi32> + return %0 : vector<[2]x4xi32> +} +// CHECK-LABEL: func.func @negative_vector_bitcast_2d_leading_scalable_dim +// CHECK-NOT: vector.extract +// CHECK-NOT: vector.insert module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { |