aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp18
1 files changed, 11 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 8295492..04e8836 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() {
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, BlockLoadOp, BlockStoreOp>::value>>
LogicalResult verify1DBlockArg(OpType op) {
- VectorType vTy;
+ Type srcOrDstTy;
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
- vTy = op.getResult().getType();
+ srcOrDstTy = op.getResult().getType();
else
- vTy = op.getVal().getType();
+ srcOrDstTy = op.getVal().getType();
+ VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
+ // scalar case is always valid
+ if (!vTy)
+ return success();
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
if (elemTySize == 1) {
- llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
+ llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
+ "vector size must be 2, 4, 8 or 16 for 8-bit element type");
} else {
- llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
+ llvm::SmallSet<int, 3> validSizes{2, 4, 8};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4 or 8 for element type > 8 bits");
+ "vector size must be 2, 4 or 8 for element type > 8 bits");
}
}