diff options
-rw-r--r-- | mlir/include/mlir/Dialect/Math/Transforms/Passes.h | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 87 | ||||
-rw-r--r-- | mlir/test/Dialect/Math/expand-math.mlir | 99 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 1 |
4 files changed, 183 insertions, 5 deletions
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index 11b2c7a..e2c5130 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns); void populateExpandCeilFPattern(RewritePatternSet &patterns); void populateExpandExp2FPattern(RewritePatternSet &patterns); void populateExpandPowFPattern(RewritePatternSet &patterns); +void populateExpandFPowIPattern(RewritePatternSet &patterns); void populateExpandRoundFPattern(RewritePatternSet &patterns); void populateExpandRoundEvenPattern(RewritePatternSet &patterns); void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index e1ab9c9..0b85462 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -1,4 +1,4 @@ -//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// +//===- ExpandPatterns.cpp - Code to expand various math operations. -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This file implements expansion of tanh op. +// This file implements expansion of various math operations. // //===----------------------------------------------------------------------===// @@ -23,9 +23,14 @@ using namespace mlir; /// Create a float constant. -static Value createFloatConst(Location loc, Type type, double value, +static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b) { - auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value); + bool losesInfo = false; + auto eltType = getElementTypeOrSelf(type); + // Convert double to the given `FloatType` with round-to-nearest-ties-to-even. + value.convert(cast<FloatType>(eltType).getFloatSemantics(), + APFloat::rmNearestTiesToEven, &losesInfo); + auto attr = b.getFloatAttr(eltType, value); if (auto shapedTy = dyn_cast<ShapedType>(type)) { return b.create<arith::ConstantOp>(loc, DenseElementsAttr::get(shapedTy, attr)); @@ -34,7 +39,12 @@ static Value createFloatConst(Location loc, Type type, double value, return b.create<arith::ConstantOp>(loc, attr); } -/// Create a float constant. +static Value createFloatConst(Location loc, Type type, double value, + OpBuilder &b) { + return createFloatConst(loc, type, APFloat(value), b); +} + +/// Create an integer constant. static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b) { auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); @@ -202,6 +212,69 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { rewriter.replaceOp(op, ret); return success(); } + +// Convert `math.fpowi` to a series of `arith.mulf` operations. +// If the power is negative, we divide one by the result. +// If both the base and power are zero, the result is 1. +static LogicalResult convertFPowICstOp(math::FPowIOp op, + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value base = op.getOperand(0); + Value power = op.getOperand(1); + Type baseType = base.getType(); + + Attribute cstAttr; + if (!matchPattern(power, m_Constant(&cstAttr))) + return failure(); + + APInt value; + if (!matchPattern(cstAttr, m_ConstantInt(&value))) + return failure(); + + int64_t powerInt = value.getSExtValue(); + bool isNegative = powerInt < 0; + int64_t absPower = std::abs(powerInt); + Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); + Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); + + while (absPower > 0) { + if (absPower & 1) + res = b.create<arith::MulFOp>(baseType, base, res); + absPower >>= 1; + base = b.create<arith::MulFOp>(baseType, base, base); + } + + // Make sure not to introduce UB in case of negative power. + if (isNegative) { + auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType)) + .getFloatSemantics(); + Value zero = + createFloatConst(op->getLoc(), baseType, + APFloat::getZero(sem, /*Negative=*/false), rewriter); + Value negZero = + createFloatConst(op->getLoc(), baseType, + APFloat::getZero(sem, /*Negative=*/true), rewriter); + Value posInfinity = + createFloatConst(op->getLoc(), baseType, + APFloat::getInf(sem, /*Negative=*/false), rewriter); + Value negInfinity = + createFloatConst(op->getLoc(), baseType, + APFloat::getInf(sem, /*Negative=*/true), rewriter); + Value zeroEqCheck = + b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero); + Value negZeroEqCheck = + b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero); + res = b.create<arith::DivFOp>(baseType, one, res); + res = + b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res); + res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity, + res); + } + + rewriter.replaceOp(op, res); + return success(); +} + // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); @@ -517,6 +590,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { patterns.add(convertPowfOp); } +void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { + patterns.add(convertFPowICstOp); +} + void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { patterns.add(convertRoundOp); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 6326d3a..bfcff27 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -511,3 +511,102 @@ func.func @roundeven16(%arg: f16) -> f16 { // CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16 // CHECK: return %[[COPYSIGN]] : f16 + +// ----- + +// CHECK-LABEL: func.func @math_fpowi_neg_odd_power +func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> { + %1 = arith.constant dense<-3> : tensor<8xi64> + %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64> + return %2 : tensor<8xf32> +} +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> { +// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32> +// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> +// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32> +// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32> +// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32> +// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32> +// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32> +// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CST0]] : tensor<8xf32> +// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CSTNEG0]] : tensor<8xf32> +// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32> +// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32> +// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32> +// CHECK: return %[[UB2]] : tensor<8xf32> + +// ----- + +// CHECK-LABEL: func.func @math_fpowi_neg_even_power +func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> { + %1 = arith.constant dense<-4> : tensor<8xi64> + %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64> + return %2 : tensor<8xf32> +} +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> { +// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32> +// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> +// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32> +// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32> +// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32> +// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32> +// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32> +// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CST0]] : tensor<8xf32> +// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CSTNEG0]] : tensor<8xf32> +// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32> +// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32> +// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32> +// CHECK: return %[[UB2]] : tensor<8xf32> + +// ----- + +// CHECK-LABEL: func.func @math_fpowi_pos_odd_power +func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> { + %1 = arith.constant dense<5> : tensor<8xi64> + %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64> + return %2 : tensor<8xf32> +} +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> { +// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32> +// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32> +// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32> +// CHECK: return %[[PW5]] : tensor<8xf32> + +// ----- + +// CHECK-LABEL: func.func @math_fpowi_pos_even_power +func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> { + %1 = arith.constant dense<4> : tensor<8xi64> + %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64> + return %2 : tensor<8xf32> +} +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> { +// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32> +// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32> +// CHECK: return %[[PW4]] : tensor<8xf32> + +// ----- + +// CHECK-LABEL: func.func @math_fpowi_even_scalar +func.func @math_fpowi_even_scalar(%0 : f32) -> f32 { + %pow = arith.constant 2 : i64 + %2 = math.fpowi %0, %pow : f32, i64 + return %2 : f32 +} +// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 { +// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32 +// CHECK: return %[[SQ]] : f32 + +// ----- + +// CHECK-LABEL: func.func @math_fpowi_scalar_zero +func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 { + %pow = arith.constant 0 : i64 + %2 = math.fpowi %0, %pow : f32, i64 + return %2 : f32 +} +// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 { +// CHECK: %[[RET:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: return %[[RET]] : f32 + +// ----- diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp index 7ce8b5a..97600ad 100644 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() { populateExpandFloorFPattern(patterns); populateExpandCeilFPattern(patterns); populateExpandPowFPattern(patterns); + populateExpandFPowIPattern(patterns); populateExpandRoundFPattern(patterns); populateExpandRoundEvenPattern(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |