//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

static bool isSharedMemory(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == 3;
  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
    return memrefSpace.getValue() == MemorySpace::SLM;
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}

// Pretty-prints an array-like container as "[a, b, c]".
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}

static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING ||
         kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  auto chunkSize = tdescTy.getChunkSizeAsInt();

  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (dyn_cast<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // A valid shape for the SIMT case.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
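// For illustration, with hypothetical shapes: a scattered descriptor such as
// !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
// (chunk size 8) expects a mask of shape [16] (the chunk-size dim is
// dropped), a SIMD value of shape [16, 8], or a SIMT value of shape [8]
// (rank 1, numElements == chunkSize). The attribute syntax follows the XeGPU
// op documentation and is a sketch only, not verified against a build.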
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                 VectorType valueTy, int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  auto maskVecTy = dyn_cast<VectorType>(maskTy);
  auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
    return success();
  }

  auto valueSize = valueTy.getNumElements();
  // SIMT mode with scalar mask and offsets.
  if (!maskVecTy && !offsetsVecTy) {
    if (valueSize != chunkSize)
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);

  if (!maskVecTy)
    return emitError() << "Expecting a vector type mask.";
  int64_t maskSize = maskVecTy.getNumElements();

  if (chunkSize > 1) {
    if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
  } else {
    if (valueSize != maskSize)
      return emitError() << "Mask should match value except the chunk size "
                            "dim.";
  }

  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (maskSize == 1)
    return success();
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";
  return success();
}
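// For illustration, with hypothetical shapes: loading from a raw buffer with
// offsets vector<16xindex>, mask vector<16xi1>, and chunkSize == 8, a valid
// value type is vector<16x8xf32>; with scalar offsets and mask (SIMT mode),
// the value must have exactly chunkSize elements, e.g. vector<8xf32>. These
// shapes are examples only, not an exhaustive enumeration.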
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");

  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        DenseI64ArrayAttr({}) /* const offsets */,
        DenseI64ArrayAttr({}) /* empty const shape */,
        DenseI64ArrayAttr({}) /* empty const strides */);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If shape and strides come from the memref, we don't need attributes
    // for them, to keep the IR print clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
        staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
        {} /* empty const shape */, {} /* empty const strides */);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If shape and strides come from the memref, we don't need attributes
    // for them, to keep the IR print clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
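// Example usage (a sketch based on the XeGPU op documentation; the shapes
// and offsets are illustrative):
//   %0 = xegpu.create_nd_tdesc %src[0, 0]
//          : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>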
LogicalResult CreateNdDescOp::verify() {
  size_t rank = getMixedSizes().size();
  bool invalidRank = rank != getMixedStrides().size();
  bool invalidElemTy = false;

  // Memory space of created TensorDesc should match with the source.
  // Both source and TensorDesc are considered for global memory by default,
  // if the memory scope attr is not specified. If source is an integer,
  // it is considered as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  if (size_t offsetRank = getMixedOffsets().size())
    invalidRank |= (offsetRank != rank);

  // Check that the source type matches the rank if it is a memref.
  // It also should have the same ElementType as TensorDesc.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
    invalidElemTy |= memrefTy.getElementType() != getElementType();

  if (llvm::isa<IntegerType>(getSourceType())) {
    // Strides and shape must be present for an integer source.
    if (getMixedStrides().empty() || getMixedSizes().empty())
      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }

  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // Check the result TensorDesc rank.
  if (getType().getRank() > (int64_t)rank)
    return emitOpError(
        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}

static ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {

  SmallVector<int64_t> integerVals;
  auto parseIntegerOrValue = [&]() {
    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);

    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // If the optional values are given, there must be a left bracket.
  if (parser.parseOptionalLSquare().succeeded()) {
    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
        parser.parseRSquare())
      return parser.emitError(parser.getNameLoc())
             << "expected a list of SSA values or integers";
    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
    return success();
  }
  return success();
}

static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                          OperandRange values,
                                          DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;
  printDynamicIndexList(printer, op, values, integers,
                        /*scalableFlags=*/{}, {},
                        AsmParser::Delimiter::Square);
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
               l1_hint, l2_hint, l3_hint);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
        l2_hint, l3_hint);
}
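// Example usage (a sketch based on the XeGPU op documentation; the cache
// hints and tensor_desc shape are illustrative):
//   xegpu.prefetch_nd %tdesc <{l1_hint = #xegpu.cache_hint<cached>,
//                              l2_hint = #xegpu.cache_hint<uncached>}>
//     : !xegpu.tensor_desc<8x16xf16>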
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
  int64_t constOffsetSize =
      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, UnitAttr packed,
                     DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, retType, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
               l3_hint);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                     UnitAttr packed, DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        packed, transpose, l1_hint, l2_hint, l3_hint);
}
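// Example shape math for LoadNdOp::verify below (hypothetical shapes): with
// a tensor_desc<16x16xf16> and packed set, valueShape.back() == 2 gives an
// expected result of vector<8x16x2xf16> (dim 0 divided by the vnni factor,
// which is appended as the innermost dim). In SIMT mode the same descriptor
// may yield a 1-D result such as vector<16xf16>, as long as the descriptor
// element count (256) is divisible by the result element count.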
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since subgroup size is arch dependent, we only check even
    // distribution here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Make sure the transpose value is valid, and apply it.
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
  int64_t constOffsetSize =
      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, value, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        l1_hint, l2_hint, l3_hint);
}
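// Example usage (a sketch based on the XeGPU op documentation; shapes and
// hints are illustrative):
//   xegpu.store_nd %val, %tdesc <{l1_hint = #xegpu.cache_hint<write_back>}>
//     : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>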
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp, if the value vector is 1D and has fewer elements
  // than the tensor descriptor, it is supposed to be a SIMT op. The layout
  // attribute in tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;

    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
  int64_t constOffsetSize =
      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNDOffsetOp
//===----------------------------------------------------------------------===//

LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets())
    return emitOpError("Invalid number of offsets.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
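// Example usage (a sketch based on the XeGPU op documentation; the source
// type and offsets are illustrative):
//   %0 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
//          -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>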
") << "Expected is " << makeString(shape) << "\n"; return success(); } //===----------------------------------------------------------------------===// // XeGPU_PrefetchOp //===----------------------------------------------------------------------===// LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); if (!tdescTy && !getOffsets()) return emitOpError("Expects offsets."); if (tdescTy && getOffsets()) return emitOpError("offsets not allowed."); if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc."); if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isReadHintOrNone(getL2HintAttr())) return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); auto srcTy = getSourceType(); if (srcTy.isInteger() && !getOffsetAlignByteAttr()) return emitOpError("offset_align_byte is required with integer source."); if (getOffsetAlignByteAttr() && !srcTy.isInteger()) return emitOpError("offset_align_byte only allowed with integer source."); return success(); } void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, IntegerAttr{}); } //===----------------------------------------------------------------------===// // XeGPU_LoadGatherOp //===----------------------------------------------------------------------===// LogicalResult LoadGatherOp::verify() { auto tdescTy = getTensorDescType(); auto maskTy = getMaskType(); auto valueTy = getValueType(); if (!tdescTy && !getOffsets()) return emitOpError("Expects offsets."); if (tdescTy && getOffsets()) return emitOpError("offsets not allowed."); if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc."); if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isReadHintOrNone(getL2HintAttr())) return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); if (tdescTy) return isValidGatherScatterParams(maskTy, valueTy, tdescTy, [&]() { return emitOpError(); }); auto srcTy = getSourceType(); uint64_t chunkSize = static_cast(getChunkSize().value_or(1)); auto memTy = dyn_cast(srcTy); if (memTy && (getElementType() != memTy.getElementType())) return emitError() << "Value should have the same element type as MemRef."; auto offsetsTy = getOffsets().getType(); return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize, [&]() { return emitOpError(); }); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type valueType, Value source, Value mask, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, valueType, source, Value(), mask, IntegerAttr(), l1_hint, l2_hint, l3_hint); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type valueType, Value source, ArrayRef offsets, Value mask, IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { auto loc = source.getLoc(); int64_t size = static_cast(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); auto values = getValueOrCreateConstantIndexOp(builder, 
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//

LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  auto destTy = getDestType();
  uint64_t chunkSize = static_cast<uint64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(destTy);

  if (memTy && (getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}
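// Example usage (a sketch mirroring the gather form above; illustrative):
//   xegpu.store %val, %tdesc, %mask
//     : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>,
//       vector<16xi1>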
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  // Call the correct builder overload that does not expect result types.
  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

LogicalResult UpdateOffsetOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
  SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
  if (tdescTy.getChunkSizeAsInt() > 1)
    expectedOffsetShape.pop_back();

  if (expectedOffsetShape != offsetShape)
    return emitOpError(
        "Offsets should match TensorDesc except the chunk size dim.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
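// Example shape math for DpasOp::verify below (hypothetical shapes): lhs
// vector<8x16xf16> times rhs vector<16x16xf16> accumulates into
// vector<8x16xf32> (M=8, K=16, N=16). With a packed (3-D) rhs such as
// vector<8x16x2xf16>, K is recovered as rhsShape[0] * rhsShape[2] == 16.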
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The semantic check is skipped due to the lack of architecture
  // information. Users need to ensure the correctness.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");

  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//

LogicalResult ConvertLayoutOp::verify() {
  auto srcLayout = getInputLayout();
  auto resLayout = getTargetLayout();
  if (!srcLayout)
    return emitOpError("expected input layout.");
  if (!resLayout)
    return emitOpError("expected target layout.");

  // Both input and target layouts should be WgLayout or SgLayout at the same
  // time.
  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input layout and target layout to be "
                       "WgLayout or SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
    return emitOpError(
        "invalid input layout, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
    return emitOpError(
        "invalid target layout, data cannot be evenly distributed.");

  return mlir::success();
}

OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
  if (getInputLayout() == getTargetLayout())
    return getSource();
  return {};
}

struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInputLayout() == op.getTargetLayout()) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }
    return failure();
  }
};

void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<FoldConvertLayoutOp>(context);
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadMatrixOp
//===----------------------------------------------------------------------===//

void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MemDescType> memDesc,
                         llvm::ArrayRef<OpFoldResult> offsets,
                         DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
        layout);
}

LogicalResult LoadMatrixOp::verify() {
  VectorType resTy = getRes().getType();
  MemDescType mdescTy = getMemDesc().getType();

  if (mdescTy.getRank() != 2)
    return emitOpError("mem_desc must be 2D.");

  ArrayRef<int64_t> valueShape = resTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
    return emitOpError("result shape must not exceed mem_desc shape.");

  return success();
}
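// Shape-rule illustration for the matrix ops (hypothetical types): loading
// vector<8x16xf16> from a 2-D mem_desc with shape 32x32 passes LoadMatrixOp
// verification, since 8 <= 32 and 16 <= 32 in every dimension; the same
// containment rule applies to StoreMatrixOp below.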
//===----------------------------------------------------------------------===//
// XeGPU_StoreMatrixOp
//===----------------------------------------------------------------------===//

void StoreMatrixOp::build(OpBuilder &builder, OperationState &state,
                          Value data, TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
                          DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
        layout);
}

LogicalResult StoreMatrixOp::verify() {
  VectorType dataTy = getData().getType();
  MemDescType mdescTy = getMemDesc().getType();

  if (mdescTy.getRank() != 2)
    return emitOpError("mem_desc must be 2D.");

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
    return emitOpError("data shape must not exceed mem_desc shape.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_MemDescSubviewOp
//===----------------------------------------------------------------------===//

void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
                             Type resTy, Value src,
                             llvm::ArrayRef<OpFoldResult> offsets) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
}

LogicalResult MemDescSubviewOp::verify() {
  MemDescType srcTy = getSrc().getType();
  MemDescType resTy = getRes().getType();
  ArrayRef<int64_t> srcShape = srcTy.getShape();
  ArrayRef<int64_t> resShape = resTy.getShape();

  if (srcTy.getRank() < resTy.getRank())
    return emitOpError("result rank must not exceed source rank.");

  if (llvm::any_of(
          llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
          [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
    return emitOpError("result shape must not exceed source shape.");

  if (srcTy.getStrides() != resTy.getStrides())
    return emitOpError("result must inherit the source strides.");

  return success();
}

} // namespace xegpu
} // namespace mlir

// Tablegen-generated definitions. The original include paths were lost in
// extraction; these follow the standard MLIR naming convention for the XeGPU
// dialect and may need adjustment against the actual build files.
namespace mlir {
#include "mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc"
} // namespace mlir

#include "mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc"