triton_distributed/pipeline/
network.rs1pub mod codec;
19pub mod egress;
20pub mod ingress;
21pub mod tcp;
22
23use std::sync::{Arc, OnceLock};
24
25use anyhow::Result;
26use async_trait::async_trait;
27use bytes::Bytes;
28use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
29use derive_builder::Builder;
30use futures::StreamExt;
31use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
33use serde::{Deserialize, Serialize};
34
35use super::{
36 context, AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO,
37 SegmentSource, ServiceBackend, ServiceEngine, SingleIn, Source,
38};
39
40pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
41impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
42
43#[async_trait]
45pub trait WorkQueueConsumer {
46 async fn dequeue(&self) -> Result<Bytes, String>;
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
50#[serde(rename_all = "snake_case")]
51pub enum StreamType {
52 Request,
53 Response,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
63pub struct ResponseStreamPrologue {
64 error: Option<String>,
65}
66
67pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
68
69#[derive(Debug)]
78pub struct RegisteredStream<T> {
79 pub connection_info: ConnectionInfo,
80 pub stream_provider: StreamProvider<T>,
81}
82
83impl<T> RegisteredStream<T> {
84 pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
85 (self.connection_info, self.stream_provider)
86 }
87}
88
89pub struct PendingConnections {
92 pub send_stream: Option<RegisteredStream<StreamSender>>,
93 pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
94}
95
96impl PendingConnections {
97 pub fn into_parts(
98 self,
99 ) -> (
100 Option<RegisteredStream<StreamSender>>,
101 Option<RegisteredStream<StreamReceiver>>,
102 ) {
103 (self.send_stream, self.recv_stream)
104 }
105}
106
107#[async_trait::async_trait]
114pub trait ResponseService {
115 async fn register(&self, options: StreamOptions) -> PendingConnections;
116}
117
118pub struct StreamSender {
143 tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
144 prologue: Option<ResponseStreamPrologue>,
145}
146
147impl StreamSender {
148 pub async fn send(&self, data: Bytes) -> Result<(), String> {
149 self.tx
150 .send(TwoPartMessage::from_data(data))
151 .await
152 .map_err(|e| e.to_string())
153 }
154
155 #[allow(clippy::needless_update)]
156 pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
157 if let Some(prologue) = self.prologue.take() {
158 let prologue = ResponseStreamPrologue { error, ..prologue };
159 self.tx
160 .send(TwoPartMessage::from_header(
161 serde_json::to_vec(&prologue).unwrap().into(),
162 ))
163 .await
164 .map_err(|e| e.to_string())?;
165 } else {
166 panic!("Prologue already sent; or not set; logic error");
167 }
168 Ok(())
169 }
170}
171
172pub struct StreamReceiver {
173 rx: tokio::sync::mpsc::Receiver<Bytes>,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct ConnectionInfo {
187 pub transport: String,
188 pub info: String,
189}
190
191#[derive(Clone, Builder)]
198pub struct StreamOptions {
199 pub context: Arc<dyn AsyncEngineContext>,
201
202 pub enable_request_stream: bool,
207
208 pub enable_response_stream: bool,
211
212 #[builder(default = "8")]
214 pub send_buffer_count: usize,
215
216 #[builder(default = "8")]
218 pub recv_buffer_count: usize,
219}
220
221impl StreamOptions {
222 pub fn builder() -> StreamOptionsBuilder {
223 StreamOptionsBuilder::default()
224 }
225}
226
227pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
228 transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
229}
230
231#[async_trait]
232impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
233 for Egress<SingleIn<T>, ManyOut<U>>
234where
235 T: Data + Serialize,
236 U: for<'de> Deserialize<'de> + Data,
237{
238 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
239 self.transport_engine.generate(request).await
240 }
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
244#[serde(rename_all = "snake_case")]
245enum RequestType {
246 SingleIn,
247 ManyIn,
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize)]
251#[serde(rename_all = "snake_case")]
252enum ResponseType {
253 SingleOut,
254 ManyOut,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258struct RequestControlMessage {
259 id: String,
260 request_type: RequestType,
261 response_type: ResponseType,
262 connection_info: ConnectionInfo,
263}
264
265pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
266 segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
267}
268
269impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
270 pub fn new() -> Arc<Self> {
271 Arc::new(Self {
272 segment: OnceLock::new(),
273 })
274 }
275
276 pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
277 self.segment
278 .set(segment)
279 .map_err(|_| anyhow::anyhow!("Segment already set"))
280 }
281
282 pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
283 let ingress = Ingress::new();
284 ingress.attach(segment)?;
285 Ok(ingress)
286 }
287
288 pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
289 let ingress = Ingress::new();
290 ingress.attach(segment)?;
291 Ok(ingress)
292 }
293
294 pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
295 let frontend = SegmentSource::<Req, Resp>::new();
296 let backend = ServiceBackend::from_engine(engine);
297
298 let pipeline = frontend.link(backend)?.link(frontend)?;
300
301 let ingress = Ingress::new();
302 ingress.attach(pipeline)?;
303
304 Ok(ingress)
305 }
306}
307
308#[async_trait]
309pub trait PushWorkHandler: Send + Sync {
310 async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
311}