triton_distributed/pipeline/network/tcp/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::sync::Arc;
17
18use futures::{SinkExt, StreamExt};
19use tokio::{io::AsyncWriteExt, net::TcpStream};
20use tokio_util::codec::{FramedRead, FramedWrite};
21use tracing as log;
22
23use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
24use crate::engine::AsyncEngineContext;
25use crate::pipeline::network::{
26    codec::{TwoPartCodec, TwoPartMessage},
27    tcp::StreamType,
28    ConnectionInfo, ResponseStreamPrologue, StreamSender,
29}; // Import SinkExt to use the `send` method
30
/// Client side of the TCP response-stream transport.
///
/// The stream-creation logic lives in associated functions on this type; the
/// struct itself only carries an identifier for the worker that owns it.
#[derive(Debug, Clone)]
// `worker_id` is stored but not read by any code visible in this file, so the
// dead-code lint is silenced for now.
#[allow(dead_code)]
pub struct TcpClient {
    /// Identifier of the worker this client belongs to.
    worker_id: String,
}
35
36impl Default for TcpClient {
37    fn default() -> Self {
38        TcpClient {
39            worker_id: uuid::Uuid::new_v4().to_string(),
40        }
41    }
42}
43
44impl TcpClient {
45    pub fn new(worker_id: String) -> Self {
46        TcpClient { worker_id }
47    }
48
49    async fn connect(address: &str) -> Result<TcpStream, String> {
50        let socket = TcpStream::connect(address)
51            .await
52            .map_err(|e| format!("failed to connect: {:?}", e))?;
53
54        socket
55            .set_nodelay(true)
56            .map_err(|e| format!("failed to set nodelay: {:?}", e))?;
57
58        Ok(socket)
59    }
60
61    pub async fn create_response_steam(
62        context: Arc<dyn AsyncEngineContext>,
63        info: ConnectionInfo,
64    ) -> Result<StreamSender, String> {
65        let info = TcpStreamConnectionInfo::try_from(info)?;
66        tracing::trace!("Creating response stream for {:?}", info);
67
68        if info.stream_type != StreamType::Response {
69            return Err(format!(
70                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
71                info.stream_type
72            ));
73        }
74
75        if info.context != context.id() {
76            return Err(format!(
77                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
78                context.id(),
79                info.context
80            ));
81        }
82
83        let stream = TcpClient::connect(&info.address).await?;
84        let (read_half, write_half) = tokio::io::split(stream);
85
86        let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
87        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
88
89        // this is a oneshot channel that will be used to signal when the stream is closed
90        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
91        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
92        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
93        // captured by the monitor task
94        let (mut alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
95
96        // monitors the channel for a cancellation signal
97        // this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received
98        tokio::spawn(async move {
99            loop {
100                tokio::select! {
101                    msg = framed_reader.next() => {
102                        match msg {
103                            Some(Ok(two_part_msg)) => {
104                                match two_part_msg.optional_parts() {
105                                   (Some(bytes), None) => {
106                                        let msg: ControlMessage = serde_json::from_slice(bytes).unwrap();
107                                        match msg {
108                                            ControlMessage::Stop => {
109                                                context.stop();
110                                                break;
111                                            }
112                                            ControlMessage::Kill => {
113                                                context.kill();
114                                                break;
115                                            }
116                                        }
117                                   }
118                                   _ => {
119                                       // we should not receive this
120                                   }
121                                }
122                            }
123                            Some(Err(e)) => {
124                                panic!("failed to decode message from stream: {:?}", e);
125                                // break;
126                            }
127                            None => {
128                                // the stream was closed, we should stop the stream
129                                return;
130                            }
131                        }
132                    }
133                    _ = alive_tx.closed() => {
134                        // the channel was closed, we should stop the stream
135                        break;
136                    }
137                }
138            }
139            // framed_writer.get_mut().shutdown().await.unwrap();
140        });
141
142        // transport specific handshake message
143        let handshake = CallHomeHandshake {
144            subject: info.subject,
145            stream_type: StreamType::Response,
146        };
147
148        let handshake_bytes = serde_json::to_vec(&handshake).unwrap();
149        let msg = TwoPartMessage::from_header(handshake_bytes.into());
150
151        // issue the the first tcp handshake message
152        framed_writer
153            .send(msg)
154            .await
155            .map_err(|e| format!("failed to send handshake: {:?}", e))?;
156
157        // set up the channel to send bytes to the transport layer
158        let (bytes_tx, mut bytes_rx) = tokio::sync::mpsc::channel(16);
159
160        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
161        tokio::spawn(async move {
162            while let Some(msg) = bytes_rx.recv().await {
163                if let Err(e) = framed_writer.send(msg).await {
164                    log::trace!(
165                        "failed to send message to stream; possible disconnect: {:?}",
166                        e
167                    );
168
169                    // TODO - possibly propagate the error upstream
170                    break;
171                }
172            }
173            drop(alive_rx);
174            if let Err(e) = framed_writer.get_mut().shutdown().await {
175                log::trace!("failed to shutdown writer: {:?}", e);
176            }
177        });
178
179        // set up the prologue for the stream
180        // this might have transport specific metadata in the future
181        let prologue = Some(ResponseStreamPrologue { error: None });
182
183        // create the stream sender
184        let stream_sender = StreamSender {
185            tx: bytes_tx,
186            prologue,
187        };
188
189        Ok(stream_sender)
190    }
191}