aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp4
-rw-r--r--mlir/test/Dialect/Vector/linearize.mlir13
2 files changed, 15 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7ca0353..38536de 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -22,9 +22,9 @@ using namespace mlir;
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
- VectorType vecType = cast<VectorType>(resType);
+ VectorType vecType = dyn_cast<VectorType>(resType);
// Reject index since getElementTypeBitWidth will abort for Index types.
- if (vecType.getElementType().isIndex())
+ if (!vecType || vecType.getElementType().isIndex())
return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2cbf9be..e865fcb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -90,3 +90,16 @@ func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
+
+// -----
+
+// vectorizable operation (arith.mulf) with tensor result types.
+
+func.func @nonvec_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) {
+ // CHECK: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
+ // CHECK128: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
+ // CHECK0: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
+ %0 = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
+
+ return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
+}