diff options
-rw-r--r-- | mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 10 | ||||
-rw-r--r-- | mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp | 8 |
2 files changed, 8 insertions, 10 deletions
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 3b1fdb6..aa566c0 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape, // this utility). if (numSourceDims <= numTargetDims) return std::nullopt; - // Early handling for scalar target types. + // Early handling for scalar target types. We should report an invalid + // reassociation for non-unit static dimensions - no chance to collapse these + // into a scalar. if (numTargetDims == 0) { - ReassociationIndices allSourceIndices; - allSourceIndices.reserve(numSourceDims); for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; ++sourceDimIdx) { int64_t sourceSize = sourceShape[sourceDimIdx]; - // All source dimensions must be unit or dynamic. if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) return std::nullopt; - allSourceIndices.push_back(sourceDimIdx); } - return SmallVector<ReassociationIndices>{allSourceIndices}; + return SmallVector<ReassociationIndices>{}; } // Collect source ranges by iterating over the target shape left-to-right. diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index db1a87a..05f97e8 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) { TEST(ReassociationIndicesForCollapse, ScalarTest) { EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}), - makeOptionalIndices({{0}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}), - makeOptionalIndices({{0, 1}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}), - makeOptionalIndices({{0}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, ShapedType::kDynamic, 1, ShapedType::kDynamic}, {}), - makeOptionalIndices({{0, 1, 2, 3, 4}})); + makeOptionalIndices({})); } TEST(ReassociationIndicesForCollapse, ScalarTestFailure) { |