diff options
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp')
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 2a8c330..f0de4db 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() { return success(); } +LogicalResult ConvertF8x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f8x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF8x2ToBF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + if (!llvm::isa<Float8E8M0FNUType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E8M0FNUType::get(ctx) + << " type is supported for conversions from f8x2 to bf16x2."; + + return success(); +} + +LogicalResult ConvertF6x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f6x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF4x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float4E2M1FNType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f4x2 to f16x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn; + }) + .Case<Float8E5M2Type>([&](Float8E5M2Type type) { + return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op); + + llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2; + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn; + }) + .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *extendedI16 = + builder.CreateZExt(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {extendedI16}}; +} + llvm::Intrinsic::ID Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, |
