From dc5d541081381e3dca80b982097596546e0619fc Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 21 Jun 2024 14:38:19 +0100 Subject: [mlir][vector] Support scalable vectors when unrolling vector.bitcast (#94197) Follow up to #94064. --- mlir/include/mlir/Dialect/Utils/IndexingUtils.h | 5 ++++ .../Vector/Transforms/LowerVectorBitCast.cpp | 17 ++++------- .../Vector/vector-bitcast-lowering-transforms.mlir | 34 +++++++++++++++++++++- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h index 9892253..b774359 100644 --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -287,6 +287,8 @@ public: return getDynamicTileOffsets(linearIndex); } + size_t getRank() const { return tileShape.size(); } + private: /// The sub-shape that divides the larger outer shape (which is provided to /// the constructor). @@ -388,6 +390,9 @@ public: /// Returns the total number of tiles that fit in the larger shape. size_t size() const { return params.getMaxLinearIndex(); } + /// Returns rank of the iterator's shape. + size_t getRank() const { return params.getRank(); } + private: const ParamsTy params; IteratorTy beginValue; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp index 092ec92..e5f11d8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp @@ -56,17 +56,12 @@ public: if (!unrollIterator) return failure(); - // TODO: Support the scalable vector cases. It is not supported because - // the final rank could be values other than `targetRank`. It makes creating - // the result type of new vector.bitcast ops much harder. 
- if (resultType.isScalable()) { - return rewriter.notifyMatchFailure(op, - "unrolling vector.bitcast on scalable " - "vectors is not yet implemented"); - } - - ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank); - auto bitcastResType = VectorType::get(shape, resultType.getElementType()); + auto unrollRank = unrollIterator->getRank(); + ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank); + ArrayRef<bool> scalableDims = + resultType.getScalableDims().drop_front(unrollRank); + auto bitcastResType = + VectorType::get(shape, resultType.getElementType(), scalableDims); Location loc = op.getLoc(); Value result = rewriter.create<arith::ConstantOp>( diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir index 23fece2..3462910 100644 --- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir @@ -38,7 +38,39 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> return %0 : vector<1x2x[3]x8xi32> } // CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim -// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32> +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] +// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64> +// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32> +// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[IN]][0, 1] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64> +// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[3]x4xi64> to vector<[3]x8xi32> +// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [0, 1] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32> +// CHECK: return %[[R2]] : vector<1x2x[3]x8xi32>
+ +func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) -> vector<2x[4]xi32> { + %0 = vector.bitcast %arg0 : vector<2x[2]xi64> to vector<2x[4]xi32> + return %0 : vector<2x[4]xi32> +} +// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] +// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64> +// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vector<2x[4]xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<[2]xi64> from vector<2x[2]xi64> +// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] : vector<[4]xi32> into vector<2x[4]xi32> +// CHECK: return %[[R2]] : vector<2x[4]xi32> + +func.func @negative_vector_bitcast_2d_leading_scalable_dim(%arg0: vector<[2]x2xi64>) -> vector<[2]x4xi32> +{ + %0 = vector.bitcast %arg0 : vector<[2]x2xi64> to vector<[2]x4xi32> + return %0 : vector<[2]x4xi32> +} +// CHECK-LABEL: func.func @negative_vector_bitcast_2d_leading_scalable_dim +// CHECK-NOT: vector.extract +// CHECK-NOT: vector.insert module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { -- cgit v1.1