aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorChao Chen <116223022+chencha3@users.noreply.github.com>2024-03-20 17:32:30 -0500
committerGitHub <noreply@github.com>2024-03-20 17:32:30 -0500
commit61b24c61a90802e06e40a7ab0aa5e2138486bd73 (patch)
treec25c945913bef58ee6f52280475cd1f8c647ff30 /mlir/lib
parentde0abc0983d355bbd971c5c571ba4c209a0c63ea (diff)
downloadllvm-61b24c61a90802e06e40a7ab0aa5e2138486bd73.zip
llvm-61b24c61a90802e06e40a7ab0aa5e2138486bd73.tar.gz
llvm-61b24c61a90802e06e40a7ab0aa5e2138486bd73.tar.bz2
[MLIR][XeGPU] Adding XeGPU 2d block operators (#85804)
This PR adds XeGPU 2D block operators. It contains: 1. TensorDescType and TensorDescAttr definitions 2. MemoryScopeAttr and CacheHintAttr definitions which are used by TensorDescAttr. 3. CreateNdDescOp, PrefetchNdOp, LoadNdOp, and StoreNdOp definitions, and their corresponding testcases for illustration. It cherry-picks daebe5c4f27ba140ac8d13abf41e3fe4db72b91a with asan fix. --------- Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp73
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp187
2 files changed, 254 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 4f839ee..0b3f4b9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,7 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include <mlir/Dialect/XeGPU/IR/XeGPU.h>
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
namespace xegpu {
@@ -26,8 +29,72 @@ void XeGPUDialect::initialize() {
>();
}
-// this file is for position occupation,
-// we will add functions in following PRs.
+//===----------------------------------------------------------------------===//
+// XeGPU_TensorDescAttr
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// XeGPU_TensorDescType
+//===----------------------------------------------------------------------===//
+// Parses the custom assembly form of TensorDescType:
+//   `<` dim-list element-type (`,` encoding-attr)? `>`
+// On any failure a diagnostic is emitted and a null Type is returned.
+mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+  llvm::SmallVector<int64_t> shape;
+  mlir::Type elementType;
+  mlir::FailureOr<mlir::Attribute> encoding;
+
+  // Parse literal '<'
+  if (parser.parseLess())
+    return {};
+
+  // Dimension list, e.g. `8x16x` (dynamic dims allowed by the dim-list
+  // grammar); location captured first so the error points at the shape.
+  auto shapeLoc = parser.getCurrentLocation();
+  if (mlir::failed(parser.parseDimensionList(shape))) {
+    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
+    return {};
+  }
+
+  auto elemTypeLoc = parser.getCurrentLocation();
+  if (mlir::failed(parser.parseType(elementType))) {
+    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
+    return {};
+  }
+
+  // parse optional attributes; the encoding is only present after a comma
+  if (mlir::succeeded(parser.parseOptionalComma())) {
+    encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
+    if (mlir::failed(encoding)) {
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "Failed to parse the attribute field for TensorDescType.\n");
+      return {};
+    }
+  }
+
+  // Parse literal '>'
+  if (parser.parseGreater())
+    return {};
+
+  // When no encoding was parsed, pass a null Attribute to the builder.
+  return TensorDescType::get(parser.getContext(), shape, elementType,
+                             encoding.value_or(mlir::Attribute()));
+}
+
+// Prints the custom assembly form, the inverse of TensorDescType::parse:
+// `<` dims `x` ... element-type (`,` encoding)? `>`.
+void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+  printer << "<";
+
+  // Each dimension is followed by 'x'; dynamic dims render as '?'.
+  for (int64_t dim : getShape()) {
+    if (mlir::ShapedType::isDynamic(dim))
+      printer << "?x";
+    else
+      printer << dim << 'x';
+  }
+
+  printer << getElementType();
+
+  // The encoding attribute is optional and only printed when non-null.
+  if (mlir::Attribute encoding = getEncoding())
+    printer << ", " << encoding;
+
+  printer << ">";
+}
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index b356c39..a0bed51 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -6,15 +6,196 @@
//
//===----------------------------------------------------------------------===//
-#include <mlir/Dialect/XeGPU/IR/XeGPU.h>
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Builders.h"
#define DEBUG_TYPE "xegpu"
namespace mlir {
namespace xegpu {
-// this file is for position occupation,
-// we will add functions in following PRs.
+// Permutes `shape` in place: position i receives the dimension that
+// `trans[i]` indexes in the original shape. Assumes every entry of
+// `trans` is a valid index into `shape` (callers check sizes first).
+static void transpose(llvm::ArrayRef<int64_t> trans,
+                      std::vector<int64_t> &shape) {
+  const std::vector<int64_t> original = shape;
+  for (size_t i = 0, e = trans.size(); i < e; ++i)
+    shape[i] = original[trans[i]];
+}
+
+// Renders an indexable container as "[a, b, c]" for diagnostics; with
+// `breakline` each separator is followed by a newline + indent.
+template <typename T>
+static std::string makeString(T array, bool breakline = false) {
+  std::string buf;
+  llvm::raw_string_ostream os(buf);
+  os << "[";
+  // Guard the empty case: the original unconditionally dereferenced
+  // array.back(), which is undefined behavior for an empty container.
+  if (!array.empty()) {
+    for (size_t i = 1; i < array.size(); i++) {
+      os << array[i - 1] << ", ";
+      if (breakline)
+        os << "\n\t\t";
+    }
+    os << array.back();
+  }
+  os << "]";
+  os.flush();
+  return buf;
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_CreateNdDescOp
+//===----------------------------------------------------------------------===//
+// Builds a CreateNdDescOp whose source is a statically-shaped memref.
+// Only the offsets may be dynamic; shape and strides come from the
+// memref type itself, so the dynamic shape/stride operand lists and the
+// constant shape/stride attributes are deliberately left empty.
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<MemRefType> source,
+                           llvm::ArrayRef<OpFoldResult> offsets) {
+  auto ty = source.getType();
+  // One offset per memref dimension is required.
+  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
+
+  // Split mixed OpFoldResult offsets into SSA operands + static ints.
+  llvm::SmallVector<int64_t> staticOffsets;
+  llvm::SmallVector<Value> dynamicOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
+        ValueRange({}) /* empty dynamic shape */,
+        ValueRange({}) /* empty dynamic strides */,
+        staticOffsets /* const offsets */, {} /* empty const shape*/,
+        {} /* empty const strides*/);
+}
+
+// Builds a CreateNdDescOp whose source is a raw (integer) base address.
+// Offsets, shape, and strides may each mix static and dynamic values;
+// every list must be non-empty and all three must have the same rank.
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<IntegerType> source,
+                           llvm::ArrayRef<OpFoldResult> offsets,
+                           llvm::ArrayRef<OpFoldResult> shape,
+                           llvm::ArrayRef<OpFoldResult> strides) {
+  assert(shape.size() && offsets.size() && strides.size() &&
+         shape.size() == strides.size() && shape.size() == offsets.size());
+
+  llvm::SmallVector<int64_t> staticOffsets;
+  llvm::SmallVector<int64_t> staticShape;
+  llvm::SmallVector<int64_t> staticStrides;
+  llvm::SmallVector<Value> dynamicOffsets;
+  llvm::SmallVector<Value> dynamicShape;
+  llvm::SmallVector<Value> dynamicStrides;
+
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
+  // Bug fix: strides were previously dispatched into staticOffsets,
+  // corrupting the offsets and leaving staticStrides permanently empty.
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
+  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+
+  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
+        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
+}
+
+// Verifies that the offsets rank, the source memref rank (when the
+// source is a memref), and the result TensorDesc rank all agree and are
+// 2D, and that a memref source shares the TensorDesc element type.
+LogicalResult CreateNdDescOp::verify() {
+  auto rank = (int64_t)getMixedOffsets().size();
+  bool invalidRank = (rank != 2);
+  bool invalidElemTy = false;
+
+  // check source type matches the rank if it is a memref.
+  // It also should have the same ElementType as TensorDesc.
+  auto memrefTy = getSourceType().dyn_cast<MemRefType>();
+  if (memrefTy) {
+    invalidRank |= (memrefTy.getRank() != rank);
+    invalidElemTy |= memrefTy.getElementType() != getElementType();
+  }
+
+  // check result type matches the rank.
+  // Bug fix: this was a plain assignment, which discarded the rank
+  // mismatch detected for a memref source above; accumulate with |=.
+  invalidRank |= (getType().getRank() != rank);
+
+  // mismatches among shape, strides, and offsets are
+  // already handled by OffsetSizeAndStrideOpInterface.
+  // So they are not checked here.
+  if (invalidRank)
+    return emitOpError(
+        "Expecting the rank of shape, strides, offsets, "
+        "source memref type (if source is a memref) and TensorDesc "
+        "should match with each other. They currently are 2D.");
+
+  if (invalidElemTy)
+    return emitOpError("TensorDesc should have the same element "
+                       "type with the source if it is a memref.\n");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadNdOp
+//===----------------------------------------------------------------------===//
+// Verifies LoadNdOp: the result vector shape must equal the TensorDesc
+// shape after applying, in order, the optional transpose, the optional
+// vnni packing, and the array_length prefix. The order of these three
+// transformations is significant.
+LogicalResult LoadNdOp::verify() {
+  auto tdescTy = getTensorDescType();
+  auto valueTy = getType();
+
+  if (tdescTy.getRank() != 2)
+    return emitOpError(
+        "The TensorDesc for LoadNdOp should be a 2D TensorDesc.");
+
+  if (!valueTy)
+    return emitOpError("Invalid result, it should be a VectorType.\n");
+
+  auto tdescElemTy = tdescTy.getElementType();
+  auto valueElemTy = valueTy.getElementType();
+
+  if (tdescElemTy != valueElemTy)
+    return emitOpError(
+        "Value should have the same element type as TensorDesc.");
+
+  auto array_len = tdescTy.getArrayLength();
+  auto tdescShape = tdescTy.getShape().vec();
+  auto valueShape = valueTy.getShape().vec();
+
+  // An oversized transpose permutation is ignored with a warning rather
+  // than a hard error; a permutation shorter than the shape only
+  // permutes the leading dims.
+  if (getTranspose()) {
+    auto trans = getTranspose().value();
+    if (tdescShape.size() >= trans.size())
+      transpose(trans, tdescShape);
+    else
+      emitWarning("Invalid transpose attr. It is ignored.");
+  }
+
+  // vnni packing: the chosen axis is divided by the packing factor and
+  // the factor is appended as a new innermost dim. The factor is read
+  // from the result's innermost dim — presumably the caller guarantees
+  // tdescShape[axis] is divisible by it; TODO confirm.
+  if (getVnniAxis()) {
+    auto axis = getVnniAxis().value();
+    auto vnni_factor = valueShape.back();
+    tdescShape[axis] /= vnni_factor;
+    tdescShape.push_back(vnni_factor);
+  }
+
+  // array_length > 1 prepends the array dimension to the expected shape.
+  if (array_len > 1) {
+    auto it = tdescShape.begin();
+    tdescShape.insert(it, array_len);
+  }
+
+  if (tdescShape != valueShape)
+    return emitOpError() << "Result shape doesn't match TensorDesc shape."
+                         << "The expected shape is " << makeString(tdescShape)
+                         << ". But the given shape is "
+                         << makeString(valueShape) << ".\n";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreNdOp
+//===----------------------------------------------------------------------===//
+// Verifies StoreNdOp: the destination TensorDesc must be 2D, the stored
+// value must be a vector, and both must agree on element type and shape.
+LogicalResult StoreNdOp::verify() {
+  auto dstTy = getTensorDesc().getType(); // Tile
+  // Bug fix: use dyn_cast instead of cast — cast asserts on a
+  // non-vector value, which made the `!valTy` diagnostic below
+  // unreachable.
+  auto valTy = getValue().getType().dyn_cast<VectorType>(); // Vector
+
+  if (dstTy.getRank() != 2)
+    return emitOpError("Expecting a 2D TensorDesc shape.\n");
+
+  if (!valTy)
+    return emitOpError("Expecting a VectorType result.\n");
+
+  auto dstElemTy = dstTy.getElementType();
+  auto valElemTy = valTy.getElementType();
+
+  if (dstElemTy != valElemTy) {
+    return emitOpError() << "The element type of the value should "
+                            "match the element type of the TensorDesc.\n";
+  }
+
+  if (dstTy.getShape() != valTy.getShape())
+    return emitOpError()
+           << "The result shape should match the TensorDesc shape.\n";
+  return success();
+}
} // namespace xegpu
} // namespace mlir