diff options
author | Chao Chen <116223022+chencha3@users.noreply.github.com> | 2024-03-20 17:32:30 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-20 17:32:30 -0500 |
commit | 61b24c61a90802e06e40a7ab0aa5e2138486bd73 (patch) | |
tree | c25c945913bef58ee6f52280475cd1f8c647ff30 /mlir/lib | |
parent | de0abc0983d355bbd971c5c571ba4c209a0c63ea (diff) | |
download | llvm-61b24c61a90802e06e40a7ab0aa5e2138486bd73.zip llvm-61b24c61a90802e06e40a7ab0aa5e2138486bd73.tar.gz llvm-61b24c61a90802e06e40a7ab0aa5e2138486bd73.tar.bz2 |
[MLIR][XeGPU] Adding XeGPU 2d block operators (#85804)
This PR adds XeGPU 2D block operators. It contains:
1. TensorDescType and TensorDescAttr definitions
2. MemoryScopeAttr and CacheHintAttr definitions which are used by
TensorDescAttr.
3. CreateNdDescOp, PrefetchNdOp, LoadNdOp, and StoreNdOp definitions,
and their corresponding testcases for illustration.
It cherry-picks daebe5c4f27ba140ac8d13abf41e3fe4db72b91a with an ASan fix.
---------
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 73 | ||||
-rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 187 |
2 files changed, 254 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 4f839ee..0b3f4b9 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -6,7 +6,10 @@ // //===----------------------------------------------------------------------===// -#include <mlir/Dialect/XeGPU/IR/XeGPU.h> +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" namespace mlir { namespace xegpu { @@ -26,8 +29,72 @@ void XeGPUDialect::initialize() { >(); } -// this file is for position occupation, -// we will add functions in following PRs. +//===----------------------------------------------------------------------===// +// XeGPU_TensorDescAttr +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// XeGPU_TensorDescType +//===----------------------------------------------------------------------===// +mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { + llvm::SmallVector<int64_t> shape; + mlir::Type elementType; + mlir::FailureOr<mlir::Attribute> encoding; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + auto shapeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseDimensionList(shape))) { + parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); + return {}; + } + + auto elemTypeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseType(elementType))) { + parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); + return {}; + } + + // parse optional attributes + if (mlir::succeeded(parser.parseOptionalComma())) { + encoding = mlir::FieldParser<mlir::Attribute>::parse(parser); + if (mlir::failed(encoding)) { + parser.emitError( + parser.getCurrentLocation(), + "Failed to parse the attribute field for 
TensorDescType.\n"); + return {}; + } + } + + // Parse literal '>' + if (parser.parseGreater()) + return {}; + + return TensorDescType::get(parser.getContext(), shape, elementType, + encoding.value_or(mlir::Attribute())); +} + +void TensorDescType::print(::mlir::AsmPrinter &printer) const { + printer << "<"; + + auto shape = getShape(); + for (int64_t dim : shape) { + if (mlir::ShapedType::isDynamic(dim)) + printer << '?'; + else + printer << dim; + printer << 'x'; + } + + printer << getElementType(); + + if (auto encoding = getEncoding()) + printer << ", " << encoding; + + printer << ">"; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index b356c39..a0bed51 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -6,15 +6,196 @@ // //===----------------------------------------------------------------------===// -#include <mlir/Dialect/XeGPU/IR/XeGPU.h> +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Builders.h" #define DEBUG_TYPE "xegpu" namespace mlir { namespace xegpu { -// this file is for position occupation, -// we will add functions in following PRs. 
+static void transpose(llvm::ArrayRef<int64_t> trans, + std::vector<int64_t> &shape) { + std::vector<int64_t> old = shape; + for (size_t i = 0; i < trans.size(); i++) + shape[i] = old[trans[i]]; +} + +template <typename T> +static std::string makeString(T array, bool breakline = false) { + std::string buf; + buf.clear(); + llvm::raw_string_ostream os(buf); + os << "["; + for (size_t i = 1; i < array.size(); i++) { + os << array[i - 1] << ", "; + if (breakline) + os << "\n\t\t"; + } + os << array.back() << "]"; + os.flush(); + return buf; +} + +//===----------------------------------------------------------------------===// +// XeGPU_CreateNdDescOp +//===----------------------------------------------------------------------===// +void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, + Type tdesc, TypedValue<MemRefType> source, + llvm::ArrayRef<OpFoldResult> offsets) { + auto ty = source.getType(); + assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank()); + + llvm::SmallVector<int64_t> staticOffsets; + llvm::SmallVector<Value> dynamicOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, + ValueRange({}) /* empty dynamic shape */, + ValueRange({}) /* empty dynamic strides */, + staticOffsets /* const offsets */, {} /* empty const shape*/, + {} /* empty const strides*/); +} + +void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, + Type tdesc, TypedValue<IntegerType> source, + llvm::ArrayRef<OpFoldResult> offsets, + llvm::ArrayRef<OpFoldResult> shape, + llvm::ArrayRef<OpFoldResult> strides) { + assert(shape.size() && offsets.size() && strides.size() && + shape.size() == strides.size() && shape.size() == offsets.size()); + + llvm::SmallVector<int64_t> staticOffsets; + llvm::SmallVector<int64_t> staticShape; + llvm::SmallVector<int64_t> staticStrides; + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<Value> 
dynamicShape; + llvm::SmallVector<Value> dynamicStrides; + + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); + dispatchIndexOpFoldResults(strides, dynamicStrides, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); + auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); + + build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, + dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); +} + +LogicalResult CreateNdDescOp::verify() { + auto rank = (int64_t)getMixedOffsets().size(); + bool invalidRank = (rank != 2); + bool invalidElemTy = false; + + // check source type matches the rank if it is a memref. + // It also should have the same ElementType as TensorDesc. + auto memrefTy = getSourceType().dyn_cast<MemRefType>(); + if (memrefTy) { + invalidRank |= (memrefTy.getRank() != rank); + invalidElemTy |= memrefTy.getElementType() != getElementType(); + } + + // check result type matches the rank + invalidRank = (getType().getRank() != rank); + + // mismatches among shape, strides, and offsets are + // already handeled by OffsetSizeAndStrideOpInterface. + // So they are not check here. + if (invalidRank) + return emitOpError( + "Expecting the rank of shape, strides, offsets, " + "source memref type (if source is a memref) and TensorDesc " + "should match with each other. 
They currenlty are 2D."); + + if (invalidElemTy) + return emitOpError("TensorDesc should have the same element " + "type with the source if it is a memref.\n"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_LoadNdOp +//===----------------------------------------------------------------------===// +LogicalResult LoadNdOp::verify() { + auto tdescTy = getTensorDescType(); + auto valueTy = getType(); + + if (tdescTy.getRank() != 2) + return emitOpError( + "The TensorDesc for LoadNdOp should be a 2D TensorDesc."); + + if (!valueTy) + return emitOpError("Invalid result, it should be a VectorType.\n"); + + auto tdescElemTy = tdescTy.getElementType(); + auto valueElemTy = valueTy.getElementType(); + + if (tdescElemTy != valueElemTy) + return emitOpError( + "Value should have the same element type as TensorDesc."); + + auto array_len = tdescTy.getArrayLength(); + auto tdescShape = tdescTy.getShape().vec(); + auto valueShape = valueTy.getShape().vec(); + + if (getTranspose()) { + auto trans = getTranspose().value(); + if (tdescShape.size() >= trans.size()) + transpose(trans, tdescShape); + else + emitWarning("Invalid transpose attr. It is ignored."); + } + + if (getVnniAxis()) { + auto axis = getVnniAxis().value(); + auto vnni_factor = valueShape.back(); + tdescShape[axis] /= vnni_factor; + tdescShape.push_back(vnni_factor); + } + + if (array_len > 1) { + auto it = tdescShape.begin(); + tdescShape.insert(it, array_len); + } + + if (tdescShape != valueShape) + return emitOpError() << "Result shape doesn't match TensorDesc shape." + << "The expected shape is " << makeString(tdescShape) + << ". 
But the given shape is " + << makeString(valueShape) << ".\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_StoreNdOp +//===----------------------------------------------------------------------===// +LogicalResult StoreNdOp::verify() { + auto dstTy = getTensorDesc().getType(); // Tile + auto valTy = getValue().getType().cast<VectorType>(); // Vector + + if (dstTy.getRank() != 2) + return emitOpError("Expecting a 2D TensorDesc shape.\n"); + + if (!valTy) + return emitOpError("Exepcting a VectorType result.\n"); + + auto dstElemTy = dstTy.getElementType(); + auto valElemTy = valTy.getElementType(); + + if (dstElemTy != valElemTy) { + return emitOpError() << "The element type of the value should " + "match the elementtype of the TensorDesc.\n"; + } + + if (dstTy.getShape() != valTy.getShape()) + return emitOpError() + << "The result shape should match the TensorDesc shape.\n"; + return success(); +} } // namespace xegpu } // namespace mlir |