aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td1
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp16
2 files changed, 9 insertions, 8 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8e333de..fdf51f0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1121,7 +1121,6 @@ def Vector_ExtractStridedSliceOp :
attribute and extracts the n-D subvector at the proper offset.
At the moment strides must contain only 1s.
- // TODO: support non-1 strides.
Returns an n-D vector where the first k-D dimensions match the `sizes`
attribute. The returned subvector contains the elements starting at offset
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 749eb56..791924f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2789,9 +2789,11 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
return success();
}
-// Returns true if all integers in `arrayAttr` are in the interval [min, max}.
-// interval. If `halfOpen` is true then the admissible interval is [min, max).
-// Otherwise, the admissible interval is [min, max].
+// Returns true if, for all indices i = 0..shape.size()-1, val is in the
+// [min, max} interval:
+// val = `arrayAttr1[i]` + `arrayAttr2[i]`,
+// If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
+// the admissible interval is [min, max].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
@@ -2845,8 +2847,8 @@ LogicalResult InsertStridedSliceOp::verify() {
auto stridesName = InsertStridedSliceOp::getStridesAttrName();
if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
offName)) ||
- failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
- stridesName,
+ failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
+ /*max=*/1, stridesName,
/*halfOpen=*/false)) ||
failed(isSumOfIntegerArrayAttrConfinedToShape(
*this, offsets,
@@ -3250,8 +3252,8 @@ LogicalResult ExtractStridedSliceOp::verify() {
failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
/*halfOpen=*/false,
/*min=*/1)) ||
- failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
- stridesName,
+ failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
+ /*max=*/1, stridesName,
/*halfOpen=*/false)) ||
failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
shape, offName, sizesName,