// wick_component/adapters/binary/interleaved_pairs.rs

1use tokio_stream::StreamExt;
2use wick_packet::{BoxStream, PacketExt, VPacket};
3
4use crate::adapters::encode;
5use crate::runtime::BoxFuture;
6use crate::{if_done_close_then, make_substream_window, propagate_if_error, runtime as wasmrs_runtime, SingleOutput};
7
/// Generates an `Operation` implementation for a simple binary operation: one that takes two
/// inputs, produces one output, and largely wants to remain ignorant of stream state.
///
/// Usage: `binary_interleaved_pairs!(op_name => handler);`
///
/// * `$name` — the operation; the expansion implements `$name::Operation` for `Component`, uses
///   `$name::{Inputs, Outputs, Config}` as the associated types, and names the generated method
///   `$name`.
/// * `$handler` — the function called for each `(left, right, context)` pair. It is passed on as
///   `&$handler` where a `&'static` reference is required, so it must name something that can be
///   borrowed for `'static`, such as a plain function item.
///
/// Pairing of the two input streams is delegated to `interleaved_pairs` in this crate's `binary`
/// module.
///
/// The expansion refers to `Component`, `Context`, and the `async_trait` and `wick_packet` crates
/// without a `$crate` prefix, so those must be resolvable where the macro is invoked.
#[macro_export]
macro_rules! binary_interleaved_pairs {
  ($name:ident => $handler:ident) => {
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl $name::Operation for Component {
      type Error = $crate::AnyError;
      type Inputs = $name::Inputs;
      type Outputs = $name::Outputs;
      type Config = $name::Config;

      async fn $name(
        inputs: Self::Inputs,
        mut outputs: Self::Outputs,
        ctx: Context<Self::Config>,
      ) -> Result<(), Self::Error> {
        use wick_packet::BinaryInputs;

        let (left, right) = inputs.both();
        // `$crate` keeps these paths valid even when the caller renamed this dependency, or when
        // the macro is used inside the defining crate itself.
        $crate::binary::interleaved_pairs(left, right, &mut outputs, &ctx, &$handler).await?;

        Ok(())
      }
    }
  };
}
34
35/// Operation helper for common binary operations that have one output.
36pub async fn interleaved_pairs<'c, LEFT, RIGHT, OUTPUT, CONTEXT, OUTPORT, F, E>(
37  left: BoxStream<VPacket<LEFT>>,
38  right: BoxStream<VPacket<RIGHT>>,
39  outputs: &mut OUTPORT,
40  ctx: &'c CONTEXT,
41  func: &'static F,
42) -> Result<(), E>
43where
44  CONTEXT: Clone + wasmrs_runtime::ConditionallySendSync,
45  F: Fn(LEFT, RIGHT, CONTEXT) -> BoxFuture<Result<OUTPUT, E>> + wasmrs_runtime::ConditionallySendSync,
46  OUTPORT: SingleOutput + wasmrs_runtime::ConditionallySendSync,
47  LEFT: serde::de::DeserializeOwned + Clone + wasmrs_runtime::ConditionallySendSync,
48  RIGHT: serde::de::DeserializeOwned + Clone + wasmrs_runtime::ConditionallySendSync,
49  OUTPUT: serde::Serialize + wasmrs_runtime::ConditionallySendSync,
50  E: std::fmt::Display + wasmrs_runtime::ConditionallySendSync,
51{
52  let (_, _) = inner::<LEFT, RIGHT, OUTPUT, CONTEXT, OUTPORT, F, E>(None, None, left, right, outputs, ctx, func).await;
53  outputs.single_output().done();
54
55  Ok(())
56}
57
58#[cfg_attr(not(target_family = "wasm"), async_recursion::async_recursion)]
59#[cfg_attr(target_family = "wasm", async_recursion::async_recursion(?Send))]
60async fn inner<'out, 'c, LEFT, RIGHT, OUTPUT, CONTEXT, OUTPORT, F, E>(
61  last_left: Option<LEFT>,
62  last_right: Option<RIGHT>,
63  mut l_stream: BoxStream<VPacket<LEFT>>,
64  mut r_stream: BoxStream<VPacket<RIGHT>>,
65  outputs: &'out mut OUTPORT,
66  ctx: &'c CONTEXT,
67  func: &'static F,
68) -> (BoxStream<VPacket<LEFT>>, BoxStream<VPacket<RIGHT>>)
69where
70  CONTEXT: Clone + wasmrs_runtime::ConditionallySendSync,
71  F: Fn(LEFT, RIGHT, CONTEXT) -> BoxFuture<Result<OUTPUT, E>> + wasmrs_runtime::ConditionallySendSync,
72  OUTPORT: SingleOutput + wasmrs_runtime::ConditionallySendSync,
73  LEFT: serde::de::DeserializeOwned + Clone + wasmrs_runtime::ConditionallySendSync,
74  RIGHT: serde::de::DeserializeOwned + Clone + wasmrs_runtime::ConditionallySendSync,
75  OUTPUT: serde::Serialize + wasmrs_runtime::ConditionallySendSync,
76  E: std::fmt::Display + wasmrs_runtime::ConditionallySendSync,
77{
78  loop {
79    match (&last_left, &last_right) {
80      (Some(left), None) => {
81        let Some(right) = r_stream.next().await else { break };
82
83        if_done_close_then!([right], break);
84
85        if right.is_open_bracket() {
86          make_substream_window!(outputs, {
87            (l_stream, r_stream) = inner(Some(left.clone()), None, l_stream, r_stream, outputs, ctx, func).await;
88          });
89        } else {
90          let right: RIGHT = propagate_if_error!(right.decode(), outputs, continue);
91          outputs
92            .single_output()
93            .send_raw_payload(encode(func(left.clone(), right, ctx.clone()).await));
94        }
95      }
96      (None, Some(right)) => {
97        let Some(left) = l_stream.next().await else { break };
98
99        if_done_close_then!([left], break);
100
101        if left.is_open_bracket() {
102          make_substream_window!(outputs, {
103            (l_stream, r_stream) = inner(None, Some(right.clone()), l_stream, r_stream, outputs, ctx, func).await;
104          });
105        } else {
106          let left: LEFT = propagate_if_error!(left.decode(), outputs, continue);
107          outputs
108            .single_output()
109            .send_raw_payload(encode(func(left, right.clone(), ctx.clone()).await));
110        }
111      }
112      (None, None) => {
113        let Some(left) = l_stream.next().await else { break };
114        let Some(right) = r_stream.next().await else { break };
115
116        match (left.is_open_bracket(), right.is_open_bracket()) {
117          (true, true) => {
118            make_substream_window!(outputs, {
119              (l_stream, r_stream) = inner(None, None, l_stream, r_stream, outputs, ctx, func).await;
120            });
121          }
122          (true, false) => {
123            if_done_close_then!([right], break);
124
125            let right: RIGHT = propagate_if_error!(right.decode(), outputs, continue);
126            make_substream_window!(outputs, {
127              (l_stream, r_stream) = inner(None, Some(right), l_stream, r_stream, outputs, ctx, func).await;
128            });
129          }
130          (false, true) => {
131            if_done_close_then!([left], break);
132
133            let left: LEFT = propagate_if_error!(left.decode(), outputs, continue);
134            make_substream_window!(outputs, {
135              (l_stream, r_stream) = inner(Some(left), None, l_stream, r_stream, outputs, ctx, func).await;
136            });
137          }
138          (false, false) => {
139            if_done_close_then!([left, right], break);
140            let left: LEFT = propagate_if_error!(left.decode(), outputs, continue);
141            let right: RIGHT = propagate_if_error!(right.decode(), outputs, continue);
142            outputs
143              .single_output()
144              .send_raw_payload(encode(func(left, right, ctx.clone()).await));
145          }
146        }
147      }
148      (Some(_), Some(_)) => {
149        unreachable!()
150      }
151    }
152  }
153
154  (l_stream, r_stream)
155}