Skip to main content

sim_lib_control/
ops.rs

1use std::sync::Arc;
2
3use sim_kernel::{
4    Args, Callable, ClassRef, Cx, Error, Expr, Object, ObjectCompat, RawArgs, Ref, Result, Symbol,
5    Value,
6    control::{
7        ControlAbort, ControlCapture, ControlPrompt, ControlResume, abort, capture,
8        default_control_result_shape, prompt, resume,
9    },
10};
11
12use crate::model::{ContinuationValue, ControlResultValue};
13
/// A callable runtime object exposing one control primitive.
///
/// Each instance wraps one `ControlFunctionKind` (`prompt`, `capture`, `abort`,
/// or `resume`). The control lib installs these as `control/*` functions,
/// turning the kernel control-policy operations into callables the runtime can
/// invoke.
#[derive(Clone, Debug)]
pub struct ControlFunction {
    /// Which control primitive this callable dispatches to.
    kind: ControlFunctionKind,
}

/// The control primitive a `ControlFunction` dispatches to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ControlFunctionKind {
    /// `control/prompt`: establishes a prompt.
    Prompt,
    /// `control/capture`: captures a continuation.
    Capture,
    /// `control/abort`: aborts to a prompt.
    Abort,
    /// `control/resume`: resumes a continuation.
    Resume,
}
31
32impl ControlFunction {
33    /// Builds the `control/prompt` function, which establishes a prompt.
34    pub fn prompt() -> Self {
35        Self {
36            kind: ControlFunctionKind::Prompt,
37        }
38    }
39
40    /// Builds the `control/capture` function, which captures a continuation.
41    pub fn capture() -> Self {
42        Self {
43            kind: ControlFunctionKind::Capture,
44        }
45    }
46
47    /// Builds the `control/abort` function, which aborts to a prompt.
48    pub fn abort() -> Self {
49        Self {
50            kind: ControlFunctionKind::Abort,
51        }
52    }
53
54    /// Builds the `control/resume` function, which resumes a continuation.
55    pub fn resume() -> Self {
56        Self {
57            kind: ControlFunctionKind::Resume,
58        }
59    }
60
61    /// Returns the `control/*` symbol under which this function is exported.
62    pub fn symbol(&self) -> Symbol {
63        self.kind.symbol()
64    }
65}
66
67impl Object for ControlFunction {
68    fn display(&self, _cx: &mut Cx) -> Result<String> {
69        Ok(format!("#<function {}>", self.kind.symbol()))
70    }
71
72    fn as_any(&self) -> &dyn std::any::Any {
73        self
74    }
75}
76
77impl ObjectCompat for ControlFunction {
78    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
79        cx.resolve_class(&Symbol::qualified("core", "Function"))
80    }
81
82    fn as_callable(&self) -> Option<&dyn Callable> {
83        Some(self)
84    }
85}
86
87impl Callable for ControlFunction {
88    fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
89        self.kind.call(cx, args.into_vec())
90    }
91
92    fn call_exprs(&self, cx: &mut Cx, args: RawArgs) -> Result<Value> {
93        let values = args
94            .into_exprs()
95            .into_iter()
96            .map(|expr| cx.eval_expr(expr))
97            .collect::<Result<Vec<_>>>()?;
98        self.kind.call(cx, values)
99    }
100}
101
102impl ControlFunctionKind {
103    fn symbol(self) -> Symbol {
104        match self {
105            Self::Prompt => prompt_symbol(),
106            Self::Capture => capture_symbol(),
107            Self::Abort => abort_symbol(),
108            Self::Resume => resume_symbol(),
109        }
110    }
111
112    fn call(self, cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
113        match self {
114            Self::Prompt => call_prompt(cx, args),
115            Self::Capture => call_capture(cx, args),
116            Self::Abort => call_abort(cx, args),
117            Self::Resume => call_resume(cx, args),
118        }
119    }
120}
121
122/// Returns the `control/prompt` symbol.
123pub fn prompt_symbol() -> Symbol {
124    Symbol::qualified("control", "prompt")
125}
126
127/// Returns the `control/capture` symbol.
128pub fn capture_symbol() -> Symbol {
129    Symbol::qualified("control", "capture")
130}
131
132/// Returns the `control/abort` symbol.
133pub fn abort_symbol() -> Symbol {
134    Symbol::qualified("control", "abort")
135}
136
137/// Returns the `control/resume` symbol.
138pub fn resume_symbol() -> Symbol {
139    Symbol::qualified("control", "resume")
140}
141
142fn call_prompt(cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
143    let refs = refs_from_args(cx, args, "control/prompt")?;
144    let [prompt_ref, value_ref] = refs.as_slice() else {
145        return Err(arity_error("control/prompt", "prompt value"));
146    };
147    let prompt_ref = prompt_ref.clone();
148    let value_ref = value_ref.clone();
149    let result = prompt(
150        cx,
151        ControlPrompt::new(
152            prompt_ref,
153            value_ref.clone(),
154            default_control_result_shape(),
155        ),
156        |_cx| Ok(value_ref),
157    )?;
158    control_result_value(cx, result)
159}
160
161fn call_capture(cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
162    let multishot = optional_bool_arg(cx, args.get(3))?;
163    let refs = refs_from_args(cx, args.into_iter().take(3).collect(), "control/capture")?;
164    let [prompt_ref, continuation_ref, value_ref] = refs.as_slice() else {
165        return Err(arity_error(
166            "control/capture",
167            "prompt continuation value [multishot]",
168        ));
169    };
170    let mut request = ControlCapture::new(
171        prompt_ref.clone(),
172        continuation_ref.clone(),
173        value_ref.clone(),
174        default_control_result_shape(),
175    );
176    if multishot {
177        request = request.multishot();
178    }
179    let capture_result = capture(cx, request)?;
180    cx.factory().opaque(Arc::new(ContinuationValue::new(
181        continuation_ref.clone(),
182        capture_result,
183        multishot,
184    )))
185}
186
187fn call_abort(cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
188    let refs = refs_from_args(cx, args, "control/abort")?;
189    let [prompt_ref, value_ref] = refs.as_slice() else {
190        return Err(arity_error("control/abort", "prompt value"));
191    };
192    let prompt_ref = prompt_ref.clone();
193    let value_ref = value_ref.clone();
194    let result = abort(
195        cx,
196        ControlAbort::new(prompt_ref, value_ref, default_control_result_shape()),
197    )?;
198    control_result_value(cx, result)
199}
200
201fn call_resume(cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
202    if args.len() != 2 {
203        return Err(arity_error("control/resume", "continuation value"));
204    }
205    let continuation = continuation_ref(cx, &args[0])?;
206    let value = value_ref(cx, &args[1], "control/resume value")?;
207    let result = resume(
208        cx,
209        ControlResume::new(continuation, value, default_control_result_shape()),
210    )?;
211    control_result_value(cx, result)
212}
213
214fn refs_from_args(cx: &mut Cx, args: Vec<Value>, context: &'static str) -> Result<Vec<Ref>> {
215    args.iter()
216        .map(|value| value_ref(cx, value, context))
217        .collect()
218}
219
220fn continuation_ref(cx: &mut Cx, value: &Value) -> Result<Ref> {
221    if let Some(continuation) = value.object().downcast_ref::<ContinuationValue>() {
222        return Ok(continuation.continuation().clone());
223    }
224    value_ref(cx, value, "control continuation")
225}
226
227fn value_ref(cx: &mut Cx, value: &Value, context: &'static str) -> Result<Ref> {
228    if let Some(result) = value.object().downcast_ref::<ControlResultValue>() {
229        return Ok(result.reference().clone());
230    }
231    let expr = value.object().as_expr(cx)?;
232    match expr {
233        Expr::Symbol(symbol) => Ok(Ref::Symbol(symbol)),
234        _ => Err(Error::TypeMismatch {
235            expected: context,
236            found: "non-ref value",
237        }),
238    }
239}
240
241fn optional_bool_arg(cx: &mut Cx, value: Option<&Value>) -> Result<bool> {
242    let Some(value) = value else {
243        return Ok(false);
244    };
245    match value.object().as_expr(cx)? {
246        Expr::Bool(value) => Ok(value),
247        _ => Err(Error::TypeMismatch {
248            expected: "bool",
249            found: "non-bool",
250        }),
251    }
252}
253
254fn control_result_value(cx: &mut Cx, reference: Ref) -> Result<Value> {
255    cx.factory()
256        .opaque(Arc::new(ControlResultValue::new(reference)))
257}
258
259fn arity_error(function: &'static str, expected: &'static str) -> Error {
260    Error::Eval(format!("{function} expects {expected}"))
261}