aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp')
-rw-r--r--mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp20
1 files changed, 15 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 0b85462..42629e1 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -216,20 +216,30 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
// Convert `math.fpowi` to a series of `arith.mulf` operations.
// If the power is negative, we divide one by the result.
// If both the base and power are zero, the result is 1.
-static LogicalResult convertFPowICstOp(math::FPowIOp op,
- PatternRewriter &rewriter) {
+// In the case of non constant power, we convert the operation to `math.powf`.
+static LogicalResult convertFPowIOp(math::FPowIOp op,
+ PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value base = op.getOperand(0);
Value power = op.getOperand(1);
Type baseType = base.getType();
+ auto convertFPowItoPowf = [&]() -> LogicalResult {
+ Value castPowerToFp =
+ rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
+ Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
+ castPowerToFp);
+ rewriter.replaceOp(op, res);
+ return success();
+ };
+
Attribute cstAttr;
if (!matchPattern(power, m_Constant(&cstAttr)))
- return failure();
+ return convertFPowItoPowf();
APInt value;
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
- return failure();
+ return convertFPowItoPowf();
int64_t powerInt = value.getSExtValue();
bool isNegative = powerInt < 0;
@@ -591,7 +601,7 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
}
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
- patterns.add(convertFPowICstOp);
+ patterns.add(convertFPowIOp);
}
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {