aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp112
1 files changed, 57 insertions, 55 deletions
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 59db14e..a877ad2 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -36,12 +36,12 @@ static Value getScalarOrVectorI32Constant(Type type, int value,
if (!vectorType.getElementType().isInteger(32))
return nullptr;
SmallVector<int> values(vectorType.getNumElements(), value);
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getI32VectorAttr(values));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getI32VectorAttr(values));
}
if (type.isInteger(32))
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getI32IntegerAttr(value));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getI32IntegerAttr(value));
return nullptr;
}
@@ -144,10 +144,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
Type intType = rewriter.getIntegerType(bitwidth);
uint64_t intValue = uint64_t(1) << (bitwidth - 1);
- Value signMask = rewriter.create<spirv::ConstantOp>(
- loc, intType, rewriter.getIntegerAttr(intType, intValue));
- Value valueMask = rewriter.create<spirv::ConstantOp>(
- loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
+ Value signMask = spirv::ConstantOp::create(
+ rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
+ Value valueMask = spirv::ConstantOp::create(
+ rewriter, loc, intType,
+ rewriter.getIntegerAttr(intType, intValue - 1u));
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
@@ -155,26 +156,26 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
intType = VectorType::get(count, intType);
SmallVector<Value> signSplat(count, signMask);
- signMask =
- rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
+ signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
+ signSplat);
SmallVector<Value> valueSplat(count, valueMask);
- valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
- valueSplat);
+ valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
+ valueSplat);
}
Value lhsCast =
- rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
+ spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
Value rhsCast =
- rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
+ spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
- Value value = rewriter.create<spirv::BitwiseAndOp>(
- loc, intType, ValueRange{lhsCast, valueMask});
- Value sign = rewriter.create<spirv::BitwiseAndOp>(
- loc, intType, ValueRange{rhsCast, signMask});
+ Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
+ ValueRange{lhsCast, valueMask});
+ Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
+ ValueRange{rhsCast, signMask});
- Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
- ValueRange{value, sign});
+ Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
+ ValueRange{value, sign});
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
return success();
}
@@ -214,18 +215,18 @@ struct CountLeadingZerosPattern final
Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
- Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
+ Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
// We need to subtract from 31 given that the index returned by GLSL
// FindUMsb is counted from the least significant bit. Theoretically this
// also gives the correct result even if the integer has all zero bits, in
// which case GL FindUMsb would return -1.
- Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+ Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
// However, certain Vulkan implementations have driver bugs for the corner
// case where the input is zero. And.. it can be smart to optimize a select
// only involving the corner case. So separately compute the result when the
// input is either zero or one.
- Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
- Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
+ Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
+ Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
subMsb);
return success();
@@ -253,7 +254,7 @@ struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
if (!type)
return failure();
- Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
+ Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
return success();
@@ -283,7 +284,7 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
Value onePlus =
- rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
+ spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
return success();
}
@@ -321,15 +322,15 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
auto getConstantValue = [&](double value) {
if (auto floatType = dyn_cast<FloatType>(type)) {
- return rewriter.create<spirv::ConstantOp>(
- loc, type, rewriter.getFloatAttr(floatType, value));
+ return spirv::ConstantOp::create(
+ rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
}
if (auto vectorType = dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
if (isa<FloatType>(elemType)) {
- return rewriter.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ rewriter, loc, type,
DenseFPElementsAttr::get(
vectorType, FloatAttr::get(elemType, value).getValue()));
}
@@ -341,7 +342,7 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
Value constantValue = getConstantValue(
std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
: log10Reciprocal);
- Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
+ Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand());
rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
constantValue);
return success();
@@ -386,7 +387,7 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
Location loc = powfOp.getLoc();
Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
Value lessThan =
- rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
+ spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
// Per C/C++ spec:
// > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
@@ -394,11 +395,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Calculate the reminder from the exponent and check whether it is zero.
Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
Value expRem =
- rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
+ spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
Value expRemNonZero =
- rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
+ spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
Value cmpNegativeWithFractionalExp =
- rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
+ spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
// Create NaN result and replace base value if conditions are met.
const auto &floatSemantics = scalarFloatType.getFloatSemantics();
const auto nan = APFloat::getNaN(floatSemantics);
@@ -407,10 +408,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
nanAttr = DenseElementsAttr::get(vectorType, nan);
Value NanValue =
- rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
- Value lhs = rewriter.create<spirv::SelectOp>(
- loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
- Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
+ spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
+ Value lhs =
+ spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
+ NanValue, adaptor.getLhs());
+ Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
// order to properly propagate the sign, assuming integer y cases. It
@@ -418,18 +420,18 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Cast exponent to integer and calculate exponent % 2 != 0.
Value intRhs =
- rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+ spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
Value bitwiseAndOne =
- rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
- Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+ spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
+ Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
// calculate pow based on abs(lhs)^rhs.
- Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
- Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
+ Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
+ Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
// if the exponent is odd and lhs < 0, negate the result.
Value shouldNegate =
- rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+ spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
pow);
return success();
@@ -455,22 +457,22 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
Value half;
if (VectorType vty = dyn_cast<VectorType>(ty)) {
- half = rewriter.create<spirv::ConstantOp>(
- loc, vty,
+ half = spirv::ConstantOp::create(
+ rewriter, loc, vty,
DenseElementsAttr::get(vty,
rewriter.getFloatAttr(ety, 0.5).getValue()));
} else {
- half = rewriter.create<spirv::ConstantOp>(
- loc, ty, rewriter.getFloatAttr(ety, 0.5));
+ half = spirv::ConstantOp::create(rewriter, loc, ty,
+ rewriter.getFloatAttr(ety, 0.5));
}
- auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
- auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
- auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
+ auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
+ auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
+ auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
auto greater =
- rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
- auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
- auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
+ spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
+ auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
+ auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
return success();
}