diff options
Diffstat (limited to 'libphobos/src/std/numeric.d')
-rw-r--r-- | libphobos/src/std/numeric.d | 3467 |
1 files changed, 3467 insertions, 0 deletions
diff --git a/libphobos/src/std/numeric.d b/libphobos/src/std/numeric.d new file mode 100644 index 0000000..307406e --- /dev/null +++ b/libphobos/src/std/numeric.d @@ -0,0 +1,3467 @@ +// Written in the D programming language. + +/** +This module is a port of a growing fragment of the $(D_PARAM numeric) +header in Alexander Stepanov's $(LINK2 http://sgi.com/tech/stl, +Standard Template Library), with a few additions. + +Macros: +Copyright: Copyright Andrei Alexandrescu 2008 - 2009. +License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0). +Authors: $(HTTP erdani.org, Andrei Alexandrescu), + Don Clugston, Robert Jacques, Ilya Yaroshenko +Source: $(PHOBOSSRC std/_numeric.d) +*/ +/* + Copyright Andrei Alexandrescu 2008 - 2009. +Distributed under the Boost Software License, Version 1.0. + (See accompanying file LICENSE_1_0.txt or copy at + http://www.boost.org/LICENSE_1_0.txt) +*/ +module std.numeric; + +import std.complex; +import std.math; +import std.range.primitives; +import std.traits; +import std.typecons; + +version (unittest) +{ + import std.stdio; +} +/// Format flags for CustomFloat. +public enum CustomFloatFlags +{ + /// Adds a sign bit to allow for signed numbers. + signed = 1, + + /** + * Store values in normalized form by default. The actual precision of the + * significand is extended by 1 bit by assuming an implicit leading bit of 1 + * instead of 0. i.e. $(D 1.nnnn) instead of $(D 0.nnnn). + * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEE754) types + */ + storeNormalized = 2, + + /** + * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, + * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0. + */ + allowDenorm = 4, + + /** + * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity, + * IEEE754 _infinity) values. + */ + infinity = 8, + + /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values. + nan = 16, + + /** + * If set, select an exponent bias such that max_exp = 1. + * i.e. so that the maximum value is >= 1.0 and < 2.0. + * Ignored if the exponent bias is manually specified. + */ + probability = 32, + + /// If set, unsigned custom floats are assumed to be negative. + negativeUnsigned = 64, + + /**If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, + * IEEE754 denormalized) number. + * Requires allowDenorm and storeNormalized. + */ + allowDenormZeroOnly = 128 | allowDenorm | storeNormalized, + + /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options. + ieee = signed | storeNormalized | allowDenorm | infinity | nan , + + /// Include none of the above options. + none = 0 +} + +private template CustomFloatParams(uint bits) +{ + enum CustomFloatFlags flags = CustomFloatFlags.ieee + ^ ((bits == 80) ? CustomFloatFlags.storeNormalized : CustomFloatFlags.none); + static if (bits == 8) alias CustomFloatParams = CustomFloatParams!( 4, 3, flags); + static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10, 5, flags); + static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23, 8, flags); + static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags); + static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags); +} + +private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags) +{ + import std.meta : AliasSeq; + alias CustomFloatParams = + AliasSeq!( + precision, + exponentWidth, + flags, + (1 << (exponentWidth - ((flags & flags.probability) == 0))) + - ((flags & (flags.nan | flags.infinity)) != 0) - ((flags & flags.probability) != 0) + ); // ((flags & CustomFloatFlags.probability) == 0) +} + +/** + * Allows user code to define custom floating-point formats. These formats are + * for storage only; all operations on them are performed by first implicitly + * extracting them to $(D real) first. After the operation is completed the + * result can be stored in a custom floating-point value via assignment. + */ +template CustomFloat(uint bits) +if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80) +{ + alias CustomFloat = CustomFloat!(CustomFloatParams!(bits)); +} + +/// ditto +template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee) +if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0) +{ + alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags)); +} + +/// +@safe unittest +{ + import std.math : sin, cos; + + // Define a 16-bit floating point values + CustomFloat!16 x; // Using the number of bits + CustomFloat!(10, 5) y; // Using the precision and exponent width + CustomFloat!(10, 5,CustomFloatFlags.ieee) z; // Using the precision, exponent width and format flags + CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w; // Using the precision, exponent width, format flags and exponent offset bias + + // Use the 16-bit floats mostly like normal numbers + w = x*y - 1; + + // Functions calls require conversion + z = sin(+x) + cos(+y); // Use unary plus to concisely convert to a real + z = sin(x.get!float) + cos(y.get!float); // Or use get!T + z = sin(cast(float) x) + cos(cast(float) y); // Or use cast(T) to explicitly convert + + // Define a 8-bit custom float for storing probabilities + alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed ); + auto p = Probability(0.5); +} + +/// ditto +struct CustomFloat(uint precision, // fraction bits (23 for float) + uint exponentWidth, // exponent bits (8 for float) Exponent width + CustomFloatFlags flags, + uint bias) +if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && + precision + exponentWidth > 0) +{ + import std.bitmanip : bitfields; + import std.meta : staticIndexOf; +private: + // get the correct unsigned bitfield type to support > 32 bits + template uType(uint bits) + { + static if (bits <= size_t.sizeof*8) alias uType = size_t; + else alias uType = ulong ; + } + + // get the correct signed bitfield type to support > 32 bits + template sType(uint bits) + { + static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t; + else alias sType = long; + } + + alias T_sig = uType!precision; + alias T_exp = uType!exponentWidth; + alias T_signed_exp = sType!exponentWidth; + + alias Flags = CustomFloatFlags; + + // Facilitate converting numeric types to custom float + union ToBinary(F) + if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real)) + { + F set; + + // If on Linux or Mac, where 80-bit reals are padded, ignore the + // padding. + import std.algorithm.comparison : min; + CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get; + + // Convert F to the correct binary type. + static typeof(get) opCall(F value) + { + ToBinary r; + r.set = value; + return r.get; + } + alias get this; + } + + // Perform IEEE rounding with round to nearest detection + void roundedShift(T,U)(ref T sig, U shift) + { + if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1)) + { + // round to even + sig >>= shift; + sig += sig & 1; + } + else + { + sig >>= shift - 1; + sig += sig & 1; + // Perform standard rounding + sig >>= 1; + } + } + + // Convert the current value to signed exponent, normalized form + void toNormalized(T,U)(ref T sig, ref U exp) + { + sig = significand; + auto shift = (T.sizeof*8) - precision; + exp = exponent; + static if (flags&(Flags.infinity|Flags.nan)) + { + // Handle inf or nan + if (exp == exponent_max) + { + exp = exp.max; + sig <<= shift; + static if (flags&Flags.storeNormalized) + { + // Save inf/nan in denormalized format + sig >>= 1; + sig += cast(T) 1uL << (T.sizeof*8 - 1); + } + return; + } + } + if ((~flags&Flags.storeNormalized) || + // Convert denormalized form to normalized form + ((flags&Flags.allowDenorm) && exp == 0)) + { + if (sig > 0) + { + import core.bitop : bsr; + auto shift2 = precision - bsr(sig); + exp -= shift2-1; + shift += shift2; + } + else // value = 0.0 + { + exp = exp.min; + return; + } + } + sig <<= shift; + exp -= bias; + } + + // Set the current value from signed exponent, normalized form + void fromNormalized(T,U)(ref T sig, ref U exp) + { + auto shift = (T.sizeof*8) - precision; + if (exp == exp.max) + { + // infinity or nan + exp = exponent_max; + static if (flags & Flags.storeNormalized) + sig <<= 1; + + // convert back to normalized form + static if (~flags & Flags.infinity) + // No infinity support? + assert(sig != 0, "Infinity floating point value assigned to a " + ~ typeof(this).stringof ~ " (no infinity support)."); + + static if (~flags & Flags.nan) // No NaN support? + assert(sig == 0, "NaN floating point value assigned to a " ~ + typeof(this).stringof ~ " (no nan support)."); + sig >>= shift; + return; + } + if (exp == exp.min) // 0.0 + { + exp = 0; + sig = 0; + return; + } + + exp += bias; + if (exp <= 0) + { + static if ((flags&Flags.allowDenorm) || + // Convert from normalized form to denormalized + (~flags&Flags.storeNormalized)) + { + shift += -exp; + roundedShift(sig,1); + sig += cast(T) 1uL << (T.sizeof*8 - 1); + // Add the leading 1 + exp = 0; + } + else + assert((flags&Flags.storeNormalized) && exp == 0, + "Underflow occured assigning to a " ~ + typeof(this).stringof ~ " (no denormal support)."); + } + else + { + static if (~flags&Flags.storeNormalized) + { + // Convert from normalized form to denormalized + roundedShift(sig,1); + sig += cast(T) 1uL << (T.sizeof*8 - 1); + // Add the leading 1 + } + } + + if (shift > 0) + roundedShift(sig,shift); + if (sig > significand_max) + { + // handle significand overflow (should only be 1 bit) + static if (~flags&Flags.storeNormalized) + { + sig >>= 1; + } + else + sig &= significand_max; + exp++; + } + static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly) + { + // disallow non-zero denormals + if (exp == 0) + { + sig <<= 1; + if (sig > significand_max && (sig&significand_max) > 0) + // Check and round to even + exp++; + sig = 0; + } + } + + if (exp >= exponent_max) + { + static if (flags&(Flags.infinity|Flags.nan)) + { + sig = 0; + exp = exponent_max; + static if (~flags&(Flags.infinity)) + assert(0, "Overflow occured assigning to a " ~ + typeof(this).stringof ~ " (no infinity support)."); + } + else + assert(exp == exponent_max, "Overflow occured assigning to a " + ~ typeof(this).stringof ~ " (no infinity support)."); + } + } + +public: + static if (precision == 64) // CustomFloat!80 support hack + { + ulong significand; + enum ulong significand_max = ulong.max; + mixin(bitfields!( + T_exp , "exponent", exponentWidth, + bool , "sign" , flags & flags.signed )); + } + else + { + mixin(bitfields!( + T_sig, "significand", precision, + T_exp, "exponent" , exponentWidth, + bool , "sign" , flags & flags.signed )); + } + + /// Returns: infinity value + static if (flags & Flags.infinity) + static @property CustomFloat infinity() + { + CustomFloat value; + static if (flags & Flags.signed) + value.sign = 0; + value.significand = 0; + value.exponent = exponent_max; + return value; + } + + /// Returns: NaN value + static if (flags & Flags.nan) + static @property CustomFloat nan() + { + CustomFloat value; + static if (flags & Flags.signed) + value.sign = 0; + value.significand = cast(typeof(significand_max)) 1L << (precision-1); + value.exponent = exponent_max; + return value; + } + + /// Returns: number of decimal digits of precision + static @property size_t dig() + { + auto shiftcnt = precision - ((flags&Flags.storeNormalized) != 0); + immutable x = (shiftcnt == 64) ? 0 : 1uL << shiftcnt; + return cast(size_t) log10(x); + } + + /// Returns: smallest increment to the value 1 + static @property CustomFloat epsilon() + { + CustomFloat value; + static if (flags & Flags.signed) + value.sign = 0; + T_signed_exp exp = -precision; + T_sig sig = 0; + + value.fromNormalized(sig,exp); + if (exp == 0 && sig == 0) // underflowed to zero + { + static if ((flags&Flags.allowDenorm) || + (~flags&Flags.storeNormalized)) + sig = 1; + else + sig = cast(T) 1uL << (precision - 1); + } + value.exponent = cast(value.T_exp) exp; + value.significand = cast(value.T_sig) sig; + return value; + } + + /// the number of bits in mantissa + enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0); + + /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable + static @property int max_10_exp(){ return cast(int) log10( +max ); } + + /// maximum int value such that 2<sup>max_exp-1</sup> is representable + enum max_exp = exponent_max-bias+((~flags&(Flags.infinity|flags.nan))!=0); + + /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable + static @property int min_10_exp(){ return cast(int) log10( +min_normal ); } + + /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value + enum min_exp = cast(T_signed_exp)-bias +1+ ((flags&Flags.allowDenorm)!=0); + + /// Returns: largest representable value that's not infinity + static @property CustomFloat max() + { + CustomFloat value; + static if (flags & Flags.signed) + value.sign = 0; + value.exponent = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0); + value.significand = significand_max; + return value; + } + + /// Returns: smallest representable normalized value that's not 0 + static @property CustomFloat min_normal() { + CustomFloat value; + static if (flags & Flags.signed) + value.sign = 0; + value.exponent = 1; + static if (flags&Flags.storeNormalized) + value.significand = 0; + else + value.significand = cast(T_sig) 1uL << (precision - 1); + return value; + } + + /// Returns: real part + @property CustomFloat re() { return this; } + + /// Returns: imaginary part + static @property CustomFloat im() { return CustomFloat(0.0f); } + + /// Initialize from any $(D real) compatible type. + this(F)(F input) if (__traits(compiles, cast(real) input )) + { + this = input; + } + + /// Self assignment + void opAssign(F:CustomFloat)(F input) + { + static if (flags & Flags.signed) + sign = input.sign; + exponent = input.exponent; + significand = input.significand; + } + + /// Assigns from any $(D real) compatible type. + void opAssign(F)(F input) + if (__traits(compiles, cast(real) input)) + { + import std.conv : text; + + static if (staticIndexOf!(Unqual!F, float, double, real) >= 0) + auto value = ToBinary!(Unqual!F)(input); + else + auto value = ToBinary!(real )(input); + + // Assign the sign bit + static if (~flags & Flags.signed) + assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0), + "Incorrectly signed floating point value assigned to a " ~ + typeof(this).stringof ~ " (no sign support)."); + else + sign = value.sign; + + CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent; + CommonType!(T_sig, value.T_sig ) sig = value.significand; + + value.toNormalized(sig,exp); + fromNormalized(sig,exp); + + assert(exp <= exponent_max, text(typeof(this).stringof ~ + " exponent too large: " ,exp," > ",exponent_max, "\t",input,"\t",sig)); + assert(sig <= significand_max, text(typeof(this).stringof ~ + " significand too large: ",sig," > ",significand_max, + "\t",input,"\t",exp," ",exponent_max)); + exponent = cast(T_exp) exp; + significand = cast(T_sig) sig; + } + + /// Fetches the stored value either as a $(D float), $(D double) or $(D real). + @property F get(F)() + if (staticIndexOf!(Unqual!F, float, double, real) >= 0) + { + import std.conv : text; + + ToBinary!F result; + + static if (flags&Flags.signed) + result.sign = sign; + else + result.sign = (flags&flags.negativeUnsigned) > 0; + + CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction + CommonType!(T_sig, result.get.T_sig ) sig = significand; + + toNormalized(sig,exp); + result.fromNormalized(sig,exp); + assert(exp <= result.exponent_max, text("get exponent too large: " ,exp," > ",result.exponent_max) ); + assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) ); + result.exponent = cast(result.get.T_exp) exp; + result.significand = cast(result.get.T_sig) sig; + return result.set; + } + + ///ditto + T opCast(T)() if (__traits(compiles, get!T )) { return get!T; } + + /// Convert the CustomFloat to a real and perform the relavent operator on the result + real opUnary(string op)() + if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--") + { + static if (op=="++" || op=="--") + { + auto result = get!real; + this = mixin(op~`result`); + return result; + } + else + return mixin(op~`get!real`); + } + + /// ditto + real opBinary(string op,T)(T b) + if (__traits(compiles, mixin(`get!real`~op~`b`))) + { + return mixin(`get!real`~op~`b`); + } + + /// ditto + real opBinaryRight(string op,T)(T a) + if ( __traits(compiles, mixin(`a`~op~`get!real`)) && + !__traits(compiles, mixin(`get!real`~op~`b`))) + { + return mixin(`a`~op~`get!real`); + } + + /// ditto + int opCmp(T)(auto ref T b) + if (__traits(compiles, cast(real) b)) + { + auto x = get!real; + auto y = cast(real) b; + return (x >= y)-(x <= y); + } + + /// ditto + void opOpAssign(string op, T)(auto ref T b) + if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`))) + { + return mixin(`this = this `~op~` cast(real) b`); + } + + /// ditto + template toString() + { + import std.format : FormatSpec, formatValue; + // Needs to be a template because of DMD @@BUG@@ 13737. + void toString()(scope void delegate(const(char)[]) sink, FormatSpec!char fmt) + { + sink.formatValue(get!real, fmt); + } + } +} + +@safe unittest +{ + import std.meta; + alias FPTypes = + AliasSeq!( + CustomFloat!(5, 10), + CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), + CustomFloat!(1, 15, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), + CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed) + ); + + foreach (F; FPTypes) + { + auto x = F(0.125); + assert(x.get!float == 0.125F); + assert(x.get!double == 0.125); + + x -= 0.0625; + assert(x.get!float == 0.0625F); + assert(x.get!double == 0.0625); + + x *= 2; + assert(x.get!float == 0.125F); + assert(x.get!double == 0.125); + + x /= 4; + assert(x.get!float == 0.03125); + assert(x.get!double == 0.03125); + + x = 0.5; + x ^^= 4; + assert(x.get!float == 1 / 16.0F); + assert(x.get!double == 1 / 16.0); + } +} + +@system unittest +{ + // @system due to to!string(CustomFloat) + import std.conv; + CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125); + assert(y.to!string == "0.125"); +} + +/** +Defines the fastest type to use when storing temporaries of a +calculation intended to ultimately yield a result of type $(D F) +(where $(D F) must be one of $(D float), $(D double), or $(D +real)). When doing a multi-step computation, you may want to store +intermediate results as $(D FPTemporary!F). + +The necessity of $(D FPTemporary) stems from the optimized +floating-point operations and registers present in virtually all +processors. When adding numbers in the example above, the addition may +in fact be done in $(D real) precision internally. In that case, +storing the intermediate $(D result) in $(D double format) is not only +less precise, it is also (surprisingly) slower, because a conversion +from $(D real) to $(D double) is performed every pass through the +loop. This being a lose-lose situation, $(D FPTemporary!F) has been +defined as the $(I fastest) type to use for calculations at precision +$(D F). There is no need to define a type for the $(I most accurate) +calculations, as that is always $(D real). + +Finally, there is no guarantee that using $(D FPTemporary!F) will +always be fastest, as the speed of floating-point calculations depends +on very many factors. + */ +template FPTemporary(F) +if (isFloatingPoint!F) +{ + version (X86) + alias FPTemporary = real; + else + alias FPTemporary = Unqual!F; +} + +/// +@safe unittest +{ + import std.math : approxEqual; + + // Average numbers in an array + double avg(in double[] a) + { + if (a.length == 0) return 0; + FPTemporary!double result = 0; + foreach (e; a) result += e; + return result / a.length; + } + + auto a = [1.0, 2.0, 3.0]; + assert(approxEqual(avg(a), 2)); +} + +/** +Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a +root of the function $(D fun) starting from points $(D [xn_1, x_n]) +(ideally close to the root). $(D Num) may be $(D float), $(D double), +or $(D real). +*/ +template secantMethod(alias fun) +{ + import std.functional : unaryFun; + Num secantMethod(Num)(Num xn_1, Num xn) + { + auto fxn = unaryFun!(fun)(xn_1), d = xn_1 - xn; + typeof(fxn) fxn_1; + + xn = xn_1; + while (!approxEqual(d, 0) && isFinite(d)) + { + xn_1 = xn; + xn -= d; + fxn_1 = fxn; + fxn = unaryFun!(fun)(xn); + d *= -fxn / (fxn - fxn_1); + } + return xn; + } +} + +/// +@safe unittest +{ + import std.math : approxEqual, cos; + + float f(float x) + { + return cos(x) - x*x*x; + } + auto x = secantMethod!(f)(0f, 1f); + assert(approxEqual(x, 0.865474)); +} + +@system unittest +{ + // @system because of __gshared stderr + scope(failure) stderr.writeln("Failure testing secantMethod"); + float f(float x) + { + return cos(x) - x*x*x; + } + immutable x = secantMethod!(f)(0f, 1f); + assert(approxEqual(x, 0.865474)); + auto d = &f; + immutable y = secantMethod!(d)(0f, 1f); + assert(approxEqual(y, 0.865474)); +} + + +/** + * Return true if a and b have opposite sign. + */ +private bool oppositeSigns(T1, T2)(T1 a, T2 b) +{ + return signbit(a) != signbit(b); +} + +public: + +/** Find a real root of a real function f(x) via bracketing. + * + * Given a function `f` and a range `[a .. b]` such that `f(a)` + * and `f(b)` have opposite signs or at least one of them equals ±0, + * returns the value of `x` in + * the range which is closest to a root of `f(x)`. If `f(x)` + * has more than one root in the range, one will be chosen + * arbitrarily. If `f(x)` returns NaN, NaN will be returned; + * otherwise, this algorithm is guaranteed to succeed. + * + * Uses an algorithm based on TOMS748, which uses inverse cubic + * interpolation whenever possible, otherwise reverting to parabolic + * or secant interpolation. Compared to TOMS748, this implementation + * improves worst-case performance by a factor of more than 100, and + * typical performance by a factor of 2. For 80-bit reals, most + * problems require 8 to 15 calls to `f(x)` to achieve full machine + * precision. The worst-case performance (pathological cases) is + * approximately twice the number of bits. + * + * References: "On Enclosing Simple Roots of Nonlinear Equations", + * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61, + * pp733-744 (1993). Fortran code available from $(HTTP + * www.netlib.org,www.netlib.org) as algorithm TOMS478. + * + */ +T findRoot(T, DF, DT)(scope DF f, in T a, in T b, + scope DT tolerance) //= (T a, T b) => false) +if ( + isFloatingPoint!T && + is(typeof(tolerance(T.init, T.init)) : bool) && + is(typeof(f(T.init)) == R, R) && isFloatingPoint!R + ) +{ + immutable fa = f(a); + if (fa == 0) + return a; + immutable fb = f(b); + if (fb == 0) + return b; + immutable r = findRoot(f, a, b, fa, fb, tolerance); + // Return the first value if it is smaller or NaN + return !(fabs(r[2]) > fabs(r[3])) ? r[0] : r[1]; +} + +///ditto +T findRoot(T, DF)(scope DF f, in T a, in T b) +{ + return findRoot(f, a, b, (T a, T b) => false); +} + +/** Find root of a real function f(x) by bracketing, allowing the + * termination condition to be specified. + * + * Params: + * + * f = Function to be analyzed + * + * ax = Left bound of initial range of `f` known to contain the + * root. + * + * bx = Right bound of initial range of `f` known to contain the + * root. + * + * fax = Value of $(D f(ax)). + * + * fbx = Value of $(D f(bx)). $(D fax) and $(D fbx) should have opposite signs. + * ($(D f(ax)) and $(D f(bx)) are commonly known in advance.) + * + * + * tolerance = Defines an early termination condition. Receives the + * current upper and lower bounds on the root. The + * delegate must return $(D true) when these bounds are + * acceptable. If this function always returns $(D false), + * full machine precision will be achieved. + * + * Returns: + * + * A tuple consisting of two ranges. The first two elements are the + * range (in `x`) of the root, while the second pair of elements + * are the corresponding function values at those points. If an exact + * root was found, both of the first two elements will contain the + * root, and the second pair of elements will be 0. + */ +Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, in T ax, in T bx, in R fax, in R fbx, + scope DT tolerance) // = (T a, T b) => false) +if ( + isFloatingPoint!T && + is(typeof(tolerance(T.init, T.init)) : bool) && + is(typeof(f(T.init)) == R) && isFloatingPoint!R + ) +in +{ + assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN"); + assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root."); +} +body +{ + // Author: Don Clugston. This code is (heavily) modified from TOMS748 + // (www.netlib.org). The changes to improve the worst-cast performance are + // entirely original. + + T a, b, d; // [a .. b] is our current bracket. d is the third best guess. + R fa, fb, fd; // Values of f at a, b, d. + bool done = false; // Has a root been found? + + // Allow ax and bx to be provided in reverse order + if (ax <= bx) + { + a = ax; fa = fax; + b = bx; fb = fbx; + } + else + { + a = bx; fa = fbx; + b = ax; fb = fax; + } + + // Test the function at point c; update brackets accordingly + void bracket(T c) + { + R fc = f(c); + if (fc == 0 || fc.isNaN()) // Exact solution, or NaN + { + a = c; + fa = fc; + d = c; + fd = fc; + done = true; + return; + } + + // Determine new enclosing interval + if (signbit(fa) != signbit(fc)) + { + d = b; + fd = fb; + b = c; + fb = fc; + } + else + { + d = a; + fd = fa; + a = c; + fa = fc; + } + } + + /* Perform a secant interpolation. If the result would lie on a or b, or if + a and b differ so wildly in magnitude that the result would be meaningless, + perform a bisection instead. + */ + static T secant_interpolate(T a, T b, R fa, R fb) + { + if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b))) + { + // Catastrophic cancellation + if (a == 0) + a = copysign(T(0), b); + else if (b == 0) + b = copysign(T(0), a); + else if (signbit(a) != signbit(b)) + return 0; + T c = ieeeMean(a, b); + return c; + } + // avoid overflow + if (b - a > T.max) + return b / 2 + a / 2; + if (fb - fa > R.max) + return a - (b - a) / 2; + T c = a - (fa / (fb - fa)) * (b - a); + if (c == a || c == b) + return (a + b) / 2; + return c; + } + + /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the + quadratic polynomial interpolating f(x) at a, b, and d. + Returns: + The approximate zero in [a .. b] of the quadratic polynomial. + */ + T newtonQuadratic(int numsteps) + { + // Find the coefficients of the quadratic polynomial. + immutable T a0 = fa; + immutable T a1 = (fb - fa)/(b - a); + immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a); + + // Determine the starting point of newton steps. + T c = oppositeSigns(a2, fa) ? a : b; + + // start the safeguarded newton steps. + foreach (int i; 0 .. numsteps) + { + immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a); + immutable T pdc = a1 + a2*((2 * c) - (a + b)); + if (pdc == 0) + return a - a0 / a1; + else + c = c - pc / pdc; + } + return c; + } + + // On the first iteration we take a secant step: + if (fa == 0 || fa.isNaN()) + { + done = true; + b = a; + fb = fa; + } + else if (fb == 0 || fb.isNaN()) + { + done = true; + a = b; + fa = fb; + } + else + { + bracket(secant_interpolate(a, b, fa, fb)); + } + + // Starting with the second iteration, higher-order interpolation can + // be used. + int itnum = 1; // Iteration number + int baditer = 1; // Num bisections to take if an iteration is bad. + T c, e; // e is our fourth best guess + R fe; + +whileloop: + while (!done && (b != nextUp(a)) && !tolerance(a, b)) + { + T a0 = a, b0 = b; // record the brackets + + // Do two higher-order (cubic or parabolic) interpolation steps. + foreach (int QQ; 0 .. 2) + { + // Cubic inverse interpolation requires that + // all four function values fa, fb, fd, and fe are distinct; + // otherwise use quadratic interpolation. + bool distinct = (fa != fb) && (fa != fd) && (fa != fe) + && (fb != fd) && (fb != fe) && (fd != fe); + // The first time, cubic interpolation is impossible. + if (itnum<2) distinct = false; + bool ok = distinct; + if (distinct) + { + // Cubic inverse interpolation of f(x) at a, b, d, and e + immutable q11 = (d - e) * fd / (fe - fd); + immutable q21 = (b - d) * fb / (fd - fb); + immutable q31 = (a - b) * fa / (fb - fa); + immutable d21 = (b - d) * fd / (fd - fb); + immutable d31 = (a - b) * fb / (fb - fa); + + immutable q22 = (d21 - q11) * fb / (fe - fb); + immutable q32 = (d31 - q21) * fa / (fd - fa); + immutable d32 = (d31 - q21) * fd / (fd - fa); + immutable q33 = (d32 - q22) * fa / (fe - fa); + c = a + (q31 + q32 + q33); + if (c.isNaN() || (c <= a) || (c >= b)) + { + // DAC: If the interpolation predicts a or b, it's + // probable that it's the actual root. Only allow this if + // we're already close to the root. + if (c == a && a - b != a) + { + c = nextUp(a); + } + else if (c == b && a - b != -b) + { + c = nextDown(b); + } + else + { + ok = false; + } + } + } + if (!ok) + { + // DAC: Alefeld doesn't explain why the number of newton steps + // should vary. + c = newtonQuadratic(distinct ? 3 : 2); + if (c.isNaN() || (c <= a) || (c >= b)) + { + // Failure, try a secant step: + c = secant_interpolate(a, b, fa, fb); + } + } + ++itnum; + e = d; + fe = fd; + bracket(c); + if (done || ( b == nextUp(a)) || tolerance(a, b)) + break whileloop; + if (itnum == 2) + continue whileloop; + } + + // Now we take a double-length secant step: + T u; + R fu; + if (fabs(fa) < fabs(fb)) + { + u = a; + fu = fa; + } + else + { + u = b; + fu = fb; + } + c = u - 2 * (fu / (fb - fa)) * (b - a); + + // DAC: If the secant predicts a value equal to an endpoint, it's + // probably false. + if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2) + { + if ((a-b) == a || (b-a) == b) + { + if ((a>0 && b<0) || (a<0 && b>0)) + c = 0; + else + { + if (a == 0) + c = ieeeMean(copysign(T(0), b), b); + else if (b == 0) + c = ieeeMean(copysign(T(0), a), a); + else + c = ieeeMean(a, b); + } + } + else + { + c = a + (b - a) / 2; + } + } + e = d; + fe = fd; + bracket(c); + if (done || (b == nextUp(a)) || tolerance(a, b)) + break; + + // IMPROVE THE WORST-CASE PERFORMANCE + // We must ensure that the bounds reduce by a factor of 2 + // in binary space! every iteration. If we haven't achieved this + // yet, or if we don't yet know what the exponent is, + // perform a binary chop. + + if ((a == 0 || b == 0 || + (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a))) + && (b - a) < T(0.25) * (b0 - a0)) + { + baditer = 1; + continue; + } + + // DAC: If this happens on consecutive iterations, we probably have a + // pathological function. Perform a number of bisections equal to the + // total number of consecutive bad iterations. + + if ((b - a) < T(0.25) * (b0 - a0)) + baditer = 1; + foreach (int QQ; 0 .. baditer) + { + e = d; + fe = fd; + + T w; + if ((a>0 && b<0) || (a<0 && b>0)) + w = 0; + else + { + T usea = a; + T useb = b; + if (a == 0) + usea = copysign(T(0), b); + else if (b == 0) + useb = copysign(T(0), a); + w = ieeeMean(usea, useb); + } + bracket(w); + } + ++baditer; + } + return Tuple!(T, T, R, R)(a, b, fa, fb); +} + +///ditto +Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, in T ax, in T bx, in R fax, in R fbx) +{ + return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false); +} + +///ditto +T findRoot(T, R)(scope R delegate(T) f, in T a, in T b, + scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false) +{ + return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance); +} + +@safe nothrow unittest +{ + int numProblems = 0; + int numCalls; + + void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure + { + //numCalls=0; + //++numProblems; + assert(!x1.isNaN() && !x2.isNaN()); + assert(signbit(x1) != signbit(x2)); + auto result = findRoot(f, x1, x2, f(x1), f(x2), + (real lo, real hi) { return false; }); + + auto flo = f(result[0]); + auto fhi = f(result[1]); + if (flo != 0) + { + assert(oppositeSigns(flo, fhi)); + } + } + + // Test functions + real cubicfn(real x) @nogc @safe nothrow pure + { + //++numCalls; + if (x>float.max) + x = float.max; + if (x<-double.max) + x = -double.max; + // This has a single real root at -59.286543284815 + return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2; + } + // Test a function with more than one root. + real multisine(real x) { ++numCalls; return sin(x); } + //testFindRoot( &multisine, 6, 90); + //testFindRoot(&cubicfn, -100, 100); + //testFindRoot( &cubicfn, -double.max, real.max); + + +/* Tests from the paper: + * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra, + * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993). + */ + // Parameters common to many alefeld tests. + int n; + real ale_a, ale_b; + + int powercalls = 0; + + real power(real x) + { + ++powercalls; + ++numCalls; + return pow(x, n) + double.min_normal; + } + int [] power_nvals = [3, 5, 7, 9, 19, 25]; + // Alefeld paper states that pow(x,n) is a very poor case, where bisection + // outperforms his method, and gives total numcalls = + // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit), + // 2624 for brent (6.8/bit) + // ... but that is for double, not real80. + // This poor performance seems mainly due to catastrophic cancellation, + // which is avoided here by the use of ieeeMean(). + // I get: 231 (0.48/bit). + // IE this is 10X faster in Alefeld's worst case + numProblems=0; + foreach (k; power_nvals) + { + n = k; + //testFindRoot(&power, -1, 10); + } + + int powerProblems = numProblems; + + // Tests from Alefeld paper + + int [9] alefeldSums; + real alefeld0(real x) + { + ++alefeldSums[0]; + ++numCalls; + real q = sin(x) - x/2; + for (int i=1; i<20; ++i) + q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i)); + return q; + } + real alefeld1(real x) + { + ++numCalls; + ++alefeldSums[1]; + return ale_a*x + exp(ale_b * x); + } + real alefeld2(real x) + { + ++numCalls; + ++alefeldSums[2]; + return pow(x, n) - ale_a; + } + real alefeld3(real x) + { + ++numCalls; + ++alefeldSums[3]; + return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2); + } + real alefeld4(real x) + { + ++numCalls; + ++alefeldSums[4]; + return x*x - pow(1-x, n); + } + real alefeld5(real x) + { + ++numCalls; + ++alefeldSums[5]; + return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4); + } + real alefeld6(real x) + { + ++numCalls; + ++alefeldSums[6]; + return exp(-n*x)*(x-1.01L) + pow(x, n); + } + real alefeld7(real x) + { + ++numCalls; + ++alefeldSums[7]; + return (n*x-1)/((n-1)*x); + } + + numProblems=0; + //testFindRoot(&alefeld0, PI_2, PI); + for (n=1; n <= 10; ++n) + { + //testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L); + } + ale_a = -40; ale_b = -1; + //testFindRoot(&alefeld1, -9, 31); + ale_a = -100; ale_b = -2; + //testFindRoot(&alefeld1, -9, 31); + ale_a = -200; ale_b = -3; + //testFindRoot(&alefeld1, -9, 31); + int [] nvals_3 = [1, 2, 5, 10, 15, 20]; + int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20]; + int [] nvals_6 = [1, 5, 10, 15, 20]; + int [] nvals_7 = [2, 5, 15, 20]; + + for (int i=4; i<12; i+=2) + { + n = i; + ale_a = 0.2; + //testFindRoot(&alefeld2, 0, 5); + ale_a=1; + //testFindRoot(&alefeld2, 0.95, 4.05); + //testFindRoot(&alefeld2, 0, 1.5); + } + foreach (i; nvals_3) + { + n=i; + //testFindRoot(&alefeld3, 0, 1); + } + foreach (i; nvals_3) + { + n=i; + //testFindRoot(&alefeld4, 0, 1); + } + foreach (i; nvals_5) + { + n=i; + //testFindRoot(&alefeld5, 0, 1); + } + foreach (i; nvals_6) + { + n=i; + //testFindRoot(&alefeld6, 0, 1); + } + foreach (i; nvals_7) + { + n=i; + //testFindRoot(&alefeld7, 0.01L, 1); + } + real worstcase(real x) + { + ++numCalls; + return x<0.3*real.max? -0.999e-3 : 1.0; + } + //testFindRoot(&worstcase, -real.max, real.max); + + // just check that the double + float cases compile + //findRoot((double x){ return 0.0; }, -double.max, double.max); + //findRoot((float x){ return 0.0f; }, -float.max, float.max); + +/* + int grandtotal=0; + foreach (calls; alefeldSums) + { + grandtotal+=calls; + } + grandtotal-=2*numProblems; + printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n", + grandtotal, (1.0*grandtotal)/numProblems); + powercalls -= 2*powerProblems; + printf("POWER TOTAL = %d avg = %f ", powercalls, + (1.0*powercalls)/powerProblems); +*/ + //Issue 14231 + auto xp = findRoot((float x) => x, 0f, 1f); + auto xn = findRoot((float x) => x, -1f, -0f); +} + +//regression control +@system unittest +{ + // @system due to the case in the 2nd line + static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init))); + static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init))); + static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init))); +} + +/++ +Find a real minimum of a real function `f(x)` via bracketing. +Given a function `f` and a range `(ax .. bx)`, +returns the value of `x` in the range which is closest to a minimum of `f(x)`. +`f` is never evaluted at the endpoints of `ax` and `bx`. +If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily. +If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned; +otherwise, this algorithm is guaranteed to succeed. + +Params: + f = Function to be analyzed + ax = Left bound of initial range of f known to contain the minimum. + bx = Right bound of initial range of f known to contain the minimum. + relTolerance = Relative tolerance. + absTolerance = Absolute tolerance. + +Preconditions: + `ax` and `bx` shall be finite reals. $(BR) + $(D relTolerance) shall be normal positive real. $(BR) + $(D absTolerance) shall be normal positive real no less then $(D T.epsilon*2). + +Returns: + A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`. + + The method used is a combination of golden section search and +successive parabolic interpolation. Convergence is never much slower +than that for a Fibonacci search. + +References: + "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973) + +See_Also: $(LREF findRoot), $(REF isNormal, std,math) ++/ +Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error") +findLocalMin(T, DF)( + scope DF f, + in T ax, + in T bx, + in T relTolerance = sqrt(T.epsilon), + in T absTolerance = sqrt(T.epsilon), + ) +if (isFloatingPoint!T + && __traits(compiles, {T _ = DF.init(T.init);})) +in +{ + assert(isFinite(ax), "ax is not finite"); + assert(isFinite(bx), "bx is not finite"); + assert(isNormal(relTolerance), "relTolerance is not normal floating point number"); + assert(isNormal(absTolerance), "absTolerance is not normal floating point number"); + assert(relTolerance >= 0, "absTolerance is not positive"); + assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`"); +} +out (result) +{ + assert(isFinite(result.x)); +} +body +{ + alias R = Unqual!(CommonType!(ReturnType!DF, T)); + // c is the squared inverse of the golden ratio + // (3 - sqrt(5))/2 + // Value obtained from Wolfram Alpha. + enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L; + enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L; + R tolerance; + T a = ax > bx ? bx : ax; + T b = ax > bx ? ax : bx; + // sequence of declarations suitable for SIMD instructions + T v = a * cm1 + b * c; + assert(isFinite(v)); + R fv = f(v); + if (isNaN(fv) || fv == -T.infinity) + { + return typeof(return)(v, fv, T.init); + } + T w = v; + R fw = fv; + T x = v; + R fx = fv; + size_t i; + for (R d = 0, e = 0;;) + { + i++; + T m = (a + b) / 2; + // This fix is not part of the original algorithm + if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R. + { + m = a / 2 + b / 2; + if (!isFinite(m)) // fast-math compiler switch is enabled + { + //SIMD instructions can be used by compiler, do not reduce declarations + int a_exp = void; + int b_exp = void; + immutable an = frexp(a, a_exp); + immutable bn = frexp(b, b_exp); + immutable am = ldexp(an, a_exp-1); + immutable bm = ldexp(bn, b_exp-1); + m = am + bm; + if (!isFinite(m)) // wrong input: constraints are disabled in release mode + { + return typeof(return).init; + } + } + } + tolerance = absTolerance * fabs(x) + relTolerance; + immutable t2 = tolerance * 2; + // check stopping criterion + if (!(fabs(x - m) > t2 - (b - a) / 2)) + { + break; + } + R p = 0; + R q = 0; + R r = 0; + // fit parabola + if (fabs(e) > tolerance) + { + immutable xw = x - w; + immutable fxw = fx - fw; + immutable xv = x - v; + immutable fxv = fx - fv; + immutable xwfxv = xw * fxv; + immutable xvfxw = xv * fxw; + p = xv * xvfxw - xw * xwfxv; + q = (xvfxw - xwfxv) * 2; + if (q > 0) + p = -p; + else + q = -q; + r = e; + e = d; + } + T u; + // a parabolic-interpolation step + if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x)) + { + d = p / q; + u = x + d; + // f must not be evaluated too close to a or b + if (u - a < t2 || b - u < t2) + d = x < m ? tolerance : -tolerance; + } + // a golden-section step + else + { + e = (x < m ? b : a) - x; + d = c * e; + } + // f must not be evaluated too close to x + u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance); + immutable fu = f(u); + if (isNaN(fu) || fu == -T.infinity) + { + return typeof(return)(u, fu, T.init); + } + // update a, b, v, w, and x + if (fu <= fx) + { + u < x ? b : a = x; + v = w; fv = fw; + w = x; fw = fx; + x = u; fx = fu; + } + else + { + u < x ? a : b = u; + if (fu <= fw || w == x) + { + v = w; fv = fw; + w = u; fw = fu; + } + else if (fu <= fv || v == x || v == w) + { // do not remove this braces + v = u; fv = fu; + } + } + } + return typeof(return)(x, fx, tolerance * 3); +} + +/// +@safe unittest +{ + import std.math : approxEqual; + + auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7); + assert(ret.x.approxEqual(4.0)); + assert(ret.y.approxEqual(0.0)); +} + +@safe unittest +{ + import std.meta : AliasSeq; + foreach (T; AliasSeq!(double, float, real)) + { + { + auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7); + assert(ret.x.approxEqual(T(4))); + assert(ret.y.approxEqual(T(0))); + } + { + auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon); + assert(approxEqual(ret.x, T(1))); + assert(approxEqual(ret.y, T(0))); + assert(ret.error <= 10 * T.epsilon); + } + { + auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon); + assert(!ret.x.isNaN); + assert(ret.y.isNaN); + assert(ret.error.isNaN); + } + { + auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon); + assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x)+ T.min_normal)); + assert(ret.x >= 0 && ret.x <= ret.error); + } + { + auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon); + assert(ret.y < -18); + assert(ret.error < 5e-08); + assert(ret.x >= 0 && ret.x <= ret.error); + } + { + auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon); + assert(ret.x.fabs.approxEqual(T(1))); + assert(ret.y.fabs.approxEqual(T(1))); + assert(ret.error.approxEqual(T(0))); + } + } +} + +/** +Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance, +Euclidean distance) between input ranges $(D a) and +$(D b). The two ranges must have the same length. The three-parameter +version stops computation as soon as the distance is greater than or +equal to $(D limit) (this is useful to save computation if a small +distance is sought). + */ +CommonType!(ElementType!(Range1), ElementType!(Range2)) +euclideanDistance(Range1, Range2)(Range1 a, Range2 b) +if (isInputRange!(Range1) && isInputRange!(Range2)) +{ + enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); + static if (haveLen) assert(a.length == b.length); + Unqual!(typeof(return)) result = 0; + for (; !a.empty; a.popFront(), b.popFront()) + { + immutable t = a.front - b.front; + result += t * t; + } + static if (!haveLen) assert(b.empty); + return sqrt(result); +} + +/// Ditto +CommonType!(ElementType!(Range1), ElementType!(Range2)) +euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit) +if (isInputRange!(Range1) && isInputRange!(Range2)) +{ + limit *= limit; + enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); + static if (haveLen) assert(a.length == b.length); + Unqual!(typeof(return)) result = 0; + for (; ; a.popFront(), b.popFront()) + { + if (a.empty) + { + static if (!haveLen) assert(b.empty); + break; + } + immutable t = a.front - b.front; + result += t * t; + if (result >= limit) break; + } + return sqrt(result); +} + +@safe unittest +{ + import std.meta : AliasSeq; + foreach (T; AliasSeq!(double, const double, immutable double)) + { + T[] a = [ 1.0, 2.0, ]; + T[] b = [ 4.0, 6.0, ]; + assert(euclideanDistance(a, b) == 5); + assert(euclideanDistance(a, b, 5) == 5); + assert(euclideanDistance(a, b, 4) == 5); + assert(euclideanDistance(a, b, 2) == 3); + } +} + +/** +Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product, +dot product) of input ranges $(D a) and $(D +b). The two ranges must have the same length. If both ranges define +length, the check is done once; otherwise, it is done at each +iteration. + */ +CommonType!(ElementType!(Range1), ElementType!(Range2)) +dotProduct(Range1, Range2)(Range1 a, Range2 b) +if (isInputRange!(Range1) && isInputRange!(Range2) && + !(isArray!(Range1) && isArray!(Range2))) +{ + enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); + static if (haveLen) assert(a.length == b.length); + Unqual!(typeof(return)) result = 0; + for (; !a.empty; a.popFront(), b.popFront()) + { + result += a.front * b.front; + } + static if (!haveLen) assert(b.empty); + return result; +} + +/// Ditto +CommonType!(F1, F2) +dotProduct(F1, F2)(in F1[] avector, in F2[] bvector) +{ + immutable n = avector.length; + assert(n == bvector.length); + auto avec = avector.ptr, bvec = bvector.ptr; + Unqual!(typeof(return)) sum0 = 0, sum1 = 0; + + const all_endp = avec + n; + const smallblock_endp = avec + (n & ~3); + const bigblock_endp = avec + (n & ~15); + + for (; avec != bigblock_endp; avec += 16, bvec += 16) + { + sum0 += avec[0] * bvec[0]; + sum1 += avec[1] * bvec[1]; + sum0 += avec[2] * bvec[2]; + sum1 += avec[3] * bvec[3]; + sum0 += avec[4] * bvec[4]; + sum1 += avec[5] * bvec[5]; + sum0 += avec[6] * bvec[6]; + sum1 += avec[7] * bvec[7]; + sum0 += avec[8] * bvec[8]; + sum1 += avec[9] * bvec[9]; + sum0 += avec[10] * bvec[10]; + sum1 += avec[11] * bvec[11]; + sum0 += avec[12] * bvec[12]; + sum1 += avec[13] * bvec[13]; + sum0 += avec[14] * bvec[14]; + sum1 += avec[15] * bvec[15]; + } + + for (; avec != smallblock_endp; avec += 4, bvec += 4) + { + sum0 += avec[0] * bvec[0]; + sum1 += avec[1] * bvec[1]; + sum0 += avec[2] * bvec[2]; + sum1 += avec[3] * bvec[3]; + } + + sum0 += sum1; + + /* Do trailing portion in naive loop. */ + while (avec != all_endp) + { + sum0 += *avec * *bvec; + ++avec; + ++bvec; + } + + return sum0; +} + +@system unittest +{ + // @system due to dotProduct and assertCTFEable + import std.exception : assertCTFEable; + import std.meta : AliasSeq; + foreach (T; AliasSeq!(double, const double, immutable double)) + { + T[] a = [ 1.0, 2.0, ]; + T[] b = [ 4.0, 6.0, ]; + assert(dotProduct(a, b) == 16); + assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3); + } + + // Make sure the unrolled loop codepath gets tested. + static const x = + [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]; + static const y = + [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]; + assertCTFEable!({ assert(dotProduct(x, y) == 2280); }); +} + +/** +Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity, +cosine similarity) of input ranges $(D a) and $(D +b). The two ranges must have the same length. If both ranges define +length, the check is done once; otherwise, it is done at each +iteration. If either range has all-zero elements, return 0. + */ +CommonType!(ElementType!(Range1), ElementType!(Range2)) +cosineSimilarity(Range1, Range2)(Range1 a, Range2 b) +if (isInputRange!(Range1) && isInputRange!(Range2)) +{ + enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); + static if (haveLen) assert(a.length == b.length); + Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0; + for (; !a.empty; a.popFront(), b.popFront()) + { + immutable t1 = a.front, t2 = b.front; + norma += t1 * t1; + normb += t2 * t2; + dotprod += t1 * t2; + } + static if (!haveLen) assert(b.empty); + if (norma == 0 || normb == 0) return 0; + return dotprod / sqrt(norma * normb); +} + +@safe unittest +{ + import std.meta : AliasSeq; + foreach (T; AliasSeq!(double, const double, immutable double)) + { + T[] a = [ 1.0, 2.0, ]; + T[] b = [ 4.0, 3.0, ]; + assert(approxEqual( + cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25), + 0.01)); + } +} + +/** +Normalizes values in $(D range) by multiplying each element with a +number chosen such that values sum up to $(D sum). If elements in $(D +range) sum to zero, assigns $(D sum / range.length) to +all. Normalization makes sense only if all elements in $(D range) are +positive. $(D normalize) assumes that is the case without checking it. + +Returns: $(D true) if normalization completed normally, $(D false) if +all elements in $(D range) were zero or if $(D range) is empty. + */ +bool normalize(R)(R range, ElementType!(R) sum = 1) +if (isForwardRange!(R)) +{ + ElementType!(R) s = 0; + // Step 1: Compute sum and length of the range + static if (hasLength!(R)) + { + const length = range.length; + foreach (e; range) + { + s += e; + } + } + else + { + uint length = 0; + foreach (e; range) + { + s += e; + ++length; + } + } + // Step 2: perform normalization + if (s == 0) + { + if (length) + { + immutable f = sum / range.length; + foreach (ref e; range) e = f; + } + return false; + } + // The path most traveled + assert(s >= 0); + immutable f = sum / s; + foreach (ref e; range) + e *= f; + return true; +} + +/// +@safe unittest +{ + double[] a = []; + assert(!normalize(a)); + a = [ 1.0, 3.0 ]; + assert(normalize(a)); + assert(a == [ 0.25, 0.75 ]); + a = [ 0.0, 0.0 ]; + assert(!normalize(a)); + assert(a == [ 0.5, 0.5 ]); +} + +/** +Compute the sum of binary logarithms of the input range $(D r). +The error of this method is much smaller than with a naive sum of log2. + */ +ElementType!Range sumOfLog2s(Range)(Range r) +if (isInputRange!Range && isFloatingPoint!(ElementType!Range)) +{ + long exp = 0; + Unqual!(typeof(return)) x = 1; + foreach (e; r) + { + if (e < 0) + return typeof(return).nan; + int lexp = void; + x *= frexp(e, lexp); + exp += lexp; + if (x < 0.5) + { + x *= 2; + exp--; + } + } + return exp + log2(x); +} + +/// +@safe unittest +{ + import std.math : isNaN; + + assert(sumOfLog2s(new double[0]) == 0); + assert(sumOfLog2s([0.0L]) == -real.infinity); + assert(sumOfLog2s([-0.0L]) == -real.infinity); + assert(sumOfLog2s([2.0L]) == 1); + assert(sumOfLog2s([-2.0L]).isNaN()); + assert(sumOfLog2s([real.nan]).isNaN()); + assert(sumOfLog2s([-real.nan]).isNaN()); + assert(sumOfLog2s([real.infinity]) == real.infinity); + assert(sumOfLog2s([-real.infinity]).isNaN()); + assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9); +} + +/** +Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory), +_entropy) of input range $(D r) in bits. This +function assumes (without checking) that the values in $(D r) are all +in $(D [0, 1]). For the entropy to be meaningful, often $(D r) should +be normalized too (i.e., its values should sum to 1). The +two-parameter version stops evaluating as soon as the intermediate +result is greater than or equal to $(D max). + */ +ElementType!Range entropy(Range)(Range r) +if (isInputRange!Range) +{ + Unqual!(typeof(return)) result = 0.0; + for (;!r.empty; r.popFront) + { + if (!r.front) continue; + result -= r.front * log2(r.front); + } + return result; +} + +/// Ditto +ElementType!Range entropy(Range, F)(Range r, F max) +if (isInputRange!Range && + !is(CommonType!(ElementType!Range, F) == void)) +{ + Unqual!(typeof(return)) result = 0.0; + for (;!r.empty; r.popFront) + { + if (!r.front) continue; + result -= r.front * log2(r.front); + if (result >= max) break; + } + return result; +} + +@safe unittest +{ + import std.meta : AliasSeq; + foreach (T; AliasSeq!(double, const double, immutable double)) + { + T[] p = [ 0.0, 0, 0, 1 ]; + assert(entropy(p) == 0); + p = [ 0.25, 0.25, 0.25, 0.25 ]; + assert(entropy(p) == 2); + assert(entropy(p, 1) == 1); + } +} + +/** +Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence, +Kullback-Leibler divergence) between input ranges +$(D a) and $(D b), which is the sum $(D ai * log(ai / bi)). The base +of logarithm is 2. The ranges are assumed to contain elements in $(D +[0, 1]). Usually the ranges are normalized probability distributions, +but this is not required or checked by $(D +kullbackLeiblerDivergence). If any element $(D bi) is zero and the +corresponding element $(D ai) nonzero, returns infinity. (Otherwise, +if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is +considered zero.) If the inputs are normalized, the result is +positive. + */ +CommonType!(ElementType!Range1, ElementType!Range2) +kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b) +if (isInputRange!(Range1) && isInputRange!(Range2)) +{ + enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); + static if (haveLen) assert(a.length == b.length); + Unqual!(typeof(return)) result = 0; + for (; !a.empty; a.popFront(), b.popFront()) + { + immutable t1 = a.front; + if (t1 == 0) continue; + immutable t2 = b.front; + if (t2 == 0) return result.infinity; + assert(t1 > 0 && t2 > 0); + result += t1 * log2(t1 / t2); + } + static if (!haveLen) assert(b.empty); + return result; +} + +/// +@safe unittest +{ + import std.math : approxEqual; + + double[] p = [ 0.0, 0, 0, 1 ]; + assert(kullbackLeiblerDivergence(p, p) == 0); + double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; + assert(kullbackLeiblerDivergence(p1, p1) == 0); + assert(kullbackLeiblerDivergence(p, p1) == 2); + assert(kullbackLeiblerDivergence(p1, p) == double.infinity); + double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; + assert(approxEqual(kullbackLeiblerDivergence(p1, p2), 0.0719281)); + assert(approxEqual(kullbackLeiblerDivergence(p2, p1), 0.0780719)); +} + +/** +Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence, +Jensen-Shannon divergence) between $(D a) and $(D +b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * +bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are +assumed to contain elements in $(D [0, 1]). Usually the ranges are +normalized probability distributions, but this is not required or +checked by $(D jensenShannonDivergence). If the inputs are normalized, +the result is bounded within $(D [0, 1]). The three-parameter version +stops evaluations as soon as the intermediate result is greater than +or equal to $(D limit). + */ +CommonType!(ElementType!Range1, ElementType!Range2) +jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b) +if (isInputRange!Range1 && isInputRange!Range2 && + is(CommonType!(ElementType!Range1, ElementType!Range2))) +{ + enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); + static if (haveLen) assert(a.length == b.length); + Unqual!(typeof(return)) result = 0; + for (; !a.empty; a.popFront(), b.popFront()) + { + immutable t1 = a.front; + immutable t2 = b.front; + immutable avg = (t1 + t2) / 2; + if (t1 != 0) + { + result += t1 * log2(t1 / avg); + } + if (t2 != 0) + { + result += t2 * log2(t2 / avg); + } + } + static if (!haveLen) assert(b.empty); + return result / 2; +} + +/// Ditto +CommonType!(ElementType!Range1, ElementType!Range2) +jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit) +if (isInputRange!Range1 && isInputRange!Range2 && + is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init + >= F.init) : bool)) +{ + enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); + static if (haveLen) assert(a.length == b.length); + Unqual!(typeof(return)) result = 0; + limit *= 2; + for (; !a.empty; a.popFront(), b.popFront()) + { + immutable t1 = a.front; + immutable t2 = b.front; + immutable avg = (t1 + t2) / 2; + if (t1 != 0) + { + result += t1 * log2(t1 / avg); + } + if (t2 != 0) + { + result += t2 * log2(t2 / avg); + } + if (result >= limit) break; + } + static if (!haveLen) assert(b.empty); + return result / 2; +} + +/// +@safe unittest +{ + import std.math : approxEqual; + + double[] p = [ 0.0, 0, 0, 1 ]; + assert(jensenShannonDivergence(p, p) == 0); + double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; + assert(jensenShannonDivergence(p1, p1) == 0); + assert(approxEqual(jensenShannonDivergence(p1, p), 0.548795)); + double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; + assert(approxEqual(jensenShannonDivergence(p1, p2), 0.0186218)); + assert(approxEqual(jensenShannonDivergence(p2, p1), 0.0186218)); + assert(approxEqual(jensenShannonDivergence(p2, p1, 0.005), 0.00602366)); +} + +/** +The so-called "all-lengths gap-weighted string kernel" computes a +similarity measure between $(D s) and $(D t) based on all of their +common subsequences of all lengths. Gapped subsequences are also +included. + +To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes, +consider first the case $(D lambda = 1) and the strings $(D s = +["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new", +"world"]). In that case, $(D gapWeightedSimilarity) counts the +following matches: + +$(OL $(LI three matches of length 1, namely $(D "Hello"), $(D "new"), +and $(D "world");) $(LI three matches of length 2, namely ($(D +"Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));) +$(LI one match of length 3, namely ($(D "Hello", "new", "world")).)) + +The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of +these matches and adds them up, returning 7. + +---- +string[] s = ["Hello", "brave", "new", "world"]; +string[] t = ["Hello", "new", "world"]; +assert(gapWeightedSimilarity(s, t, 1) == 7); +---- + +Note how the gaps in matching are simply ignored, for example ($(D +"Hello", "new")) is deemed as good a match as ($(D "new", +"world")). This may be too permissive for some applications. To +eliminate gapped matches entirely, use $(D lambda = 0): + +---- +string[] s = ["Hello", "brave", "new", "world"]; +string[] t = ["Hello", "new", "world"]; +assert(gapWeightedSimilarity(s, t, 0) == 4); +---- + +The call above eliminated the gapped matches ($(D "Hello", "new")), +($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the +tally. That leaves only 4 matches. + +The most interesting case is when gapped matches still participate in +the result, but not as strongly as ungapped matches. The result will +be a smooth, fine-grained similarity measure between the input +strings. This is where values of $(D lambda) between 0 and 1 enter +into play: gapped matches are $(I exponentially penalized with the +number of gaps) with base $(D lambda). This means that an ungapped +match adds 1 to the return value; a match with one gap in either +string adds $(D lambda) to the return value; ...; a match with a total +of $(D n) gaps in both strings adds $(D pow(lambda, n)) to the return +value. In the example above, we have 4 matches without gaps, 2 matches +with one gap, and 1 match with three gaps. The latter match is ($(D +"Hello", "world")), which has two gaps in the first string and one gap +in the second string, totaling to three gaps. Summing these up we get +$(D 4 + 2 * lambda + pow(lambda, 3)). + +---- +string[] s = ["Hello", "brave", "new", "world"]; +string[] t = ["Hello", "new", "world"]; +assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125); +---- + +$(D gapWeightedSimilarity) is useful wherever a smooth similarity +measure between sequences allowing for approximate matches is +needed. The examples above are given with words, but any sequences +with elements comparable for equality are allowed, e.g. characters or +numbers. $(D gapWeightedSimilarity) uses a highly optimized dynamic +programming implementation that needs $(D 16 * min(s.length, +t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time +to complete. + */ +F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda) +if (isRandomAccessRange!(R1) && hasLength!(R1) && + isRandomAccessRange!(R2) && hasLength!(R2)) +{ + import core.exception : onOutOfMemoryError; + import core.stdc.stdlib : malloc, free; + import std.algorithm.mutation : swap; + import std.functional : binaryFun; + + if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda); + if (!t.length) return 0; + + auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length); + if (!dpvi) + onOutOfMemoryError(); + + auto dpvi1 = dpvi + t.length; + scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1); + dpvi[0 .. t.length] = 0; + dpvi1[0] = 0; + immutable lambda2 = lambda * lambda; + + F result = 0; + foreach (i; 0 .. s.length) + { + const si = s[i]; + for (size_t j = 0;;) + { + F dpsij = void; + if (binaryFun!(comp)(si, t[j])) + { + dpsij = 1 + dpvi[j]; + result += dpsij; + } + else + { + dpsij = 0; + } + immutable j1 = j + 1; + if (j1 == t.length) break; + dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) - + lambda2 * dpvi[j]; + j = j1; + } + swap(dpvi, dpvi1); + } + return result; +} + +@system unittest +{ + string[] s = ["Hello", "brave", "new", "world"]; + string[] t = ["Hello", "new", "world"]; + assert(gapWeightedSimilarity(s, t, 1) == 7); + assert(gapWeightedSimilarity(s, t, 0) == 4); + assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125); +} + +/** +The similarity per $(D gapWeightedSimilarity) has an issue in that it +grows with the lengths of the two strings, even though the strings are +not actually very similar. For example, the range $(D ["Hello", +"world"]) is increasingly similar with the range $(D ["Hello", +"world", "world", "world",...]) as more instances of $(D "world") are +appended. To prevent that, $(D gapWeightedSimilarityNormalized) +computes a normalized version of the similarity that is computed as +$(D gapWeightedSimilarity(s, t, lambda) / +sqrt(gapWeightedSimilarity(s, t, lambda) * gapWeightedSimilarity(s, t, +lambda))). The function $(D gapWeightedSimilarityNormalized) (a +so-called normalized kernel) is bounded in $(D [0, 1]), reaches $(D 0) +only for ranges that don't match in any position, and $(D 1) only for +identical ranges. + +The optional parameters $(D sSelfSim) and $(D tSelfSim) are meant for +avoiding duplicate computation. Many applications may have already +computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D +gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed +as $(D sSelfSim) and $(D tSelfSim), respectively. + */ +Select!(isFloatingPoint!(F), F, double) +gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F) + (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init) +if (isRandomAccessRange!(R1) && hasLength!(R1) && + isRandomAccessRange!(R2) && hasLength!(R2)) +{ + static bool uncomputed(F n) + { + static if (isFloatingPoint!(F)) + return isNaN(n); + else + return n == n.init; + } + if (uncomputed(sSelfSim)) + sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda); + if (sSelfSim == 0) return 0; + if (uncomputed(tSelfSim)) + tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda); + if (tSelfSim == 0) return 0; + + return gapWeightedSimilarity!(comp)(s, t, lambda) / + sqrt(cast(typeof(return)) sSelfSim * tSelfSim); +} + +/// +@system unittest +{ + import std.math : approxEqual, sqrt; + + string[] s = ["Hello", "brave", "new", "world"]; + string[] t = ["Hello", "new", "world"]; + assert(gapWeightedSimilarity(s, s, 1) == 15); + assert(gapWeightedSimilarity(t, t, 1) == 7); + assert(gapWeightedSimilarity(s, t, 1) == 7); + assert(approxEqual(gapWeightedSimilarityNormalized(s, t, 1), + 7.0 / sqrt(15.0 * 7), 0.01)); +} + +/** +Similar to $(D gapWeightedSimilarity), just works in an incremental +manner by first revealing the matches of length 1, then gapped matches +of length 2, and so on. The memory requirement is $(BIGOH s.length * +t.length). The time complexity is $(BIGOH s.length * t.length) time +for computing each step. Continuing on the previous example: + +The implementation is based on the pseudocode in Fig. 4 of the paper +$(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf, +"Efficient Computation of Gapped Substring Kernels on Large Alphabets") +by Rousu et al., with additional algorithmic and systems-level +optimizations. + */ +struct GapWeightedSimilarityIncremental(Range, F = double) +if (isRandomAccessRange!(Range) && hasLength!(Range)) +{ + import core.stdc.stdlib : malloc, realloc, alloca, free; + +private: + Range s, t; + F currentValue = 0; + F* kl; + size_t gram = void; + F lambda = void, lambda2 = void; + +public: +/** +Constructs an object given two ranges $(D s) and $(D t) and a penalty +$(D lambda). Constructor completes in $(BIGOH s.length * t.length) +time and computes all matches of length 1. + */ + this(Range s, Range t, F lambda) + { + import core.exception : onOutOfMemoryError; + + assert(lambda > 0); + this.gram = 0; + this.lambda = lambda; + this.lambda2 = lambda * lambda; // for efficiency only + + size_t iMin = size_t.max, jMin = size_t.max, + iMax = 0, jMax = 0; + /* initialize */ + Tuple!(size_t, size_t) * k0; + size_t k0len; + scope(exit) free(k0); + currentValue = 0; + foreach (i, si; s) + { + foreach (j; 0 .. t.length) + { + if (si != t[j]) continue; + k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof); + with (k0[k0len - 1]) + { + field[0] = i; + field[1] = j; + } + // Maintain the minimum and maximum i and j + if (iMin > i) iMin = i; + if (iMax < i) iMax = i; + if (jMin > j) jMin = j; + if (jMax < j) jMax = j; + } + } + + if (iMin > iMax) return; + assert(k0len); + + currentValue = k0len; + // Chop strings down to the useful sizes + s = s[iMin .. iMax + 1]; + t = t[jMin .. jMax + 1]; + this.s = s; + this.t = t; + + kl = cast(F*) malloc(s.length * t.length * F.sizeof); + if (!kl) + onOutOfMemoryError(); + + kl[0 .. s.length * t.length] = 0; + foreach (pos; 0 .. k0len) + { + with (k0[pos]) + { + kl[(field[0] - iMin) * t.length + field[1] -jMin] = lambda2; + } + } + } + + /** + Returns: $(D this). + */ + ref GapWeightedSimilarityIncremental opSlice() + { + return this; + } + + /** + Computes the match of the popFront length. Completes in $(BIGOH s.length * + t.length) time. + */ + void popFront() + { + import std.algorithm.mutation : swap; + + // This is a large source of optimization: if similarity at + // the gram-1 level was 0, then we can safely assume + // similarity at the gram level is 0 as well. + if (empty) return; + + // Now attempt to match gapped substrings of length `gram' + ++gram; + currentValue = 0; + + auto Si = cast(F*) alloca(t.length * F.sizeof); + Si[0 .. t.length] = 0; + foreach (i; 0 .. s.length) + { + const si = s[i]; + F Sij_1 = 0; + F Si_1j_1 = 0; + auto kli = kl + i * t.length; + for (size_t j = 0;;) + { + const klij = kli[j]; + const Si_1j = Si[j]; + const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1; + // now update kl and currentValue + if (si == t[j]) + currentValue += kli[j] = lambda2 * Si_1j_1; + else + kli[j] = 0; + // commit to Si + Si[j] = tmp; + if (++j == t.length) break; + // get ready for the popFront step; virtually increment j, + // so essentially stuffj_1 <-- stuffj + Si_1j_1 = Si_1j; + Sij_1 = tmp; + } + } + currentValue /= pow(lambda, 2 * (gram + 1)); + + version (none) + { + Si_1[0 .. t.length] = 0; + kl[0 .. min(t.length, maxPerimeter + 1)] = 0; + foreach (i; 1 .. min(s.length, maxPerimeter + 1)) + { + auto kli = kl + i * t.length; + assert(s.length > i); + const si = s[i]; + auto kl_1i_1 = kl_1 + (i - 1) * t.length; + kli[0] = 0; + F lastS = 0; + foreach (j; 1 .. min(maxPerimeter - i + 1, t.length)) + { + immutable j_1 = j - 1; + immutable tmp = kl_1i_1[j_1] + + lambda * (Si_1[j] + lastS) + - lambda2 * Si_1[j_1]; + kl_1i_1[j_1] = float.nan; + Si_1[j_1] = lastS; + lastS = tmp; + if (si == t[j]) + { + currentValue += kli[j] = lambda2 * lastS; + } + else + { + kli[j] = 0; + } + } + Si_1[t.length - 1] = lastS; + } + currentValue /= pow(lambda, 2 * (gram + 1)); + // get ready for the popFront computation + swap(kl, kl_1); + } + } + + /** + Returns: The gapped similarity at the current match length (initially + 1, grows with each call to $(D popFront)). + */ + @property F front() { return currentValue; } + + /** + Returns: Whether there are more matches. + */ + @property bool empty() + { + if (currentValue) return false; + if (kl) + { + free(kl); + kl = null; + } + return true; + } +} + +/** +Ditto + */ +GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F) +(R r1, R r2, F penalty) +{ + return typeof(return)(r1, r2, penalty); +} + +/// +@system unittest +{ + string[] s = ["Hello", "brave", "new", "world"]; + string[] t = ["Hello", "new", "world"]; + auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); + assert(simIter.front == 3); // three 1-length matches + simIter.popFront(); + assert(simIter.front == 3); // three 2-length matches + simIter.popFront(); + assert(simIter.front == 1); // one 3-length match + simIter.popFront(); + assert(simIter.empty); // no more match +} + +@system unittest +{ + import std.conv : text; + string[] s = ["Hello", "brave", "new", "world"]; + string[] t = ["Hello", "new", "world"]; + auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); + //foreach (e; simIter) writeln(e); + assert(simIter.front == 3); // three 1-length matches + simIter.popFront(); + assert(simIter.front == 3, text(simIter.front)); // three 2-length matches + simIter.popFront(); + assert(simIter.front == 1); // one 3-length matches + simIter.popFront(); + assert(simIter.empty); // no more match + + s = ["Hello"]; + t = ["bye"]; + simIter = gapWeightedSimilarityIncremental(s, t, 0.5); + assert(simIter.empty); + + s = ["Hello"]; + t = ["Hello"]; + simIter = gapWeightedSimilarityIncremental(s, t, 0.5); + assert(simIter.front == 1); // one match + simIter.popFront(); + assert(simIter.empty); + + s = ["Hello", "world"]; + t = ["Hello"]; + simIter = gapWeightedSimilarityIncremental(s, t, 0.5); + assert(simIter.front == 1); // one match + simIter.popFront(); + assert(simIter.empty); + + s = ["Hello", "world"]; + t = ["Hello", "yah", "world"]; + simIter = gapWeightedSimilarityIncremental(s, t, 0.5); + assert(simIter.front == 2); // two 1-gram matches + simIter.popFront(); + assert(simIter.front == 0.5, text(simIter.front)); // one 2-gram match, 1 gap +} + +@system unittest +{ + GapWeightedSimilarityIncremental!(string[]) sim = + GapWeightedSimilarityIncremental!(string[])( + ["nyuk", "I", "have", "no", "chocolate", "giba"], + ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"], + 0.5); + double[] witness = [ 7.0, 4.03125, 0, 0 ]; + foreach (e; sim) + { + //writeln(e); + assert(e == witness.front); + witness.popFront(); + } + witness = [ 3.0, 1.3125, 0.25 ]; + sim = GapWeightedSimilarityIncremental!(string[])( + ["I", "have", "no", "chocolate"], + ["I", "have", "some", "chocolate"], + 0.5); + foreach (e; sim) + { + //writeln(e); + assert(e == witness.front); + witness.popFront(); + } + assert(witness.empty); +} + +/** +Computes the greatest common divisor of $(D a) and $(D b) by using +an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's) +or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm. + +Params: + T = Any numerical type that supports the modulo operator `%`. If + bit-shifting `<<` and `>>` are also supported, Stein's algorithm will + be used; otherwise, Euclid's algorithm is used as _a fallback. +Returns: + The greatest common divisor of the given arguments. + */ +T gcd(T)(T a, T b) + if (isIntegral!T) +{ + static if (is(T == const) || is(T == immutable)) + { + return gcd!(Unqual!T)(a, b); + } + else version (DigitalMars) + { + static if (T.min < 0) + { + assert(a >= 0 && b >= 0); + } + while (b) + { + immutable t = b; + b = a % b; + a = t; + } + return a; + } + else + { + if (a == 0) + return b; + if (b == 0) + return a; + + import core.bitop : bsf; + import std.algorithm.mutation : swap; + + immutable uint shift = bsf(a | b); + a >>= a.bsf; + + do + { + b >>= b.bsf; + if (a > b) + swap(a, b); + b -= a; + } while (b); + + return a << shift; + } +} + +/// +@safe unittest +{ + assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7); + const int a = 5 * 13 * 23 * 23, b = 13 * 59; + assert(gcd(a, b) == 13); +} + +// This overload is for non-builtin numerical types like BigInt or +// user-defined types. +/// ditto +T gcd(T)(T a, T b) + if (!isIntegral!T && + is(typeof(T.init % T.init)) && + is(typeof(T.init == 0 || T.init > 0))) +{ + import std.algorithm.mutation : swap; + + enum canUseBinaryGcd = is(typeof(() { + T t, u; + t <<= 1; + t >>= 1; + t -= u; + bool b = (t & 1) == 0; + swap(t, u); + })); + + assert(a >= 0 && b >= 0); + + static if (canUseBinaryGcd) + { + uint shift = 0; + while ((a & 1) == 0 && (b & 1) == 0) + { + a >>= 1; + b >>= 1; + shift++; + } + + do + { + assert((a & 1) != 0); + while ((b & 1) == 0) + b >>= 1; + if (a > b) + swap(a, b); + b -= a; + } while (b); + + return a << shift; + } + else + { + // The only thing we have is %; fallback to Euclidean algorithm. + while (b != 0) + { + auto t = b; + b = a % b; + a = t; + } + return a; + } +} + +// Issue 7102 +@system pure unittest +{ + import std.bigint : BigInt; + assert(gcd(BigInt("71_000_000_000_000_000_000"), + BigInt("31_000_000_000_000_000_000")) == + BigInt("1_000_000_000_000_000_000")); +} + +@safe pure nothrow unittest +{ + // A numerical type that only supports % and - (to force gcd implementation + // to use Euclidean algorithm). + struct CrippledInt + { + int impl; + CrippledInt opBinary(string op : "%")(CrippledInt i) + { + return CrippledInt(impl % i.impl); + } + int opEquals(CrippledInt i) { return impl == i.impl; } + int opEquals(int i) { return impl == i; } + int opCmp(int i) { return (impl < i) ? -1 : (impl > i) ? 1 : 0; } + } + assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77)); +} + +// This is to make tweaking the speed/size vs. accuracy tradeoff easy, +// though floats seem accurate enough for all practical purposes, since +// they pass the "approxEqual(inverseFft(fft(arr)), arr)" test even for +// size 2 ^^ 22. +private alias lookup_t = float; + +/**A class for performing fast Fourier transforms of power of two sizes. + * This class encapsulates a large amount of state that is reusable when + * performing multiple FFTs of sizes smaller than or equal to that specified + * in the constructor. This results in substantial speedups when performing + * multiple FFTs with a known maximum size. However, + * a free function API is provided for convenience if you need to perform a + * one-off FFT. + * + * References: + * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) + */ +final class Fft +{ + import core.bitop : bsf; + import std.algorithm.iteration : map; + import std.array : uninitializedArray; + +private: + immutable lookup_t[][] negSinLookup; + + void enforceSize(R)(R range) const + { + import std.conv : text; + assert(range.length <= size, text( + "FFT size mismatch. Expected ", size, ", got ", range.length)); + } + + void fftImpl(Ret, R)(Stride!R range, Ret buf) const + in + { + assert(range.length >= 4); + assert(isPowerOf2(range.length)); + } + body + { + auto recurseRange = range; + recurseRange.doubleSteps(); + + if (buf.length > 4) + { + fftImpl(recurseRange, buf[0..$ / 2]); + recurseRange.popHalf(); + fftImpl(recurseRange, buf[$ / 2..$]); + } + else + { + // Do this here instead of in another recursion to save on + // recursion overhead. + slowFourier2(recurseRange, buf[0..$ / 2]); + recurseRange.popHalf(); + slowFourier2(recurseRange, buf[$ / 2..$]); + } + + butterfly(buf); + } + + // This algorithm works by performing the even and odd parts of our FFT + // using the "two for the price of one" method mentioned at + // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521 + // by making the odd terms into the imaginary components of our new FFT, + // and then using symmetry to recombine them. + void fftImplPureReal(Ret, R)(R range, Ret buf) const + in + { + assert(range.length >= 4); + assert(isPowerOf2(range.length)); + } + body + { + alias E = ElementType!R; + + // Converts odd indices of range to the imaginary components of + // a range half the size. The even indices become the real components. + static if (isArray!R && isFloatingPoint!E) + { + // Then the memory layout of complex numbers provides a dirt + // cheap way to convert. This is a common case, so take advantage. + auto oddsImag = cast(Complex!E[]) range; + } + else + { + // General case: Use a higher order range. We can assume + // source.length is even because it has to be a power of 2. + static struct OddToImaginary + { + R source; + alias C = Complex!(CommonType!(E, typeof(buf[0].re))); + + @property + { + C front() + { + return C(source[0], source[1]); + } + + C back() + { + immutable n = source.length; + return C(source[n - 2], source[n - 1]); + } + + typeof(this) save() + { + return typeof(this)(source.save); + } + + bool empty() + { + return source.empty; + } + + size_t length() + { + return source.length / 2; + } + } + + void popFront() + { + source.popFront(); + source.popFront(); + } + + void popBack() + { + source.popBack(); + source.popBack(); + } + + C opIndex(size_t index) + { + return C(source[index * 2], source[index * 2 + 1]); + } + + typeof(this) opSlice(size_t lower, size_t upper) + { + return typeof(this)(source[lower * 2 .. upper * 2]); + } + } + + auto oddsImag = OddToImaginary(range); + } + + fft(oddsImag, buf[0..$ / 2]); + auto evenFft = buf[0..$ / 2]; + auto oddFft = buf[$ / 2..$]; + immutable halfN = evenFft.length; + oddFft[0].re = buf[0].im; + oddFft[0].im = 0; + evenFft[0].im = 0; + // evenFft[0].re is already right b/c it's aliased with buf[0].re. + + foreach (k; 1 .. halfN / 2 + 1) + { + immutable bufk = buf[k]; + immutable bufnk = buf[buf.length / 2 - k]; + evenFft[k].re = 0.5 * (bufk.re + bufnk.re); + evenFft[halfN - k].re = evenFft[k].re; + evenFft[k].im = 0.5 * (bufk.im - bufnk.im); + evenFft[halfN - k].im = -evenFft[k].im; + + oddFft[k].re = 0.5 * (bufk.im + bufnk.im); + oddFft[halfN - k].re = oddFft[k].re; + oddFft[k].im = 0.5 * (bufnk.re - bufk.re); + oddFft[halfN - k].im = -oddFft[k].im; + } + + butterfly(buf); + } + + void butterfly(R)(R buf) const + in + { + assert(isPowerOf2(buf.length)); + } + body + { + immutable n = buf.length; + immutable localLookup = negSinLookup[bsf(n)]; + assert(localLookup.length == n); + + immutable cosMask = n - 1; + immutable cosAdd = n / 4 * 3; + + lookup_t negSinFromLookup(size_t index) pure nothrow + { + return localLookup[index]; + } + + lookup_t cosFromLookup(size_t index) pure nothrow + { + // cos is just -sin shifted by PI * 3 / 2. + return localLookup[(index + cosAdd) & cosMask]; + } + + immutable halfLen = n / 2; + + // This loop is unrolled and the two iterations are interleaved + // relative to the textbook FFT to increase ILP. This gives roughly 5% + // speedups on DMD. + for (size_t k = 0; k < halfLen; k += 2) + { + immutable cosTwiddle1 = cosFromLookup(k); + immutable sinTwiddle1 = negSinFromLookup(k); + immutable cosTwiddle2 = cosFromLookup(k + 1); + immutable sinTwiddle2 = negSinFromLookup(k + 1); + + immutable realLower1 = buf[k].re; + immutable imagLower1 = buf[k].im; + immutable realLower2 = buf[k + 1].re; + immutable imagLower2 = buf[k + 1].im; + + immutable upperIndex1 = k + halfLen; + immutable upperIndex2 = upperIndex1 + 1; + immutable realUpper1 = buf[upperIndex1].re; + immutable imagUpper1 = buf[upperIndex1].im; + immutable realUpper2 = buf[upperIndex2].re; + immutable imagUpper2 = buf[upperIndex2].im; + + immutable realAdd1 = cosTwiddle1 * realUpper1 + - sinTwiddle1 * imagUpper1; + immutable imagAdd1 = sinTwiddle1 * realUpper1 + + cosTwiddle1 * imagUpper1; + immutable realAdd2 = cosTwiddle2 * realUpper2 + - sinTwiddle2 * imagUpper2; + immutable imagAdd2 = sinTwiddle2 * realUpper2 + + cosTwiddle2 * imagUpper2; + + buf[k].re += realAdd1; + buf[k].im += imagAdd1; + buf[k + 1].re += realAdd2; + buf[k + 1].im += imagAdd2; + + buf[upperIndex1].re = realLower1 - realAdd1; + buf[upperIndex1].im = imagLower1 - imagAdd1; + buf[upperIndex2].re = realLower2 - realAdd2; + buf[upperIndex2].im = imagLower2 - imagAdd2; + } + } + + // This constructor is used within this module for allocating the + // buffer space elsewhere besides the GC heap. It's definitely **NOT** + // part of the public API and definitely **IS** subject to change. + // + // Also, this is unsafe because the memSpace buffer will be cast + // to immutable. + public this(lookup_t[] memSpace) // Public b/c of bug 4636. + { + immutable size = memSpace.length / 2; + + /* Create a lookup table of all negative sine values at a resolution of + * size and all smaller power of two resolutions. This may seem + * inefficient, but having all the lookups be next to each other in + * memory at every level of iteration is a huge win performance-wise. + */ + if (size == 0) + { + return; + } + + assert(isPowerOf2(size), + "Can only do FFTs on ranges with a size that is a power of two."); + + auto table = new lookup_t[][bsf(size) + 1]; + + table[$ - 1] = memSpace[$ - size..$]; + memSpace = memSpace[0 .. size]; + + auto lastRow = table[$ - 1]; + lastRow[0] = 0; // -sin(0) == 0. + foreach (ptrdiff_t i; 1 .. size) + { + // The hard coded cases are for improved accuracy and to prevent + // annoying non-zeroness when stuff should be zero. + + if (i == size / 4) + lastRow[i] = -1; // -sin(pi / 2) == -1. + else if (i == size / 2) + lastRow[i] = 0; // -sin(pi) == 0. + else if (i == size * 3 / 4) + lastRow[i] = 1; // -sin(pi * 3 / 2) == 1 + else + lastRow[i] = -sin(i * 2.0L * PI / size); + } + + // Fill in all the other rows with strided versions. + foreach (i; 1 .. table.length - 1) + { + immutable strideLength = size / (2 ^^ i); + auto strided = Stride!(lookup_t[])(lastRow, strideLength); + table[i] = memSpace[$ - strided.length..$]; + memSpace = memSpace[0..$ - strided.length]; + + size_t copyIndex; + foreach (elem; strided) + { + table[i][copyIndex++] = elem; + } + } + + negSinLookup = cast(immutable) table; + } + +public: + /**Create an $(D Fft) object for computing fast Fourier transforms of + * power of two sizes of $(D size) or smaller. $(D size) must be a + * power of two. + */ + this(size_t size) + { + // Allocate all twiddle factor buffers in one contiguous block so that, + // when one is done being used, the next one is next in cache. + auto memSpace = uninitializedArray!(lookup_t[])(2 * size); + this(memSpace); + } + + @property size_t size() const + { + return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length; + } + + /**Compute the Fourier transform of range using the $(BIGOH N log N) + * Cooley-Tukey Algorithm. $(D range) must be a random-access range with + * slicing and a length equal to $(D size) as provided at the construction of + * this object. The contents of range can be either numeric types, + * which will be interpreted as pure real values, or complex types with + * properties or members $(D .re) and $(D .im) that can be read. + * + * Note: Pure real FFTs are automatically detected and the relevant + * optimizations are performed. + * + * Returns: An array of complex numbers representing the transformed data in + * the frequency domain. + * + * Conventions: The exponent is negative and the factor is one, + * i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ]. + */ + Complex!F[] fft(F = double, R)(R range) const + if (isFloatingPoint!F && isRandomAccessRange!R) + { + enforceSize(range); + Complex!F[] ret; + if (range.length == 0) + { + return ret; + } + + // Don't waste time initializing the memory for ret. + ret = uninitializedArray!(Complex!F[])(range.length); + + fft(range, ret); + return ret; + } + + /**Same as the overload, but allows for the results to be stored in a user- + * provided buffer. The buffer must be of the same length as range, must be + * a random-access range, must have slicing, and must contain elements that are + * complex-like. This means that they must have a .re and a .im member or + * property that can be both read and written and are floating point numbers. + */ + void fft(Ret, R)(R range, Ret buf) const + if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) + { + assert(buf.length == range.length); + enforceSize(range); + + if (range.length == 0) + { + return; + } + else if (range.length == 1) + { + buf[0] = range[0]; + return; + } + else if (range.length == 2) + { + slowFourier2(range, buf); + return; + } + else + { + alias E = ElementType!R; + static if (is(E : real)) + { + return fftImplPureReal(range, buf); + } + else + { + static if (is(R : Stride!R)) + return fftImpl(range, buf); + else + return fftImpl(Stride!R(range, 1), buf); + } + } + } + + /** + * Computes the inverse Fourier transform of a range. The range must be a + * random access range with slicing, have a length equal to the size + * provided at construction of this object, and contain elements that are + * either of type std.complex.Complex or have essentially + * the same compile-time interface. + * + * Returns: The time-domain signal. + * + * Conventions: The exponent is positive and the factor is 1/N, i.e., + * output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ]. + */ + Complex!F[] inverseFft(F = double, R)(R range) const + if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F) + { + enforceSize(range); + Complex!F[] ret; + if (range.length == 0) + { + return ret; + } + + // Don't waste time initializing the memory for ret. + ret = uninitializedArray!(Complex!F[])(range.length); + + inverseFft(range, ret); + return ret; + } + + /** + * Inverse FFT that allows a user-supplied buffer to be provided. The buffer + * must be a random access range with slicing, and its elements + * must be some complex-like type. + */ + void inverseFft(Ret, R)(R range, Ret buf) const + if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) + { + enforceSize(range); + + auto swapped = map!swapRealImag(range); + fft(swapped, buf); + + immutable lenNeg1 = 1.0 / buf.length; + foreach (ref elem; buf) + { + immutable temp = elem.re * lenNeg1; + elem.re = elem.im * lenNeg1; + elem.im = temp; + } + } +} + +// This mixin creates an Fft object in the scope it's mixed into such that all +// memory owned by the object is deterministically destroyed at the end of that +// scope. +private enum string MakeLocalFft = q{ + import core.stdc.stdlib; + import core.exception : onOutOfMemoryError; + + auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof)) + [0 .. 2 * range.length]; + if (!lookupBuf.ptr) + onOutOfMemoryError(); + + scope(exit) free(cast(void*) lookupBuf.ptr); + auto fftObj = scoped!Fft(lookupBuf); +}; + +/**Convenience functions that create an $(D Fft) object, run the FFT or inverse + * FFT and return the result. Useful for one-off FFTs. + * + * Note: In addition to convenience, these functions are slightly more + * efficient than manually creating an Fft object for a single use, + * as the Fft object is deterministically destroyed before these + * functions return. + */ +Complex!F[] fft(F = double, R)(R range) +{ + mixin(MakeLocalFft); + return fftObj.fft!(F, R)(range); +} + +/// ditto +void fft(Ret, R)(R range, Ret buf) +{ + mixin(MakeLocalFft); + return fftObj.fft!(Ret, R)(range, buf); +} + +/// ditto +Complex!F[] inverseFft(F = double, R)(R range) +{ + mixin(MakeLocalFft); + return fftObj.inverseFft!(F, R)(range); +} + +/// ditto +void inverseFft(Ret, R)(R range, Ret buf) +{ + mixin(MakeLocalFft); + return fftObj.inverseFft!(Ret, R)(range, buf); +} + +@system unittest +{ + import std.algorithm; + import std.conv; + import std.range; + // Test values from R and Octave. + auto arr = [1,2,3,4,5,6,7,8]; + auto fft1 = fft(arr); + assert(approxEqual(map!"a.re"(fft1), + [36.0, -4, -4, -4, -4, -4, -4, -4])); + assert(approxEqual(map!"a.im"(fft1), + [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568])); + + auto fft1Retro = fft(retro(arr)); + assert(approxEqual(map!"a.re"(fft1Retro), + [36.0, 4, 4, 4, 4, 4, 4, 4])); + assert(approxEqual(map!"a.im"(fft1Retro), + [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568])); + + auto fft1Float = fft(to!(float[])(arr)); + assert(approxEqual(map!"a.re"(fft1), map!"a.re"(fft1Float))); + assert(approxEqual(map!"a.im"(fft1), map!"a.im"(fft1Float))); + + alias C = Complex!float; + auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10), + C(11,12), C(13,14), C(15,16)]; + auto fft2 = fft(arr2); + assert(approxEqual(map!"a.re"(fft2), + [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137])); + assert(approxEqual(map!"a.im"(fft2), + [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137])); + + auto inv1 = inverseFft(fft1); + assert(approxEqual(map!"a.re"(inv1), arr)); + assert(reduce!max(map!"a.im"(inv1)) < 1e-10); + + auto inv2 = inverseFft(fft2); + assert(approxEqual(map!"a.re"(inv2), map!"a.re"(arr2))); + assert(approxEqual(map!"a.im"(inv2), map!"a.im"(arr2))); + + // FFTs of size 0, 1 and 2 are handled as special cases. Test them here. + ushort[] empty; + assert(fft(empty) == null); + assert(inverseFft(fft(empty)) == null); + + real[] oneElem = [4.5L]; + auto oneFft = fft(oneElem); + assert(oneFft.length == 1); + assert(oneFft[0].re == 4.5L); + assert(oneFft[0].im == 0); + + auto oneInv = inverseFft(oneFft); + assert(oneInv.length == 1); + assert(approxEqual(oneInv[0].re, 4.5)); + assert(approxEqual(oneInv[0].im, 0)); + + long[2] twoElems = [8, 4]; + auto twoFft = fft(twoElems[]); + assert(twoFft.length == 2); + assert(approxEqual(twoFft[0].re, 12)); + assert(approxEqual(twoFft[0].im, 0)); + assert(approxEqual(twoFft[1].re, 4)); + assert(approxEqual(twoFft[1].im, 0)); + auto twoInv = inverseFft(twoFft); + assert(approxEqual(twoInv[0].re, 8)); + assert(approxEqual(twoInv[0].im, 0)); + assert(approxEqual(twoInv[1].re, 4)); + assert(approxEqual(twoInv[1].im, 0)); +} + +// Swaps the real and imaginary parts of a complex number. This is useful +// for inverse FFTs. +C swapRealImag(C)(C input) +{ + return C(input.im, input.re); +} + +private: +// The reasons I couldn't use std.algorithm were b/c its stride length isn't +// modifiable on the fly and because range has grown some performance hacks +// for powers of 2. +struct Stride(R) +{ + import core.bitop : bsf; + Unqual!R range; + size_t _nSteps; + size_t _length; + alias E = ElementType!(R); + + this(R range, size_t nStepsIn) + { + this.range = range; + _nSteps = nStepsIn; + _length = (range.length + _nSteps - 1) / nSteps; + } + + size_t length() const @property + { + return _length; + } + + typeof(this) save() @property + { + auto ret = this; + ret.range = ret.range.save; + return ret; + } + + E opIndex(size_t index) + { + return range[index * _nSteps]; + } + + E front() @property + { + return range[0]; + } + + void popFront() + { + if (range.length >= _nSteps) + { + range = range[_nSteps .. range.length]; + _length--; + } + else + { + range = range[0 .. 0]; + _length = 0; + } + } + + // Pops half the range's stride. + void popHalf() + { + range = range[_nSteps / 2 .. range.length]; + } + + bool empty() const @property + { + return length == 0; + } + + size_t nSteps() const @property + { + return _nSteps; + } + + void doubleSteps() + { + _nSteps *= 2; + _length /= 2; + } + + size_t nSteps(size_t newVal) @property + { + _nSteps = newVal; + + // Using >> bsf(nSteps) is a few cycles faster than / nSteps. + _length = (range.length + _nSteps - 1) >> bsf(nSteps); + return newVal; + } +} + +// Hard-coded base case for FFT of size 2. This is actually a TON faster than +// using a generic slow DFT. This seems to be the best base case. (Size 1 +// can be coded inline as buf[0] = range[0]). +void slowFourier2(Ret, R)(R range, Ret buf) +{ + assert(range.length == 2); + assert(buf.length == 2); + buf[0] = range[0] + range[1]; + buf[1] = range[0] - range[1]; +} + +// Hard-coded base case for FFT of size 4. Doesn't work as well as the size +// 2 case. +void slowFourier4(Ret, R)(R range, Ret buf) +{ + alias C = ElementType!Ret; + + assert(range.length == 4); + assert(buf.length == 4); + buf[0] = range[0] + range[1] + range[2] + range[3]; + buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1); + buf[2] = range[0] - range[1] + range[2] - range[3]; + buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1); +} + +N roundDownToPowerOf2(N)(N num) +if (isScalarType!N && !isFloatingPoint!N) +{ + import core.bitop : bsr; + return num & (cast(N) 1 << bsr(num)); +} + +@safe unittest +{ + assert(roundDownToPowerOf2(7) == 4); + assert(roundDownToPowerOf2(4) == 4); +} + +template isComplexLike(T) +{ + enum bool isComplexLike = is(typeof(T.init.re)) && + is(typeof(T.init.im)); +} + +@safe unittest +{ + static assert(isComplexLike!(Complex!double)); + static assert(!isComplexLike!(uint)); +} |