diff options
11 files changed, 158 insertions, 149 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index cd8b68e..caba614 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -324,8 +324,8 @@ def LLVM_NoAliasScopeDeclOp return success(); if (scopeAttrs->size() != 1) return failure(); - $_op = $_builder.create<LLVM::NoAliasScopeDeclOp>( - $_location, (*scopeAttrs)[0]); + $_op = LLVM::NoAliasScopeDeclOp::create( + $_builder, $_location, (*scopeAttrs)[0]); }]; let assemblyFormat = "$scope attr-dict"; } @@ -468,7 +468,7 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs, $_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(), roundingModeAttr)); }], true : "") # [{ - $res = $_builder.create<$_qualCppClassName>($_location, + $res = $_qualCppClassName::create($_builder, $_location, $_resultType, mlirOperands, mlirAttrs); }]; } @@ -743,7 +743,7 @@ def LLVM_DbgLabelOp : LLVM_IntrOp<"dbg.label", [], [], [], 0> { // Drop the intrinsic if the label translation fails due to cylic metadata. if (!labelAttr) return success(); - $_op = $_builder.create<$_qualCppClassName>($_location, labelAttr); + $_op = $_qualCppClassName::create($_builder, $_location, labelAttr); }]; let assemblyFormat = "$label attr-dict"; } @@ -883,7 +883,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa $columns); }]; string mlirBuilder = [{ - $res = $_builder.create<LLVM::MatrixColumnMajorLoadOp>( + $res = LLVM::MatrixColumnMajorLoadOp::create($_builder, $_location, $_resultType, $data, $stride, $_int_attr($isVolatile), $_int_attr($rows), $_int_attr($columns)); }]; @@ -917,7 +917,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s $rows, $columns); }]; string mlirBuilder = [{ - $_op = $_builder.create<LLVM::MatrixColumnMajorStoreOp>( + $_op = LLVM::MatrixColumnMajorStoreOp::create($_builder, $_location, $matrix, $data, $stride, $_int_attr($isVolatile), $_int_attr($rows), $_int_attr($columns)); }]; @@ -940,7 +940,7 @@ def LLVM_MatrixMultiplyOp : LLVM_OneResultIntrOp<"matrix.multiply"> { $rhs_columns); }]; string mlirBuilder = [{ - $res = $_builder.create<LLVM::MatrixMultiplyOp>( + $res = LLVM::MatrixMultiplyOp::create($_builder, $_location, $_resultType, $lhs, $rhs, $_int_attr($lhs_rows), $_int_attr($lhs_columns), $_int_attr($rhs_columns)); }]; @@ -960,7 +960,7 @@ def LLVM_MatrixTransposeOp : LLVM_OneResultIntrOp<"matrix.transpose"> { $matrix, $rows, $columns); }]; string mlirBuilder = [{ - $res = $_builder.create<LLVM::MatrixTransposeOp>( + $res = LLVM::MatrixTransposeOp::create($_builder, $_location, $_resultType, $matrix, $_int_attr($rows), $_int_attr($columns)); }]; @@ -997,7 +997,7 @@ def LLVM_MaskedLoadOp : LLVM_OneResultIntrOp<"masked.load"> { string mlirBuilder = [{ auto *intrinInst = dyn_cast<llvm::IntrinsicInst>(inst); bool nontemporal = intrinInst->hasMetadata(llvm::LLVMContext::MD_nontemporal); - $res = $_builder.create<LLVM::MaskedLoadOp>($_location, + $res = LLVM::MaskedLoadOp::create($_builder, $_location, $_resultType, $data, $mask, $pass_thru, $_int_attr($alignment), nontemporal ? $_builder.getUnitAttr() : nullptr); }]; @@ -1017,7 +1017,7 @@ def LLVM_MaskedStoreOp : LLVM_ZeroResultIntrOp<"masked.store"> { $value, $data, llvm::Align($alignment), $mask); }]; string mlirBuilder = [{ - $_op = $_builder.create<LLVM::MaskedStoreOp>($_location, + $_op = LLVM::MaskedStoreOp::create($_builder, $_location, $value, $data, $mask, $_int_attr($alignment)); }]; list<int> llvmArgIndices = [0, 1, 3, 2]; @@ -1040,7 +1040,7 @@ def LLVM_masked_gather : LLVM_OneResultIntrOp<"masked.gather"> { $_resultType, $ptrs, llvm::Align($alignment), $mask, $pass_thru[0]); }]; string mlirBuilder = [{ - $res = $_builder.create<LLVM::masked_gather>($_location, + $res = LLVM::masked_gather::create($_builder, $_location, $_resultType, $ptrs, $mask, $pass_thru, $_int_attr($alignment)); }]; list<int> llvmArgIndices = [0, 2, 3, 1]; @@ -1061,7 +1061,7 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> { $value, $ptrs, llvm::Align($alignment), $mask); }]; string mlirBuilder = [{ - $_op = $_builder.create<LLVM::masked_scatter>($_location, + $_op = LLVM::masked_scatter::create($_builder, $_location, $value, $ptrs, $mask, $_int_attr($alignment)); }]; list<int> llvmArgIndices = [0, 1, 3, 2]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index e08c7b7..e845ea9f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -363,7 +363,7 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, } SmallVector<Type> resultTypes = }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{ - auto op = $_builder.create<$_qualCppClassName>( + auto op = $_qualCppClassName::create($_builder, $_location, resultTypes, mlirOperands, mlirAttrs); }]; string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 4a9bc90..51004f5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -53,7 +53,7 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName, LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName, traits> { let arguments = commonArgs; string mlirBuilder = [{ - $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + $res = $_qualCppClassName::create($_builder, $_location, $lhs, $rhs); }]; } class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName, @@ -64,7 +64,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName, let arguments = !con(commonArgs, iofArg); string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + auto op = $_qualCppClassName::create($_builder, $_location, $lhs, $rhs); moduleImport.setIntegerOverflowFlags(inst, op); $res = op; }]; @@ -82,7 +82,7 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName, let arguments = !con(commonArgs, (ins UnitAttr:$isExact)); string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + auto op = $_qualCppClassName::create($_builder, $_location, $lhs, $rhs); moduleImport.setExactFlag(inst, op); $res = op; }]; @@ -100,7 +100,7 @@ class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName, let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint)); string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + auto op = $_qualCppClassName::create($_builder, $_location, $lhs, $rhs); moduleImport.setDisjointFlag(inst, op); $res = op; }]; @@ -121,7 +121,7 @@ class LLVM_FloatArithmeticOp<string mnemonic, string instName, ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags); let arguments = !con(commonArgs, fmfArg); string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + auto op = $_qualCppClassName::create($_builder, $_location, $lhs, $rhs); moduleImport.setFastmathFlagsAttr(inst, op); $res = op; }]; @@ -141,7 +141,7 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic, let assemblyFormat = "$operand attr-dict `:` type($res)"; string llvmInstName = instName; string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>($_location, $operand); + auto op = $_qualCppClassName::create($_builder, $_location, $operand); moduleImport.setFastmathFlagsAttr(inst, op); $res = op; }]; @@ -193,7 +193,7 @@ def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> { }]; string mlirBuilder = [{ auto *iCmpInst = cast<llvm::ICmpInst>(inst); - $res = $_builder.create<$_qualCppClassName>($_location, + $res = $_qualCppClassName::create($_builder, $_location, convertICmpPredicateFromLLVM(iCmpInst->getPredicate()), $lhs, $rhs); }]; // Set the $predicate index to -1 to indicate there is no matching operand @@ -217,7 +217,7 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [ }]; string mlirBuilder = [{ auto *fCmpInst = cast<llvm::FCmpInst>(inst); - auto op = $_builder.create<$_qualCppClassName>( + auto op = $_qualCppClassName::create($_builder, $_location, convertFCmpPredicateFromLLVM(fCmpInst->getPredicate()), $lhs, $rhs); moduleImport.setFastmathFlagsAttr(inst, op); $res = op; @@ -261,7 +261,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca", Type allocatedType = moduleImport.convertType(allocaInst->getAllocatedType()); unsigned alignment = allocaInst->getAlign().value(); - $res = $_builder.create<LLVM::AllocaOp>( + $res = LLVM::AllocaOp::create($_builder, $_location, $_resultType, $arraySize, alignment == 0 ? IntegerAttr() : $_builder.getI64IntegerAttr(alignment), allocatedType, allocaInst->isUsedWithInAlloca()); @@ -440,7 +440,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load", string mlirBuilder = [{ auto *loadInst = cast<llvm::LoadInst>(inst); unsigned alignment = loadInst->getAlign().value(); - $res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr, + $res = LLVM::LoadOp::create($_builder, $_location, $_resultType, $addr, alignment, loadInst->isVolatile(), loadInst->hasMetadata(llvm::LLVMContext::MD_nontemporal), loadInst->hasMetadata(llvm::LLVMContext::MD_invariant_load), @@ -518,7 +518,7 @@ def LLVM_StoreOp : LLVM_MemAccessOpBase<"store", string mlirBuilder = [{ auto *storeInst = cast<llvm::StoreInst>(inst); unsigned alignment = storeInst->getAlign().value(); - $_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr, + $_op = LLVM::StoreOp::create($_builder, $_location, $value, $addr, alignment, storeInst->isVolatile(), storeInst->hasMetadata(llvm::LLVMContext::MD_nontemporal), storeInst->hasMetadata(llvm::LLVMContext::MD_invariant_group), @@ -547,7 +547,7 @@ class LLVM_CastOp<string mnemonic, string instName, Type type, let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)"; string llvmInstName = instName; string mlirBuilder = [{ - $res = $_builder.create<$_qualCppClassName>( + $res = $_qualCppClassName::create($_builder, $_location, $_resultType, $arg); }]; } @@ -561,7 +561,7 @@ class LLVM_CastOpWithNNegFlag<string mnemonic, string instName, Type type, let assemblyFormat = "(`nneg` $nonNeg^)? $arg attr-dict `:` type($arg) `to` type($res)"; string llvmInstName = instName; string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>( + auto op = $_qualCppClassName::create($_builder, $_location, $_resultType, $arg); moduleImport.setNonNegFlag(inst, op); $res = op; @@ -578,7 +578,7 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type, let assemblyFormat = "$arg ($overflowFlags^)? attr-dict `:` type($arg) `to` type($res)"; string llvmInstName = instName; string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>( + auto op = $_qualCppClassName::create($_builder, $_location, $_resultType, $arg); moduleImport.setIntegerOverflowFlags(inst, op); $res = op; @@ -602,7 +602,7 @@ class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type, } }]; string mlirBuilder = [{ - auto op = $_builder.create<$_qualCppClassName>( + auto op = $_qualCppClassName::create($_builder, $_location, $_resultType, $arg); $res = op; }]; @@ -725,7 +725,7 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> { string llvmInstName = "VAArg"; string mlirBuilder = [{ - $res = $_builder.create<mlir::LLVM::VaArgOp>( + $res = mlir::LLVM::VaArgOp::create($_builder, $_location, $_resultType, $arg); }]; } @@ -847,7 +847,7 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure, $res = builder.CreateExtractElement($vector, $position); }]; string mlirBuilder = [{ - $res = $_builder.create<LLVM::ExtractElementOp>( + $res = LLVM::ExtractElementOp::create($_builder, $_location, $vector, $position); }]; } @@ -881,7 +881,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> { }]; string mlirBuilder = [{ auto *evInst = cast<llvm::ExtractValueInst>(inst); - $res = $_builder.create<LLVM::ExtractValueOp>($_location, + $res = LLVM::ExtractValueOp::create($_builder, $_location, $container, getPositionFromIndices(evInst->getIndices())); }]; } @@ -913,7 +913,7 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure, $res = builder.CreateInsertElement($vector, $value, $position); }]; string mlirBuilder = [{ - $res = $_builder.create<LLVM::InsertElementOp>( + $res = LLVM::InsertElementOp::create($_builder, $_location, $vector, $value, $position); }]; } @@ -945,7 +945,7 @@ def LLVM_InsertValueOp : LLVM_Op< }]; string mlirBuilder = [{ auto *ivInst = cast<llvm::InsertValueInst>(inst); - $res = $_builder.create<LLVM::InsertValueOp>($_location, + $res = LLVM::InsertValueOp::create($_builder, $_location, $container, $value, getPositionFromIndices(ivInst->getIndices())); }]; } @@ -982,7 +982,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", string mlirBuilder = [{ auto *svInst = cast<llvm::ShuffleVectorInst>(inst); SmallVector<int32_t> mask(svInst->getShuffleMask()); - $res = $_builder.create<LLVM::ShuffleVectorOp>( + $res = LLVM::ShuffleVectorOp::create($_builder, $_location, $v1, $v2, mask); }]; } @@ -1003,7 +1003,7 @@ def LLVM_SelectOp let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)"; string llvmInstName = "Select"; string mlirBuilder = [{ - auto op = $_builder.create<LLVM::SelectOp>( + auto op = LLVM::SelectOp::create($_builder, $_location, $_resultType, $condition, $trueValue, $falseValue); moduleImport.setFastmathFlagsAttr(inst, op); $res = op; @@ -1017,7 +1017,7 @@ def LLVM_FreezeOp : LLVM_Op<"freeze", [Pure, SameOperandsAndResultType]> { string llvmInstName = "Freeze"; string llvmBuilder = "$res = builder.CreateFreeze($val);"; string mlirBuilder = [{ - $res = $_builder.create<LLVM::FreezeOp>($_location, $val); + $res = LLVM::FreezeOp::create($_builder, $_location, $val); }]; } @@ -1108,7 +1108,7 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [Pure, ReturnLike]> { moduleImport.convertValues(llvmOperands); if (failed(mlirOperands)) return failure(); - $_op = $_builder.create<LLVM::ReturnOp>($_location, *mlirOperands); + $_op = LLVM::ReturnOp::create($_builder, $_location, *mlirOperands); }]; } @@ -1120,7 +1120,7 @@ def LLVM_ResumeOp : LLVM_TerminatorOp<"resume"> { string llvmInstName = "Resume"; string llvmBuilder = [{ builder.CreateResume($value); }]; string mlirBuilder = [{ - $_op = $_builder.create<LLVM::ResumeOp>($_location, $value); + $_op = LLVM::ResumeOp::create($_builder, $_location, $value); }]; } def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> { @@ -1128,7 +1128,7 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> { string llvmInstName = "Unreachable"; string llvmBuilder = [{ builder.CreateUnreachable(); }]; string mlirBuilder = [{ - $_op = $_builder.create<LLVM::UnreachableOp>($_location); + $_op = LLVM::UnreachableOp::create($_builder, $_location); }]; } @@ -2256,7 +2256,7 @@ def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [ string mlirBuilder = [{ auto *atomicInst = cast<llvm::AtomicRMWInst>(inst); unsigned alignment = atomicInst->getAlign().value(); - $res = $_builder.create<LLVM::AtomicRMWOp>($_location, + $res = LLVM::AtomicRMWOp::create($_builder, $_location, convertAtomicBinOpFromLLVM(atomicInst->getOperation()), $ptr, $val, convertAtomicOrderingFromLLVM(atomicInst->getOrdering()), getLLVMSyncScope(atomicInst), alignment, atomicInst->isVolatile()); @@ -2311,7 +2311,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_MemAccessOpBase<"cmpxchg", [ string mlirBuilder = [{ auto *cmpXchgInst = cast<llvm::AtomicCmpXchgInst>(inst); unsigned alignment = cmpXchgInst->getAlign().value(); - $res = $_builder.create<LLVM::AtomicCmpXchgOp>( + $res = LLVM::AtomicCmpXchgOp::create($_builder, $_location, $ptr, $cmp, $val, convertAtomicOrderingFromLLVM(cmpXchgInst->getSuccessOrdering()), convertAtomicOrderingFromLLVM(cmpXchgInst->getFailureOrdering()), @@ -2340,7 +2340,7 @@ def LLVM_FenceOp : LLVM_Op<"fence">, LLVM_MemOpPatterns { }] # setSyncScopeCode; string mlirBuilder = [{ llvm::FenceInst *fenceInst = cast<llvm::FenceInst>(inst); - $_op = $_builder.create<LLVM::FenceOp>( + $_op = LLVM::FenceOp::create($_builder, $_location, convertAtomicOrderingFromLLVM(fenceInst->getOrdering()), getLLVMSyncScope(fenceInst)); diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index c17ef10..894de44 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -90,8 +90,8 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { } for (auto [idx, t] : llvm::enumerate(stype.getBody())) { if (itype != PTXRegisterMod::Write) { - Value extractValue = rewriter.create<LLVM::ExtractValueOp>( - interfaceOp->getLoc(), v, idx); + Value extractValue = LLVM::ExtractValueOp::create( + rewriter, interfaceOp->getLoc(), v, idx); addValue(extractValue); } if (itype == PTXRegisterMod::ReadWrite) { @@ -132,8 +132,8 @@ LLVM::InlineAsmOp PtxBuilder::build() { // Replace all % with $ llvm::replace(ptxInstruction, '%', '$'); - return rewriter.create<LLVM::InlineAsmOp>( - interfaceOp->getLoc(), + return LLVM::InlineAsmOp::create( + rewriter, interfaceOp->getLoc(), /*result types=*/resultTypes, /*operands=*/ptxOperands, /*asm_string=*/llvm::StringRef(ptxInstruction), diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 89f765d..feaffa3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -89,8 +89,8 @@ mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, OpBuilder::InsertionGuard g(b); assert(!moduleOp->getRegion(0).empty() && "expected non-empty region"); b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); - auto funcOp = b.create<LLVM::LLVMFuncOp>( - moduleOp->getLoc(), name, + auto funcOp = LLVM::LLVMFuncOp::create( + b, moduleOp->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); if (symbolTables) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 4a1527c..34ffd1e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3384,7 +3384,7 @@ bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) { ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value, Type type, Location loc) { if (isBuildableWith(value, type)) - return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value)); + return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(value)); return nullptr; } @@ -4133,9 +4133,11 @@ void LLVMDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" + , #define GET_OP_LIST #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" + >(); // Support unknown operations because not all LLVM operations are registered. @@ -4350,13 +4352,13 @@ Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value, // a builtin zero attribute and thus will materialize as a llvm.mlir.constant. if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value)) if (isa<LLVM::LLVMPointerType>(type)) - return builder.create<LLVM::AddressOfOp>(loc, type, symbol); + return LLVM::AddressOfOp::create(builder, loc, type, symbol); if (isa<LLVM::UndefAttr>(value)) - return builder.create<LLVM::UndefOp>(loc, type); + return LLVM::UndefOp::create(builder, loc, type); if (isa<LLVM::PoisonAttr>(value)) - return builder.create<LLVM::PoisonOp>(loc, type); + return LLVM::PoisonOp::create(builder, loc, type); if (isa<LLVM::ZeroAttr>(value)) - return builder.create<LLVM::ZeroOp>(loc, type); + return LLVM::ZeroOp::create(builder, loc, type); // Otherwise try materializing it as a regular llvm.mlir.constant op. return LLVM::ConstantOp::materialize(builder, value, type, loc); } @@ -4379,16 +4381,16 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); MLIRContext *ctx = builder.getContext(); auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); - auto global = moduleBuilder.create<LLVM::GlobalOp>( - loc, type, /*isConstant=*/true, linkage, name, + auto global = LLVM::GlobalOp::create( + moduleBuilder, loc, type, /*isConstant=*/true, linkage, name, builder.getStringAttr(value), /*alignment=*/0); LLVMPointerType ptrType = LLVMPointerType::get(ctx); // Get the pointer to the first character in the global string. Value globalPtr = - builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr()); - return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr, - ArrayRef<GEPArg>{0, 0}); + LLVM::AddressOfOp::create(builder, loc, ptrType, global.getSymNameAttr()); + return LLVM::GEPOp::create(builder, loc, ptrType, type, globalPtr, + ArrayRef<GEPArg>{0, 0}); } bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index bc451f8..e7d5dad 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -37,7 +37,7 @@ llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() { Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType); + return LLVM::UndefOp::create(builder, getLoc(), slot.elemType); } void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot, @@ -45,9 +45,9 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot, OpBuilder &builder) { for (Operation *user : getOperation()->getUsers()) if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user)) - builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument, - declareOp.getVarInfo(), - declareOp.getLocationExpr()); + LLVM::DbgValueOp::create(builder, declareOp.getLoc(), argument, + declareOp.getVarInfo(), + declareOp.getLocationExpr()); } std::optional<PromotableAllocationOpInterface> @@ -89,8 +89,8 @@ DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure( for (Attribute index : usedIndices) { Type elemType = destructurableType.getTypeAtIndex(index); assert(elemType && "used index must exist"); - auto subAlloca = builder.create<LLVM::AllocaOp>( - getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType, + auto subAlloca = LLVM::AllocaOp::create( + builder, getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType, getArraySize()); newAllocators.push_back(subAlloca); slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType}); @@ -260,14 +260,14 @@ static Value createExtractAndCast(OpBuilder &builder, Location loc, // Truncate the integer if the size of the target is less than the value. if (isBigEndian(dataLayout)) { uint64_t shiftAmount = srcTypeSize - targetTypeSize; - auto shiftConstant = builder.create<LLVM::ConstantOp>( - loc, builder.getIntegerAttr(srcType, shiftAmount)); + auto shiftConstant = LLVM::ConstantOp::create( + builder, loc, builder.getIntegerAttr(srcType, shiftAmount)); replacement = builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant); } - replacement = builder.create<LLVM::TruncOp>( - loc, builder.getIntegerType(targetTypeSize), replacement); + replacement = LLVM::TruncOp::create( + builder, loc, builder.getIntegerType(targetTypeSize), replacement); // Now cast the integer to the actual target type if required. return castIntValueToSameSizedType(builder, loc, replacement, targetType); @@ -304,8 +304,9 @@ static Value createInsertAndCast(OpBuilder &builder, Location loc, // On big endian systems, a store to the base pointer overwrites the most // significant bits. To accomodate for this, the stored value needs to be // shifted into the according position. - Value bigEndianShift = builder.create<LLVM::ConstantOp>( - loc, builder.getIntegerAttr(defAsInt.getType(), sizeDifference)); + Value bigEndianShift = LLVM::ConstantOp::create( + builder, loc, + builder.getIntegerAttr(defAsInt.getType(), sizeDifference)); valueAsInt = builder.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift); } @@ -325,8 +326,8 @@ static Value createInsertAndCast(OpBuilder &builder, Location loc, } // Mask out the affected bits ... - Value mask = builder.create<LLVM::ConstantOp>( - loc, builder.getIntegerAttr(defAsInt.getType(), maskValue)); + Value mask = LLVM::ConstantOp::create( + builder, loc, builder.getIntegerAttr(defAsInt.getType(), maskValue)); Value masked = builder.createOrFold<LLVM::AndOp>(loc, defAsInt, mask); // ... and combine the result with the new value. @@ -644,7 +645,7 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses( // debug local variable info. This allows the debugger to inform the user that // the variable has been optimized out. auto undef = - builder.create<UndefOp>(getValue().getLoc(), getValue().getType()); + UndefOp::create(builder, getValue().getLoc(), getValue().getType()); getValueMutable().assign(undef); return DeletionKind::Keep; } @@ -655,8 +656,8 @@ void LLVM::DbgDeclareOp::visitReplacedValues( ArrayRef<std::pair<Operation *, Value>> definitions, OpBuilder &builder) { for (auto [op, value] : definitions) { builder.setInsertionPointAfter(op); - builder.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(), - getLocationExpr()); + LLVM::DbgValueOp::create(builder, getLoc(), value, getVarInfo(), + getLocationExpr()); } } @@ -972,15 +973,14 @@ void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace, DenseMap<Attribute, MemorySlot> &subslots, Attribute index) { Value newMemsetSizeValue = - builder - .create<LLVM::ConstantOp>( - toReplace.getLen().getLoc(), - IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize)) + LLVM::ConstantOp::create( + builder, toReplace.getLen().getLoc(), + IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize)) .getResult(); - builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr, - toReplace.getVal(), newMemsetSizeValue, - toReplace.getIsVolatile()); + LLVM::MemsetOp::create(builder, toReplace.getLoc(), subslots.at(index).ptr, + toReplace.getVal(), newMemsetSizeValue, + toReplace.getIsVolatile()); } template <> @@ -991,9 +991,9 @@ void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace, auto newMemsetSizeValue = IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize); - builder.create<LLVM::MemsetInlineOp>( - toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(), - newMemsetSizeValue, toReplace.getIsVolatile()); + LLVM::MemsetInlineOp::create(builder, toReplace.getLoc(), + subslots.at(index).ptr, toReplace.getVal(), + newMemsetSizeValue, toReplace.getIsVolatile()); } } // namespace @@ -1063,8 +1063,8 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, APInt memsetVal(/*numBits=*/width, /*val=*/0); for (unsigned loBit = 0; loBit < width; loBit += 8) memsetVal.insertBits(constantPattern.getValue(), loBit); - return builder.create<LLVM::ConstantOp>( - op.getLoc(), IntegerAttr::get(intType, memsetVal)); + return LLVM::ConstantOp::create(builder, op.getLoc(), + IntegerAttr::get(intType, memsetVal)); } // If the output is a single byte, we can return the pattern directly. @@ -1075,14 +1075,14 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, // value and or-ing it with the previous value. uint64_t coveredBits = 8; Value currentValue = - builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal()); + LLVM::ZExtOp::create(builder, op.getLoc(), intType, op.getVal()); while (coveredBits < width) { Value shiftBy = - builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits); + LLVM::ConstantOp::create(builder, op.getLoc(), intType, coveredBits); Value shifted = - builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy); + LLVM::ShlOp::create(builder, op.getLoc(), currentValue, shiftBy); currentValue = - builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted); + LLVM::OrOp::create(builder, op.getLoc(), currentValue, shifted); coveredBits *= 2; } @@ -1094,7 +1094,7 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, }) .Case([&](FloatType type) -> Value { Value intVal = buildMemsetValue(type.getWidth()); - return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal); + return LLVM::BitcastOp::create(builder, op.getLoc(), type, intVal); }) .Default([](Type) -> Value { llvm_unreachable( @@ -1282,7 +1282,7 @@ static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) { template <class MemcpyLike> static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot, OpBuilder &builder) { - return builder.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc()); + return LLVM::LoadOp::create(builder, op.getLoc(), slot.elemType, op.getSrc()); } template <class MemcpyLike> @@ -1309,7 +1309,8 @@ memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder, Value reachingDefinition) { if (op.loadsFrom(slot)) - builder.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition, op.getDst()); + LLVM::StoreOp::create(builder, op.getLoc(), reachingDefinition, + op.getDst()); return DeletionKind::Delete; } @@ -1354,11 +1355,12 @@ template <class MemcpyLike> void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout, MemcpyLike toReplace, Value dst, Value src, Type toCpy, bool isVolatile) { - Value memcpySize = builder.create<LLVM::ConstantOp>( - toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(), - layout.getTypeSize(toCpy))); - builder.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize, - isVolatile); + Value memcpySize = + LLVM::ConstantOp::create(builder, toReplace.getLoc(), + IntegerAttr::get(toReplace.getLen().getType(), + layout.getTypeSize(toCpy))); + MemcpyLike::create(builder, toReplace.getLoc(), dst, src, memcpySize, + isVolatile); } template <> @@ -1367,8 +1369,8 @@ void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout, Value src, Type toCpy, bool isVolatile) { Type lenType = IntegerType::get(toReplace->getContext(), toReplace.getLen().getBitWidth()); - builder.create<LLVM::MemcpyInlineOp>( - toReplace.getLoc(), dst, src, + LLVM::MemcpyInlineOp::create( + builder, toReplace.getLoc(), dst, src, IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile); } @@ -1409,9 +1411,9 @@ memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, SmallVector<LLVM::GEPArg> gepIndices{ 0, static_cast<int32_t>( cast<IntegerAttr>(index).getValue().getZExtValue())}; - Value subslotPtrInOther = builder.create<LLVM::GEPOp>( - op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType, - isDst ? op.getSrc() : op.getDst(), gepIndices); + Value subslotPtrInOther = LLVM::GEPOp::create( + builder, op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), + slot.elemType, isDst ? op.getSrc() : op.getDst(), gepIndices); // Then create a new memcpy out of this source pointer. createMemcpyLikeToReplace(builder, dataLayout, op, diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp index 6fbb0d2..1fb482b 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp @@ -28,14 +28,15 @@ static void addComdat(LLVM::LLVMFuncOp &op, OpBuilder &builder, PatternRewriter::InsertionGuard guard(builder); builder.setInsertionPointToStart(module.getBody()); comdatOp = - builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName); + mlir::LLVM::ComdatOp::create(builder, module.getLoc(), comdatName); symbolTable.insert(comdatOp); } PatternRewriter::InsertionGuard guard(builder); builder.setInsertionPointToStart(&comdatOp.getBody().back()); - auto selectorOp = builder.create<mlir::LLVM::ComdatSelectorOp>( - comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any); + auto selectorOp = mlir::LLVM::ComdatSelectorOp::create( + builder, comdatOp.getLoc(), op.getSymName(), + mlir::LLVM::comdat::Comdat::Any); op.setComdatAttr(mlir::SymbolRefAttr::get( builder.getContext(), comdatName, mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()))); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 7f3afff..935aa3c 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -113,20 +113,22 @@ handleInlinedAllocas(Operation *call, // scope if some are already present in the body of the caller. This is not // invalid IR, but LLVM cleans these up in InstCombineCalls.cpp, along with // other cases where the stacksave/stackrestore is redundant. - stackPtr = builder.create<LLVM::StackSaveOp>( - call->getLoc(), LLVM::LLVMPointerType::get(call->getContext())); + stackPtr = LLVM::StackSaveOp::create( + builder, call->getLoc(), + LLVM::LLVMPointerType::get(call->getContext())); } builder.setInsertionPointToStart(callerEntryBlock); for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { - auto newConstant = builder.create<LLVM::ConstantOp>( - allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize); + auto newConstant = + LLVM::ConstantOp::create(builder, allocaOp->getLoc(), + allocaOp.getArraySize().getType(), arraySize); // Insert a lifetime start intrinsic where the alloca was before moving it. if (shouldInsertLifetime) { OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(allocaOp); - builder.create<LLVM::LifetimeStartOp>( - allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), - allocaOp.getResult()); + LLVM::LifetimeStartOp::create(builder, allocaOp.getLoc(), + arraySize.getValue().getLimitedValue(), + allocaOp.getResult()); } allocaOp->moveAfter(newConstant); allocaOp.getArraySizeMutable().assign(newConstant.getResult()); @@ -139,12 +141,12 @@ handleInlinedAllocas(Operation *call, continue; builder.setInsertionPoint(block.getTerminator()); if (hasDynamicAlloca) - builder.create<LLVM::StackRestoreOp>(call->getLoc(), stackPtr); + LLVM::StackRestoreOp::create(builder, call->getLoc(), stackPtr); for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { if (shouldInsertLifetime) - builder.create<LLVM::LifetimeEndOp>( - allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), - allocaOp.getResult()); + LLVM::LifetimeEndOp::create(builder, allocaOp.getLoc(), + arraySize.getValue().getLimitedValue(), + allocaOp.getResult()); } } } @@ -311,7 +313,8 @@ static void createNewAliasScopesFromNoAliasParameter( auto scope = LLVM::AliasScopeAttr::get(functionDomain); pointerScopes[copyOp] = scope; - OpBuilder(call).create<LLVM::NoAliasScopeDeclOp>(call->getLoc(), scope); + auto builder = OpBuilder(call); + LLVM::NoAliasScopeDeclOp::create(builder, call->getLoc(), scope); } // Go through every instruction and attempt to find which noalias parameters @@ -603,16 +606,17 @@ static Value handleByValArgumentInit(OpBuilder &builder, Location loc, OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = &(*argument.getParentRegion()->begin()); builder.setInsertionPointToStart(entryBlock); - Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), - builder.getI64IntegerAttr(1)); - allocaOp = builder.create<LLVM::AllocaOp>( - loc, argument.getType(), elementType, one, targetAlignment); + Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getI64IntegerAttr(1)); + allocaOp = LLVM::AllocaOp::create(builder, loc, argument.getType(), + elementType, one, targetAlignment); } // Copy the pointee to the newly allocated value. - Value copySize = builder.create<LLVM::ConstantOp>( - loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize)); - builder.create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize, - /*isVolatile=*/false); + Value copySize = + LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getI64IntegerAttr(elementTypeSize)); + LLVM::MemcpyOp::create(builder, loc, allocaOp, argument, copySize, + /*isVolatile=*/false); return allocaOp; } @@ -747,7 +751,7 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { // Replace the return with a branch to the dest. OpBuilder builder(op); - builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest); + LLVM::BrOp::create(builder, op->getLoc(), returnOp.getOperands(), newDest); op->erase(); } @@ -801,7 +805,7 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { // and is extremely unlikely to exist in the code prior to inlining, using // this to communicate between this method and `processInlinedCallBlocks`. // TODO: Fix this by refactoring the inliner interface. - auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument); + auto copyOp = LLVM::SSACopyOp::create(builder, call->getLoc(), argument); if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName())) copyOp->setDiscardableAttr( builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()), diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp index 1a5a6e4..38a4bc8 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp @@ -58,8 +58,8 @@ static void ensureDistinctSuccessors(Block &bb) { terminator->setSuccessor(dummyBlock, position); for (BlockArgument arg : successor.first->getArguments()) dummyBlock->addArgument(arg.getType(), arg.getLoc()); - builder.create<LLVM::BrOp>(terminator->getLoc(), - dummyBlock->getArguments(), successor.first); + LLVM::BrOp::create(builder, terminator->getLoc(), + dummyBlock->getArguments(), successor.first); } } } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp index 8db32ec..7f34f7d 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp @@ -59,32 +59,32 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, Type i32Type = rewriter.getI32Type(); // Extend lhs and rhs to fp32. - Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs()); - Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs()); + Value lhs = LLVM::FPExtOp::create(rewriter, loc, f32Type, op.getLhs()); + Value rhs = LLVM::FPExtOp::create(rewriter, loc, f32Type, op.getRhs()); // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. - Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs); - Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp); + Value rcp = NVVM::RcpApproxFtzF32Op::create(rewriter, loc, f32Type, rhs); + Value approx = LLVM::FMulOp::create(rewriter, loc, lhs, rcp); // Refine the approximation with one Newton iteration: // float refined = approx + (lhs - approx * rhs) * rcp; - Value err = rewriter.create<LLVM::FMAOp>( - loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs); - Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx); + Value err = LLVM::FMAOp::create( + rewriter, loc, approx, LLVM::FNegOp::create(rewriter, loc, rhs), lhs); + Value refined = LLVM::FMAOp::create(rewriter, loc, err, rcp, approx); // Use refined value if approx is normal (exponent neither all 0 or all 1). - Value mask = rewriter.create<LLVM::ConstantOp>( - loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); - Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx); - Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask); - Value zero = rewriter.create<LLVM::ConstantOp>( - loc, i32Type, rewriter.getUI32IntegerAttr(0)); - Value pred = rewriter.create<LLVM::OrOp>( - loc, - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero), - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask)); + Value mask = LLVM::ConstantOp::create( + rewriter, loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); + Value cast = LLVM::BitcastOp::create(rewriter, loc, i32Type, approx); + Value exp = LLVM::AndOp::create(rewriter, loc, i32Type, cast, mask); + Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, + rewriter.getUI32IntegerAttr(0)); + Value pred = LLVM::OrOp::create( + rewriter, loc, + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, exp, zero), + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, exp, mask)); Value result = - rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined); + LLVM::SelectOp::create(rewriter, loc, f32Type, pred, approx, refined); // Replace with trucation back to fp16. rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result); |