diff options
Diffstat (limited to 'mlir/lib/Conversion/TosaToArith/TosaToArith.cpp')
-rw-r--r-- | mlir/lib/Conversion/TosaToArith/TosaToArith.cpp | 123 |
1 files changed, 62 insertions, 61 deletions
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 40ad6361..044b725 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -51,8 +51,8 @@ TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { Value getConstantValue(Location loc, Type type, int64_t value, PatternRewriter &rewriter) { - return rewriter.create<arith::ConstantOp>( - loc, getConstantAttr(type, value, rewriter)); + return arith::ConstantOp::create(rewriter, loc, + getConstantAttr(type, value, rewriter)); } // This converts the TOSA ApplyScale operator to a set of arithmetic ops, @@ -82,41 +82,41 @@ public: Value one64 = getConstantValue(loc, i64Ty, 1, rewriter); Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter); - Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift()); // Compute the multiplication in 64-bits then select the high / low parts. Value value64 = value; if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) - value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value); + value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value); Value multiplier64 = - rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32); + arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32); Value multiply64 = - rewriter.create<arith::MulIOp>(loc, value64, multiplier64); + arith::MulIOp::create(rewriter, loc, value64, multiplier64); // Apply normal rounding. - Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32); - Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64); - round = rewriter.create<arith::ShRUIOp>(loc, round, one64); - multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round); + Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32); + Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64); + round = arith::ShRUIOp::create(rewriter, loc, round, one64); + multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round); // Apply double rounding if necessary. if (op.getRoundingMode() == "DOUBLE_ROUND") { int64_t roundInt = 1 << 30; Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter); Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter); - Value positive = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, value, zero); + Value positive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, value, zero); Value dir = - rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown); - Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64); - Value valid = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); + arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown); + Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64); + Value valid = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); multiply64 = - rewriter.create<arith::SelectOp>(loc, valid, val, multiply64); + arith::SelectOp::create(rewriter, loc, valid, val, multiply64); } - Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64); - Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64); + Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64); + Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64); rewriter.replaceOp(op, result32); return success(); @@ -146,7 +146,7 @@ public: Value value32 = op.getValue(); Value multiplier32 = op.getMultiplier(); - Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift()); // Constants used during the scaling operation. Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter); @@ -158,86 +158,87 @@ public: // Compute the multiplication in 64-bits then select the high / low parts. // Grab out the high/low of the computation auto value64 = - rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32); + arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32); Value low32 = value64.getLow(); Value high32 = value64.getHigh(); // Determine the direction and amount to shift the high bits. - Value shiftOver32 = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); - Value roundHighBits = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); + Value shiftOver32 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); + Value roundHighBits = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); Value shiftHighL = - rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32); + arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32); Value shiftHighR = - rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32); + arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32); shiftHighL = - rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL); + arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL); shiftHighR = - rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32); // Conditionally perform our double round. if (op.getRoundingMode() == "DOUBLE_ROUND") { Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); - Value valuePositive = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, value32, zero32); + Value valuePositive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, value32, zero32); - Value roundDir = - rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32); + Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive, + one32, negOne32); roundDir = - rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32); - Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32); - Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir); - Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32); + Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32); + Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir); + Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32); Value shiftRound = - rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32); + arith::ShLIOp::create(rewriter, loc, roundDir, thirty32); - low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound); - high32 = rewriter.create<arith::AddIOp>(loc, high32, carry); + low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound); + high32 = arith::AddIOp::create(rewriter, loc, high32, carry); } // Conditionally apply rounding in the low bits. { - Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32); - Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); - roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32, - roundBit); - - Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit); - Value wasRounded = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ugt, low32, newLow32); + Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32, + roundBit); + + Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit); + Value wasRounded = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32); low32 = newLow32; - Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded); - high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32); + Value rounded32 = + arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded); + high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32); } // Conditionally apply rounding in the high bits. { Value shiftSubOne = - rewriter.create<arith::SubIOp>(loc, shiftHighR, one32); - Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); - roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit, - zero32); - high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit); + arith::SubIOp::create(rewriter, loc, shiftHighR, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit, + zero32); + high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit); } // Combine the correct high/low bits into the final rescale result. - high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL); - high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR); - low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32); - low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32); + high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL); + high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR); + low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32); + low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32); // Apply the rounding behavior and shift to the final alignment. - Value result = rewriter.create<arith::AddIOp>(loc, low32, high32); + Value result = arith::AddIOp::create(rewriter, loc, low32, high32); // Truncate if necessary. if (!getElementTypeOrSelf(resultTy).isInteger(32)) { - result = rewriter.create<arith::TruncIOp>(loc, resultTy, result); + result = arith::TruncIOp::create(rewriter, loc, resultTy, result); } rewriter.replaceOp(op, result); |