aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td5
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp17
-rw-r--r--mlir/test/Dialect/Vector/int-range-interface.mlir8
3 files changed, 29 insertions, 1 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e68a3c7..5d45508 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2877,7 +2877,10 @@ def Vector_ScanOp :
// VectorStepOp
//===----------------------------------------------------------------------===//
-def Vector_StepOp : Vector_Op<"step", [Pure]> {
+def Vector_StepOp : Vector_Op<"step", [
+ Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
+ ]> {
let summary = "A linear sequence of values from 0 to N";
let description = [{
A `step` operation produces an index vector, i.e. a 1-D vector of values of
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 25ce292..86fbb76 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7203,6 +7203,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
}
//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ auto resultType = cast<VectorType>(getType());
+ if (resultType.isScalable()) {
+ return;
+ }
+ unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
+ APInt zero(bitwidth, 0);
+ APInt high(bitwidth, resultType.getDimSize(0) - 1);
+ ConstantIntRanges result = {zero, high, zero, high};
+ setResultRanges(getResult(), result);
+}
+
+//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index f89d307..b2f16bb 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -108,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> {
%2 = test.reflect_bounds %1 : vector<2xi32>
func.return %2 : vector<2xi32>
}
+
+// CHECK-LABEL: func @vector_step
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @vector_step() -> vector<8xindex> {
+ %0 = vector.step : vector<8xindex>
+ %1 = test.reflect_bounds %0 : vector<8xindex>
+ func.return %1 : vector<8xindex>
+}