triton_distributed/pipeline/
network.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! TODO - we need to reconcile what is in this crate with distributed::transports
17
18pub mod codec;
19pub mod egress;
20pub mod ingress;
21pub mod tcp;
22
23use std::sync::{Arc, OnceLock};
24
25use anyhow::Result;
26use async_trait::async_trait;
27use bytes::Bytes;
28use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
29use derive_builder::Builder;
30use futures::StreamExt;
31// io::Cursor, TryStreamExt
32use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
33use serde::{Deserialize, Serialize};
34
35use super::{
36    context, AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO,
37    SegmentSource, ServiceBackend, ServiceEngine, SingleIn, Source,
38};
39
40pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
41impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
42
/// `WorkQueueConsumer` is a generic interface for the consuming side of a work queue.
///
/// The only operation is [`dequeue`](WorkQueueConsumer::dequeue), which asynchronously yields
/// an item as raw [`Bytes`] or a `String` describing the failure.
#[async_trait]
pub trait WorkQueueConsumer {
    /// Awaits an item from the queue and returns its payload, or an error message on failure.
    async fn dequeue(&self) -> Result<Bytes, String>;
}
48
/// Distinguishes a request stream from a response stream.
///
/// Serialized by serde with snake_case variant names (`request` / `response`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StreamType {
    /// A request stream.
    Request,
    /// A response stream.
    Response,
}
55
/// This is the first message in a `ResponseStream`. This is not a message that gets processed
/// by the general pipeline, but is a control message that is awaited before the
/// [`AsyncEngine::generate`] method is allowed to return.
///
/// If an error is present, the [`AsyncEngine::generate`] method will return the error instead
/// of returning the `ResponseStream`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ResponseStreamPrologue {
    /// `None` when no error is being reported; `Some(message)` carries the error text.
    error: Option<String>,
}
66
/// Receiving half of a `tokio` oneshot channel that resolves to either a value of type `T` or a
/// `String` describing why the value could not be provided.
pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
68
/// The [`RegisteredStream`] object is acquired from a [`StreamProvider`] and is used to provide
/// an awaitable receiver which will yield the `T`, which is either a stream writer for a request
/// stream or a stream reader for a response stream.
///
/// TODO: make this an RAII object linked to some stream provider. If the object has not been
/// awaited and the type `T` unwrapped, the registered stream on the stream provider should be
/// informed so that it can clean up a stream that will never be connected.
#[derive(Debug)]
pub struct RegisteredStream<T> {
    /// Connection details associated with this registered stream.
    pub connection_info: ConnectionInfo,
    /// Awaitable receiver that resolves to the `T` or to an error message.
    pub stream_provider: StreamProvider<T>,
}
82
83impl<T> RegisteredStream<T> {
84    pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
85        (self.connection_info, self.stream_provider)
86    }
87}
88
/// After registering a stream, the [`PendingConnections`] object is returned to the caller. This
/// object can be used to await the connection to be established.
pub struct PendingConnections {
    /// Pending registration that resolves to a [`StreamSender`], if any.
    pub send_stream: Option<RegisteredStream<StreamSender>>,
    /// Pending registration that resolves to a [`StreamReceiver`], if any.
    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
}
95
96impl PendingConnections {
97    pub fn into_parts(
98        self,
99    ) -> (
100        Option<RegisteredStream<StreamSender>>,
101        Option<RegisteredStream<StreamReceiver>>,
102    ) {
103        (self.send_stream, self.recv_stream)
104    }
105}
106
107/// A [`ResponseService`] implements a services in which a context a specific subject with will
108/// be associated with a stream of responses. The key difference between a [`ResponseService`]
109/// and a [`RequestService`] is that the [`ResponseService`] is the awaits an explicit connection
110/// to be established, where as a [`RequestService`] has no known knowledge about incoming
111/// connections. All [`ResponseService`] connections are expected, all [`RequestService`] connections
112/// are unexpected.
113#[async_trait::async_trait]
114pub trait ResponseService {
115    async fn register(&self, options: StreamOptions) -> PendingConnections;
116}
117
118// #[derive(Debug, Clone, Serialize, Deserialize)]
119// struct Handshake {
120//     request_id: String,
121//     worker_id: Option<String>,
122//     error: Option<String>,
123// }
124
125// impl Handshake {
126//     pub fn validate(&self) -> Result<(), String> {
127//         if let Some(e) = &self.error {
128//             return Err(e.clone());
129//         }
130//         Ok(())
131//     }
132// }
133
// This probably needs to become a ResponseStreamSender, since the prologue in this
// scenario is the sender telling the receiver that all is good and it's ready to send.
//
// In the RequestStreamSender, the prologue would be coming from the receiver, so the
// sender would have to await the prologue which, if it was not an error, would indicate
// the RequestStreamReceiver is ready to receive data.
/// Sending side of a stream: wraps a `tokio` mpsc sender of [`TwoPartMessage`]s together with
/// a prologue that has not been sent yet.
pub struct StreamSender {
    /// Channel on which outgoing [`TwoPartMessage`]s are sent.
    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
    /// The not-yet-sent prologue; taken (left as `None`) by `send_prologue`.
    prologue: Option<ResponseStreamPrologue>,
}
146
147impl StreamSender {
148    pub async fn send(&self, data: Bytes) -> Result<(), String> {
149        self.tx
150            .send(TwoPartMessage::from_data(data))
151            .await
152            .map_err(|e| e.to_string())
153    }
154
155    #[allow(clippy::needless_update)]
156    pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
157        if let Some(prologue) = self.prologue.take() {
158            let prologue = ResponseStreamPrologue { error, ..prologue };
159            self.tx
160                .send(TwoPartMessage::from_header(
161                    serde_json::to_vec(&prologue).unwrap().into(),
162                ))
163                .await
164                .map_err(|e| e.to_string())?;
165        } else {
166            panic!("Prologue already sent; or not set; logic error");
167        }
168        Ok(())
169    }
170}
171
/// Receiving side of a stream: wraps a `tokio` mpsc receiver of [`Bytes`].
pub struct StreamReceiver {
    /// Channel from which incoming [`Bytes`] payloads are received.
    rx: tokio::sync::mpsc::Receiver<Bytes>,
}
175
/// Connection Info is encoded as JSON and then again serialized as part of the Transport
/// Layer. The double serialization is not performance critical as it is only done once per
/// connection. The primary reason for storing the ConnectionInfo as a JSON string is type
/// erasure. The Transport Layer will check the [`ConnectionInfo::transport`] type and then
/// route it to the appropriate instance of the Transport, which will then deserialize the
/// [`ConnectionInfo::info`] field to its internal connection info object.
///
/// Optionally, this object could become strongly typed for which all possible combinations
/// of transport and connection info would need to be enumerated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionInfo {
    /// Transport type; used by the Transport Layer to select the Transport instance.
    pub transport: String,
    /// Transport-specific connection info, stored as a JSON string.
    pub info: String,
}
190
/// When registering a new TransportStream on the server, the caller specifies if the
/// stream is a sender, receiver or both.
///
/// Senders and Receivers share a Context, but result in separate tcp socket
/// connections to the server. Internally, we may use bcast channels to coordinate the
/// internal control messages between the sender and receiver socket connections.
#[derive(Clone, Builder)]
pub struct StreamOptions {
    /// Context
    pub context: Arc<dyn AsyncEngineContext>,

    /// Register with the server that this connection will have a server-side Sender
    /// that can be picked up by the Request/Forward pipeline
    ///
    /// TODO - note, this option is currently not implemented and will cause a panic
    pub enable_request_stream: bool,

    /// Register with the server that this connection will have a server-side Receiver
    /// that can be picked up by the Response/Reverse pipeline
    pub enable_response_stream: bool,

    /// The number of messages to buffer on the send side before blocking
    /// (the builder defaults this to 8)
    #[builder(default = "8")]
    pub send_buffer_count: usize,

    /// The number of messages to buffer on the receive side before blocking
    /// (the builder defaults this to 8)
    #[builder(default = "8")]
    pub recv_buffer_count: usize,
}
220
221impl StreamOptions {
222    pub fn builder() -> StreamOptionsBuilder {
223        StreamOptionsBuilder::default()
224    }
225}
226
/// Wraps a shared [`AsyncTransportEngine`] for the request/response types `Req` and `Resp`.
pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
    /// The transport engine to which `generate` calls are forwarded.
    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
}
230
231#[async_trait]
232impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
233    for Egress<SingleIn<T>, ManyOut<U>>
234where
235    T: Data + Serialize,
236    U: for<'de> Deserialize<'de> + Data,
237{
238    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
239        self.transport_engine.generate(request).await
240    }
241}
242
/// Kind of request carried in a [`RequestControlMessage`].
///
/// Serialized by serde with snake_case variant names (`single_in` / `many_in`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum RequestType {
    /// NOTE(review): presumably corresponds to the `SingleIn` pipeline input — confirm.
    SingleIn,
    /// NOTE(review): presumably a multi-message (streamed) input — confirm.
    ManyIn,
}
249
/// Kind of response carried in a [`RequestControlMessage`].
///
/// Serialized by serde with snake_case variant names (`single_out` / `many_out`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
    /// NOTE(review): presumably a single-message output — confirm.
    SingleOut,
    /// NOTE(review): presumably corresponds to the `ManyOut` pipeline output — confirm.
    ManyOut,
}
256
/// Serializable control message describing a request: its id, the request and response
/// kinds, and the [`ConnectionInfo`] that accompanies it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RequestControlMessage {
    /// Identifier string. NOTE(review): presumably the request id — confirm.
    id: String,
    /// Kind of request.
    request_type: RequestType,
    /// Kind of response.
    response_type: ResponseType,
    /// Connection info carried with the request.
    /// NOTE(review): presumably used to connect the response stream — confirm.
    connection_info: ConnectionInfo,
}
264
/// Holds a [`SegmentSource`] for the request/response types `Req` and `Resp`.
///
/// The segment is stored in a [`OnceLock`], so it can be set at most once after construction.
pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
    /// The attached segment; empty until `attach` succeeds.
    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
}
268
269impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
270    pub fn new() -> Arc<Self> {
271        Arc::new(Self {
272            segment: OnceLock::new(),
273        })
274    }
275
276    pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
277        self.segment
278            .set(segment)
279            .map_err(|_| anyhow::anyhow!("Segment already set"))
280    }
281
282    pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
283        let ingress = Ingress::new();
284        ingress.attach(segment)?;
285        Ok(ingress)
286    }
287
288    pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
289        let ingress = Ingress::new();
290        ingress.attach(segment)?;
291        Ok(ingress)
292    }
293
294    pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
295        let frontend = SegmentSource::<Req, Resp>::new();
296        let backend = ServiceBackend::from_engine(engine);
297
298        // create the pipeline
299        let pipeline = frontend.link(backend)?.link(frontend)?;
300
301        let ingress = Ingress::new();
302        ingress.attach(pipeline)?;
303
304        Ok(ingress)
305    }
306}
307
/// Handler for payloads delivered as raw [`Bytes`]. Implementors must be `Send + Sync`.
#[async_trait]
pub trait PushWorkHandler: Send + Sync {
    /// Processes `payload`, returning a [`PipelineError`] on failure.
    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
}