diff options
Diffstat (limited to 'mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 80 |
1 files changed, 42 insertions, 38 deletions
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index f4d69ce..853f454 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -107,8 +107,8 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> { return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), typeConverter, [&](Type llvm1DVectorTy, ValueRange operands) { - return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0], - false); + return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0], + false); }, rewriter); } @@ -145,15 +145,16 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(llvmOperandType)) { - one = rewriter.create<LLVM::ConstantOp>( - loc, llvmOperandType, + one = LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast<ShapedType>(llvmOperandType), floatOne)); } else { - one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne); + one = + LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(), - expAttrs.getAttrs()); + auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(), + expAttrs.getAttrs()); rewriter.replaceOpWithNewOp<LLVM::FSubOp>( op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs()); return success(); @@ -170,12 +171,13 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); - auto exp = rewriter.create<LLVM::ExpOp>( - loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); - return rewriter.create<LLVM::FSubOp>( - loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy, + operands[0], expAttrs.getAttrs()); + return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{exp, one}, + subAttrs.getAttrs()); }, rewriter); } @@ -205,16 +207,16 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) { LLVM::ConstantOp one = isa<VectorType>(llvmOperandType) - ? rewriter.create<LLVM::ConstantOp>( - loc, llvmOperandType, + ? LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast<ShapedType>(llvmOperandType), floatOne)) - : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, - floatOne); + : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, + floatOne); - auto add = rewriter.create<LLVM::FAddOp>( - loc, llvmOperandType, ValueRange{one, adaptor.getOperand()}, - addAttrs.getAttrs()); + auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType, + ValueRange{one, adaptor.getOperand()}, + addAttrs.getAttrs()); rewriter.replaceOpWithNewOp<LLVM::LogOp>( op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs()); return success(); @@ -231,13 +233,13 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); - auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, - ValueRange{one, operands[0]}, - addAttrs.getAttrs()); - return rewriter.create<LLVM::LogOp>( - loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{one, operands[0]}, + addAttrs.getAttrs()); + return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{add}, logAttrs.getAttrs()); }, rewriter); } @@ -267,15 +269,16 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) { LLVM::ConstantOp one; if (isa<VectorType>(llvmOperandType)) { - one = rewriter.create<LLVM::ConstantOp>( - loc, llvmOperandType, + one = LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast<ShapedType>(llvmOperandType), floatOne)); } else { - one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne); + one = + LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(), - sqrtAttrs.getAttrs()); + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(), + sqrtAttrs.getAttrs()); rewriter.replaceOpWithNewOp<LLVM::FDivOp>( op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); return success(); @@ -292,12 +295,13 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); - auto sqrt = rewriter.create<LLVM::SqrtOp>( - loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); - return rewriter.create<LLVM::FDivOp>( - loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy, + operands[0], sqrtAttrs.getAttrs()); + return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{one, sqrt}, + divAttrs.getAttrs()); }, rewriter); } |