diff options
Diffstat (limited to 'mlir/lib/Conversion')
3 files changed, 188 insertions, 43 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 478b6aa..1eca43d 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { smfma.getN(), smfma.getK(), 1u, chipset); } -/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` -/// if one exists. This includes checking to ensure the intrinsic is supported -/// on the architecture you are compiling for. -static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, - Chipset chipset) { - auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); - auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); - auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); - Type elemSourceType = sourceVectorType.getElementType(); - Type elemBSourceType = sourceBVectorType.getElementType(); - Type elemDestType = destVectorType.getElementType(); - - const uint32_t k = wmma.getK(); - +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for RDNA3/4 architectures. +static std::optional<StringRef> +wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, + Type elemDestType, uint32_t k, bool isRDNA3) { + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + + // Handle k == 16 for RDNA3/4. if (k == 16) { + // Common patterns for RDNA3 and RDNA4. if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isF32()) @@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - if (chipset.majorVersion == 11) { + + // RDNA3 specific patterns. + if (isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + return std::nullopt; } - } - if (chipset.majorVersion < 12) - return std::nullopt; - // gfx12+ - if (k == 16) { - if (isa<Float8E4M3FNType>(elemSourceType) && - isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) + // RDNA4 specific patterns (fp8/bf8). + if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); - if (isa<Float8E4M3FNType>(elemSourceType) && - isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) + if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); - if (isa<Float8E5M2Type>(elemSourceType) && - isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) + if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); - if (isa<Float8E5M2Type>(elemSourceType) && - isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) + if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); return std::nullopt; } - if (k == 32) { + + // Handle k == 32 for RDNA4. + if (k == 32 && !isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + } + + llvm_unreachable("Unsupported k value"); +} + +/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for the gfx1250 architecture. +static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType, + Type elemBSourceType, + Type elemDestType, + uint32_t k) { + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + + if (k == 4) { + if (elemSourceType.isF32() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x4_f32::getOperationName(); + return std::nullopt; } + if (k == 32) { + if (elemSourceType.isF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x32_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x32_bf16::getOperationName(); + if (elemSourceType.isF16() && elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x32_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isBF16()) + return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName(); + + return std::nullopt; + } + + if (k == 64) { + if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName(); + } + if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName(); + } + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x64_iu8::getOperationName(); + + return std::nullopt; + } + + if (k == 128) { + if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName(); + } + if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName(); + } + + return std::nullopt; + } + + llvm_unreachable("Unsupported k value"); +} + +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// if one exists. This includes checking to ensure the intrinsic is supported +/// on the architecture you are compiling for. +static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, + Chipset chipset) { + auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); + auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); + auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); + Type elemSourceType = sourceVectorType.getElementType(); + Type elemBSourceType = sourceBVectorType.getElementType(); + Type elemDestType = destVectorType.getElementType(); + + const uint32_t k = wmma.getK(); + const bool isRDNA3 = chipset.majorVersion == 11; + const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0; + + // Handle RDNA3 and RDNA4. + if (isRDNA3 || isRDNA4) + return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType, + k, isRDNA3); + + // Handle gfx1250. + if (chipset == Chipset{12, 5, 0}) + return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, + elemDestType, k); + llvm_unreachable("unhandled WMMA case"); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0fe7239..9e46b7d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -313,25 +313,53 @@ private: struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { using OpConversionPattern<complex::ExpOp>::OpConversionPattern; + // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y)) + // Handle special cases as StableHLO implementation does: + // 1. When b == 0, set imag(exp(z)) = 0 + // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2) LogicalResult matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = cast<ComplexType>(adaptor.getComplex().getType()); - auto elementType = cast<FloatType>(type.getElementType()); - arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - - Value real = - complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); - Value imag = - complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); - Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue()); - Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue()); + auto ET = cast<FloatType>(type.getElementType()); + arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); + const auto &floatSemantics = ET.getFloatSemantics(); + ImplicitLocOpBuilder b(loc, rewriter); + + Value x = complex::ReOp::create(b, ET, adaptor.getComplex()); + Value y = complex::ImOp::create(b, ET, adaptor.getComplex()); + Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET)); + Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5)); + Value inf = arith::ConstantOp::create( + b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics))); + + Value exp = math::ExpOp::create(b, x, fmf); + Value xHalf = arith::MulFOp::create(b, x, half, fmf); + Value expHalf = math::ExpOp::create(b, xHalf, fmf); + Value cos = math::CosOp::create(b, y, fmf); + Value sin = math::SinOp::create(b, y, fmf); + + Value expIsInf = + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf); + Value yIsZero = + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero); + + // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2) + Value realNormal = arith::MulFOp::create(b, exp, cos, fmf); + Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf); + Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf); Value resultReal = - arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue()); - Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue()); - Value resultImag = - arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue()); + arith::SelectOp::create(b, expIsInf, realOverflow, realNormal); + + // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and + // exp(x/2)*sin(y)*exp(x/2) + Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf); + Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf); + Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf); + Value imagNonZero = + arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal); + Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index b711e33..a4c66e1 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -692,7 +692,7 @@ SymbolRefAttr PatternLowering::generateRewriter( llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); args.append(mappedArgs.begin(), mappedArgs.end()); pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(), - /*resultTypes=*/TypeRange(), rewriteName, + /*results=*/TypeRange(), rewriteName, args); } else { // Otherwise this is a dag rewriter defined using PDL operations. |
