diff options
Diffstat (limited to 'mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 112 |
1 files changed, 57 insertions, 55 deletions
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 59db14e..a877ad2 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -36,12 +36,12 @@ static Value getScalarOrVectorI32Constant(Type type, int value, if (!vectorType.getElementType().isInteger(32)) return nullptr; SmallVector<int> values(vectorType.getNumElements(), value); - return builder.create<spirv::ConstantOp>(loc, type, - builder.getI32VectorAttr(values)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getI32VectorAttr(values)); } if (type.isInteger(32)) - return builder.create<spirv::ConstantOp>(loc, type, - builder.getI32IntegerAttr(value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getI32IntegerAttr(value)); return nullptr; } @@ -144,10 +144,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { Type intType = rewriter.getIntegerType(bitwidth); uint64_t intValue = uint64_t(1) << (bitwidth - 1); - Value signMask = rewriter.create<spirv::ConstantOp>( - loc, intType, rewriter.getIntegerAttr(intType, intValue)); - Value valueMask = rewriter.create<spirv::ConstantOp>( - loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); + Value signMask = spirv::ConstantOp::create( + rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue)); + Value valueMask = spirv::ConstantOp::create( + rewriter, loc, intType, + rewriter.getIntegerAttr(intType, intValue - 1u)); if (auto vectorType = dyn_cast<VectorType>(type)) { assert(vectorType.getRank() == 1); @@ -155,26 +156,26 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { intType = VectorType::get(count, intType); SmallVector<Value> signSplat(count, signMask); - signMask = - rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); + signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, + signSplat); SmallVector<Value> valueSplat(count, valueMask); - valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, - valueSplat); + valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, + valueSplat); } Value lhsCast = - rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs()); Value rhsCast = - rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs()); - Value value = rewriter.create<spirv::BitwiseAndOp>( - loc, intType, ValueRange{lhsCast, valueMask}); - Value sign = rewriter.create<spirv::BitwiseAndOp>( - loc, intType, ValueRange{rhsCast, signMask}); + Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType, + ValueRange{lhsCast, valueMask}); + Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType, + ValueRange{rhsCast, signMask}); - Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, - ValueRange{value, sign}); + Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType, + ValueRange{value, sign}); rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); return success(); } @@ -214,18 +215,18 @@ struct CountLeadingZerosPattern final Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); - Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input); + Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input); // We need to subtract from 31 given that the index returned by GLSL // FindUMsb is counted from the least significant bit. Theoretically this // also gives the correct result even if the integer has all zero bits, in // which case GL FindUMsb would return -1. - Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb); + Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb); // However, certain Vulkan implementations have driver bugs for the corner // case where the input is zero. And.. it can be smart to optimize a select // only involving the corner case. So separately compute the result when the // input is either zero or one. - Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input); - Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1); + Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input); + Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1); rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput, subMsb); return success(); @@ -253,7 +254,7 @@ struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { if (!type) return failure(); - Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); + Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); return success(); @@ -283,7 +284,7 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); Value onePlus = - rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); + spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); return success(); } @@ -321,15 +322,15 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> { auto getConstantValue = [&](double value) { if (auto floatType = dyn_cast<FloatType>(type)) { - return rewriter.create<spirv::ConstantOp>( - loc, type, rewriter.getFloatAttr(floatType, value)); + return spirv::ConstantOp::create( + rewriter, loc, type, rewriter.getFloatAttr(floatType, value)); } if (auto vectorType = dyn_cast<VectorType>(type)) { Type elemType = vectorType.getElementType(); if (isa<FloatType>(elemType)) { - return rewriter.create<spirv::ConstantOp>( - loc, type, + return spirv::ConstantOp::create( + rewriter, loc, type, DenseFPElementsAttr::get( vectorType, FloatAttr::get(elemType, value).getValue())); } @@ -341,7 +342,7 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> { Value constantValue = getConstantValue( std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal : log10Reciprocal); - Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand()); + Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand()); rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log, constantValue); return success(); @@ -386,7 +387,7 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { Location loc = powfOp.getLoc(); Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter); Value lessThan = - rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero); + spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero); // Per C/C++ spec: // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is @@ -394,11 +395,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { // Calculate the reminder from the exponent and check whether it is zero. Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter); Value expRem = - rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne); + spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne); Value expRemNonZero = - rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero); + spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero); Value cmpNegativeWithFractionalExp = - rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan); + spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan); // Create NaN result and replace base value if conditions are met. const auto &floatSemantics = scalarFloatType.getFloatSemantics(); const auto nan = APFloat::getNaN(floatSemantics); @@ -407,10 +408,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { nanAttr = DenseElementsAttr::get(vectorType, nan); Value NanValue = - rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr); - Value lhs = rewriter.create<spirv::SelectOp>( - loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs()); - Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs); + spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); + Value lhs = + spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, + NanValue, adaptor.getLhs()); + Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in // order to properly propagate the sign, assuming integer y cases. It @@ -418,18 +420,18 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { // Cast exponent to integer and calculate exponent % 2 != 0. Value intRhs = - rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs()); + spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs()); Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter); Value bitwiseAndOne = - rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne); - Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne); + spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne); + Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne); // calculate pow based on abs(lhs)^rhs. - Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs()); - Value negate = rewriter.create<spirv::FNegateOp>(loc, pow); + Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs()); + Value negate = spirv::FNegateOp::create(rewriter, loc, pow); // if the exponent is odd and lhs < 0, negate the result. Value shouldNegate = - rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd); + spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd); rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate, pow); return success(); @@ -455,22 +457,22 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> { auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); Value half; if (VectorType vty = dyn_cast<VectorType>(ty)) { - half = rewriter.create<spirv::ConstantOp>( - loc, vty, + half = spirv::ConstantOp::create( + rewriter, loc, vty, DenseElementsAttr::get(vty, rewriter.getFloatAttr(ety, 0.5).getValue())); } else { - half = rewriter.create<spirv::ConstantOp>( - loc, ty, rewriter.getFloatAttr(ety, 0.5)); + half = spirv::ConstantOp::create(rewriter, loc, ty, + rewriter.getFloatAttr(ety, 0.5)); } - auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand); - auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs); - auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor); + auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand); + auto floor = spirv::GLFloorOp::create(rewriter, loc, abs); + auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor); auto greater = - rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half); - auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero); - auto add = rewriter.create<spirv::FAddOp>(loc, floor, select); + spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half); + auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero); + auto add = spirv::FAddOp::create(rewriter, loc, floor, select); rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand); return success(); } |