Program Listing for File roots.hpp
↰ Return to documentation for file (include/ruckig/roots.hpp)
#pragma once
#include <array>
#include <algorithm>
#include <cfloat>
#include <cmath>
namespace ruckig {
template<typename T>
inline T pow2(T v) {
return v * v;
}
namespace roots {
// Use own set class on stack for real-time capability
template<typename T, size_t N>
class Set {
protected:
using Container = typename std::array<T, N>;
using iterator = typename Container::iterator;
Container data;
size_t size {0};
public:
// Sort when accessing the elements
const iterator begin() {
std::sort(data.begin(), data.begin() + size);
return data.begin();
}
const iterator end() {
return data.begin() + size;
}
void insert(T value) {
data[size] = value;
++size;
}
};
// Set that only inserts positive values
template<typename T, size_t N>
class PositiveSet: public Set<T, N> {
public:
void insert(T value) {
if (value >= 0) {
Set<T, N>::insert(value);
}
}
};
inline PositiveSet<double, 3> solve_cubic(double a, double b, double c, double d) {
PositiveSet<double, 3> roots;
if (std::abs(d) < DBL_EPSILON) {
// First solution is x = 0
roots.insert(0.0);
// Converting to a quadratic equation
d = c;
c = b;
b = a;
a = 0.0;
}
if (std::abs(a) < DBL_EPSILON) {
if (std::abs(b) < DBL_EPSILON) {
// Linear equation
if (std::abs(c) > DBL_EPSILON) {
roots.insert(-d / c);
}
} else {
// Quadratic equation
const double discriminant = c * c - 4 * b * d;
if (discriminant >= 0) {
const double inv2b = 1.0 / (2 * b);
const double y = std::sqrt(discriminant);
roots.insert((-c + y) * inv2b);
roots.insert((-c - y) * inv2b);
}
}
} else {
// Cubic equation
const double inva = 1.0 / a;
const double invaa = inva * inva;
const double bb = b * b;
const double bover3a = b * inva / 3;
const double p = (a * c - bb / 3) * invaa;
const double halfq = (2 * bb * b - 9 * a * b * c + 27 * a * a * d) / 54 * invaa * inva;
const double yy = p * p * p / 27 + halfq * halfq;
constexpr double cos120 = -0.50;
constexpr double sin120 = 0.866025403784438646764;
if (yy > DBL_EPSILON) {
// Sqrt is positive: one real solution
const double y = std::sqrt(yy);
const double uuu = -halfq + y;
const double vvv = -halfq - y;
const double www = std::abs(uuu) > std::abs(vvv) ? uuu : vvv;
const double w = std::cbrt(www);
roots.insert(w - p / (3 * w) - bover3a);
} else if (yy < -DBL_EPSILON) {
// Sqrt is negative: three real solutions
const double x = -halfq;
const double y = std::sqrt(-yy);
double theta;
double r;
// Convert to polar form
if (std::abs(x) > DBL_EPSILON) {
theta = (x > 0.0) ? std::atan(y / x) : (std::atan(y / x) + M_PI);
r = std::sqrt(x * x - yy);
} else {
// Vertical line
theta = M_PI / 2;
r = y;
}
// Calculate cube root
theta /= 3;
r = 2 * std::cbrt(r);
// Convert to complex coordinate
const double ux = std::cos(theta) * r;
const double uyi = std::sin(theta) * r;
roots.insert(ux - bover3a);
roots.insert(ux * cos120 - uyi * sin120 - bover3a);
roots.insert(ux * cos120 + uyi * sin120 - bover3a);
} else {
// Sqrt is zero: two real solutions
const double www = -halfq;
const double w = 2 * std::cbrt(www);
roots.insert(w - bover3a);
roots.insert(w * cos120 - bover3a);
}
}
return roots;
}
// Solve resolvent eqaution of corresponding Quartic equation
// The input x must be of length 3
// Number of zeros are returned
inline int solve_resolvent(std::array<double, 3>& x, double a, double b, double c) {
constexpr double cos120 = -0.50;
constexpr double sin120 = 0.866025403784438646764;
a /= 3;
const double a2 = a * a;
double q = a2 - b / 3;
const double r = (a * (2 * a2 - b) + c) / 2;
const double r2 = r * r;
const double q3 = q * q * q;
if (r2 < q3) {
const double qsqrt = std::sqrt(q);
const double t = std::min(std::max(r / (q * qsqrt), -1.0), 1.0);
q = -2 * qsqrt;
const double theta = std::acos(t) / 3;
const double ux = std::cos(theta) * q;
const double uyi = std::sin(theta) * q;
x[0] = ux - a;
x[1] = ux * cos120 - uyi * sin120 - a;
x[2] = ux * cos120 + uyi * sin120 - a;
return 3;
} else {
double A = -std::cbrt(std::abs(r) + std::sqrt(r2 - q3));
if (r < 0.0) {
A = -A;
}
const double B = (0.0 == A ? 0.0 : q / A);
x[0] = (A + B) - a;
x[1] = -(A + B) / 2 - a;
x[2] = std::sqrt(3) * (A - B) / 2;
if (std::abs(x[2]) < DBL_EPSILON) {
x[2] = x[1];
return 2;
}
return 1;
}
}
inline PositiveSet<double, 4> solve_quart_monic(double a, double b, double c, double d) {
PositiveSet<double, 4> roots;
if (std::abs(d) < DBL_EPSILON) {
if (std::abs(c) < DBL_EPSILON) {
roots.insert(0.0);
const double D = a * a - 4 * b;
if (std::abs(D) < DBL_EPSILON) {
roots.insert(-a / 2);
} else if (D > 0.0) {
const double sqrtD = std::sqrt(D);
roots.insert((-a - sqrtD) / 2);
roots.insert((-a + sqrtD) / 2);
}
return roots;
}
if (std::abs(a) < DBL_EPSILON && std::abs(b) < DBL_EPSILON) {
roots.insert(0.0);
roots.insert(-std::cbrt(c));
return roots;
}
}
const double a3 = -b;
const double b3 = a * c - 4 * d;
const double c3 = -a * a * d - c * c + 4 * b * d;
std::array<double, 3> x3;
const int number_zeroes = solve_resolvent(x3, a3, b3, c3);
double y = x3[0];
// Choosing Y with maximal absolute value.
if (number_zeroes != 1) {
if (std::abs(x3[1]) > std::abs(y)) {
y = x3[1];
}
if (std::abs(x3[2]) > std::abs(y)) {
y = x3[2];
}
}
double q1, q2, p1, p2;
double D = y * y - 4 * d;
if (std::abs(D) < DBL_EPSILON) {
q1 = q2 = y / 2;
D = a * a - 4 * (b - y);
if (std::abs(D) < DBL_EPSILON) {
p1 = p2 = a / 2;
} else {
const double sqrtD = std::sqrt(D);
p1 = (a + sqrtD) / 2;
p2 = (a - sqrtD) / 2;
}
} else {
const double sqrtD = std::sqrt(D);
q1 = (y + sqrtD) / 2;
q2 = (y - sqrtD) / 2;
p1 = (a * q1 - c) / (q1 - q2);
p2 = (c - a * q2) / (q1 - q2);
}
constexpr double eps {16 * DBL_EPSILON};
D = p1 * p1 - 4 * q1;
if (std::abs(D) < eps) {
roots.insert(-p1 / 2);
} else if (D > 0.0) {
const double sqrtD = std::sqrt(D);
roots.insert((-p1 - sqrtD) / 2);
roots.insert((-p1 + sqrtD) / 2);
}
D = p2 * p2 - 4 * q2;
if (std::abs(D) < eps) {
roots.insert(-p2 / 2);
} else if (D > 0.0) {
const double sqrtD = std::sqrt(D);
roots.insert((-p2 - sqrtD) / 2);
roots.insert((-p2 + sqrtD) / 2);
}
return roots;
}
inline PositiveSet<double, 4> solve_quart_monic(const std::array<double, 4>& polynom) {
return solve_quart_monic(polynom[0], polynom[1], polynom[2], polynom[3]);
}
template<size_t N>
inline double poly_eval(const std::array<double, N>& p, double x) {
double retVal = 0.0;
if constexpr (N == 0) {
return retVal;
}
if (std::abs(x) < DBL_EPSILON) {
retVal = p[N - 1];
} else if (x == 1.0) {
for (int i = N - 1; i >= 0; i--) {
retVal += p[i];
}
} else {
double xn = 1.0;
for (int i = N - 1; i >= 0; i--) {
retVal += p[i] * xn;
xn *= x;
}
}
return retVal;
}
// Calculate the derivative poly coefficients of a given poly
template<size_t N>
inline std::array<double, N-1> poly_derivative(const std::array<double, N>& coeffs) {
std::array<double, N-1> deriv;
for (size_t i = 0; i < N - 1; ++i) {
deriv[i] = (N - 1 - i) * coeffs[i];
}
return deriv;
}
template<size_t N>
inline std::array<double, N-1> poly_monic_derivative(const std::array<double, N>& monic_coeffs) {
std::array<double, N-1> deriv;
deriv[0] = 1.0;
for (size_t i = 1; i < N - 1; ++i) {
deriv[i] = (N - 1 - i) * monic_coeffs[i] / (N - 1);
}
return deriv;
}
// Safe Newton Method
constexpr double tolerance {1e-14};
// Calculate a single zero of polynom p(x) inside [lbound, ubound]
// Requirements: p(lbound)*p(ubound) < 0, lbound < ubound
template<size_t N, size_t maxIts = 128>
inline double shrink_interval(const std::array<double, N>& p, double l, double h) {
const double fl = poly_eval(p, l);
const double fh = poly_eval(p, h);
if (fl == 0.0) {
return l;
}
if (fh == 0.0) {
return h;
}
if (fl > 0.0) {
std::swap(l, h);
}
double rts = (l + h) / 2;
double dxold = std::abs(h - l);
double dx = dxold;
const auto deriv = poly_derivative(p);
double f = poly_eval(p, rts);
double df = poly_eval(deriv, rts);
double temp;
for (size_t j = 0; j < maxIts; j++) {
if ((((rts - h) * df - f) * ((rts - l) * df - f) > 0.0) || (std::abs(2 * f) > std::abs(dxold * df))) {
dxold = dx;
dx = (h - l) / 2;
rts = l + dx;
if (l == rts) {
break;
}
} else {
dxold = dx;
dx = f / df;
temp = rts;
rts -= dx;
if (temp == rts) {
break;
}
}
if (std::abs(dx) < tolerance) {
break;
}
f = poly_eval(p, rts);
df = poly_eval(deriv, rts);
if (f < 0.0) {
l = rts;
} else {
h = rts;
}
}
return rts;
}
} // namespace roots
} // namespace ruckig