1use std::sync::Arc;
5
6use sim_kernel::{
7 AbiVersion, Cx, Dependency, Export, Lib, LibManifest, LibTarget, Linker, Result, Symbol, Value,
8 Version,
9};
10use sim_lib_numbers_codec::{numeric_plugin_descriptor_symbol, numeric_plugin_descriptor_value};
11use sim_lib_numbers_core::domains;
12use sim_lib_numbers_func::Func;
13use sim_lib_numbers_numeric::{
14 NumericKind, NumericPlugin, QuadOpts, Quadrature, register_differentiator, register_quadrature,
15};
16
17use super::{
18 diff::differentiators,
19 support::{
20 abs_error, add, add_scaled, call_unary_func, f64_value, scale, sub, value_to_f64, zero_like,
21 },
22};
23
/// Stateless library handle for the quadrature numbers domain.
///
/// Loading it registers the finite-difference differentiators and the
/// quadrature plugins. It also binds one descriptor value per plugin, under
/// symbols built from `"numbers/quad"` and the plugin name.
pub struct QuadNumbersLib;
47
48impl QuadNumbersLib {
49 pub fn new() -> Self {
52 Self
53 }
54}
55
56impl Default for QuadNumbersLib {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl Lib for QuadNumbersLib {
63 fn manifest(&self) -> LibManifest {
64 LibManifest {
65 id: domains::quad(),
66 version: Version(env!("CARGO_PKG_VERSION").to_owned()),
67 abi: AbiVersion { major: 0, minor: 1 },
68 target: LibTarget::HostRegistered,
69 requires: Vec::<Dependency>::new(),
70 capabilities: Vec::new(),
71 exports: descriptor_exports(),
72 }
73 }
74
75 fn load(&self, cx: &mut sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
76 for plugin in differentiators() {
77 register_differentiator(plugin)?;
78 }
79 for plugin in quadratures() {
80 register_quadrature(plugin)?;
81 }
82 install_descriptors(cx, linker)?;
83 Ok(())
84 }
85}
86
87fn descriptor_exports() -> Vec<Export> {
88 descriptor_specs()
89 .into_iter()
90 .map(|(name, _, _adaptive)| Export::Value {
91 symbol: numeric_plugin_descriptor_symbol("numbers/quad", name),
92 })
93 .collect()
94}
95
96fn install_descriptors(cx: &sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
97 for (name, kind, adaptive) in descriptor_specs() {
98 linker.value(
99 numeric_plugin_descriptor_symbol("numbers/quad", name),
100 numeric_plugin_descriptor_value(
101 cx.factory(),
102 Symbol::new(name),
103 kind,
104 adaptive,
105 domains::quad(),
106 )?,
107 )?;
108 }
109 Ok(())
110}
111
/// Static table of published plugin descriptors, as
/// `(plugin name, kind string, adaptive)` triples.
///
/// The five differentiators come first, followed by the seven quadratures.
/// Only `romberg` and `adaptive-gauss-kronrod` are flagged adaptive.
fn descriptor_specs() -> Vec<(&'static str, &'static str, bool)> {
    const DIFF: &str = "differentiator";
    const QUAD: &str = "quadrature";
    const SPECS: [(&str, &str, bool); 12] = [
        ("forward", DIFF, false),
        ("backward", DIFF, false),
        ("central-3", DIFF, false),
        ("central-5", DIFF, false),
        ("richardson", DIFF, false),
        ("trapezoid", QUAD, false),
        ("simpson", QUAD, false),
        ("romberg", QUAD, true),
        ("gauss-legendre-8", QUAD, false),
        ("gauss-legendre-16", QUAD, false),
        ("gauss-legendre-32", QUAD, false),
        ("adaptive-gauss-kronrod", QUAD, true),
    ];
    SPECS.to_vec()
}
128
/// Integration rule used by a `QuadPlugin`.
///
/// Deriving `Debug`/`PartialEq`/`Eq` lets a plugin's configured rule be
/// printed in diagnostics and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Method {
    /// Composite trapezoidal rule; `QuadOpts::n` is the panel count.
    Trapezoid,
    /// Composite Simpson's rule; `QuadOpts::n` is the (evened) panel count.
    Simpson,
    /// Romberg extrapolation; `QuadOpts::n` is the number of tableau levels.
    Romberg,
    /// Fixed Gauss–Legendre rule with the given number of points.
    GaussLegendre(usize),
    /// Recursive bisection driven by the 7/15-point Gauss–Kronrod pair.
    AdaptiveGaussKronrod,
}
137
138fn quadratures() -> Vec<Arc<dyn Quadrature>> {
139 vec![
140 Arc::new(QuadPlugin::new(
141 "trapezoid",
142 NumericKind::QuadratureFixed,
143 Method::Trapezoid,
144 )),
145 Arc::new(QuadPlugin::new(
146 "simpson",
147 NumericKind::QuadratureFixed,
148 Method::Simpson,
149 )),
150 Arc::new(QuadPlugin::new(
151 "romberg",
152 NumericKind::QuadratureFixed,
153 Method::Romberg,
154 )),
155 Arc::new(QuadPlugin::new(
156 "romberg",
157 NumericKind::QuadratureAdaptive,
158 Method::Romberg,
159 )),
160 Arc::new(QuadPlugin::new(
161 "gauss-legendre-8",
162 NumericKind::QuadratureFixed,
163 Method::GaussLegendre(8),
164 )),
165 Arc::new(QuadPlugin::new(
166 "gauss-legendre-16",
167 NumericKind::QuadratureFixed,
168 Method::GaussLegendre(16),
169 )),
170 Arc::new(QuadPlugin::new(
171 "gauss-legendre-32",
172 NumericKind::QuadratureFixed,
173 Method::GaussLegendre(32),
174 )),
175 Arc::new(QuadPlugin::new(
176 "adaptive-gauss-kronrod",
177 NumericKind::QuadratureAdaptive,
178 Method::AdaptiveGaussKronrod,
179 )),
180 ]
181}
182
/// A single quadrature plugin: the name and kind it reports to the numeric
/// registry, plus the rule that `Quadrature::integrate` dispatches on.
struct QuadPlugin {
    /// Name returned by `NumericPlugin::name`.
    name: Symbol,
    /// Kind returned by `NumericPlugin::kind` (fixed or adaptive quadrature).
    kind: NumericKind,
    /// Integration rule selected in `Quadrature::integrate`.
    method: Method,
}
188
189impl QuadPlugin {
190 fn new(name: &str, kind: NumericKind, method: Method) -> Self {
191 Self {
192 name: Symbol::new(name),
193 kind,
194 method,
195 }
196 }
197}
198
199impl NumericPlugin for QuadPlugin {
200 fn name(&self) -> Symbol {
201 self.name.clone()
202 }
203
204 fn kind(&self) -> NumericKind {
205 self.kind
206 }
207}
208
209impl Quadrature for QuadPlugin {
210 fn integrate(
211 &self,
212 cx: &mut Cx,
213 f: &Func,
214 _var: &Symbol,
215 lo: &Value,
216 hi: &Value,
217 opt: QuadOpts,
218 ) -> Result<Value> {
219 let a = value_to_f64(cx, lo, "quadrature lower bound")?;
220 let b = value_to_f64(cx, hi, "quadrature upper bound")?;
221 match self.method {
222 Method::Trapezoid => trapezoid(cx, f, a, b, opt.n.unwrap_or(256)),
223 Method::Simpson => simpson(cx, f, a, b, opt.n.unwrap_or(128)),
224 Method::Romberg => romberg(cx, f, a, b, opt.n.unwrap_or(6), opt.tol.unwrap_or(1.0e-10)),
225 Method::GaussLegendre(n) => gauss_legendre(cx, f, a, b, n),
226 Method::AdaptiveGaussKronrod => {
227 adaptive_gauss_kronrod(cx, f, a, b, opt.tol.unwrap_or(1.0e-10), 10)
228 }
229 }
230 }
231}
232
233fn trapezoid(cx: &mut Cx, f: &Func, a: f64, b: f64, n: usize) -> Result<Value> {
234 let n = n.max(1);
235 let h = (b - a) / n as f64;
236 let fa = sample_at(cx, f, a)?;
237 let fb = sample_at(cx, f, b)?;
238 let fa = scale(cx, fa, 0.5)?;
239 let fb = scale(cx, fb, 0.5)?;
240 let mut acc = add(cx, fa, fb)?;
241 for i in 1..n {
242 let x = a + i as f64 * h;
243 let sample = sample_at(cx, f, x)?;
244 acc = add(cx, acc, sample)?;
245 }
246 scale(cx, acc, h)
247}
248
249fn simpson(cx: &mut Cx, f: &Func, a: f64, b: f64, n: usize) -> Result<Value> {
250 let n = if n < 2 {
251 2
252 } else if n.is_multiple_of(2) {
253 n
254 } else {
255 n + 1
256 };
257 let h = (b - a) / n as f64;
258 let fa = sample_at(cx, f, a)?;
259 let fb = sample_at(cx, f, b)?;
260 let mut acc = add(cx, fa, fb)?;
261 for i in 1..n {
262 let x = a + i as f64 * h;
263 let coeff = if i.is_multiple_of(2) { 2.0 } else { 4.0 };
264 let sample = sample_at(cx, f, x)?;
265 acc = add_scaled(cx, acc, sample, coeff)?;
266 }
267 scale(cx, acc, h / 3.0)
268}
269
270fn romberg(cx: &mut Cx, f: &Func, a: f64, b: f64, levels: usize, tol: f64) -> Result<Value> {
271 let levels = levels.max(1);
272 let mut table: Vec<Vec<Value>> = Vec::with_capacity(levels);
273 for k in 0..levels {
274 let panels = 1usize << k;
275 let trap = trapezoid(cx, f, a, b, panels)?;
276 let mut row = vec![trap];
277 for j in 1..=k {
278 let factor = 4_f64.powi(j as i32);
279 let prev = row[j - 1].clone();
280 let coarse = table[k - 1][j - 1].clone();
281 let prev = scale(cx, prev, factor)?;
282 let numerator = sub(cx, prev, coarse)?;
283 row.push(scale(cx, numerator, 1.0 / (factor - 1.0))?);
284 }
285 if let Some(previous_row) = table.last()
286 && let (Some(last), Some(prev)) = (row.last(), previous_row.last())
287 && abs_error(cx, last.clone(), prev.clone())? <= tol
288 {
289 return Ok(last.clone());
290 }
291 table.push(row);
292 }
293 Ok(table
294 .last()
295 .and_then(|row| row.last())
296 .cloned()
297 .expect("romberg table should contain at least one result"))
298}
299
300fn gauss_legendre(cx: &mut Cx, f: &Func, a: f64, b: f64, n: usize) -> Result<Value> {
301 let nodes = gauss_legendre_nodes(n);
302 let mid = 0.5 * (a + b);
303 let half = 0.5 * (b - a);
304 let seed = sample_at(cx, f, mid + half * nodes[0].0)?;
305 let mut acc = zero_like(cx, seed)?;
306 for (x, w) in nodes {
307 let sample = sample_at(cx, f, mid + half * x)?;
308 acc = add_scaled(cx, acc, sample, w)?;
309 }
310 scale(cx, acc, half)
311}
312
313fn adaptive_gauss_kronrod(
314 cx: &mut Cx,
315 f: &Func,
316 a: f64,
317 b: f64,
318 tol: f64,
319 depth: usize,
320) -> Result<Value> {
321 let (kronrod, gauss) = gauss_kronrod_15(cx, f, a, b)?;
322 if depth == 0 || abs_error(cx, kronrod.clone(), gauss)? <= tol {
323 return Ok(kronrod);
324 }
325 let mid = 0.5 * (a + b);
326 let left = adaptive_gauss_kronrod(cx, f, a, mid, tol * 0.5, depth - 1)?;
327 let right = adaptive_gauss_kronrod(cx, f, mid, b, tol * 0.5, depth - 1)?;
328 add(cx, left, right)
329}
330
331fn gauss_kronrod_15(cx: &mut Cx, f: &Func, a: f64, b: f64) -> Result<(Value, Value)> {
332 const XGK: [f64; 8] = [
333 0.991_455_371_120_812_6,
334 0.949_107_912_342_758_5,
335 0.864_864_423_359_769_1,
336 0.741_531_185_599_394_5,
337 0.586_087_235_467_691_1,
338 0.405_845_151_377_397_2,
339 0.207_784_955_007_898_47,
340 0.0,
341 ];
342 const WGK: [f64; 8] = [
343 0.022_935_322_010_529_224,
344 0.063_092_092_629_978_56,
345 0.104_790_010_322_250_18,
346 0.140_653_259_715_525_92,
347 0.169_004_726_639_267_9,
348 0.190_350_578_064_785_42,
349 0.204_432_940_075_298_89,
350 0.209_482_141_084_727_82,
351 ];
352 const WG: [f64; 4] = [
353 0.129_484_966_168_869_7,
354 0.279_705_391_489_276_64,
355 0.381_830_050_505_118_9,
356 0.417_959_183_673_469_4,
357 ];
358 let mid = 0.5 * (a + b);
359 let half = 0.5 * (b - a);
360 let seed = sample_at(cx, f, mid)?;
361 let mut kronrod = zero_like(cx, seed.clone())?;
362 let mut gauss = zero_like(cx, seed)?;
363 for (i, x) in XGK.iter().copied().enumerate() {
364 let sample = if x == 0.0 {
365 sample_at(cx, f, mid)?
366 } else {
367 let plus = sample_at(cx, f, mid + half * x)?;
368 let minus = sample_at(cx, f, mid - half * x)?;
369 add(cx, plus, minus)?
370 };
371 kronrod = add_scaled(cx, kronrod, sample.clone(), WGK[i])?;
372 if i == 1 {
373 gauss = add_scaled(cx, gauss, sample, WG[0])?;
374 } else if i == 3 {
375 gauss = add_scaled(cx, gauss, sample, WG[1])?;
376 } else if i == 5 {
377 gauss = add_scaled(cx, gauss, sample, WG[2])?;
378 } else if i == 7 {
379 gauss = add_scaled(cx, gauss, sample, WG[3])?;
380 }
381 }
382 Ok((scale(cx, kronrod, half)?, scale(cx, gauss, half)?))
383}
384
385fn sample_at(cx: &mut Cx, f: &Func, x: f64) -> Result<Value> {
386 let x = f64_value(cx, x)?;
387 call_unary_func(cx, f, x)
388}
389
390fn gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
391 let m = n.div_ceil(2);
392 let mut nodes = vec![(0.0, 0.0); n];
393 for i in 0..m {
394 let mut z = (std::f64::consts::PI * (i as f64 + 0.75) / (n as f64 + 0.5)).cos();
395 loop {
396 let (pn, pnm1) = legendre(n, z);
397 let derivative = (n as f64) * (z * pn - pnm1) / (z * z - 1.0);
398 let next = z - pn / derivative;
399 if (next - z).abs() < 1.0e-15 {
400 z = next;
401 break;
402 }
403 z = next;
404 }
405 let (pn, pnm1) = legendre(n, z);
406 let derivative = (n as f64) * (z * pn - pnm1) / (z * z - 1.0);
407 let weight = 2.0 / ((1.0 - z * z) * derivative * derivative);
408 nodes[i] = (-z, weight);
409 nodes[n - 1 - i] = (z, weight);
410 }
411 nodes
412}
413
/// Evaluates Legendre polynomials at `x`, returning `(P_n(x), P_{n-1}(x))`.
///
/// For `n == 0` the pair is `(1, 0)`. Higher degrees use Bonnet's recurrence
/// `k P_k = (2k - 1) x P_{k-1} - (k - 1) P_{k-2}`, starting from
/// `(P_1, P_0) = (x, 1)`.
fn legendre(n: usize, x: f64) -> (f64, f64) {
    if n == 0 {
        return (1.0, 0.0);
    }
    // The fold state is (P_k, P_{k-1}); an empty range (n == 1) yields (x, 1).
    (2..=n).fold((x, 1.0), |(current, previous), k| {
        let next = ((2 * k - 1) as f64 * x * current - (k - 1) as f64 * previous) / k as f64;
        (next, current)
    })
}