1use self::eval::run_eval;
4use self::function::run_fn;
5use self::graph::GraphOperation;
6use self::r#box::run_box;
7use crate::util::JoinHandleWithDrop;
8use crate::workers::EscapeHatch;
9use crate::Runtime;
10use anyhow::{anyhow, bail};
11use futures::future::{self, AbortHandle, BoxFuture};
12use futures::stream::BoxStream;
13use futures::{Future, FutureExt, Stream, StreamExt};
14use std::collections::HashSet;
15use std::pin::Pin;
16use std::{collections::HashMap, sync::Arc};
17use thiserror::Error;
18use tierkreis_core::graph::{Edge, Graph, GraphBuilder, Node, Value};
19use tierkreis_core::prelude::TryInto;
20use tierkreis_core::symbol::{Label, Location, SymbolError};
21use tierkreis_proto::messages::{Callback, Completed, GraphTrace, Status};
22use tokio::sync::mpsc;
23use tokio::sync::watch;
24use tokio_stream::wrappers::UnboundedReceiverStream;
25use tracing::Instrument;
26
27pub(crate) mod r#box;
28pub(crate) mod eval;
29pub(crate) mod function;
30pub(crate) mod graph;
31pub(crate) mod variant;
32
33pub use graph::checkpoint_client::CheckpointClient;
34
/// Stream of [`Input`] messages delivered to a running operation.
///
/// Wraps a boxed stream so operation implementations do not need to be
/// generic over where their inputs come from.
pub struct OperationInputs(BoxStream<'static, Input>);
38
39impl Stream for OperationInputs {
40 type Item = Input;
41
42 fn poll_next(
43 mut self: std::pin::Pin<&mut Self>,
44 cx: &mut std::task::Context<'_>,
45 ) -> std::task::Poll<Option<Self::Item>> {
46 Pin::new(&mut self.as_mut().0).poll_next(cx)
47 }
48}
49
/// Stream of [`Output`] messages produced by a running operation.
///
/// Dropping this value aborts the operation's input stream.
pub struct OperationOutputs {
    /// Outputs emitted by the operation; `RuntimeOperation::run` sends a final
    /// `Success` or `Failure` once the operation's future completes.
    stream: BoxStream<'static, Output>,
    /// Aborts the operation's input stream when these outputs are dropped.
    abort_input: AbortHandle,
}
57
impl OperationOutputs {
    /// Consumes the outputs and hands them to a [`TaskHandle`], which collects
    /// them in a background task and exposes the final status.
    pub fn into_task(self) -> TaskHandle {
        TaskHandle::new(self)
    }
}
64
impl Drop for OperationOutputs {
    /// Aborts the operation's (abortable) input stream, so an operation whose
    /// outputs are no longer observed stops receiving inputs.
    fn drop(&mut self) {
        self.abort_input.abort()
    }
}
70
71impl Stream for OperationOutputs {
72 type Item = Output;
73
74 fn poll_next(
75 mut self: Pin<&mut Self>,
76 cx: &mut std::task::Context<'_>,
77 ) -> std::task::Poll<Option<Self::Item>> {
78 Pin::new(&mut self.as_mut().stream).poll_next(cx)
79 }
80}
81
/// A not-yet-started runtime operation.
///
/// Holds a one-shot start function that, given an [`OperationContext`] and
/// [`OperationInputs`], returns a future resolving to the operation's overall
/// success or failure. Outputs are emitted through the context.
pub struct RuntimeOperation {
    start: Box<
        dyn FnOnce(OperationContext, OperationInputs) -> BoxFuture<'static, anyhow::Result<()>>
            + Send,
    >,
}
92
impl RuntimeOperation {
    /// Wraps a start function as a [`RuntimeOperation`].
    ///
    /// `f` is called with the operation's context and input stream when the
    /// operation is run; the future it returns is boxed so that every
    /// operation shares one concrete type.
    fn new<F, FF>(f: F) -> Self
    where
        F: FnOnce(OperationContext, OperationInputs) -> FF + Send + 'static,
        FF: Future<Output = anyhow::Result<()>> + Send + 'static,
    {
        Self {
            start: Box::new(|ctx, inputs| f(ctx, inputs).boxed()),
        }
    }

    /// Creates an operation that ignores its inputs and outputs `value` on the
    /// `value` port.
    pub fn new_const(value: Value) -> RuntimeOperation {
        operation_const(value)
    }

    /// Creates an operation that runs `graph` through `run_box` with location `loc`.
    pub fn new_box(loc: Location, graph: Graph) -> RuntimeOperation {
        RuntimeOperation::new(|ctx, inputs| run_box(loc, graph, ctx, inputs))
    }

    /// Creates an operation that executes `graph` via [`GraphOperation`].
    pub fn new_graph(graph: Graph) -> RuntimeOperation {
        RuntimeOperation::new(move |ctx, inputs| GraphOperation::new(graph, ctx, inputs).run())
    }

    /// Creates the variant-matching operation implemented by `variant::run_match`.
    pub(crate) fn new_match() -> RuntimeOperation {
        RuntimeOperation::new(variant::run_match)
    }

    /// Creates the tagging operation for label `tag`, implemented by
    /// `variant::run_tag`.
    pub(crate) fn new_tag(tag: Label) -> RuntimeOperation {
        RuntimeOperation::new(move |ctx, inputs| variant::run_tag(tag, ctx, inputs))
    }

    /// Creates an operation from a synchronous function of its inputs.
    ///
    /// `run_fn` (not shown here) supplies `f` with the inputs gathered into a
    /// map plus the context, and publishes the returned map as outputs —
    /// presumably once all inputs have arrived; confirm in `function.rs`.
    /// `f` is invoked directly (wrapped in a ready future), not on a separate
    /// task, so it should be quick and must not block.
    pub fn new_fn_simple<F>(f: F) -> Self
    where
        F: FnOnce(HashMap<Label, Value>, OperationContext) -> anyhow::Result<HashMap<Label, Value>>
            + Send
            + 'static,
    {
        let f = |inputs, ctx| futures::future::ready(f(inputs, ctx));
        RuntimeOperation::new(move |ctx, inputs| run_fn(f, ctx, inputs))
    }

    /// Creates an operation from an async function of its inputs.
    ///
    /// The future returned by `f` is spawned onto its own Tokio task,
    /// instrumented with the span that is current when it is spawned. A
    /// failure to join that task (e.g. a panic) is turned into an operation
    /// error.
    pub fn new_fn_async<F, FF>(f: F) -> Self
    where
        F: FnOnce(HashMap<Label, Value>, OperationContext) -> FF + Send + 'static,
        FF: Future<Output = anyhow::Result<HashMap<Label, Value>>> + Send + 'static,
    {
        let f = |inputs, ctx| {
            let span = tracing::Span::current();
            let handle = tokio::spawn(f(inputs, ctx).instrument(span));
            // `JoinHandleWithDrop` presumably aborts the task when dropped, tying
            // the spawned work to the lifetime of this operation.
            JoinHandleWithDrop::from(handle).map(|r| r.unwrap_or_else(|e| Err(e.into())))
        };
        RuntimeOperation::new(move |ctx, inputs| run_fn(f, ctx, inputs))
    }

    /// Creates an operation from a blocking function of its inputs.
    ///
    /// `f` runs on Tokio's blocking thread pool (`spawn_blocking`) inside the
    /// span that was current when it was scheduled. A failure to join the
    /// blocking task (e.g. a panic) is turned into an operation error.
    pub fn new_fn_blocking<F>(f: F) -> Self
    where
        F: FnOnce(HashMap<Label, Value>, OperationContext) -> anyhow::Result<HashMap<Label, Value>>
            + Send
            + 'static,
    {
        let f = |inputs, ctx| {
            let span = tracing::Span::current();
            let handle = tokio::task::spawn_blocking(move || span.in_scope(|| f(inputs, ctx)));
            JoinHandleWithDrop::from(handle).map(|r| r.unwrap_or_else(|e| Err(e.into())))
        };
        RuntimeOperation::new(move |ctx, inputs| run_fn(f, ctx, inputs))
    }

    /// Starts the operation on a new Tokio task and returns its output stream.
    ///
    /// The operation sees `inputs` followed by a never-ending stream, so its
    /// input stream only terminates when aborted — which happens when the
    /// returned [`OperationOutputs`] is dropped. After the operation's future
    /// resolves, a final `Output::Success` or `Output::Failure` is sent.
    pub fn run<S>(
        self,
        runtime: Runtime,
        callback: Callback,
        escape: EscapeHatch,
        inputs: S,
        stack_trace: GraphTrace,
        checkpoint_client: Option<CheckpointClient>,
    ) -> OperationOutputs
    where
        S: Stream<Item = Input> + Send + 'static,
    {
        // Never end on our own: termination is driven solely by `abort_handle`.
        let inputs = inputs.chain(futures::stream::pending());
        let (inputs, abort_handle) = futures::stream::abortable(inputs);

        let (output_tx, output_rx) = mpsc::unbounded_channel();

        let context = OperationContext {
            output: output_tx.clone(),
            callback,
            escape,
            runtime,
            graph_trace: stack_trace,
            checkpoint_client,
        };

        let span = tracing::Span::current();

        // The task is detached; its results are observed only via `output_rx`.
        tokio::spawn(
            async move {
                let result = (self.start)(context, OperationInputs(inputs.boxed())).await;

                // A send error just means nobody is listening any more.
                let _ = match result {
                    Ok(()) => output_tx.send(Output::Success),
                    Err(error) => output_tx.send(Output::Failure { error }),
                };
            }
            .instrument(span),
        );

        OperationOutputs {
            stream: UnboundedReceiverStream::new(output_rx).boxed(),
            abort_input: abort_handle,
        }
    }

    /// Runs the operation with a fixed set of inputs.
    ///
    /// Each `(port, value)` pair becomes an `Input::Input`, followed by a single
    /// `Input::Complete`. Because the pairs are first collected into a
    /// `HashMap`, a port given more than once keeps only its last value, and
    /// the order in which inputs are sent is the map's (unspecified) order.
    pub fn run_simple<I>(
        self,
        runtime: Runtime,
        callback: Callback,
        escape: EscapeHatch,
        inputs: I,
        stack_trace: GraphTrace,
        checkpoint_client: Option<CheckpointClient>,
    ) -> OperationOutputs
    where
        I: IntoIterator<Item = (Label, Value)>,
    {
        let inputs: HashMap<_, _> = inputs.into_iter().collect();

        let inputs_stream = futures::stream::iter(inputs)
            .map(|(port, value)| Input::Input { port, value })
            .chain(tokio_stream::once(Input::Complete));

        self.run(
            runtime,
            callback,
            escape,
            inputs_stream,
            stack_trace,
            checkpoint_client,
        )
    }
}
259
/// Cloneable handle to an operation whose outputs are being collected by a
/// background task; exposes the current [`Status`] and allows cancellation.
#[derive(Clone)]
pub struct TaskHandle {
    /// Latest status published by the background collector task.
    status: watch::Receiver<Status>,
    /// Sending on this channel asks the collector task to stop.
    abort: mpsc::UnboundedSender<()>,
}
267
268impl std::fmt::Debug for TaskHandle {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 f.debug_struct("TaskHandle").finish()
271 }
272}
273
274impl TaskHandle {
275 fn new(mut output_stream: OperationOutputs) -> Self {
276 let (status_tx, status_rx) = watch::channel(Status::Running);
277 let (abort_tx, mut abort_rx) = mpsc::unbounded_channel();
278
279 tokio::spawn(async move {
280 let mut outputs = HashMap::new();
281
282 loop {
283 let output = tokio::select! { biased;
284 _ = abort_rx.recv() => break,
285 msg = output_stream.next() => {
286 match msg {
287 Some(msg) => msg,
288 None => break,
289 }
290 }
291 };
292
293 match output {
294 Output::Output { port, value } => {
295 outputs.insert(port, value);
296 }
297 Output::Success => {
298 let _ = status_tx.send(Status::Completed(Ok(Arc::new(outputs))));
299
300 return;
301 }
302 Output::Failure { error } => {
303 let _ = status_tx.send(Status::Completed(Err(Arc::new(error))));
304 return;
305 }
306 }
307 }
308
309 let error = anyhow!("task was cancelled");
310 let _ = status_tx.send(Status::Completed(Err(Arc::new(error))));
311 });
312
313 Self {
314 status: status_rx,
315 abort: abort_tx,
316 }
317 }
318
319 pub fn status(&self) -> Status {
322 self.status.borrow().clone()
323 }
324
325 pub fn cancel(&self) {
327 let _ = self.abort.send(());
328 }
329
330 pub async fn complete(&mut self) -> Completed {
332 loop {
333 if let Status::Completed(result) = self.status() {
334 match result {
335 Ok(_) => tracing::debug!("complete"),
336 Err(_) => tracing::warn!("complete with error"),
337 }
338 return result;
339 }
340 tracing::debug!("still running");
341
342 let event = self.status.changed().await;
343
344 if event.is_err() {
345 unreachable!("watch was closed before setting a completed status");
346 }
347 }
348 }
349}
350
/// A message on an operation's input stream.
#[derive(Debug)]
pub enum Input {
    /// A value arriving on the named input port.
    #[allow(missing_docs)]
    Input { port: Label, value: Value },
    /// Signals that all inputs have been sent (`run_simple` sends it after the
    /// last `Input::Input`).
    Complete,
}
360
/// A message on an operation's output stream.
#[derive(Debug)]
pub enum Output {
    /// A value emitted on the named output port.
    #[allow(missing_docs)]
    Output { port: Label, value: Value },
    /// The operation's future finished successfully.
    Success,
    /// The operation's future finished with `error`.
    #[allow(missing_docs)]
    Failure { error: anyhow::Error },
}
373
374impl Output {
375 pub fn context<C>(self, context: C) -> Self
377 where
378 C: std::fmt::Display + Send + Sync + 'static,
379 {
380 match self {
381 Self::Output { port, value } => Self::Output { port, value },
382 Self::Success => Self::Success,
383 Self::Failure { error } => {
384 let error = error.context(context);
385 Self::Failure { error }
386 }
387 }
388 }
389}
390
/// State handed to a running operation: a channel for emitting outputs plus
/// the runtime services it forwards to nested operations.
#[derive(Clone)]
pub struct OperationContext {
    /// Channel on which outputs are sent; see [`OperationContext::set_output`].
    output: mpsc::UnboundedSender<Output>,
    /// The runtime the operation is executing in.
    pub runtime: Runtime,
    /// Callback forwarded unchanged to nested operations (see `loop` and `map`).
    pub callback: Callback,
    /// Escape hatch forwarded unchanged to nested operations.
    pub escape: EscapeHatch,
    /// Position of this operation in the graph execution trace.
    pub graph_trace: GraphTrace,
    /// Optional checkpoint client, forwarded to nested operations.
    pub checkpoint_client: Option<CheckpointClient>,
}
411
412impl OperationContext {
413 pub fn set_output(&self, port: impl Into<Label>, value: Value) {
415 let _ = self.output.send(Output::Output {
416 port: port.into(),
417 value,
418 });
419 }
420
421 fn outer_graph_checkpoint(&mut self) -> Option<&mut CheckpointClient> {
423 if self.graph_trace == GraphTrace::Root {
424 self.checkpoint_client.as_mut()
425 } else {
426 None
427 }
428 }
429}
430
/// Builds the `eval` builtin, implemented by `run_eval` in the `eval` submodule.
pub(crate) fn operation_eval() -> RuntimeOperation {
    RuntimeOperation::new(run_eval)
}
434
435pub(crate) fn operation_const(value: Value) -> RuntimeOperation {
436 RuntimeOperation::new_fn_simple(move |_inputs, _context| {
437 let mut outputs = HashMap::new();
438 outputs.insert(Label::value(), value);
439 Ok(outputs)
440 })
441}
442
443pub(crate) fn operation_id() -> RuntimeOperation {
444 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
445 let mut outputs = HashMap::new();
446 outputs.insert(Label::value(), take_input(&mut inputs, Label::value())?);
447 Ok(outputs)
448 })
449}
450
451pub(crate) fn operation_sleep() -> RuntimeOperation {
452 RuntimeOperation::new_fn_async(|mut inputs, _context| async move {
453 let delay = validate_float_input(&mut inputs, "delay_secs")?;
454 tokio::time::sleep(tokio::time::Duration::from_secs_f64(delay)).await;
455 let mut outputs = HashMap::new();
456 outputs.insert(Label::value(), take_input(&mut inputs, Label::value())?);
457 Ok(outputs)
458 })
459}
460
461pub(crate) fn operation_copy() -> RuntimeOperation {
462 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
463 let value = take_input(&mut inputs, Label::value())?;
464 let mut outputs = HashMap::new();
465 outputs.insert(TryInto::try_into("value_0")?, value.clone());
466 outputs.insert(TryInto::try_into("value_1")?, value);
467 Ok(outputs)
468 })
469}
470
471pub(crate) fn operation_discard() -> RuntimeOperation {
472 RuntimeOperation::new_fn_simple(|_inputs, _context| Ok(HashMap::new()))
473}
474
475pub(crate) fn operation_equality() -> RuntimeOperation {
476 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
477 let val0 = take_input(&mut inputs, "value_0")?;
478 let val1 = take_input(&mut inputs, "value_1")?;
479 let mut outputs = HashMap::new();
480 outputs.insert(TryInto::try_into("result")?, Value::Bool(val0 == val1));
481 Ok(outputs)
482 })
483}
484
485pub(crate) fn operation_not_equality() -> RuntimeOperation {
486 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
487 let val0 = take_input(&mut inputs, "value_0")?;
488 let val1 = take_input(&mut inputs, "value_1")?;
489 let mut outputs = HashMap::new();
490 outputs.insert(TryInto::try_into("result")?, Value::Bool(val0 != val1));
491 Ok(outputs)
492 })
493}
494
495pub(crate) fn operation_switch() -> RuntimeOperation {
496 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
497 let predicate = validate_bool_input(&mut inputs, "pred")?;
498
499 let branch_true = take_input(&mut inputs, "if_true")?;
500 let branch_false = take_input(&mut inputs, "if_false")?;
501
502 let result = if predicate { branch_true } else { branch_false };
503
504 let mut outputs = HashMap::new();
505 outputs.insert(Label::value(), result);
506 Ok(outputs)
507 })
508}
509
510pub(crate) fn operation_make_pair() -> RuntimeOperation {
511 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
512 let first = take_input(&mut inputs, "first")?;
513 let second = take_input(&mut inputs, "second")?;
514
515 let mut outputs = HashMap::new();
516 outputs.insert(
517 TryInto::try_into("pair")?,
518 Value::Pair(Box::new((first, second))),
519 );
520 Ok(outputs)
521 })
522}
523
524pub(crate) fn operation_unpack_pair() -> RuntimeOperation {
525 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
526 let (first, second) = validate_input(&mut inputs, "pair", |x| match x {
527 Value::Pair(pair) => Some((pair.0, pair.1)),
528 _ => None,
529 })?;
530
531 let mut outputs = HashMap::new();
532 outputs.insert(TryInto::try_into("first")?, first);
533 outputs.insert(TryInto::try_into("second")?, second);
534 Ok(outputs)
535 })
536}
537
538pub(crate) fn operation_push() -> RuntimeOperation {
539 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
540 let vec_l: Label = TryInto::try_into("vec")?;
541 let mut vec = match take_input(&mut inputs, vec_l)? {
542 Value::Vec(vec) => vec,
543 _ => bail!("Push function expected vector input."),
544 };
545
546 let item = take_input(&mut inputs, "item")?;
547
548 vec.push(item);
549 let mut outputs = HashMap::new();
550 outputs.insert(vec_l, Value::Vec(vec));
551 Ok(outputs)
552 })
553}
554
555pub(crate) fn operation_pop() -> RuntimeOperation {
556 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
557 let vec_l: Label = TryInto::try_into("vec")?;
558 let mut vec = match take_input(&mut inputs, vec_l)? {
559 Value::Vec(vec) => {
560 if vec.is_empty() {
561 Err(RuntimeError::EmptyVector)
562 } else {
563 Ok(vec)
564 }
565 }
566 _ => Err(RuntimeError::InvalidInput(vec_l)),
567 }?;
568
569 let item = vec.pop().unwrap();
570 let mut outputs = HashMap::new();
571 outputs.insert(TryInto::try_into("item")?, item);
572 outputs.insert(vec_l, Value::Vec(vec));
573
574 Ok(outputs)
575 })
576}
577
578pub(crate) fn operation_loop() -> RuntimeOperation {
579 RuntimeOperation::new_fn_async(|mut inputs, context| async move {
580 let _ = &context;
581 let body = validate_graph_input(&mut inputs, "body")?;
582 let mut value = take_input(&mut inputs, Label::value())?;
583
584 let node_trace = context
585 .graph_trace
586 .as_node_trace()
587 .map_err(|_| anyhow!("loop function expected stack trace to correspond to a node"))?;
588 for iteration in 1.. {
589 let graph_trace = node_trace.clone().loop_iter(iteration);
591
592 let body_output = RuntimeOperation::new_graph(body.clone())
593 .run_simple(
594 context.runtime.clone(),
595 context.callback.clone(),
596 context.escape.clone(),
597 [(Label::value(), value)],
598 graph_trace,
599 context.checkpoint_client.clone(),
600 )
601 .into_task()
602 .complete()
603 .await
604 .map_err(|err| {
605 let e = anyhow!(format!("{:?}", err.as_ref()));
607 e.context(format!("loop body (iteration {})", iteration))
608 })?;
609
610 let body_output = body_output.as_ref();
611 if let Some(Value::Variant(label, b)) = body_output.get(&Label::value()) {
612 value = *b.clone();
613 if label == &Label::continue_() {
614 continue;
615 } else if label == &Label::break_() {
616 break;
617 }
618 };
619 bail!(
621 "loop node expected body to output a Variant (break | continue) on port 'value' (iteration {})",
622 iteration
623 )
624 }
625 Ok(HashMap::from([(Label::value(), value)]))
626 })
627}
628
629pub(crate) fn operation_sequence() -> RuntimeOperation {
630 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
631 let first = validate_graph_input(&mut inputs, "first")?;
632
633 let second = validate_graph_input(&mut inputs, "second")?;
634
635 let g3 = {
636 let mut builder = GraphBuilder::new();
637 let [input, output] = Graph::boundary();
638 let inputs: Vec<Edge> = first.node_outputs(input).cloned().collect();
639 let second_input_ports: HashSet<_> =
640 second.node_outputs(input).map(|e| e.source.port).collect();
641 let middle: Vec<Edge> = first
643 .node_inputs(output)
644 .filter(|e| second_input_ports.contains(&e.target.port))
645 .cloned()
646 .collect();
647 let outputs: Vec<Edge> = second.node_inputs(output).cloned().collect();
648
649 let b1 = builder.add_node(Node::local_box(first))?;
650 let b2 = builder.add_node(Node::local_box(second))?;
651
652 for input_edge in inputs {
653 builder.add_edge(
654 (input, input_edge.source.port),
655 (b1, input_edge.source.port),
656 input_edge.edge_type,
657 )?;
658 }
659 for seq_edge in middle {
660 builder.add_edge(
661 (b1, seq_edge.target.port),
662 (b2, seq_edge.target.port),
663 seq_edge.edge_type,
664 )?;
665 }
666
667 for output_edge in outputs {
668 builder.add_edge(
669 (b2, output_edge.target.port),
670 (output, output_edge.target.port),
671 output_edge.edge_type,
672 )?;
673 }
674 builder.build()?
675 };
676
677 let mut outputs = HashMap::new();
678 outputs.insert(TryInto::try_into("sequenced")?, Value::Graph(g3));
679
680 Ok(outputs)
681 })
682}
683
684pub(crate) fn operation_parallel() -> RuntimeOperation {
685 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
686 let left = validate_graph_input(&mut inputs, "left")?;
687
688 let right = validate_graph_input(&mut inputs, "right")?;
689
690 let g3 = {
691 let mut builder = GraphBuilder::new();
692 let [input, output] = Graph::boundary();
693
694 let inputs_left: Vec<Edge> = left.node_outputs(input).cloned().collect();
695 let inputs_right: Vec<Edge> = right.node_outputs(input).cloned().collect();
696 let outputs_left: Vec<Edge> = left.node_inputs(output).cloned().collect();
697 let outputs_right: Vec<Edge> = right.node_inputs(output).cloned().collect();
698
699 let b_left = builder.add_node(Node::local_box(left))?;
700 let b_right = builder.add_node(Node::local_box(right))?;
701
702 for left_input_edge in inputs_left {
703 builder.add_edge(
704 (input, left_input_edge.source.port),
705 (b_left, left_input_edge.source.port),
706 left_input_edge.edge_type,
707 )?;
708 }
709
710 for right_input_edge in inputs_right {
711 builder.add_edge(
712 (input, right_input_edge.source.port),
713 (b_right, right_input_edge.source.port),
714 right_input_edge.edge_type,
715 )?;
716 }
717
718 for left_output_edge in outputs_left {
719 builder.add_edge(
720 (b_left, left_output_edge.target.port),
721 (output, left_output_edge.target.port),
722 left_output_edge.edge_type,
723 )?;
724 }
725
726 for right_output_edge in outputs_right {
727 builder.add_edge(
728 (b_right, right_output_edge.target.port),
729 (output, right_output_edge.target.port),
730 right_output_edge.edge_type,
731 )?;
732 }
733
734 builder.build()?
735 };
736
737 let outputs = HashMap::from([(Label::value(), Value::Graph(g3))]);
738
739 Ok(outputs)
740 })
741}
742
743pub(crate) fn operation_make_struct() -> RuntimeOperation {
744 RuntimeOperation::new_fn_simple(|inputs, _context| {
745 let struc = Value::Struct(inputs);
746 let mut outputs = HashMap::new();
747 outputs.insert(TryInto::try_into("struct")?, struc);
748 Ok(outputs)
749 })
750}
751
752pub(crate) fn operation_unpack_struct() -> RuntimeOperation {
753 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
754 let outputs = validate_input(&mut inputs, "struct", |x| match x {
755 Value::Struct(fields) => Some(fields),
756 _ => None,
757 })?;
758 Ok(outputs)
759 })
760}
761
762pub(crate) fn operation_insert_key() -> RuntimeOperation {
763 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
764 let map_l: Label = TryInto::try_into("map")?;
765 let mut map = validate_input(&mut inputs, map_l, |x| match x {
766 Value::Map(map) => Some(map),
767 _ => None,
768 })?;
769
770 let key = take_input(&mut inputs, "key")?;
771 let val = take_input(&mut inputs, "val")?;
772
773 map.insert(key, val);
774 let mut outputs = HashMap::new();
775 outputs.insert(map_l, Value::Map(map));
776 Ok(outputs)
777 })
778}
779
780pub(crate) fn operation_remove_key() -> RuntimeOperation {
781 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
782 let map_l: Label = TryInto::try_into("map")?;
783 let mut map = validate_input(&mut inputs, map_l, |x| match x {
784 Value::Map(map) => Some(map),
785 _ => None,
786 })?;
787
788 let key = take_input(&mut inputs, "key")?;
789 let val = map.remove(&key).ok_or(RuntimeError::KeyNotFound(key))?;
790 let mut outputs = HashMap::new();
791 outputs.insert(map_l, Value::Map(map));
792 outputs.insert(TryInto::try_into("val")?, val);
793 Ok(outputs)
794 })
795}
796
797pub(crate) fn binary_int_operation_with_error<F>(f: F) -> RuntimeOperation
800where
801 F: FnOnce(i64, i64) -> anyhow::Result<i64> + Sync + Send + 'static,
802{
803 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
804 let a = validate_int_input(&mut inputs, "a")?;
805 let b = validate_int_input(&mut inputs, "b")?;
806
807 let result = f(a, b)?;
808 let mut outputs = HashMap::new();
809 outputs.insert(Label::value(), Value::Int(result));
810 Ok(outputs)
811 })
812}
813
814pub(crate) fn binary_int_operation<F>(f: F) -> RuntimeOperation
815where
816 F: FnOnce(i64, i64) -> i64 + Sync + Send + 'static,
817{
818 binary_int_operation_with_error(|a, b| Ok(f(a, b)))
819}
820
821pub(crate) fn binary_flt_operation<F>(f: F) -> RuntimeOperation
822where
823 F: FnOnce(f64, f64) -> f64 + Sync + Send + 'static,
824{
825 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
826 let a = validate_float_input(&mut inputs, "a")?;
827 let b = validate_float_input(&mut inputs, "b")?;
828
829 let mut outputs = HashMap::new();
830 outputs.insert(Label::value(), Value::Float(f(a, b)));
831 Ok(outputs)
832 })
833}
834
835pub(crate) fn binary_int_comparison<F>(f: F) -> RuntimeOperation
836where
837 F: FnOnce(i64, i64) -> bool + Sync + Send + 'static,
838{
839 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
840 let a = validate_int_input(&mut inputs, "a")?;
841 let b = validate_int_input(&mut inputs, "b")?;
842
843 let mut outputs = HashMap::new();
844 outputs.insert(Label::value(), Value::Bool(f(a, b)));
845 Ok(outputs)
846 })
847}
848
849pub(crate) fn binary_flt_comparison<F>(f: F) -> RuntimeOperation
850where
851 F: FnOnce(f64, f64) -> bool + Sync + Send + 'static,
852{
853 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
854 let a = validate_float_input(&mut inputs, "a")?;
855 let b = validate_float_input(&mut inputs, "b")?;
856
857 let mut outputs = HashMap::new();
858 outputs.insert(Label::value(), Value::Bool(f(a, b)));
859 Ok(outputs)
860 })
861}
862
863pub(crate) fn binary_bool_operation<F>(f: F) -> RuntimeOperation
864where
865 F: FnOnce(bool, bool) -> bool + Sync + Send + 'static,
866{
867 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
868 let a = validate_bool_input(&mut inputs, "a")?;
869 let b = validate_bool_input(&mut inputs, "b")?;
870
871 let mut outputs = HashMap::new();
872 outputs.insert(Label::value(), Value::Bool(f(a, b)));
873 Ok(outputs)
874 })
875}
876
877pub(crate) fn operation_int_to_float() -> RuntimeOperation {
878 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
879 let int = validate_int_input(&mut inputs, "int")?;
880
881 let mut outputs = HashMap::new();
882 outputs.insert(Label::value(), Value::Float(int as f64));
883 Ok(outputs)
884 })
885}
886
887pub(crate) fn operation_float_to_int() -> RuntimeOperation {
888 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
889 let flt = validate_float_input(&mut inputs, "float")?;
890
891 let mut outputs = HashMap::new();
892 outputs.insert(Label::value(), Value::Int(flt as i64));
893 Ok(outputs)
894 })
895}
896
897pub(crate) fn operation_partial() -> RuntimeOperation {
898 RuntimeOperation::new_fn_simple(|mut inputs, _context| {
899 let thunk = validate_graph_input(&mut inputs, Label::thunk())?;
900
901 let g3 = {
902 let mut builder = GraphBuilder::new();
903 let [input, output] = Graph::boundary();
904
905 let input_edges: Vec<Edge> = thunk.node_outputs(input).cloned().collect();
906 let output_edges: Vec<Edge> = thunk.node_inputs(output).cloned().collect();
907 let b1 = builder.add_node(Node::local_box(thunk))?;
908
909 for input_edge in input_edges {
914 let port = input_edge.source.port;
915 let source = match inputs.remove(&port) {
916 Some(value) => {
917 let new_const = builder.add_node(Node::Const(value))?;
918 (new_const, Label::value())
919 }
920 None => (input, port),
921 };
922 builder.add_edge(source, (b1, port), input_edge.edge_type)?;
923 }
924 for out_edge in output_edges {
928 builder.add_edge(
929 (b1, out_edge.target.port),
930 (output, out_edge.target.port),
931 out_edge.edge_type,
932 )?;
933 }
934
935 builder.build()?
936 };
937
938 let mut outputs = HashMap::new();
939 outputs.insert(Label::value(), Value::Graph(g3));
940
941 Ok(outputs)
942 })
943}
944
945pub(crate) fn operation_map() -> RuntimeOperation {
946 RuntimeOperation::new_fn_async(|mut inputs, context| async move {
947 let _ = &context;
948 let thunk = validate_graph_input(&mut inputs, Label::thunk())?;
949
950 let list = validate_input(&mut inputs, Label::value(), |x| match x {
951 Value::Vec(a) => Some(a),
952 _ => None,
953 })?;
954
955 let mut tasks = Vec::new();
956
957 let node_trace = context
958 .graph_trace
959 .as_node_trace()
960 .map_err(|_| anyhow!("map function expected stack trace to correspond to a node"))?;
961 for (idx, x) in list.into_iter().enumerate() {
962 let thunk_clone = thunk.clone();
963 let runtime = context.runtime.clone();
964 let callback = context.callback.clone();
965 let escape = context.escape.clone();
966 let checkpoint = context.checkpoint_client.clone();
967 let graph_trace = node_trace.clone().list_elem(idx as u32);
968
969 let span = tracing::Span::current();
970
971 let t = tokio::spawn(
972 async move {
973 let value: Value = RuntimeOperation::new_graph(thunk_clone)
974 .run_simple(
975 runtime,
976 callback,
977 escape,
978 [(Label::value(), x)],
979 graph_trace,
980 checkpoint,
981 )
982 .into_task()
983 .complete()
984 .await
985 .map_err(|err| {
986 let e = anyhow!(format!("{:?}", err.as_ref()));
987 e.context("map body".to_string())
988 })?
989 .get(&Label::value())
990 .ok_or_else(|| anyhow!("map thunk should output on value port"))?
991 .clone();
992 Ok::<Value, anyhow::Error>(value)
993 }
994 .instrument(span),
995 );
996 tasks.push(t);
997 }
998 let x = future::join_all(tasks).await;
1000 let y: Result<Vec<_>, _> = x.into_iter().collect();
1001 let z: Result<Vec<Value>, _> = y?.into_iter().collect();
1002
1003 let outputs = HashMap::from([(Label::value(), Value::Vec(z?))]);
1004
1005 Ok(outputs)
1006 })
1007}
1008
1009fn take_input<E>(
1010 inputs: &mut HashMap<Label, Value>,
1011 port: impl TryInto<Label, Error = E>,
1012) -> anyhow::Result<Value>
1013where
1014 E: Into<SymbolError>,
1015{
1016 let port = TryInto::try_into(port).map_err(|e| e.into())?;
1017 inputs
1018 .remove(&port)
1019 .ok_or_else(|| RuntimeError::MissingInput(port).into())
1020}
1021
1022fn validate_input<E, T>(
1023 inputs: &mut HashMap<Label, Value>,
1024 port: impl TryInto<Label, Error = E>,
1025 validation: impl FnOnce(Value) -> Option<T>,
1026) -> anyhow::Result<T>
1027where
1028 E: Into<SymbolError>,
1029{
1030 let port = TryInto::try_into(port).map_err(|e| e.into())?;
1031 let input = inputs
1032 .remove(&port)
1033 .ok_or(RuntimeError::MissingInput(port))?;
1034 match validation(input) {
1035 Some(v) => Ok(v),
1036 None => Err(anyhow!(RuntimeError::InvalidInput(port))),
1037 }
1038}
1039
1040fn validate_int_input<E>(
1041 inputs: &mut HashMap<Label, Value>,
1042 port: impl TryInto<Label, Error = E>,
1043) -> anyhow::Result<i64>
1044where
1045 E: Into<SymbolError>,
1046{
1047 validate_input(inputs, port, |x| match x {
1048 Value::Int(a) => Some(a),
1049 _ => None,
1050 })
1051}
1052
1053fn validate_float_input<E>(
1054 inputs: &mut HashMap<Label, Value>,
1055 port: impl TryInto<Label, Error = E>,
1056) -> anyhow::Result<f64>
1057where
1058 E: Into<SymbolError>,
1059{
1060 validate_input(inputs, port, |x| match x {
1061 Value::Float(a) => Some(a),
1062 _ => None,
1063 })
1064}
1065
1066fn validate_bool_input<E>(
1067 inputs: &mut HashMap<Label, Value>,
1068 port: impl TryInto<Label, Error = E>,
1069) -> anyhow::Result<bool>
1070where
1071 E: Into<SymbolError>,
1072{
1073 validate_input(inputs, port, |x| match x {
1074 Value::Bool(a) => Some(a),
1075 _ => None,
1076 })
1077}
1078
1079fn validate_graph_input<E>(
1080 inputs: &mut HashMap<Label, Value>,
1081 port: impl TryInto<Label, Error = E>,
1082) -> anyhow::Result<Graph>
1083where
1084 E: Into<SymbolError>,
1085{
1086 validate_input(inputs, port, |x| match x {
1087 Value::Graph(a) => Some(a),
1088 _ => None,
1089 })
1090}
1091
/// Errors raised by the builtin operations in this module.
#[derive(Debug, Error)]
enum RuntimeError {
    /// A required input port had no value.
    #[error("Missing input on port {0}.")]
    MissingInput(Label),
    /// An input port held a value that failed validation (e.g. wrong variant).
    #[error("Invalid input on port {0}.")]
    InvalidInput(Label),
    /// `pop` was applied to an empty vector.
    #[error("Vector is empty.")]
    EmptyVector,
    /// `remove_key` was given a key absent from the map. The key is carried in
    /// the variant (visible in `Debug` output) but not in the message.
    #[error("Key not found in map.")]
    KeyNotFound(Value),
}