wick_component/adapters/unary/
outputs.rs1use tokio_stream::StreamExt;
2use wasmrs_runtime::BoxFuture;
3use wasmrs_rx::Observer;
4use wick_packet::{BoxStream, PacketExt, VPacket, WasmRsChannel};
5
6use crate::{propagate_if_error, Broadcast};
7
/// Implements `$name::Operation` for `Component` by delegating to
/// `wick_component::unary::with_outputs`.
///
/// The operation's single input stream is taken from `Self::Inputs` and every
/// packet on it is handed to `$handler`, together with a fresh `$name::Outputs`
/// produced by `$name::Outputs::new_parts` and a clone of the operation context.
#[macro_export]
macro_rules! unary_with_outputs {
    ($name:ident => $handler:ident) => {
        #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
        impl $name::Operation for Component {
            type Error = wick_component::AnyError;
            type Inputs = $name::Inputs;
            type Outputs = $name::Outputs;
            type Config = $name::Config;

            async fn $name(
                mut inputs: Self::Inputs,
                // Only moved into the adapter below, never mutated here.
                outputs: Self::Outputs,
                ctx: Context<Self::Config>,
            ) -> Result<(), Self::Error> {
                use wick_packet::UnaryInputs;
                let input = inputs.take_input();
                // The fn path itself serves as the per-invocation output factory.
                wick_component::unary::with_outputs(
                    input,
                    outputs,
                    $name::Outputs::new_parts,
                    &ctx,
                    &$handler,
                )
                .await?;

                Ok(())
            }
        }
    };
}
35
36pub async fn with_outputs<'out, 'c, INPUT, OUTPUTS, FACTORY, CONTEXT, F, E>(
38 input: BoxStream<VPacket<INPUT>>,
39 outputs: OUTPUTS,
40 output_factory: FACTORY,
41 ctx: &'c CONTEXT,
42 func: &'static F,
43) -> Result<(), E>
44where
45 CONTEXT: Clone + wasmrs_runtime::ConditionallySendSync,
46 INPUT: serde::de::DeserializeOwned + Clone + wasmrs_runtime::ConditionallySendSync,
47 OUTPUTS: WasmRsChannel + Broadcast + wasmrs_runtime::ConditionallySendSync,
48 FACTORY: Fn() -> (
49 OUTPUTS,
50 wasmrs_rx::FluxReceiver<wasmrs::RawPayload, wasmrs::PayloadError>,
51 ) + wasmrs_runtime::ConditionallySendSync,
52 F: Fn(INPUT, OUTPUTS, CONTEXT) -> BoxFuture<Result<(), E>> + wasmrs_runtime::ConditionallySendSync,
53 E: std::fmt::Display + wasmrs_runtime::ConditionallySendSync,
54{
55 let _ =
56 inner::<INPUT, OUTPUTS, FACTORY, CONTEXT, F, E>(input, outputs, std::sync::Arc::new(output_factory), ctx, func)
57 .await;
58
59 Ok(())
60}
61
62#[cfg_attr(not(target_family = "wasm"), async_recursion::async_recursion)]
63#[cfg_attr(target_family = "wasm", async_recursion::async_recursion(?Send))]
64async fn inner<'out, 'c, INPUT, OUTPUTS, FACTORY, CONTEXT, F, E>(
65 mut input_stream: BoxStream<VPacket<INPUT>>,
66 mut outputs: OUTPUTS,
67 output_factory: std::sync::Arc<FACTORY>,
68 ctx: &'c CONTEXT,
69 func: &'static F,
70) -> (BoxStream<VPacket<INPUT>>, OUTPUTS)
71where
72 INPUT: serde::de::DeserializeOwned + Clone + wasmrs_runtime::ConditionallySendSync,
73 OUTPUTS: WasmRsChannel + Broadcast + wasmrs_runtime::ConditionallySendSync,
74 FACTORY: Fn() -> (
75 OUTPUTS,
76 wasmrs_rx::FluxReceiver<wasmrs::RawPayload, wasmrs::PayloadError>,
77 ) + wasmrs_runtime::ConditionallySendSync,
78 CONTEXT: Clone + wasmrs_runtime::ConditionallySendSync,
79 F: Fn(INPUT, OUTPUTS, CONTEXT) -> BoxFuture<Result<(), E>> + wasmrs_runtime::ConditionallySendSync,
80 E: std::fmt::Display + wasmrs_runtime::ConditionallySendSync,
81{
82 loop {
83 let Some(input) = input_stream.next().await else { break };
84 if input.is_open_bracket() {
85 outputs.broadcast_open();
86 let (inner_outputs, mut inner_output_rx) = output_factory();
87 (input_stream, outputs) = inner(input_stream, inner_outputs, output_factory.clone(), ctx, func).await;
88 while let Some(payload) = inner_output_rx.next().await {
89 let _ = outputs.channel().send_result(payload);
90 }
91 outputs.broadcast_close();
92 } else if input.is_close_bracket() || input.is_done() {
93 break;
94 } else {
95 let input: INPUT = propagate_if_error!(input.decode(), outputs, continue);
96 let (inner_outputs, mut inner_output_rx) = output_factory();
97 let result = func(input.clone(), inner_outputs, ctx.clone()).await;
98 while let Some(payload) = inner_output_rx.next().await {
99 let _ = outputs.channel().send_result(payload);
100 }
101 propagate_if_error!(result, outputs, continue);
102 }
103 }
104
105 (input_stream, outputs)
106}