Skip to main content

sim_lib_numbers_cas_diff/implementation/
registry.rs

1//! The extensible differentiation-rule registry: a process-global map from an
2//! operator symbol to a custom `diff` rule, letting other libraries teach the
3//! differentiator new functions.
4
5use std::{
6    collections::BTreeMap,
7    sync::{OnceLock, RwLock},
8};
9
10use sim_kernel::Symbol;
11use sim_lib_numbers_cas::CasExpr;
12
13/// A custom differentiation rule: maps an operator's arguments and the variable
14/// of differentiation to a derivative tree, or `None` to decline.
15pub type DiffRule = Box<dyn Fn(&[CasExpr], &Symbol) -> Option<CasExpr> + Send + Sync>;
16
17/// A registry of per-operator differentiation rules.
18///
19/// # Examples
20///
21/// ```
22/// use sim_kernel::Symbol;
23/// use sim_lib_numbers_cas::CasExpr;
24/// use sim_lib_numbers_cas_diff::CasDiffRegistry;
25///
26/// let mut registry = CasDiffRegistry::new();
27/// registry.register_rule(
28///     Symbol::new("id"),
29///     Box::new(|args: &[CasExpr], _var: &Symbol| args.first().cloned()),
30/// );
31///
32/// let out = registry.apply(
33///     &Symbol::new("id"),
34///     &[CasExpr::Var(Symbol::new("x"))],
35///     &Symbol::new("x"),
36/// );
37/// assert!(matches!(out, Some(CasExpr::Var(_))));
38/// // An unregistered operator yields `None`.
39/// assert!(registry.apply(&Symbol::new("nope"), &[], &Symbol::new("x")).is_none());
40/// ```
41#[derive(Default)]
42pub struct CasDiffRegistry {
43    /// The registered rules, keyed by operator symbol.
44    pub rules: BTreeMap<Symbol, DiffRule>,
45}
46
47impl CasDiffRegistry {
48    /// Construct an empty registry.
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    /// Register `rule` for `symbol`, returning any rule it replaced.
54    pub fn register_rule(&mut self, symbol: Symbol, rule: DiffRule) -> Option<DiffRule> {
55        self.rules.insert(symbol, rule)
56    }
57
58    /// Apply the rule registered for `symbol`, if any, to `args` and `var`.
59    pub fn apply(&self, symbol: &Symbol, args: &[CasExpr], var: &Symbol) -> Option<CasExpr> {
60        self.rules.get(symbol).and_then(|rule| rule(args, var))
61    }
62}
63
64static REGISTRY: OnceLock<RwLock<CasDiffRegistry>> = OnceLock::new();
65
66/// Access the process-global differentiation-rule registry.
67pub fn global_diff_registry() -> &'static RwLock<CasDiffRegistry> {
68    REGISTRY.get_or_init(|| RwLock::new(CasDiffRegistry::new()))
69}
70
71/// Register `rule` for `symbol` in the global registry, returning any rule it
72/// replaced.
73pub fn register_diff_rule(symbol: Symbol, rule: DiffRule) -> Option<DiffRule> {
74    global_diff_registry()
75        .write()
76        .expect("CAS diff registry should not be poisoned")
77        .register_rule(symbol, rule)
78}
79
80pub(crate) fn apply_registered_rule(
81    symbol: &Symbol,
82    args: &[CasExpr],
83    var: &Symbol,
84) -> Option<CasExpr> {
85    global_diff_registry()
86        .read()
87        .expect("CAS diff registry should not be poisoned")
88        .apply(symbol, args, var)
89}