sim_lib_numbers_cas_diff/implementation/registry.rs
1//! The extensible differentiation-rule registry: a process-global map from an
2//! operator symbol to a custom `diff` rule, letting other libraries teach the
3//! differentiator new functions.
4
5use std::{
6 collections::BTreeMap,
7 sync::{OnceLock, RwLock},
8};
9
10use sim_kernel::Symbol;
11use sim_lib_numbers_cas::CasExpr;
12
13/// A custom differentiation rule: maps an operator's arguments and the variable
14/// of differentiation to a derivative tree, or `None` to decline.
15pub type DiffRule = Box<dyn Fn(&[CasExpr], &Symbol) -> Option<CasExpr> + Send + Sync>;
16
17/// A registry of per-operator differentiation rules.
18///
19/// # Examples
20///
21/// ```
22/// use sim_kernel::Symbol;
23/// use sim_lib_numbers_cas::CasExpr;
24/// use sim_lib_numbers_cas_diff::CasDiffRegistry;
25///
26/// let mut registry = CasDiffRegistry::new();
27/// registry.register_rule(
28/// Symbol::new("id"),
29/// Box::new(|args: &[CasExpr], _var: &Symbol| args.first().cloned()),
30/// );
31///
32/// let out = registry.apply(
33/// &Symbol::new("id"),
34/// &[CasExpr::Var(Symbol::new("x"))],
35/// &Symbol::new("x"),
36/// );
37/// assert!(matches!(out, Some(CasExpr::Var(_))));
38/// // An unregistered operator yields `None`.
39/// assert!(registry.apply(&Symbol::new("nope"), &[], &Symbol::new("x")).is_none());
40/// ```
41#[derive(Default)]
42pub struct CasDiffRegistry {
43 /// The registered rules, keyed by operator symbol.
44 pub rules: BTreeMap<Symbol, DiffRule>,
45}
46
47impl CasDiffRegistry {
48 /// Construct an empty registry.
49 pub fn new() -> Self {
50 Self::default()
51 }
52
53 /// Register `rule` for `symbol`, returning any rule it replaced.
54 pub fn register_rule(&mut self, symbol: Symbol, rule: DiffRule) -> Option<DiffRule> {
55 self.rules.insert(symbol, rule)
56 }
57
58 /// Apply the rule registered for `symbol`, if any, to `args` and `var`.
59 pub fn apply(&self, symbol: &Symbol, args: &[CasExpr], var: &Symbol) -> Option<CasExpr> {
60 self.rules.get(symbol).and_then(|rule| rule(args, var))
61 }
62}
63
64static REGISTRY: OnceLock<RwLock<CasDiffRegistry>> = OnceLock::new();
65
66/// Access the process-global differentiation-rule registry.
67pub fn global_diff_registry() -> &'static RwLock<CasDiffRegistry> {
68 REGISTRY.get_or_init(|| RwLock::new(CasDiffRegistry::new()))
69}
70
71/// Register `rule` for `symbol` in the global registry, returning any rule it
72/// replaced.
73pub fn register_diff_rule(symbol: Symbol, rule: DiffRule) -> Option<DiffRule> {
74 global_diff_registry()
75 .write()
76 .expect("CAS diff registry should not be poisoned")
77 .register_rule(symbol, rule)
78}
79
80pub(crate) fn apply_registered_rule(
81 symbol: &Symbol,
82 args: &[CasExpr],
83 var: &Symbol,
84) -> Option<CasExpr> {
85 global_diff_registry()
86 .read()
87 .expect("CAS diff registry should not be poisoned")
88 .apply(symbol, args, var)
89}