aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp80
1 files changed, 42 insertions, 38 deletions
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index f4d69ce..853f454 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -107,8 +107,8 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
- return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
- false);
+ return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
+ false);
},
rewriter);
}
@@ -145,15 +145,16 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (LLVM::isCompatibleVectorType(llvmOperandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ one = LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
+ one =
+ LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
}
- auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
- expAttrs.getAttrs());
+ auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
+ expAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
return success();
@@ -170,12 +171,13 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto exp = rewriter.create<LLVM::ExpOp>(
- loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
- return rewriter.create<LLVM::FSubOp>(
- loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
+ operands[0], expAttrs.getAttrs());
+ return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{exp, one},
+ subAttrs.getAttrs());
},
rewriter);
}
@@ -205,16 +207,16 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one =
isa<VectorType>(llvmOperandType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ ? LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne))
- : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
- floatOne);
+ : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
+ floatOne);
- auto add = rewriter.create<LLVM::FAddOp>(
- loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
- addAttrs.getAttrs());
+ auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
+ ValueRange{one, adaptor.getOperand()},
+ addAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::LogOp>(
op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
return success();
@@ -231,13 +233,13 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
- ValueRange{one, operands[0]},
- addAttrs.getAttrs());
- return rewriter.create<LLVM::LogOp>(
- loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{one, operands[0]},
+ addAttrs.getAttrs());
+ return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{add}, logAttrs.getAttrs());
},
rewriter);
}
@@ -267,15 +269,16 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (isa<VectorType>(llvmOperandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ one = LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
+ one =
+ LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
}
- auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
- sqrtAttrs.getAttrs());
+ auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
+ sqrtAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
return success();
@@ -292,12 +295,13 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto sqrt = rewriter.create<LLVM::SqrtOp>(
- loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
- return rewriter.create<LLVM::FDivOp>(
- loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
+ operands[0], sqrtAttrs.getAttrs());
+ return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{one, sqrt},
+ divAttrs.getAttrs());
},
rewriter);
}