aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/TosaToArith/TosaToArith.cpp')
-rw-r--r--mlir/lib/Conversion/TosaToArith/TosaToArith.cpp123
1 files changed, 62 insertions, 61 deletions
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 40ad6361..044b725 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -51,8 +51,8 @@ TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
Value getConstantValue(Location loc, Type type, int64_t value,
PatternRewriter &rewriter) {
- return rewriter.create<arith::ConstantOp>(
- loc, getConstantAttr(type, value, rewriter));
+ return arith::ConstantOp::create(rewriter, loc,
+ getConstantAttr(type, value, rewriter));
}
// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
@@ -82,41 +82,41 @@ public:
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
- Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
+ Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
// Compute the multiplication in 64-bits then select the high / low parts.
Value value64 = value;
if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
- value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
+ value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value);
Value multiplier64 =
- rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
+ arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32);
Value multiply64 =
- rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
+ arith::MulIOp::create(rewriter, loc, value64, multiplier64);
// Apply normal rounding.
- Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
- Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
- round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
- multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
+ Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32);
+ Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64);
+ round = arith::ShRUIOp::create(rewriter, loc, round, one64);
+ multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
- Value positive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, value, zero);
+ Value positive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, value, zero);
Value dir =
- rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
- Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
- Value valid = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
+ arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown);
+ Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64);
+ Value valid = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
multiply64 =
- rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
+ arith::SelectOp::create(rewriter, loc, valid, val, multiply64);
}
- Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
- Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
+ Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64);
+ Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64);
rewriter.replaceOp(op, result32);
return success();
@@ -146,7 +146,7 @@ public:
Value value32 = op.getValue();
Value multiplier32 = op.getMultiplier();
- Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
+ Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
// Constants used during the scaling operation.
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
@@ -158,86 +158,87 @@ public:
// Compute the multiplication in 64-bits then select the high / low parts.
// Grab out the high/low of the computation
auto value64 =
- rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
+ arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32);
Value low32 = value64.getLow();
Value high32 = value64.getHigh();
// Determine the direction and amount to shift the high bits.
- Value shiftOver32 = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
- Value roundHighBits = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
+ Value shiftOver32 = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
+ Value roundHighBits = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
Value shiftHighL =
- rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
+ arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32);
Value shiftHighR =
- rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
+ arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32);
shiftHighL =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL);
shiftHighR =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
- Value valuePositive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, value32, zero32);
+ Value valuePositive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
- Value roundDir =
- rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
+ Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive,
+ one32, negOne32);
roundDir =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32);
- Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
- Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
- Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
+ Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32);
+ Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir);
+ Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32);
Value shiftRound =
- rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
+ arith::ShLIOp::create(rewriter, loc, roundDir, thirty32);
- low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
+ low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, carry);
}
// Conditionally apply rounding in the low bits.
{
- Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
- Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
- roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
- roundBit);
-
- Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
- Value wasRounded = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ugt, low32, newLow32);
+ Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32);
+ Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
+ roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32,
+ roundBit);
+
+ Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit);
+ Value wasRounded = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32);
low32 = newLow32;
- Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
+ Value rounded32 =
+ arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32);
}
// Conditionally apply rounding in the high bits.
{
Value shiftSubOne =
- rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
- Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
- roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
- zero32);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
+ arith::SubIOp::create(rewriter, loc, shiftHighR, one32);
+ Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
+ roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit,
+ zero32);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit);
}
// Combine the correct high/low bits into the final rescale result.
- high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
- high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
- low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
- low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
+ high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL);
+ high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR);
+ low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32);
+ low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32);
// Apply the rounding behavior and shift to the final alignment.
- Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
+ Value result = arith::AddIOp::create(rewriter, loc, low32, high32);
// Truncate if necessary.
if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
- result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
+ result = arith::TruncIOp::create(rewriter, loc, resultTy, result);
}
rewriter.replaceOp(op, result);