diff options
Diffstat (limited to 'mlir/lib/Conversion/AMDGPUToROCDL')
| -rw-r--r-- | mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 177 | 
1 files changed, 147 insertions, 30 deletions
| diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 478b6aa..3a307a0 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {        .Case([](Float6E2M3FNType) { return 2u; })        .Case([](Float6E3M2FNType) { return 3u; })        .Case([](Float4E2M1FNType) { return 4u; }) -      .Default([](Type) { return std::nullopt; }); +      .Default(std::nullopt);  }  /// If there is a scaled MFMA instruction for the input element types `aType` @@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {                                   smfma.getN(), smfma.getK(), 1u, chipset);  } -/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` -/// if one exists. This includes checking to ensure the intrinsic is supported -/// on the architecture you are compiling for. -static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, -                                                  Chipset chipset) { -  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); -  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); -  auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); -  Type elemSourceType = sourceVectorType.getElementType(); -  Type elemBSourceType = sourceBVectorType.getElementType(); -  Type elemDestType = destVectorType.getElementType(); - -  const uint32_t k = wmma.getK(); +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for RDNA3/4 architectures. +static std::optional<StringRef> +wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, +                      Type elemDestType, uint32_t k, bool isRDNA3) { +  using fp8 = Float8E4M3FNType; +  using bf8 = Float8E5M2Type; +  // Handle k == 16 for RDNA3/4.    if (k == 16) { +    // Common patterns for RDNA3 and RDNA4.      if (elemSourceType.isF16() && elemDestType.isF32())        return ROCDL::wmma_f32_16x16x16_f16::getOperationName();      if (elemSourceType.isBF16() && elemDestType.isF32()) @@ -1014,40 +1010,161 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,        return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();      if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))        return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); -    if (chipset.majorVersion == 11) { + +    // RDNA3 specific patterns. +    if (isRDNA3) {        if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))          return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); +      return std::nullopt;      } -  } -  if (chipset.majorVersion < 12) -    return std::nullopt; -  // gfx12+ -  if (k == 16) { -    if (isa<Float8E4M3FNType>(elemSourceType) && -        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) +    // RDNA4 specific patterns (fp8/bf8). +    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) && +        elemDestType.isF32())        return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); -    if (isa<Float8E4M3FNType>(elemSourceType) && -        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) +    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) && +        elemDestType.isF32())        return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); -    if (isa<Float8E5M2Type>(elemSourceType) && -        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) +    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) && +        elemDestType.isF32())        return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); -    if (isa<Float8E5M2Type>(elemSourceType) && -        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) +    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) && +        elemDestType.isF32())        return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();      return std::nullopt;    } -  if (k == 32) { + +  // Handle k == 32 for RDNA4. +  if (k == 32 && !isRDNA3) {      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); +  } + +  return std::nullopt; +} + +/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for the gfx1250 architecture. +static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType, +                                                         Type elemBSourceType, +                                                         Type elemDestType, +                                                         uint32_t k) { +  using fp8 = Float8E4M3FNType; +  using bf8 = Float8E5M2Type; + +  if (k == 4) { +    if (elemSourceType.isF32() && elemDestType.isF32()) +      return ROCDL::wmma_f32_16x16x4_f32::getOperationName(); + +    return std::nullopt; +  } + +  if (k == 32) { +    if (elemSourceType.isF16() && elemDestType.isF32()) +      return ROCDL::wmma_f32_16x16x32_f16::getOperationName(); +    if (elemSourceType.isBF16() && elemDestType.isF32()) +      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName(); +    if (elemSourceType.isF16() && elemDestType.isF16()) +      return ROCDL::wmma_f16_16x16x32_f16::getOperationName(); +    if (elemSourceType.isBF16() && elemDestType.isBF16()) +      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName(); +      return std::nullopt;    } -  llvm_unreachable("unhandled WMMA case"); +  if (k == 64) { +    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName(); +    } +    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName(); +    } +    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName(); +    } +    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName(); +    } +    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) +      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName(); + +    return std::nullopt; +  } + +  if (k == 128) { +    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName(); +    } +    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName(); +    } +    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName(); +    } +    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) { +      if (elemDestType.isF32()) +        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName(); +      if (elemDestType.isF16()) +        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName(); +    } + +    return std::nullopt; +  } + +  return std::nullopt; +} + +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// if one exists. This includes checking to ensure the intrinsic is supported +/// on the architecture you are compiling for. +static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, +                                                  Chipset chipset) { +  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); +  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); +  auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); +  Type elemSourceType = sourceVectorType.getElementType(); +  Type elemBSourceType = sourceBVectorType.getElementType(); +  Type elemDestType = destVectorType.getElementType(); + +  const uint32_t k = wmma.getK(); +  const bool isRDNA3 = chipset.majorVersion == 11; +  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0; + +  // Handle RDNA3 and RDNA4. +  if (isRDNA3 || isRDNA4) +    return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType, +                                 k, isRDNA3); + +  // Handle gfx1250. +  if (chipset == Chipset{12, 5, 0}) +    return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, +                                    elemDestType, k); + +  return std::nullopt;  }  namespace { | 
