1use std::{collections::BTreeMap, sync::Arc};
4
5use sim_kernel::{
6 Args, ClassRef, Cx, DefaultFactory, Error, Expr, Factory, HandleStore, Object, Ref, Result,
7 Symbol, Term, Value,
8};
9use sim_lib_numbers_func::Func;
10
11use super::{options::parse_symbolish_value, registry::global_numeric_registry};
12
13#[derive(Clone, Debug, PartialEq, Eq)]
15pub enum PipelineKind {
16 OdeSolve,
18 Quadrature,
20}
21
22impl PipelineKind {
23 pub fn symbol(&self) -> Symbol {
25 match self {
26 Self::OdeSolve => Symbol::new("ode-solve"),
27 Self::Quadrature => Symbol::new("quadrature"),
28 }
29 }
30}
31
32#[derive(Clone, Debug, PartialEq, Eq)]
34pub enum StateKind {
35 F64,
37 Tensor,
39}
40
41impl StateKind {
42 pub fn symbol(&self) -> Symbol {
44 match self {
45 Self::F64 => Symbol::new("f64"),
46 Self::Tensor => Symbol::new("tensor"),
47 }
48 }
49}
50
51#[derive(Clone, Debug)]
53pub struct ComposedPipeline {
54 pub func_ref: Ref,
56 pub kind: PipelineKind,
58 pub method: Symbol,
60 pub state: StateKind,
62}
63
64impl ComposedPipeline {
65 pub fn new(func_ref: Ref, kind: PipelineKind, method: Symbol, state: StateKind) -> Self {
67 Self {
68 func_ref,
69 kind,
70 method,
71 state,
72 }
73 }
74
75 pub fn table_value(&self, factory: &dyn Factory) -> Result<Value> {
77 factory.table(vec![
78 (
79 Symbol::new("kind"),
80 factory.string("composed-pipeline".to_owned())?,
81 ),
82 (Symbol::new("domain"), factory.symbol(self.kind.symbol())?),
83 (Symbol::new("method"), factory.symbol(self.method.clone())?),
84 (Symbol::new("state"), factory.symbol(self.state.symbol())?),
85 (
86 Symbol::new("func"),
87 factory.expr(Term::Ref(self.func_ref.clone()).into())?,
88 ),
89 ])
90 }
91}
92
93impl Object for ComposedPipeline {
94 fn display(&self, _cx: &mut Cx) -> Result<String> {
95 Ok(format!(
96 "#<composed-pipeline {} {} {}>",
97 self.kind.symbol(),
98 self.method,
99 self.state.symbol()
100 ))
101 }
102
103 fn as_any(&self) -> &dyn std::any::Any {
104 self
105 }
106}
107
108impl sim_kernel::ObjectCompat for ComposedPipeline {
109 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
110 if let Some(value) = cx
111 .registry()
112 .class_by_symbol(&Symbol::qualified("core", "Table"))
113 {
114 return Ok(value.clone());
115 }
116 DefaultFactory.class_stub(
117 sim_kernel::CORE_TABLE_CLASS_ID,
118 Symbol::qualified("core", "Table"),
119 )
120 }
121
122 fn as_expr(&self, cx: &mut Cx) -> Result<Expr> {
123 self.as_table(cx)?.object().as_expr(cx)
124 }
125
126 fn as_table(&self, cx: &mut Cx) -> Result<Value> {
127 self.table_value(cx.factory())
128 }
129}
130
131pub fn call_numeric_compose(cx: &mut Cx, args: Args) -> Result<Value> {
132 let values = args.into_vec();
133 let pipeline = compose_from_values(cx, &values)?;
134 pipeline_value(cx, pipeline)
135}
136
137pub fn call_numeric_compose_exprs(cx: &mut Cx, args: Vec<Expr>) -> Result<Value> {
138 let pipeline = compose_from_exprs(cx, &args)?;
139 pipeline_value(cx, pipeline)
140}
141
142pub fn call_numeric_run_composed(cx: &mut Cx, args: Args) -> Result<Value> {
143 super::pipeline_run::call_numeric_run_composed(cx, args)
144}
145
146pub fn call_numeric_run_composed_exprs(cx: &mut Cx, args: Vec<Expr>) -> Result<Value> {
147 super::pipeline_run::call_numeric_run_composed_exprs(cx, args)
148}
149
150fn compose_from_values(cx: &mut Cx, values: &[Value]) -> Result<ComposedPipeline> {
151 match values {
152 [func, kind, method, state] if !is_compose_key_value(cx, kind)? => {
153 let func_ref = require_func_ref(cx, "numeric/compose", func)?;
154 let kind = require_pipeline_kind_value(cx, "numeric/compose", kind)?;
155 let method = require_symbol_value(cx, "numeric/compose", method)?;
156 let state = require_state_kind_value(cx, "numeric/compose", state)?;
157 finish_compose(func_ref, kind, method, state)
158 }
159 [func, rest @ ..] if rest.len().is_multiple_of(2) => {
160 let func_ref = require_func_ref(cx, "numeric/compose", func)?;
161 let mut options = BTreeMap::<String, Value>::new();
162 for pair in rest.chunks(2) {
163 let key = require_compose_key_value(cx, &pair[0])?;
164 options.insert(key, pair[1].clone());
165 }
166 let kind = require_compose_kind_value(cx, &options)?;
167 let method = require_compose_symbol_value(cx, &options, "method")?;
168 let state = require_compose_state_value(cx, &options)?;
169 finish_compose(func_ref, kind, method, state)
170 }
171 _ => Err(Error::Eval(
172 "numeric/compose expects func, kind, method, state or keyword pairs".to_owned(),
173 )),
174 }
175}
176
177fn compose_from_exprs(cx: &mut Cx, args: &[Expr]) -> Result<ComposedPipeline> {
178 let Some((func_expr, rest)) = args.split_first() else {
179 return Err(Error::Eval(
180 "numeric/compose expects func, kind, method, state or keyword pairs".to_owned(),
181 ));
182 };
183 let func = cx.eval_expr(func_expr.clone())?;
184 let func_ref = require_func_ref(cx, "numeric/compose", &func)?;
185 if let [kind_expr, method_expr, state_expr] = rest
186 && !is_compose_key_expr(kind_expr)
187 {
188 let kind = require_pipeline_kind_expr("numeric/compose", kind_expr)?;
189 let method = require_symbol_expr("numeric/compose", method_expr)?;
190 let state = require_state_kind_expr("numeric/compose", state_expr)?;
191 return finish_compose(func_ref, kind, method, state);
192 }
193 if !rest.len().is_multiple_of(2) {
194 return Err(Error::Eval(
195 "numeric/compose keyword arguments must be key/value pairs".to_owned(),
196 ));
197 }
198 let mut options = BTreeMap::<String, Symbol>::new();
199 for pair in rest.chunks(2) {
200 options.insert(
201 require_compose_key_expr(&pair[0])?,
202 require_symbol_expr("numeric/compose", &pair[1])?,
203 );
204 }
205 let kind = require_compose_kind_symbol(&options)?;
206 let method = require_compose_symbol(&options, "method")?;
207 let state = parse_state_kind(&require_compose_symbol(&options, "state")?).ok_or_else(|| {
208 Error::Eval("numeric/compose expected state kind f64 or tensor".to_owned())
209 })?;
210 finish_compose(func_ref, kind, method, state)
211}
212
213fn require_func_ref(cx: &mut Cx, name: &str, value: &Value) -> Result<Ref> {
214 value.object().downcast_ref::<Func>().ok_or_else(|| {
215 Error::Eval(format!(
216 "{name} expects its first argument to be a Func value"
217 ))
218 })?;
219 Ok(Ref::Handle(cx.handles_mut().intern(value.clone())))
220}
221
222fn pipeline_value(cx: &mut Cx, pipeline: ComposedPipeline) -> Result<Value> {
223 cx.factory().opaque(Arc::new(pipeline))
224}
225
226fn require_pipeline_kind_value(cx: &mut Cx, name: &str, value: &Value) -> Result<PipelineKind> {
227 let symbol = require_symbol_value(cx, name, value)?;
228 parse_pipeline_kind(&symbol).ok_or_else(|| {
229 Error::Eval(format!(
230 "{name} expected pipeline kind ode-solve or quadrature"
231 ))
232 })
233}
234
235fn require_state_kind_value(cx: &mut Cx, name: &str, value: &Value) -> Result<StateKind> {
236 let symbol = require_symbol_value(cx, name, value)?;
237 parse_state_kind(&symbol)
238 .ok_or_else(|| Error::Eval(format!("{name} expected state kind f64 or tensor")))
239}
240
241fn require_symbol_value(cx: &mut Cx, name: &str, value: &Value) -> Result<Symbol> {
242 parse_symbolish_value(cx, value)?
243 .ok_or_else(|| Error::Eval(format!("{name} expected a symbol argument")))
244}
245
246fn finish_compose(
247 func_ref: Ref,
248 kind: PipelineKind,
249 method: Symbol,
250 state: StateKind,
251) -> Result<ComposedPipeline> {
252 if kind == PipelineKind::Quadrature {
253 validate_quadrature_method(&method)?;
254 }
255 Ok(ComposedPipeline::new(func_ref, kind, method, state))
256}
257
258fn require_compose_kind_value(
259 cx: &mut Cx,
260 options: &BTreeMap<String, Value>,
261) -> Result<PipelineKind> {
262 let symbol = options
263 .get("domain")
264 .or_else(|| options.get("kind"))
265 .ok_or_else(|| Error::Eval("numeric/compose missing :domain".to_owned()))
266 .and_then(|value| require_symbol_value(cx, "numeric/compose", value))?;
267 parse_pipeline_kind(&symbol).ok_or_else(|| {
268 Error::Eval("numeric/compose expected domain ode-solve or quadrature".to_owned())
269 })
270}
271
272fn require_compose_kind_symbol(options: &BTreeMap<String, Symbol>) -> Result<PipelineKind> {
273 let symbol = options
274 .get("domain")
275 .or_else(|| options.get("kind"))
276 .ok_or_else(|| Error::Eval("numeric/compose missing :domain".to_owned()))?;
277 parse_pipeline_kind(symbol).ok_or_else(|| {
278 Error::Eval("numeric/compose expected domain ode-solve or quadrature".to_owned())
279 })
280}
281
282fn require_compose_state_value(
283 cx: &mut Cx,
284 options: &BTreeMap<String, Value>,
285) -> Result<StateKind> {
286 let symbol = require_compose_symbol_value(cx, options, "state")?;
287 parse_state_kind(&symbol)
288 .ok_or_else(|| Error::Eval("numeric/compose expected state kind f64 or tensor".to_owned()))
289}
290
291fn require_compose_symbol_value(
292 cx: &mut Cx,
293 options: &BTreeMap<String, Value>,
294 key: &str,
295) -> Result<Symbol> {
296 let value = options
297 .get(key)
298 .ok_or_else(|| Error::Eval(format!("numeric/compose missing :{key}")))?;
299 require_symbol_value(cx, "numeric/compose", value)
300}
301
302fn require_compose_symbol(options: &BTreeMap<String, Symbol>, key: &str) -> Result<Symbol> {
303 options
304 .get(key)
305 .cloned()
306 .ok_or_else(|| Error::Eval(format!("numeric/compose missing :{key}")))
307}
308
309fn require_pipeline_kind_expr(name: &str, expr: &Expr) -> Result<PipelineKind> {
310 let symbol = require_symbol_expr(name, expr)?;
311 parse_pipeline_kind(&symbol).ok_or_else(|| {
312 Error::Eval(format!(
313 "{name} expected pipeline kind ode-solve or quadrature"
314 ))
315 })
316}
317
318fn require_state_kind_expr(name: &str, expr: &Expr) -> Result<StateKind> {
319 let symbol = require_symbol_expr(name, expr)?;
320 parse_state_kind(&symbol)
321 .ok_or_else(|| Error::Eval(format!("{name} expected state kind f64 or tensor")))
322}
323
324fn require_symbol_expr(name: &str, expr: &Expr) -> Result<Symbol> {
325 match expr {
326 Expr::Symbol(symbol) => Ok(symbol.clone()),
327 Expr::Quote { expr, .. } => match expr.as_ref() {
328 Expr::Symbol(symbol) => Ok(symbol.clone()),
329 _ => Err(Error::Eval(format!("{name} expected a symbol argument"))),
330 },
331 _ => Err(Error::Eval(format!("{name} expected a symbol argument"))),
332 }
333}
334
335fn parse_pipeline_kind(symbol: &Symbol) -> Option<PipelineKind> {
336 match keyword_name(symbol).as_str() {
337 "ode-solve" => Some(PipelineKind::OdeSolve),
338 "quadrature" => Some(PipelineKind::Quadrature),
339 _ => None,
340 }
341}
342
343fn parse_state_kind(symbol: &Symbol) -> Option<StateKind> {
344 match keyword_name(symbol).as_str() {
345 "f64" => Some(StateKind::F64),
346 "tensor" => Some(StateKind::Tensor),
347 _ => None,
348 }
349}
350
351fn keyword_name(symbol: &Symbol) -> String {
352 symbol
353 .name
354 .strip_prefix(':')
355 .unwrap_or(&symbol.name)
356 .to_owned()
357}
358
359fn is_compose_key_value(cx: &mut Cx, value: &Value) -> Result<bool> {
360 Ok(parse_symbolish_value(cx, value)?
361 .as_ref()
362 .is_some_and(|symbol| is_compose_key_name(&keyword_name(symbol))))
363}
364
365fn require_compose_key_value(cx: &mut Cx, value: &Value) -> Result<String> {
366 parse_symbolish_value(cx, value)?
367 .map(|symbol| keyword_name(&symbol))
368 .filter(|key| is_compose_key_name(key))
369 .ok_or_else(|| Error::Eval("numeric/compose expected keyword argument".to_owned()))
370}
371
372fn is_compose_key_expr(expr: &Expr) -> bool {
373 let Expr::Symbol(symbol) = expr else {
374 return false;
375 };
376 is_compose_key_name(&keyword_name(symbol))
377}
378
379fn require_compose_key_expr(expr: &Expr) -> Result<String> {
380 let Expr::Symbol(symbol) = expr else {
381 return Err(Error::Eval(
382 "numeric/compose expected keyword argument".to_owned(),
383 ));
384 };
385 let key = keyword_name(symbol);
386 if is_compose_key_name(&key) {
387 Ok(key)
388 } else {
389 Err(Error::Eval(format!(
390 "numeric/compose: unknown option :{key}"
391 )))
392 }
393}
394
/// Reports whether `key` (already stripped of any leading `:`) is a compose option:
/// `domain`, `kind`, `method` or `state`.
fn is_compose_key_name(key: &str) -> bool {
    const KEYS: [&str; 4] = ["domain", "kind", "method", "state"];
    KEYS.contains(&key)
}
398
399fn validate_quadrature_method(method: &Symbol) -> Result<()> {
400 let method = resolve_quad_method(method);
401 let registry = global_numeric_registry()
402 .read()
403 .map_err(|_| Error::PoisonedLock("numeric registry"))?;
404 if registry.quadrature_fixed(&method).is_some()
405 || registry.quadrature_adaptive(&method).is_some()
406 {
407 Ok(())
408 } else {
409 Err(unknown_numeric_method("quadrature", &method))
410 }
411}
412
413fn resolve_quad_method(method: &Symbol) -> Symbol {
414 if *method != Symbol::new("auto") {
415 return method.clone();
416 }
417 Symbol::new("simpson")
418}
419
420fn unknown_numeric_method(kind: &str, method: &Symbol) -> Error {
421 Error::Eval(format!("UnknownNumericMethod: {kind} method {method}"))
422}