Skip to main content

sim_lib_numbers_quad/
quad.rs

1//! The quadrature library and its integration backends, registering fixed and
2//! adaptive quadrature rules and the finite-difference differentiators.
3
4use std::sync::Arc;
5
6use sim_kernel::{
7    AbiVersion, Cx, Dependency, Export, Lib, LibManifest, LibTarget, Linker, Result, Symbol, Value,
8    Version,
9};
10use sim_lib_numbers_codec::{numeric_plugin_descriptor_symbol, numeric_plugin_descriptor_value};
11use sim_lib_numbers_core::domains;
12use sim_lib_numbers_func::Func;
13use sim_lib_numbers_numeric::{
14    NumericKind, NumericPlugin, QuadOpts, Quadrature, register_differentiator, register_quadrature,
15};
16
17use super::{
18    diff::differentiators,
19    support::{
20        abs_error, add, add_scaled, call_unary_func, f64_value, scale, sub, value_to_f64, zero_like,
21    },
22};
23
/// Registered numeric plugin library that installs this crate's quadrature
/// rules and finite-difference differentiators.
///
/// Loading this [`Lib`] registers the fixed and adaptive integration backends
/// (trapezoid, Simpson, Romberg, Gauss-Legendre, adaptive Gauss-Kronrod) used
/// by the numeric `integrate`/`integrate-adapt` surface, and the
/// finite-difference differentiators (forward, backward, central-3, central-5,
/// Richardson) used by `numeric-diff`. It also installs the plugin descriptor
/// values that advertise each backend to the registry.
///
/// Romberg is registered twice, once as a fixed rule and once as an adaptive
/// rule, but it is advertised through a single descriptor. The manifest
/// therefore lists one export per backend *name*, not one per registration.
///
/// # Examples
///
/// ```
/// use sim_kernel::Lib;
/// use sim_lib_numbers_quad::QuadNumbersLib;
///
/// let lib = QuadNumbersLib::new();
/// let manifest = lib.manifest();
/// // One descriptor export per registered backend (5 differentiators plus
/// // 7 quadrature rules).
/// assert_eq!(manifest.exports.len(), 12);
/// ```
pub struct QuadNumbersLib;
47
48impl QuadNumbersLib {
49    /// Creates the quadrature/differentiator library. The value is stateless;
50    /// all behavior is installed when it is loaded into a [`Cx`].
51    pub fn new() -> Self {
52        Self
53    }
54}
55
56impl Default for QuadNumbersLib {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl Lib for QuadNumbersLib {
63    fn manifest(&self) -> LibManifest {
64        LibManifest {
65            id: domains::quad(),
66            version: Version(env!("CARGO_PKG_VERSION").to_owned()),
67            abi: AbiVersion { major: 0, minor: 1 },
68            target: LibTarget::HostRegistered,
69            requires: Vec::<Dependency>::new(),
70            capabilities: Vec::new(),
71            exports: descriptor_exports(),
72        }
73    }
74
75    fn load(&self, cx: &mut sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
76        for plugin in differentiators() {
77            register_differentiator(plugin)?;
78        }
79        for plugin in quadratures() {
80            register_quadrature(plugin)?;
81        }
82        install_descriptors(cx, linker)?;
83        Ok(())
84    }
85}
86
87fn descriptor_exports() -> Vec<Export> {
88    descriptor_specs()
89        .into_iter()
90        .map(|(name, _, _adaptive)| Export::Value {
91            symbol: numeric_plugin_descriptor_symbol("numbers/quad", name),
92        })
93        .collect()
94}
95
96fn install_descriptors(cx: &sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
97    for (name, kind, adaptive) in descriptor_specs() {
98        linker.value(
99            numeric_plugin_descriptor_symbol("numbers/quad", name),
100            numeric_plugin_descriptor_value(
101                cx.factory(),
102                Symbol::new(name),
103                kind,
104                adaptive,
105                domains::quad(),
106            )?,
107        )?;
108    }
109    Ok(())
110}
111
/// Lists every advertised backend as `(name, kind, adaptive)`.
///
/// `kind` is the descriptor kind string (`"differentiator"` or
/// `"quadrature"`) and `adaptive` flags the backends advertised as adaptive.
/// The five differentiators come first, followed by the seven quadrature
/// rules; the order is preserved in the returned vector.
fn descriptor_specs() -> Vec<(&'static str, &'static str, bool)> {
    const DIFFERENTIATOR: &str = "differentiator";
    const QUADRATURE: &str = "quadrature";
    const SPECS: [(&str, &str, bool); 12] = [
        ("forward", DIFFERENTIATOR, false),
        ("backward", DIFFERENTIATOR, false),
        ("central-3", DIFFERENTIATOR, false),
        ("central-5", DIFFERENTIATOR, false),
        ("richardson", DIFFERENTIATOR, false),
        ("trapezoid", QUADRATURE, false),
        ("simpson", QUADRATURE, false),
        ("romberg", QUADRATURE, true),
        ("gauss-legendre-8", QUADRATURE, false),
        ("gauss-legendre-16", QUADRATURE, false),
        ("gauss-legendre-32", QUADRATURE, false),
        ("adaptive-gauss-kronrod", QUADRATURE, true),
    ];
    SPECS.to_vec()
}
128
/// Integration algorithm selected by a [`QuadPlugin`].
#[derive(Clone, Copy)]
enum Method {
    /// Composite trapezoid rule.
    Trapezoid,
    /// Composite Simpson rule.
    Simpson,
    /// Romberg extrapolation built on trapezoid estimates.
    Romberg,
    /// Gauss-Legendre rule; the payload is the number of nodes.
    GaussLegendre(usize),
    /// Recursive bisection driven by a 15-point Gauss-Kronrod error estimate.
    AdaptiveGaussKronrod,
}
137
138fn quadratures() -> Vec<Arc<dyn Quadrature>> {
139    vec![
140        Arc::new(QuadPlugin::new(
141            "trapezoid",
142            NumericKind::QuadratureFixed,
143            Method::Trapezoid,
144        )),
145        Arc::new(QuadPlugin::new(
146            "simpson",
147            NumericKind::QuadratureFixed,
148            Method::Simpson,
149        )),
150        Arc::new(QuadPlugin::new(
151            "romberg",
152            NumericKind::QuadratureFixed,
153            Method::Romberg,
154        )),
155        Arc::new(QuadPlugin::new(
156            "romberg",
157            NumericKind::QuadratureAdaptive,
158            Method::Romberg,
159        )),
160        Arc::new(QuadPlugin::new(
161            "gauss-legendre-8",
162            NumericKind::QuadratureFixed,
163            Method::GaussLegendre(8),
164        )),
165        Arc::new(QuadPlugin::new(
166            "gauss-legendre-16",
167            NumericKind::QuadratureFixed,
168            Method::GaussLegendre(16),
169        )),
170        Arc::new(QuadPlugin::new(
171            "gauss-legendre-32",
172            NumericKind::QuadratureFixed,
173            Method::GaussLegendre(32),
174        )),
175        Arc::new(QuadPlugin::new(
176            "adaptive-gauss-kronrod",
177            NumericKind::QuadratureAdaptive,
178            Method::AdaptiveGaussKronrod,
179        )),
180    ]
181}
182
/// A quadrature backend: a named, kind-tagged wrapper around one [`Method`].
struct QuadPlugin {
    /// Name reported through [`NumericPlugin::name`].
    name: Symbol,
    /// Kind reported through [`NumericPlugin::kind`] (fixed or adaptive).
    kind: NumericKind,
    /// Algorithm run by [`Quadrature::integrate`].
    method: Method,
}
188
189impl QuadPlugin {
190    fn new(name: &str, kind: NumericKind, method: Method) -> Self {
191        Self {
192            name: Symbol::new(name),
193            kind,
194            method,
195        }
196    }
197}
198
199impl NumericPlugin for QuadPlugin {
200    fn name(&self) -> Symbol {
201        self.name.clone()
202    }
203
204    fn kind(&self) -> NumericKind {
205        self.kind
206    }
207}
208
209impl Quadrature for QuadPlugin {
210    fn integrate(
211        &self,
212        cx: &mut Cx,
213        f: &Func,
214        _var: &Symbol,
215        lo: &Value,
216        hi: &Value,
217        opt: QuadOpts,
218    ) -> Result<Value> {
219        let a = value_to_f64(cx, lo, "quadrature lower bound")?;
220        let b = value_to_f64(cx, hi, "quadrature upper bound")?;
221        match self.method {
222            Method::Trapezoid => trapezoid(cx, f, a, b, opt.n.unwrap_or(256)),
223            Method::Simpson => simpson(cx, f, a, b, opt.n.unwrap_or(128)),
224            Method::Romberg => romberg(cx, f, a, b, opt.n.unwrap_or(6), opt.tol.unwrap_or(1.0e-10)),
225            Method::GaussLegendre(n) => gauss_legendre(cx, f, a, b, n),
226            Method::AdaptiveGaussKronrod => {
227                adaptive_gauss_kronrod(cx, f, a, b, opt.tol.unwrap_or(1.0e-10), 10)
228            }
229        }
230    }
231}
232
233fn trapezoid(cx: &mut Cx, f: &Func, a: f64, b: f64, n: usize) -> Result<Value> {
234    let n = n.max(1);
235    let h = (b - a) / n as f64;
236    let fa = sample_at(cx, f, a)?;
237    let fb = sample_at(cx, f, b)?;
238    let fa = scale(cx, fa, 0.5)?;
239    let fb = scale(cx, fb, 0.5)?;
240    let mut acc = add(cx, fa, fb)?;
241    for i in 1..n {
242        let x = a + i as f64 * h;
243        let sample = sample_at(cx, f, x)?;
244        acc = add(cx, acc, sample)?;
245    }
246    scale(cx, acc, h)
247}
248
249fn simpson(cx: &mut Cx, f: &Func, a: f64, b: f64, n: usize) -> Result<Value> {
250    let n = if n < 2 {
251        2
252    } else if n.is_multiple_of(2) {
253        n
254    } else {
255        n + 1
256    };
257    let h = (b - a) / n as f64;
258    let fa = sample_at(cx, f, a)?;
259    let fb = sample_at(cx, f, b)?;
260    let mut acc = add(cx, fa, fb)?;
261    for i in 1..n {
262        let x = a + i as f64 * h;
263        let coeff = if i.is_multiple_of(2) { 2.0 } else { 4.0 };
264        let sample = sample_at(cx, f, x)?;
265        acc = add_scaled(cx, acc, sample, coeff)?;
266    }
267    scale(cx, acc, h / 3.0)
268}
269
270fn romberg(cx: &mut Cx, f: &Func, a: f64, b: f64, levels: usize, tol: f64) -> Result<Value> {
271    let levels = levels.max(1);
272    let mut table: Vec<Vec<Value>> = Vec::with_capacity(levels);
273    for k in 0..levels {
274        let panels = 1usize << k;
275        let trap = trapezoid(cx, f, a, b, panels)?;
276        let mut row = vec![trap];
277        for j in 1..=k {
278            let factor = 4_f64.powi(j as i32);
279            let prev = row[j - 1].clone();
280            let coarse = table[k - 1][j - 1].clone();
281            let prev = scale(cx, prev, factor)?;
282            let numerator = sub(cx, prev, coarse)?;
283            row.push(scale(cx, numerator, 1.0 / (factor - 1.0))?);
284        }
285        if let Some(previous_row) = table.last()
286            && let (Some(last), Some(prev)) = (row.last(), previous_row.last())
287            && abs_error(cx, last.clone(), prev.clone())? <= tol
288        {
289            return Ok(last.clone());
290        }
291        table.push(row);
292    }
293    Ok(table
294        .last()
295        .and_then(|row| row.last())
296        .cloned()
297        .expect("romberg table should contain at least one result"))
298}
299
300fn gauss_legendre(cx: &mut Cx, f: &Func, a: f64, b: f64, n: usize) -> Result<Value> {
301    let nodes = gauss_legendre_nodes(n);
302    let mid = 0.5 * (a + b);
303    let half = 0.5 * (b - a);
304    let seed = sample_at(cx, f, mid + half * nodes[0].0)?;
305    let mut acc = zero_like(cx, seed)?;
306    for (x, w) in nodes {
307        let sample = sample_at(cx, f, mid + half * x)?;
308        acc = add_scaled(cx, acc, sample, w)?;
309    }
310    scale(cx, acc, half)
311}
312
313fn adaptive_gauss_kronrod(
314    cx: &mut Cx,
315    f: &Func,
316    a: f64,
317    b: f64,
318    tol: f64,
319    depth: usize,
320) -> Result<Value> {
321    let (kronrod, gauss) = gauss_kronrod_15(cx, f, a, b)?;
322    if depth == 0 || abs_error(cx, kronrod.clone(), gauss)? <= tol {
323        return Ok(kronrod);
324    }
325    let mid = 0.5 * (a + b);
326    let left = adaptive_gauss_kronrod(cx, f, a, mid, tol * 0.5, depth - 1)?;
327    let right = adaptive_gauss_kronrod(cx, f, mid, b, tol * 0.5, depth - 1)?;
328    add(cx, left, right)
329}
330
331fn gauss_kronrod_15(cx: &mut Cx, f: &Func, a: f64, b: f64) -> Result<(Value, Value)> {
332    const XGK: [f64; 8] = [
333        0.991_455_371_120_812_6,
334        0.949_107_912_342_758_5,
335        0.864_864_423_359_769_1,
336        0.741_531_185_599_394_5,
337        0.586_087_235_467_691_1,
338        0.405_845_151_377_397_2,
339        0.207_784_955_007_898_47,
340        0.0,
341    ];
342    const WGK: [f64; 8] = [
343        0.022_935_322_010_529_224,
344        0.063_092_092_629_978_56,
345        0.104_790_010_322_250_18,
346        0.140_653_259_715_525_92,
347        0.169_004_726_639_267_9,
348        0.190_350_578_064_785_42,
349        0.204_432_940_075_298_89,
350        0.209_482_141_084_727_82,
351    ];
352    const WG: [f64; 4] = [
353        0.129_484_966_168_869_7,
354        0.279_705_391_489_276_64,
355        0.381_830_050_505_118_9,
356        0.417_959_183_673_469_4,
357    ];
358    let mid = 0.5 * (a + b);
359    let half = 0.5 * (b - a);
360    let seed = sample_at(cx, f, mid)?;
361    let mut kronrod = zero_like(cx, seed.clone())?;
362    let mut gauss = zero_like(cx, seed)?;
363    for (i, x) in XGK.iter().copied().enumerate() {
364        let sample = if x == 0.0 {
365            sample_at(cx, f, mid)?
366        } else {
367            let plus = sample_at(cx, f, mid + half * x)?;
368            let minus = sample_at(cx, f, mid - half * x)?;
369            add(cx, plus, minus)?
370        };
371        kronrod = add_scaled(cx, kronrod, sample.clone(), WGK[i])?;
372        if i == 1 {
373            gauss = add_scaled(cx, gauss, sample, WG[0])?;
374        } else if i == 3 {
375            gauss = add_scaled(cx, gauss, sample, WG[1])?;
376        } else if i == 5 {
377            gauss = add_scaled(cx, gauss, sample, WG[2])?;
378        } else if i == 7 {
379            gauss = add_scaled(cx, gauss, sample, WG[3])?;
380        }
381    }
382    Ok((scale(cx, kronrod, half)?, scale(cx, gauss, half)?))
383}
384
385fn sample_at(cx: &mut Cx, f: &Func, x: f64) -> Result<Value> {
386    let x = f64_value(cx, x)?;
387    call_unary_func(cx, f, x)
388}
389
390fn gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
391    let m = n.div_ceil(2);
392    let mut nodes = vec![(0.0, 0.0); n];
393    for i in 0..m {
394        let mut z = (std::f64::consts::PI * (i as f64 + 0.75) / (n as f64 + 0.5)).cos();
395        loop {
396            let (pn, pnm1) = legendre(n, z);
397            let derivative = (n as f64) * (z * pn - pnm1) / (z * z - 1.0);
398            let next = z - pn / derivative;
399            if (next - z).abs() < 1.0e-15 {
400                z = next;
401                break;
402            }
403            z = next;
404        }
405        let (pn, pnm1) = legendre(n, z);
406        let derivative = (n as f64) * (z * pn - pnm1) / (z * z - 1.0);
407        let weight = 2.0 / ((1.0 - z * z) * derivative * derivative);
408        nodes[i] = (-z, weight);
409        nodes[n - 1 - i] = (z, weight);
410    }
411    nodes
412}
413
/// Evaluates the Legendre polynomials of degree `n` and `n - 1` at `x`.
///
/// Returns `(P_n(x), P_{n-1}(x))`, computed with the three-term recurrence
/// `k * P_k = (2k - 1) * x * P_{k-1} - (k - 1) * P_{k-2}`. For `n == 0` the
/// second component is `0.0`.
fn legendre(n: usize, x: f64) -> (f64, f64) {
    match n {
        0 => (1.0, 0.0),
        1 => (x, 1.0),
        _ => (2..=n).fold((x, 1.0), |(current, previous), k| {
            let next = ((2 * k - 1) as f64 * x * current - (k - 1) as f64 * previous) / k as f64;
            (next, current)
        }),
    }
}