Skip to main content

sim_lib_femm_function/
implementation.rs

1#![forbid(unsafe_code)]
2//! Callable wrapper that evaluates a model to a quantity, field, or solution.
3//!
4//! Defines the call request, output query, evaluation result, and the callable
5//! payload that turns a FEMM model into a runtime function of its parameters.
6
7use std::{any::Any, sync::Arc};
8
9use sim_kernel::{
10    ClassId, ClassRef, Cx, DefaultFactory, Expr, Factory, Object, ObjectEncode, ObjectEncoding,
11    Result as KernelResult, Symbol, Value,
12};
13use sim_lib_femm_core::{FemmError, FemmLimits, FemmResult, ParamSet, StableId, value_as_f64};
14use sim_lib_femm_field::{Field, Projection, field_as_func};
15use sim_lib_femm_mesh::FemmModel;
16use sim_lib_femm_post::{FemmSolution, QuantitySpec, quantity};
17use sim_lib_femm_solve::{GradientTrust, SolveCertificate, SteadySolve, solve_steady};
18use sim_lib_numbers_func::{Func, FuncMetadata};
19
20/// One evaluation request against a model: which parameters, output, and limits.
21///
22/// Couples a parameter binding with the [`OutputQuery`] to compute and the
23/// solver [`FemmLimits`]; `want_grad` names parameters whose sensitivities the
24/// caller also wants. See the [crate README](index.html).
25#[derive(Clone, Debug)]
26pub struct FemmCall {
27    /// The parameter binding the model is evaluated at.
28    pub params: ParamSet,
29    /// The output to compute from the solved model.
30    pub query: OutputQuery,
31    /// Parameters to also report sensitivities for, if any.
32    pub want_grad: Option<Vec<Symbol>>,
33    /// Solver budget and tolerances for this evaluation.
34    pub limits: FemmLimits,
35}
36
37/// The kind of output an evaluation produces from a solved model.
38#[derive(Clone, Debug)]
39pub enum OutputQuery {
40    /// A scalar quantity reduced from the solution (energy, flux, capacitance).
41    Quantity(QuantitySpec),
42    /// A projected field (potential or a derived component) over the mesh.
43    Field(Projection),
44    /// The full solved model solution as an opaque value.
45    Solution,
46}
47
48/// The result of evaluating a model: the output value plus optional gradient.
49#[derive(Clone, Debug)]
50pub struct FemmEval {
51    /// The computed output value (scalar, field, or solution).
52    pub value: Value,
53    /// Per-parameter sensitivities, when a gradient was requested.
54    pub gradient: Option<Vec<(Symbol, f64)>>,
55    /// Diagnostics emitted while solving and reducing the output.
56    pub diagnostics: Vec<sim_kernel::Diagnostic>,
57}
58
59/// Quantity value, certificate, and optional total gradient for a completed solve.
60#[derive(Clone, Debug)]
61pub struct QualityAnswer {
62    /// Scalar value of the requested quantity.
63    pub value: f64,
64    /// Certificate describing residual, convergence, and gradient trust.
65    pub certificate: SolveCertificate,
66    /// Gradient values and trust tag when a parameter list is supplied.
67    pub gradient: Option<(Vec<f64>, GradientTrust)>,
68}
69
70/// The opaque payload carried by a model-derived runtime function.
71///
72/// Recorded in a [`Func`]'s metadata so a differentiator can recover the model,
73/// its free variables, and the queried output to build an adjoint pass.
74#[derive(Clone)]
75pub struct FemmFuncPayload {
76    /// The model the function evaluates.
77    pub model: FemmModel,
78    /// The model inputs treated as the function's free variables.
79    pub vars: Vec<Symbol>,
80    /// The output the function returns.
81    pub query: OutputQuery,
82}
83
84impl Object for FemmFuncPayload {
85    fn display(&self, _cx: &mut Cx) -> KernelResult<String> {
86        Ok(format!(
87            "#<femm-payload model={} query={}>",
88            self.model.id.0,
89            describe_query(&self.query)
90        ))
91    }
92
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96}
97
98impl sim_kernel::ObjectCompat for FemmFuncPayload {
99    fn class(&self, cx: &mut Cx) -> KernelResult<ClassRef> {
100        if let Some(class) = cx
101            .registry()
102            .class_by_symbol(&Symbol::qualified("femm", "FuncPayload"))
103        {
104            return Ok(class.clone());
105        }
106        DefaultFactory.class_stub(ClassId(33), Symbol::qualified("femm", "FuncPayload"))
107    }
108    fn as_expr(&self, cx: &mut Cx) -> KernelResult<Expr> {
109        sim_citizen::constructor_expr(cx, self)
110    }
111    fn as_object_encoder(&self) -> Option<&dyn ObjectEncode> {
112        Some(self)
113    }
114}
115
116impl ObjectEncode for FemmFuncPayload {
117    fn object_encoding(&self, _cx: &mut Cx) -> KernelResult<ObjectEncoding> {
118        Ok(ObjectEncoding::Constructor {
119            class: func_payload_class_symbol(),
120            args: payload_constructor_args(self),
121        })
122    }
123}
124
125impl sim_citizen::Citizen for FemmFuncPayload {
126    fn citizen_symbol() -> Symbol {
127        func_payload_class_symbol()
128    }
129
130    fn citizen_version() -> u32 {
131        1
132    }
133
134    fn citizen_arity() -> usize {
135        3
136    }
137
138    fn citizen_fields() -> &'static [&'static str] {
139        &["model_id", "query", "vars"]
140    }
141}
142
143fn func_payload_class_symbol() -> Symbol {
144    Symbol::qualified("femm", "FuncPayload")
145}
146
147fn payload_constructor_args(payload: &FemmFuncPayload) -> Vec<Expr> {
148    vec![
149        Expr::Symbol(Symbol::new("v1")),
150        int_expr(payload.model.id.0),
151        Expr::String(describe_query(&payload.query)),
152        Expr::List(
153            payload
154                .vars
155                .iter()
156                .map(|name| Expr::String(name.to_string()))
157                .collect(),
158        ),
159    ]
160}
161
162fn int_expr(value: impl ToString) -> Expr {
163    Expr::Number(sim_kernel::NumberLiteral {
164        domain: Symbol::qualified("citizen", "int"),
165        canonical: value.to_string(),
166    })
167}
168
169/// Something that can be evaluated as a FEMM function of its parameters.
170pub trait FemmCallable {
171    /// Evaluates the callable for one [`FemmCall`], returning its output.
172    fn eval(&self, cx: &mut Cx, call: FemmCall) -> FemmResult<FemmEval>;
173}
174
175/// A [`FemmCallable`] that solves a concrete model on each evaluation.
176///
177/// Resolves defaults for any unbound inputs, runs the steady solve, and reduces
178/// the solution to the requested [`OutputQuery`].
179#[derive(Clone)]
180pub struct ModelCallable {
181    /// The model solved on each call.
182    pub model: FemmModel,
183}
184
185impl ModelCallable {
186    fn resolve_params(&self, params: &ParamSet) -> FemmResult<ParamSet> {
187        let mut entries = params.entries.clone();
188        for input in &self.model.inputs {
189            if entries.iter().all(|(name, _)| name != &input.name) {
190                if let Some(default) = &input.default {
191                    entries.push((input.name.clone(), default.clone()));
192                } else {
193                    return Err(FemmError::UnknownFemmParameter(input.name.to_string()));
194                }
195            }
196        }
197        Ok(ParamSet::new(entries))
198    }
199
200    fn solve_solution(
201        &self,
202        cx: &mut Cx,
203        params: &ParamSet,
204        limits: &FemmLimits,
205    ) -> FemmResult<Arc<FemmSolution>> {
206        let resolved = self.resolve_params(params)?;
207        solve_steady(cx, &self.model, &resolved, limits, None).map(|out| out.solution)
208    }
209}
210
211impl FemmCallable for ModelCallable {
212    fn eval(&self, cx: &mut Cx, call: FemmCall) -> FemmResult<FemmEval> {
213        let resolved = self.resolve_params(&call.params)?;
214        match call.query {
215            OutputQuery::Quantity(QuantitySpec::Custom { expr, .. }) => {
216                let value = sim_lib_femm_geometry::eval_expr_f64(cx, &expr, &resolved, &[])?;
217                Ok(FemmEval {
218                    value: cx
219                        .factory()
220                        .number_literal(Symbol::qualified("numbers", "f64"), value.to_string())
221                        .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?,
222                    gradient: None,
223                    diagnostics: Vec::new(),
224                })
225            }
226            OutputQuery::Quantity(spec) => {
227                let solution = self.solve_solution(cx, &resolved, &call.limits)?;
228                let scalar = quantity(&solution, &spec)?;
229                Ok(FemmEval {
230                    value: cx
231                        .factory()
232                        .number_literal(Symbol::qualified("numbers", "f64"), scalar.to_string())
233                        .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?,
234                    gradient: None,
235                    diagnostics: Vec::new(),
236                })
237            }
238            OutputQuery::Field(projection) => {
239                let solution = self.solve_solution(cx, &resolved, &call.limits)?;
240                let field = Field::new(solution, projection);
241                Ok(FemmEval {
242                    value: cx
243                        .factory()
244                        .opaque(Arc::new(field))
245                        .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?,
246                    gradient: None,
247                    diagnostics: Vec::new(),
248                })
249            }
250            OutputQuery::Solution => {
251                let solution = self.solve_solution(cx, &resolved, &call.limits)?;
252                Ok(FemmEval {
253                    value: cx
254                        .factory()
255                        .opaque(solution)
256                        .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?,
257                    gradient: None,
258                    diagnostics: Vec::new(),
259                })
260            }
261        }
262    }
263}
264
265/// Returns the requested quantity and the certificate for a completed solve.
266///
267/// Passing `Some(params)` for `wrt` also computes a total finite-difference
268/// gradient and annotates the returned certificate with its trust level.
269/// Passing `None` skips gradient work.
270pub fn quality(
271    cx: &mut Cx,
272    solve: &SteadySolve,
273    quantity_spec: &QuantitySpec,
274    wrt: Option<&[Symbol]>,
275) -> FemmResult<QualityAnswer> {
276    let value = quantity(&solve.solution, quantity_spec)?;
277    let mut certificate = solve.certificate.clone();
278    let gradient = match wrt {
279        None => None,
280        Some(params) => {
281            let (values, trust) =
282                finite_difference_quality_gradient(cx, solve, quantity_spec, params)?;
283            certificate.set_gradient_trust(trust.clone());
284            Some((values, trust))
285        }
286    };
287    Ok(QualityAnswer {
288        value,
289        certificate,
290        gradient,
291    })
292}
293
294fn finite_difference_quality_gradient(
295    cx: &mut Cx,
296    solve: &SteadySolve,
297    quantity_spec: &QuantitySpec,
298    wrt: &[Symbol],
299) -> FemmResult<(Vec<f64>, GradientTrust)> {
300    let callable = ModelCallable {
301        model: solve.model.clone(),
302    };
303    let base_params = callable.resolve_params(&solve.solution.params)?;
304    let mut out = Vec::with_capacity(wrt.len());
305    for symbol in wrt {
306        let base_value = base_params
307            .get(symbol)
308            .ok_or_else(|| FemmError::UnknownFemmParameter(symbol.to_string()))?;
309        let x = value_as_f64(cx, base_value)?;
310        if !x.is_finite() {
311            return Err(FemmError::SensitivityUnavailable(format!(
312                "non-finite FEMM parameter {symbol}"
313            )));
314        }
315        let h = fd_step(x);
316        let plus = replace_param_value(cx, &base_params, symbol, x + h)?;
317        let minus = replace_param_value(cx, &base_params, symbol, x - h)?;
318        let plus_value = quality_at_params(cx, &solve.model, plus, quantity_spec)?;
319        let minus_value = quality_at_params(cx, &solve.model, minus, quantity_spec)?;
320        out.push((plus_value - minus_value) / (2.0 * h));
321    }
322    Ok((out, GradientTrust::FiniteDifferenceOnly))
323}
324
325fn quality_at_params(
326    cx: &mut Cx,
327    model: &FemmModel,
328    params: ParamSet,
329    quantity_spec: &QuantitySpec,
330) -> FemmResult<f64> {
331    let solved = solve_steady(cx, model, &params, &FemmLimits::default(), None)?;
332    quantity(&solved.solution, quantity_spec)
333}
334
335fn replace_param_value(
336    cx: &mut Cx,
337    params: &ParamSet,
338    name: &Symbol,
339    value: f64,
340) -> FemmResult<ParamSet> {
341    let mut found = false;
342    let mut entries = params.entries.clone();
343    for (symbol, slot) in &mut entries {
344        if symbol == name {
345            *slot = cx
346                .factory()
347                .number_literal(Symbol::qualified("numbers", "f64"), value.to_string())
348                .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?;
349            found = true;
350        }
351    }
352    if found {
353        Ok(ParamSet::new(entries))
354    } else {
355        Err(FemmError::UnknownFemmParameter(name.to_string()))
356    }
357}
358
/// Central-difference step size for a parameter currently at `value`.
///
/// The step is relative (`1e-6 * |value|`) when `|value| > 1`, and absolute
/// (`1e-6`) otherwise, so it never collapses to zero near the origin.
/// A NaN input yields `1e-6`; an infinite input yields infinity.
fn fd_step(value: f64) -> f64 {
    const BASE_STEP: f64 = 1.0e-6;
    let magnitude = value.abs();
    let scale = if magnitude > 1.0 { magnitude } else { 1.0 };
    BASE_STEP * scale
}
362
363/// Wraps a model as a sim-numbers [`Func`] of the named variables.
364///
365/// The returned function solves the model on call and reduces it to `query`;
366/// its metadata carries a [`FemmFuncPayload`] and an adjoint differentiator
367/// hint so sensitivity analysis can recover the model.
368///
369/// # Examples
370///
371/// ```
372/// use sim_kernel::Symbol;
373/// use sim_lib_femm_fixtures::parallel_plate_capacitor;
374/// use sim_lib_femm_function::{femm_as_func, OutputQuery};
375/// use sim_lib_femm_post::QuantitySpec;
376///
377/// let vars = vec![Symbol::new("gap-mm")];
378/// let func = femm_as_func(
379///     parallel_plate_capacitor(),
380///     vars.clone(),
381///     OutputQuery::Quantity(QuantitySpec::Energy { region: None }),
382/// );
383/// assert_eq!(func.vars, vars);
384/// ```
385pub fn femm_as_func(model: FemmModel, vars: Vec<Symbol>, query: OutputQuery) -> Func {
386    let callable = ModelCallable {
387        model: model.clone(),
388    };
389    let closure_vars = vars.clone();
390    let payload_vars = closure_vars.clone();
391    let closure_query = query.clone();
392    Func {
393        vars,
394        body_cas: None,
395        body_native: Some(Arc::new(move |cx, args| {
396            let params = ParamSet::new(
397                closure_vars
398                    .iter()
399                    .cloned()
400                    .zip(args.iter().cloned())
401                    .collect::<Vec<_>>(),
402            );
403            callable
404                .eval(
405                    cx,
406                    FemmCall {
407                        params,
408                        query: closure_query.clone(),
409                        want_grad: None,
410                        limits: FemmLimits::default(),
411                    },
412                )
413                .map(|out| out.value)
414                .map_err(sim_kernel::Error::from)
415        })),
416        metadata: FuncMetadata {
417            source: Some(Symbol::qualified("femm", "model")),
418            differentiator_hint: Some(Symbol::new("femm-adjoint")),
419            payload: DefaultFactory
420                .opaque(Arc::new(FemmFuncPayload {
421                    model: model.clone(),
422                    vars: payload_vars,
423                    query: query.clone(),
424                }))
425                .ok(),
426        },
427    }
428}
429
430/// Wraps a model's potential field as a sim-numbers [`Func`] over position.
431///
432/// Builds a trivial single-element solution for `model` and exposes its
433/// potential projection as a spatial function, used where a model is consumed
434/// as a field-valued function rather than a parameter-to-scalar map.
435pub fn femm_field_func(model: FemmModel) -> Func {
436    let field = Arc::new(FemmSolution {
437        id: StableId(model.id.0 + 1),
438        model_id: model.id,
439        physics: model.physics.clone(),
440        formulation: model.formulation.clone(),
441        params: ParamSet::default(),
442        mesh: sim_lib_femm_mesh::FemMesh2 {
443            xy: vec![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
444            tri: vec![[0, 1, 2]],
445            elem_region: vec![Symbol::new("air")],
446            edge_boundary: Vec::new(),
447        },
448        u: vec![0.0, 1.0, 1.0],
449        diagnostics: sim_lib_femm_flow::SolveDiagnostics {
450            method: Symbol::new("femm-ptc"),
451            converged: true,
452            iterations: 1,
453            final_residual: 0.0,
454            events: Vec::new(),
455            diagnostics: Vec::new(),
456        },
457    });
458    field_as_func(Field::new(field, Projection::Potential))
459}
460
461pub(crate) fn describe_query(query: &OutputQuery) -> String {
462    match query {
463        OutputQuery::Quantity(QuantitySpec::Custom { name, .. }) => format!("quantity:{name}"),
464        OutputQuery::Quantity(_) => "quantity".to_owned(),
465        OutputQuery::Field(projection) => format!("field:{projection:?}"),
466        OutputQuery::Solution => "solution".to_owned(),
467    }
468}