From 2755c69098c9d0cf33bbbd3ff90f63ab819acfe1 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Fri, 3 May 2024 21:38:02 +0530
Subject: [mlir][linalg] Vectorize unpack op without masking (#89067)

Enables vectorization of the unpack op when vector sizes are not
provided. The vector sizes are determined by the result's shape.
---
 .../Dialect/Linalg/Transforms/Vectorization.cpp | 107 +++++++++++++++------
 mlir/test/Dialect/Linalg/vectorization.mlir     |  70 ++++++++++++++
 2 files changed, 145 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ef9a30b..7b4507c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1414,27 +1414,39 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
 /// create an empty destination tensor and create a TransferWriteOp from the
 /// input to the empty tensor. If the destination shape is not the same as the
 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
-/// mask for the write.
+/// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
+/// inBounds attribute of the transfer write op instead of masking.
 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                            Value input,
                                            SmallVector<OpFoldResult> destSizes,
-                                           ArrayRef<int64_t> inputVectorSizes) {
+                                           ArrayRef<int64_t> inputVectorSizes,
+                                           bool useInBoundsInsteadOfMasking) {
+
   auto inputType = cast<VectorType>(input.getType());
   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                                inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  SmallVector<bool> inBoundsVal(rank, true);
+  if (useInBoundsInsteadOfMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < rank; i++)
+      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
+                       !ShapedType::isDynamic(destShape[i]);
+  }
   Operation *write = builder.create<vector::TransferWriteOp>(
       loc,
       /*vector=*/input,
       /*source=*/dest,
       /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+      /*inBounds=*/inBoundsVal);
   assert(llvm::none_of(
              destShape.drop_front(inputVectorSizes.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVectorSizes may be dynamic");
+  if (useInBoundsInsteadOfMasking)
+    return write;
   bool needMaskForWrite = !llvm::equal(
       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
   if (needMaskForWrite) {
@@ -1535,9 +1547,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
-                               reifiedReturnShapes[0], inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
+      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1547,7 +1559,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
 ///   vector::TransposeOp - Transpose the Source tensor
 ///   ShapeCastOp - Reshape the data based on the target.
 ///   vector::TransferWriteOp. - Write the result vector back to the destination
-/// tensor
+/// tensor.
+/// If the vector sizes are not provided:
+/// * the vector sizes are determined by the input operand and attributes,
+/// * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
@@ -1560,40 +1575,65 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
-
-  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
-                                     inputVectorSizes.end());
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  bool useInBoundsInsteadOfMasking = false;
+  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+
+  auto destSize = unpackOp.getDestRank();
+
+  if (!inputVectorSizes.empty())
+    assert(inputVectorSizes.size() == destSize &&
+           "Incorrect number of input vector sizes");
 
-  // ReadMask is the size of tensor used to read and apply mask. It is
+  // vectorSizes is the shape of the vector that will be used to do the final
+  // write on the destination tensor. It is set like this: Let's say the
+  // source tensor is rank 'M' and the dest tensor is rank 'N', where N <= M.
+  // Thus:
+  // 1. vectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perm is present: do that permutation on vectorSizes.
+  // 3. multiply all the locations in vectorSizes pointed to by innerDimPos
+  //    by the innerTiles attribute value.
+  SmallVector<int64_t> vectorSizes(inputVectorSizes);
+  if (vectorSizes.empty()) {
+    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(vectorSizes, outerDimsPerm);
+    for (auto [i, pos] : llvm::enumerate(innerDimPos))
+      vectorSizes[pos] *= innerTiles[i];
+
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // readVectorSizes is the size of tensor used to read and apply mask. It is
   // set like this: Let's say the vectorSize (VS) array is size 'N' and
   // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
   // size M-N
   // Thus:
-  // - initially: ReadMaskShape = vectorInputSizes
+  // - initially: readVectorSizes = vectorSizes
-  // - Divide all the readMaskShape locations pointed by innerDimPos
+  // - Divide all the readVectorSizes locations pointed to by innerDimPos
   //   by the innerTileSize attribute value.
-  // - if outer_dims_perms is present: do that permutation on readMaskShape.
+  // - if outer_dims_perm is present: do that permutation on readVectorSizes.
   // - Append the remaining shape from SS
-  // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
+  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
   // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
   // 128] and outer_dims_perm is [1, 0] then read shape is:
-  //   ReadMaskShape(initial): [512, 128]
+  //   ReadVectorSizes(initial): [512, 128]
   //   Final Value(after innerDim Adjustment): [512/32, 128/16]
   //                                           = [16, 8]
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
+  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
+
   for (auto [index, size] : enumerate(innerTiles)) {
-    readMaskShape[innerDimPos[index]] =
-        llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+    readVectorSizes[innerDimPos[index]] =
+        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
   }
   if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readMaskShape, outerDimsPerm);
+    applyPermutationToVector(readVectorSizes, outerDimsPerm);
   }
-  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
-                       sourceShape.end());
+  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
+                         sourceShape.end());
 
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
@@ -1611,8 +1651,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
   // Read result, mask if necessary. If transferReadOp shape is not equal
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(),
-      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
+      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   PackingMetadata packMetadata;
@@ -1636,15 +1675,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // WriteMaskShape had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeMaskShape(
+  SmallVector<int64_t> writeVectorSizes(
       unpackOp.getDestType().hasStaticShape()
-          ? inputVectorSizes
+          ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
-                               reifiedRetShapes[0], writeMaskShape);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
+      writeVectorSizes, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1673,7 +1712,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
+      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1755,8 +1795,11 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
     return failure();
   }
-  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
-  if (!inputVectorSizes.empty() &&
+  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  bool satisfyEmptyCond = inputVectorSizes.empty() &&
+                          unpackOp.getDestType().hasStaticShape() &&
+                          unpackOp.getSourceType().hasStaticShape();
+  if (!satisfyEmptyCond &&
       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
     return failure();
 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 80a5a4c..bbeccc7 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -985,3 +985,73 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x127xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[EMPT]]{{\[}}%[[C00]], %[[C00]]]
+  // CHECK-SAME: {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
+  // CHECK: return %[[WRIT]] : tensor<64x127xf32>
+  %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
+  return %0 : tensor<64x127xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32>
+  // CHECK: return %[[WRIT]] : tensor<7x16xf32>
+  %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
+  return %0 : tensor<7x16xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
-- 
cgit v1.1
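Usage sketch (not part of the patch): besides the transform-dialect tests
above, the new no-vector-sizes path can also be driven from C++. The helper
below is hypothetical; it assumes the `linalg::vectorize` entry point declared
in mlir/Dialect/Linalg/Transforms/Transforms.h and vectorizes every statically
shaped tensor.unpack op in a function with empty vector sizes:

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Dialect/Tensor/IR/Tensor.h"
  #include "mlir/IR/PatternMatch.h"

  using namespace mlir;

  // Hypothetical helper: vectorize each tensor.unpack op in `func` without
  // providing vector sizes, so the sizes are inferred from the op's static
  // source/dest shapes as implemented in this patch.
  static LogicalResult vectorizeAllUnpackOps(func::FuncOp func) {
    IRRewriter rewriter(func.getContext());
    // Collect first: vectorization replaces the ops while we iterate.
    SmallVector<tensor::UnPackOp> candidates;
    func.walk([&](tensor::UnPackOp op) { candidates.push_back(op); });
    for (tensor::UnPackOp op : candidates) {
      // Empty vector sizes are only accepted when both source and dest are
      // fully static (see vectorizeUnPackOpPrecondition above).
      if (!op.getSourceType().hasStaticShape() ||
          !op.getDestType().hasStaticShape())
        continue;
      rewriter.setInsertionPoint(op);
      if (failed(linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{},
                                   /*inputScalableVecDims=*/{})))
        return failure();
    }
    return success();
  }

With empty inputVectorSizes, the write side relies on the in_bounds attribute
computed in createWriteOrMaskedWrite rather than on vector.mask ops, as
exercised by the tests above.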