aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorCullen Rhodes <cullen.rhodes@arm.com>2024-07-04 08:57:02 +0100
committerGitHub <noreply@github.com>2024-07-04 08:57:02 +0100
commit67b302c52f79db2ab5c46e5e8c600f1c2af57a83 (patch)
tree2a6e90d72ae7533bcb072bbd13d12447f4fc054b /mlir/lib
parent927def49728371d746476e79a6570cd93a4d335c (diff)
downloadllvm-67b302c52f79db2ab5c46e5e8c600f1c2af57a83.zip
llvm-67b302c52f79db2ab5c46e5e8c600f1c2af57a83.tar.gz
llvm-67b302c52f79db2ab5c46e5e8c600f1c2af57a83.tar.bz2
[mlir][vector] Add vector.step operation (#96776)
This patch adds a new vector.step operation to the Vector dialect. It produces a linear sequence of index values from 0 to N, where N is the number of elements in the result vector, and can be used to create vectors of indices. It supports both fixed-width and scalable vectors. For fixed the canonical representation is `arith.constant dense<[0, .., N]>`. A scalable step cannot be represented as a constant and is lowered to the `llvm.experimental.stepvector` intrinsic [1]. This op enables scalable vectorization of linalg.index ops, see #96778. It can also be used in the SparseVectorizer in-place of lower-level stepvector intrinsic, see [2] (patch to follow). [1] https://llvm.org/docs/LangRef.html#llvm-experimental-stepvector-intrinsic [2] https://github.com/llvm/llvm-project/blob/acf675b63f9426e61aac2155e29280f7d21f9421/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp#L385-L388
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp17
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp14
2 files changed, 29 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 0eac552..6a8a9d8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1860,6 +1860,19 @@ struct VectorFromElementsLowering
}
};
+/// Conversion pattern for vector.step.
+struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type llvmType = typeConverter->convertType(stepOp.getType());
+ rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
+ return success();
+ }
+};
+
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1885,8 +1898,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
- VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
- converter);
+ VectorDeinterleaveOpLowering, VectorFromElementsLowering,
+ VectorStepOpLowering>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 149723f..53a6648 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6313,6 +6313,20 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
}
//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
+ auto resultType = cast<VectorType>(getType());
+ if (resultType.isScalable())
+ return nullptr;
+ SmallVector<APInt> indices;
+ for (unsigned i = 0; i < resultType.getNumElements(); i++)
+ indices.push_back(APInt(/*width=*/64, i));
+ return DenseElementsAttr::get(resultType, indices);
+}
+
+//===----------------------------------------------------------------------===//
// WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//