aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/XeGPU/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR')
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp146
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp122
2 files changed, 205 insertions, 63 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d..1599ae9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
}
printer << ">";
}
+// a helper utility to perform binary operation on OpFoldResult.
+// If both a and b are attributes, it will simply return the result.
+// Otherwise, the corresponding arith op will be generated, and an
+// contant op will be created if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+ OpBuilder &builder) {
+ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+ return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// a helper utility to perform division operation on OpFoldResult and int64_t.
+#define div(a, b) \
+ genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform reminder operation on OpFoldResult and int64_t.
+#define rem(a, b) \
+ genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform multiply operation on OpFoldResult and int64_t.
+#define mul(a, b) \
+ genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform addition operation on two OpFoldResult.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// block the given offsets according to the block shape
+// say the original offset is [y, x], and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<int64_t> blockShape) {
+
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ return blockedOffsets;
+}
+
+// Get strides as vector of integer for MemDesc.
+SmallVector<int64_t> MemDescType::getStrideShape() {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+
+ ArrayAttr strideAttr = getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+
+ SmallVector<int64_t> innerBlkShape = getBlockShape();
+
+ // get perm from FCD to LCD
+ // perm[i] = the dim with i-th smallest stride
+ SmallVector<int, 4> perm =
+ llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
+ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+ assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+
+ SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+ innerBlkStride[perm[0]] = 1;
+ for (size_t i = 1; i < perm.size(); ++i)
+ innerBlkStride[perm[i]] =
+ innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
+
+ // compute the original matrix shape using the stride info
+ // and compute the number of blocks in each dimension
+ // The shape of highest dim can't be derived from stride info,
+ // but doesn't impact the stride computation for blocked layout.
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+ SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+ BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+ }
+
+ int64_t innerBlkSize = 1;
+ for (auto s : innerBlkShape)
+ innerBlkSize *= s;
+
+ SmallVector<int64_t> outerBlkStride(matrixShape.size());
+ outerBlkStride[perm[0]] = innerBlkSize;
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ outerBlkStride[perm[i + 1]] =
+ outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+ }
+
+ // combine the inner and outer strides
+ SmallVector<int64_t> blockedStrides;
+ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+ return blockedStrides;
+}
+
+// Calculate the linear offset using the blocked offsets and stride
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets) {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+ SmallVector<int64_t> blockShape = getBlockShape();
+ SmallVector<int64_t> strides = getStrideShape();
+ SmallVector<OpFoldResult> blockedOffsets;
+
+ // blockshape equal to matrixshape means no blocking
+ if (llvm::equal(blockShape, matrixShape)) {
+ // remove the outer dims from strides
+ strides.erase(strides.begin(), strides.begin() + matrixShape.size());
+ } else {
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ // say the original offset is [y, x], and the block shape is [By, Bx],
+ // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+ offsets = blockedOffsets;
+ }
+
+ // Start with initial value as matrix descriptor's base offset.
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ OpFoldResult mulResult = mul(offsets[i], strides[i]);
+ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+ }
+
+ return linearOffset;
+}
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788..abd12e2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,8 +20,8 @@
#define DEBUG_TYPE "xegpu"
-namespace mlir {
-namespace xegpu {
+using namespace mlir;
+using namespace mlir::xegpu;
static bool isSharedMemory(const MemRefType &memrefTy) {
Attribute attr = memrefTy.getMemorySpace();
@@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
return success();
}
+LogicalResult
+IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
+ UnitAttr subgroup_block_io,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!dataTy) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ else
+ return success();
+ }
+
+ if (mdescTy.getRank() != 2)
+ return emitError() << "mem_desc must be 2D.";
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+
+ if (dataShape.size() == 2) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitError() << "data shape must not exceed mem_desc shape.";
+ } else {
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ // if the subgroup_block_io attribute is set, mdescTy must have block
+ // attribute
+ if (subgroup_block_io && !blockShape.size())
+ return emitError() << "mem_desc must have block attribute when "
+ "subgroup_block_io is set.";
+ // if the subgroup_block_io attribute is set, the memdesc should be row
+ // major
+ if (subgroup_block_io && mdescTy.isColMajor())
+ return emitError() << "mem_desc should be row major when "
+ "subgroup_block_io is set.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ // Call the generated builder with all parameters (including optional ones as
+ // nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
- VectorType resTy = getRes().getType();
- MemDescType mdescTy = getMemDesc().getType();
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
+ auto resTy = dyn_cast<VectorType>(getRes().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
- ArrayRef<int64_t> valueShape = resTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed mem_desc shape.");
- return success();
+ return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult StoreMatrixOp::verify() {
- VectorType dataTy = getData().getType();
- MemDescType mdescTy = getMemDesc().getType();
-
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
-
- ArrayRef<int64_t> dataShape = dataTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("data shape must not exceed mem_desc shape.");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// XeGPU_MemDescSubviewOp
-//===----------------------------------------------------------------------===//
-
-void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
- Type resTy, Value src,
- llvm::ArrayRef<OpFoldResult> offsets) {
- llvm::SmallVector<Value> dynamicOffsets;
- llvm::SmallVector<int64_t> staticOffsets;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
- build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
-}
-
-LogicalResult MemDescSubviewOp::verify() {
- MemDescType srcTy = getSrc().getType();
- MemDescType resTy = getRes().getType();
- ArrayRef<int64_t> srcShape = srcTy.getShape();
- ArrayRef<int64_t> resShape = resTy.getShape();
-
- if (srcTy.getRank() < resTy.getRank())
- return emitOpError("result rank must not exceed source rank.");
- if (llvm::any_of(
- llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed source shape.");
-
- if (srcTy.getStrides() != resTy.getStrides())
- return emitOpError("result must inherit the source strides.");
-
- return success();
+ auto dataTy = dyn_cast<VectorType>(getData().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
+ return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
-} // namespace xegpu
-} // namespace mlir
-
namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir