diff options
Diffstat (limited to 'mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 38 |
1 files changed, 19 insertions, 19 deletions
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index f7bf581..18e857c 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -293,11 +293,11 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite( [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename OpTy::Adaptor adaptor(operands); if (targetBits < sourceBits) { - return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); } - return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); }, rewriter); } @@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite( Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); - Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>( - loc, structType, adaptor.getLhs(), adaptor.getRhs()); + Value addOverflow = LLVM::UAddWithOverflowOp::create( + rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = - rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0); Value overflowExtracted = - rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1); rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); return success(); } @@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite( "LLVM dialect should support all signless integer types"); using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>; - Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs()); - Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs()); - Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt); + Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs()); + Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs()); + Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt); // Split the 2*N-bit wide result into two N-bit values. - Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt); - Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr); - Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal); - Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt); + Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt); + Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr); + Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal); + Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt); rewriter.replaceOp(op, {low, high}); return success(); @@ -435,8 +435,8 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create<LLVM::ICmpOp>( - op.getLoc(), llvm1DVectorTy, + return LLVM::ICmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, @@ -471,8 +471,8 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create<LLVM::FCmpOp>( - op.getLoc(), llvm1DVectorTy, + return LLVM::FCmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs(), fmf); }, |