aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChao Chen <chao.chen@intel.com>2024-03-26 15:43:56 +0000
committerChao Chen <chao.chen@intel.com>2024-03-26 17:21:36 +0000
commit6486c994d496b8291220e77e2442eb59bf21d4f1 (patch)
tree0c92317311ef79683ebe81bef14a568d2d771c6f
parent2c3bd1384f119a753953774ccd297a7c4cad8cb1 (diff)
downloadllvm-6486c994d496b8291220e77e2442eb59bf21d4f1.zip
llvm-6486c994d496b8291220e77e2442eb59bf21d4f1.tar.gz
llvm-6486c994d496b8291220e77e2442eb59bf21d4f1.tar.bz2
refine getShapeOf implementation
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp22
1 files changed, 11 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index dc18d8c..972cee6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -19,8 +19,8 @@ namespace mlir {
namespace xegpu {
static void transpose(llvm::ArrayRef<int64_t> trans,
- std::vector<int64_t> &shape) {
- std::vector<int64_t> old = shape;
+ SmallVector<int64_t> &shape) {
+ SmallVector<int64_t> old = shape;
for (size_t i = 0; i < trans.size(); i++)
shape[i] = old[trans[i]];
}
@@ -42,9 +42,9 @@ static std::string makeString(T array, bool breakline = false) {
}
static SmallVector<int64_t> getShapeOf(Type type) {
- std::vector<int64_t> shape;
+ SmallVector<int64_t> shape;
if (auto ty = llvm::dyn_cast<ShapedType>(type))
- shape = ty.getShape().vec();
+ shape = SmallVector<int64_t>(ty.getShape());
else
shape.push_back(1);
return shape;
@@ -201,8 +201,8 @@ LogicalResult LoadNdOp::verify() {
return emitOpError("invlid l3_hint: ") << getL3HintAttr();
auto array_len = tdescTy.getArrayLength();
- auto tdescShape = tdescTy.getShape().vec();
- auto valueShape = valueTy.getShape().vec();
+ auto tdescShape = getShapeOf(tdescTy);
+ auto valueShape = getShapeOf(valueTy);
if (getTranspose()) {
auto trans = getTranspose().value();
@@ -353,9 +353,9 @@ LogicalResult LoadGatherOp::verify() {
return emitOpError(
"Value should have the same element type as TensorDesc.");
- std::vector<int64_t> maskShape = getShapeOf(maskTy);
- std::vector<int64_t> valueShape = getShapeOf(valueTy);
- std::vector<int64_t> tdescShape = getShapeOf(tdescTy);
+ auto maskShape = getShapeOf(maskTy);
+ auto valueShape = getShapeOf(valueTy);
+ auto tdescShape = getShapeOf(tdescTy);
if (tdescShape[0] != maskShape[0])
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
@@ -394,8 +394,8 @@ LogicalResult StoreScatterOp::verify() {
return emitOpError("invlid l3_hint: ") << getL3HintAttr();
auto maskTy = getMaskType();
- std::vector<int64_t> maskShape = getShapeOf(maskTy);
- std::vector<int64_t> tdescShape = getShapeOf(tdescTy);
+ auto maskShape = getShapeOf(maskTy);
+ auto tdescShape = getShapeOf(tdescTy);
if (tdescShape[0] != maskShape[0])
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");