diff options
-rw-r--r-- | mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 5 | ||||
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/int-range-interface.mlir | 8 |
3 files changed, 29 insertions, 1 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index e68a3c7..5d45508 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2877,7 +2877,10 @@ def Vector_ScanOp : // VectorStepOp //===----------------------------------------------------------------------===// -def Vector_StepOp : Vector_Op<"step", [Pure]> { +def Vector_StepOp : Vector_Op<"step", [ + Pure, + DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> + ]> { let summary = "A linear sequence of values from 0 to N"; let description = [{ A `step` operation produces an index vector, i.e. a 1-D vector of values of diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 25ce292..86fbb76 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7203,6 +7203,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, } //===----------------------------------------------------------------------===// +// StepOp +//===----------------------------------------------------------------------===// + +void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + auto resultType = cast<VectorType>(getType()); + if (resultType.isScalable()) { + return; + } + unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType); + APInt zero(bitwidth, 0); + APInt high(bitwidth, resultType.getDimSize(0) - 1); + ConstantIntRanges result = {zero, high, zero, high}; + setResultRanges(getResult(), result); +} + +//===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index f89d307..b2f16bb 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -108,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> { %2 = test.reflect_bounds %1 : vector<2xi32> func.return %2 : vector<2xi32> } + +// CHECK-LABEL: func @vector_step +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @vector_step() -> vector<8xindex> { + %0 = vector.step : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +} |