triton_distributed/pipeline/network/tcp/client.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::sync::Arc;
17
18use futures::{SinkExt, StreamExt};
19use tokio::{io::AsyncWriteExt, net::TcpStream};
20use tokio_util::codec::{FramedRead, FramedWrite};
21use tracing as log;
22
23use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
24use crate::engine::AsyncEngineContext;
25use crate::pipeline::network::{
26 codec::{TwoPartCodec, TwoPartMessage},
27 tcp::StreamType,
28 ConnectionInfo, ResponseStreamPrologue, StreamSender,
29}; // Import SinkExt to use the `send` method
30
/// TCP transport client.
///
/// Response streams are created with [`TcpClient::create_response_steam`],
/// which does not need a client instance.
pub struct TcpClient {
    /// Identifier of this client instance.
    // Not read anywhere yet; the lint is silenced on this field only (rather
    // than on the whole struct) so that any other dead code in the type still
    // gets reported.
    #[allow(dead_code)]
    worker_id: String,
}
35
36impl Default for TcpClient {
37 fn default() -> Self {
38 TcpClient {
39 worker_id: uuid::Uuid::new_v4().to_string(),
40 }
41 }
42}
43
44impl TcpClient {
45 pub fn new(worker_id: String) -> Self {
46 TcpClient { worker_id }
47 }
48
49 async fn connect(address: &str) -> Result<TcpStream, String> {
50 let socket = TcpStream::connect(address)
51 .await
52 .map_err(|e| format!("failed to connect: {:?}", e))?;
53
54 socket
55 .set_nodelay(true)
56 .map_err(|e| format!("failed to set nodelay: {:?}", e))?;
57
58 Ok(socket)
59 }
60
61 pub async fn create_response_steam(
62 context: Arc<dyn AsyncEngineContext>,
63 info: ConnectionInfo,
64 ) -> Result<StreamSender, String> {
65 let info = TcpStreamConnectionInfo::try_from(info)?;
66 tracing::trace!("Creating response stream for {:?}", info);
67
68 if info.stream_type != StreamType::Response {
69 return Err(format!(
70 "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
71 info.stream_type
72 ));
73 }
74
75 if info.context != context.id() {
76 return Err(format!(
77 "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
78 context.id(),
79 info.context
80 ));
81 }
82
83 let stream = TcpClient::connect(&info.address).await?;
84 let (read_half, write_half) = tokio::io::split(stream);
85
86 let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
87 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
88
89 // this is a oneshot channel that will be used to signal when the stream is closed
90 // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
91 // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
92 // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
93 // captured by the monitor task
94 let (mut alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
95
96 // monitors the channel for a cancellation signal
97 // this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received
98 tokio::spawn(async move {
99 loop {
100 tokio::select! {
101 msg = framed_reader.next() => {
102 match msg {
103 Some(Ok(two_part_msg)) => {
104 match two_part_msg.optional_parts() {
105 (Some(bytes), None) => {
106 let msg: ControlMessage = serde_json::from_slice(bytes).unwrap();
107 match msg {
108 ControlMessage::Stop => {
109 context.stop();
110 break;
111 }
112 ControlMessage::Kill => {
113 context.kill();
114 break;
115 }
116 }
117 }
118 _ => {
119 // we should not receive this
120 }
121 }
122 }
123 Some(Err(e)) => {
124 panic!("failed to decode message from stream: {:?}", e);
125 // break;
126 }
127 None => {
128 // the stream was closed, we should stop the stream
129 return;
130 }
131 }
132 }
133 _ = alive_tx.closed() => {
134 // the channel was closed, we should stop the stream
135 break;
136 }
137 }
138 }
139 // framed_writer.get_mut().shutdown().await.unwrap();
140 });
141
142 // transport specific handshake message
143 let handshake = CallHomeHandshake {
144 subject: info.subject,
145 stream_type: StreamType::Response,
146 };
147
148 let handshake_bytes = serde_json::to_vec(&handshake).unwrap();
149 let msg = TwoPartMessage::from_header(handshake_bytes.into());
150
151 // issue the the first tcp handshake message
152 framed_writer
153 .send(msg)
154 .await
155 .map_err(|e| format!("failed to send handshake: {:?}", e))?;
156
157 // set up the channel to send bytes to the transport layer
158 let (bytes_tx, mut bytes_rx) = tokio::sync::mpsc::channel(16);
159
160 // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
161 tokio::spawn(async move {
162 while let Some(msg) = bytes_rx.recv().await {
163 if let Err(e) = framed_writer.send(msg).await {
164 log::trace!(
165 "failed to send message to stream; possible disconnect: {:?}",
166 e
167 );
168
169 // TODO - possibly propagate the error upstream
170 break;
171 }
172 }
173 drop(alive_rx);
174 if let Err(e) = framed_writer.get_mut().shutdown().await {
175 log::trace!("failed to shutdown writer: {:?}", e);
176 }
177 });
178
179 // set up the prologue for the stream
180 // this might have transport specific metadata in the future
181 let prologue = Some(ResponseStreamPrologue { error: None });
182
183 // create the stream sender
184 let stream_sender = StreamSender {
185 tx: bytes_tx,
186 prologue,
187 };
188
189 Ok(stream_sender)
190 }
191}