triton_distributed/pipeline/network/
tcp.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! TCP Transport Module
17//!
18//! The TCP Transport module consists of two main components: Client and Server. The Client is
19//! the downstream node that is responsible for connecting back to the upstream node (Server).
20//!
21//! Both Client and Server are given a Stream object that they can specialize for their specific
22//! needs, i.e. if they are SingleIn/ManyIn or SingleOut/ManyOut.
23//!
24//! The Request object will carry the Transport Type and Connection details, i.e. how the receiver
25//! of a Request is able to communicate back to the source of the Request.
26//!
27//! There are two types of TcpStream:
//! - CallHome stream - the address of the listening socket is forwarded via some mechanism to the
//!   remote side, which then connects back to the source of the CallHome stream. The
//!   CallHomeHandshake is used to match the incoming socket with the awaiting data stream.
31
32pub mod client;
33pub mod server;
34
35use serde::{Deserialize, Serialize};
36
37#[allow(unused_imports)]
38use super::{
39    codec::TwoPartCodec, ConnectionInfo, PendingConnections, RegisteredStream, ResponseService,
40    StreamOptions, StreamReceiver, StreamSender, StreamType,
41};
42
43const TCP_TRANSPORT: &str = "tcp_server";
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct TcpStreamConnectionInfo {
47    pub address: String,
48    pub subject: String,
49    pub context: String,
50    pub stream_type: StreamType,
51}
52
53impl From<TcpStreamConnectionInfo> for ConnectionInfo {
54    fn from(info: TcpStreamConnectionInfo) -> Self {
55        // Need to consider the below. If failure should be fatal, keep the below with .expect()
56        // But if there is a default value, we can use:
57        // unwrap_or_else(|e| {
58        //     eprintln!("Failed to serialize TcpStreamConnectionInfo: {:?}", e);
59        //     "{}".to_string() // Provide a fallback empty JSON string or default value
60        ConnectionInfo {
61            transport: TCP_TRANSPORT.to_string(),
62            info: serde_json::to_string(&info)
63                .expect("Failed to serialize TcpStreamConnectionInfo"),
64        }
65    }
66}
67
68impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
69    type Error = String;
70
71    fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
72        if info.transport != TCP_TRANSPORT {
73            return Err(format!(
74                "Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed",
75                info.transport
76            ));
77        }
78
79        serde_json::from_str(&info.info)
80            .map_err(|e| format!("Failed parse ConnectionInfo: {:?}", e))
81    }
82}
83
84/// First message sent over a CallHome stream which will map the newly created socket to a specific
85/// response data stream which was registered with the same subject.
86///
87/// This is a transport specific message as part of forming/completing a CallHome TcpStream.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89struct CallHomeHandshake {
90    subject: String,
91    stream_type: StreamType,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95#[serde(rename_all = "snake_case")]
96enum ControlMessage {
97    Stop,
98    Kill,
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::AsyncEngineContextProvider;
    use crate::pipeline::Context;

    /// Minimal JSON payload used to check that bytes survive the TCP round trip intact.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestMessage {
        foo: String,
    }

    /// End-to-end check of a CallHome response stream.
    ///
    /// The server registers a pending response stream. The client connects back using the
    /// advertised connection info and sends one message, which the server receives and decodes.
    #[tokio::test]
    async fn test_tcp_stream_client_server() {
        println!("Test Started");
        // NOTE(review): fixed port; this may collide with other tests or processes bound to 9124
        // when tests run in parallel.
        let options = server::ServerOptions::builder().port(9124).build().unwrap();
        println!("Server options built");
        let server = server::TcpStreamServer::new(options).await.unwrap();
        println!("Server created");

        // [server] register a response-only stream bound to rank0's context
        let context_rank0 = Context::new(());

        let options = StreamOptions::builder()
            .context(context_rank0.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending_connection = server.register(options).await;

        // connection details the client needs in order to call home
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        // [client] set up the other rank, sharing rank0's context id
        let context_rank1 = Context::with_id((), context_rank0.id().to_string());

        // connect to the server socket
        let mut send_stream =
            client::TcpClient::create_response_steam(context_rank1.context(), connection_info)
                .await
                .unwrap();
        println!("Client connected");

        // The client can now set up its end of the stream; if that fails, it can send a message
        // to the server to stop the stream.
        //
        // This step must complete before the server-side pairing below can finish: the server's
        // stream is blocked on receiving the prologue message.
        //
        // TODO: replace the Option with an Ok/Err-style enum. Currently `None` means good-to-go
        // and `Some(String)` means an error happened on this downstream node that must be
        // reported to the upstream node.
        send_stream.send_prologue(None).await.unwrap();

        // [server] the pending connection should now be paired
        let recv_stream = pending_connection
            .recv_stream
            .unwrap()
            .stream_provider
            .await
            .unwrap();

        println!("Server paired");

        let msg = TestMessage {
            foo: "bar".to_string(),
        };

        let payload = serde_json::to_vec(&msg).unwrap();

        send_stream.send(payload.into()).await.unwrap();

        println!("Client sent message");

        let data = recv_stream.unwrap().rx.recv().await.unwrap();

        println!("Server received message");

        let recv_msg = serde_json::from_slice::<TestMessage>(&data).unwrap();

        assert_eq!(msg.foo, recv_msg.foo);
        println!("message match");

        // TODO: assert that the server observes end-of-stream once the sender is dropped. That
        // requires keeping the receiver in a binding, since `recv_stream.unwrap()` above consumes
        // it.
        drop(send_stream);
    }
}