diff options
Diffstat (limited to 'mlir/lib/Dialect/AMDGPU')
| -rw-r--r-- | mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 68 | 
1 files changed, 38 insertions, 30 deletions
| diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 61166db..585b6da 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {  //===----------------------------------------------------------------------===//  // WMMAOp  //===----------------------------------------------------------------------===// -LogicalResult WMMAOp::verify() { -  Type sourceAType = getSourceA().getType(); -  Type sourceBType = getSourceB().getType(); -  Type destType = getDestC().getType(); -  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType); -  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType); -  VectorType destVectorType = dyn_cast<VectorType>(destType); +ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser, +                                                IntegerAttr &m, IntegerAttr &n, +                                                IntegerAttr &k) { +  SmallVector<int64_t, 3> dimensions; +  if (parser.parseDimensionList(dimensions, false, false)) +    return failure(); +  if (dimensions.size() != 3) +    return parser.emitError(parser.getCurrentLocation()) +           << "expected 3 dimensions in MNK dimension list"; -  Type sourceAElemType = sourceVectorAType.getElementType(); -  Type sourceBElemType = sourceVectorBType.getElementType(); -  Type destElemType = destVectorType.getElementType(); +  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]); +  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]); +  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]); +  return success(); +} -  if (sourceVectorAType.getNumElements() != -      sourceVectorBType.getNumElements()) { +LogicalResult WMMAOp::verify() { +  auto sourceAType = cast<VectorType>(getSourceA().getType()); +  auto sourceBType = cast<VectorType>(getSourceB().getType()); +  auto destType = cast<VectorType>(getDestC().getType()); + +  Type sourceAElemType = sourceAType.getElementType(); +  Type sourceBElemType = sourceBType.getElementType(); +  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {      return emitOpError("source vectors have different lengths: ") -           << sourceVectorAType << " vs. " << sourceVectorBType; +           << sourceAType << " vs. " << sourceBType;    } -  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType); -  bool isSrcFloat = -      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>( -          sourceAElemType); - -  if (isDestFloat && !isSrcFloat) { -    return emitOpError("Expected float sources with float destination"); -  } +  bool isDestFloat = destType.getElementType().isFloat(); +  bool isSrcFloat = sourceAElemType.isFloat(); -  if (!isDestFloat && isSrcFloat) { -    return emitOpError("Expected int sources with int destination"); -  } +  if (isDestFloat && !isSrcFloat) +    return emitOpError("expected float sources with float destination"); +  if (!isDestFloat && isSrcFloat) +    return emitOpError("expected int sources with int destination"); -  if (sourceAElemType != sourceBElemType && -      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) && -        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) { +  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {      return emitOpError(                 "source element types much match (except for fp8) but have ")             << sourceAType << " and " << sourceBType;    } + +  if (!sourceAElemType.isInteger(4) && getK() != 16) { +    return emitOpError("K dimension must be 16 for source element type ") +           << sourceAElemType; +  }    return success();  } @@ -414,11 +422,11 @@ LogicalResult MFMAOp::verify() {    Type sourceElem = sourceType, destElem = destType;    uint32_t sourceLen = 1, destLen = 1; -  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) { +  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {      sourceLen = sourceVector.getNumElements();      sourceElem = sourceVector.getElementType();    } -  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) { +  if (auto destVector = dyn_cast<VectorType>(destType)) {      destLen = destVector.getNumElements();      destElem = destVector.getElementType();    } @@ -443,7 +451,7 @@ LogicalResult MFMAOp::verify() {        return emitOpError("expected both non-small-float source operand types "                           "to match exactly");    } -  // Normalize the wider integer types the compiler expects to i8 +  // Normalize the wider integer types the compiler expects to i8.    if (sourceElem.isInteger(32)) {      sourceLen *= 4;      sourceElem = b.getI8Type(); | 
