diff options
Diffstat (limited to 'mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp')
| -rw-r--r-- | mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 68 |
1 files changed, 38 insertions, 30 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 61166db..585b6da 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() { //===----------------------------------------------------------------------===// // WMMAOp //===----------------------------------------------------------------------===// -LogicalResult WMMAOp::verify() { - Type sourceAType = getSourceA().getType(); - Type sourceBType = getSourceB().getType(); - Type destType = getDestC().getType(); - VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType); - VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType); - VectorType destVectorType = dyn_cast<VectorType>(destType); +ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser, + IntegerAttr &m, IntegerAttr &n, + IntegerAttr &k) { + SmallVector<int64_t, 3> dimensions; + if (parser.parseDimensionList(dimensions, false, false)) + return failure(); + if (dimensions.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expected 3 dimensions in MNK dimension list"; - Type sourceAElemType = sourceVectorAType.getElementType(); - Type sourceBElemType = sourceVectorBType.getElementType(); - Type destElemType = destVectorType.getElementType(); + m = parser.getBuilder().getI32IntegerAttr(dimensions[0]); + n = parser.getBuilder().getI32IntegerAttr(dimensions[1]); + k = parser.getBuilder().getI32IntegerAttr(dimensions[2]); + return success(); +} - if (sourceVectorAType.getNumElements() != - sourceVectorBType.getNumElements()) { +LogicalResult WMMAOp::verify() { + auto sourceAType = cast<VectorType>(getSourceA().getType()); + auto sourceBType = cast<VectorType>(getSourceB().getType()); + auto destType = cast<VectorType>(getDestC().getType()); + + Type sourceAElemType = sourceAType.getElementType(); + Type sourceBElemType = sourceBType.getElementType(); + if (sourceAType.getNumElements() != sourceBType.getNumElements()) { return emitOpError("source vectors have different lengths: ") - << sourceVectorAType << " vs. " << sourceVectorBType; + << sourceAType << " vs. " << sourceBType; } - bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType); - bool isSrcFloat = - isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>( - sourceAElemType); - - if (isDestFloat && !isSrcFloat) { - return emitOpError("Expected float sources with float destination"); - } + bool isDestFloat = destType.getElementType().isFloat(); + bool isSrcFloat = sourceAElemType.isFloat(); - if (!isDestFloat && isSrcFloat) { - return emitOpError("Expected int sources with int destination"); - } + if (isDestFloat && !isSrcFloat) + return emitOpError("expected float sources with float destination"); + if (!isDestFloat && isSrcFloat) + return emitOpError("expected int sources with int destination"); - if (sourceAElemType != sourceBElemType && - !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) && - isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) { + if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( "source element types much match (except for fp8) but have ") << sourceAType << " and " << sourceBType; } + + if (!sourceAElemType.isInteger(4) && getK() != 16) { + return emitOpError("K dimension must be 16 for source element type ") + << sourceAElemType; + } return success(); } @@ -414,11 +422,11 @@ LogicalResult MFMAOp::verify() { Type sourceElem = sourceType, destElem = destType; uint32_t sourceLen = 1, destLen = 1; - if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) { + if (auto sourceVector = dyn_cast<VectorType>(sourceType)) { sourceLen = sourceVector.getNumElements(); sourceElem = sourceVector.getElementType(); } - if (auto destVector = llvm::dyn_cast<VectorType>(destType)) { + if (auto destVector = dyn_cast<VectorType>(destType)) { destLen = destVector.getNumElements(); destElem = destVector.getElementType(); } @@ -443,7 +451,7 @@ LogicalResult MFMAOp::verify() { return emitOpError("expected both non-small-float source operand types " "to match exactly"); } - // Normalize the wider integer types the compiler expects to i8 + // Normalize the wider integer types the compiler expects to i8. if (sourceElem.isInteger(32)) { sourceLen *= 4; sourceElem = b.getI8Type(); |
