diff options
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/linearize.mlir | 13 |
2 files changed, 15 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 7ca0353..38536de 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -22,9 +22,9 @@ using namespace mlir; static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { auto resultTypes = op->getResultTypes(); for (auto resType : resultTypes) { - VectorType vecType = cast<VectorType>(resType); + VectorType vecType = dyn_cast<VectorType>(resType); // Reject index since getElementTypeBitWidth will abort for Index types. - if (vecType.getElementType().isIndex()) + if (!vecType || vecType.getElementType().isIndex()) return false; unsigned trailingVecDimBitWidth = vecType.getShape().back() * vecType.getElementTypeBitWidth(); diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 2cbf9be..e865fcb 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -90,3 +90,16 @@ func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> return %0 : vector<2x2xindex> } + +// ----- + +// vectorizable operation (arith.mulf) with tensor result types. + +func.func @nonvec_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) { + // CHECK: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32> + // CHECK128: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32> + // CHECK0: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32> + %0 = arith.mulf %arg0, %arg1 : tensor<2x2xf32> + + return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32> +} |