diff options
-rw-r--r-- | mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 |
2 files changed, 9 insertions, 8 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 8e333de..fdf51f0 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1121,7 +1121,6 @@ def Vector_ExtractStridedSliceOp : attribute and extracts the n-D subvector at the proper offset. At the moment strides must contain only 1s. - // TODO: support non-1 strides. Returns an n-D vector where the first k-D dimensions match the `sizes` attribute. The returned subvector contains the elements starting at offset diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 749eb56..791924f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2789,9 +2789,11 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, return success(); } -// Returns true if all integers in `arrayAttr` are in the interval [min, max}. -// interval. If `halfOpen` is true then the admissible interval is [min, max). -// Otherwise, the admissible interval is [min, max]. +// Returns true if, for all indices i = 0..shape.size()-1, val is in the +// [min, max} interval: +// val = `arrayAttr1[i]` + `arrayAttr2[i]`, +// If `halfOpen` is true then the admissible interval is [min, max). Otherwise, +// the admissible interval is [min, max]. template <typename OpType> static LogicalResult isSumOfIntegerArrayAttrConfinedToShape( OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, @@ -2845,8 +2847,8 @@ LogicalResult InsertStridedSliceOp::verify() { auto stridesName = InsertStridedSliceOp::getStridesAttrName(); if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape, offName)) || - failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, - stridesName, + failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1, + /*max=*/1, stridesName, /*halfOpen=*/false)) || failed(isSumOfIntegerArrayAttrConfinedToShape( *this, offsets, @@ -3250,8 +3252,8 @@ LogicalResult ExtractStridedSliceOp::verify() { failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName, /*halfOpen=*/false, /*min=*/1)) || - failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, - stridesName, + failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1, + /*max=*/1, stridesName, /*halfOpen=*/false)) || failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape, offName, sizesName, |