triton_distributed/pipeline/network/tcp/server.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use anyhow::Result;
17use core::panic;
18use std::{collections::HashMap, sync::Arc};
19use tokio::sync::Mutex;
20
21use bytes::Bytes;
22use derive_builder::Builder;
23use futures::StreamExt;
24use local_ip_address::{list_afinet_netifas, local_ip};
25use serde::{Deserialize, Serialize};
26use tokio::{
27    io::AsyncWriteExt,
28    sync::{mpsc, oneshot},
29};
30use tokio_util::codec::{FramedRead, FramedWrite};
31
32use super::{
33    CallHomeHandshake, PendingConnections, RegisteredStream, StreamOptions, StreamReceiver,
34    StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
35};
36use crate::engine::AsyncEngineContext;
37use crate::pipeline::{
38    network::{
39        codec::{TwoPartMessage, TwoPartMessageType},
40        tcp::StreamType,
41        ResponseService, ResponseStreamPrologue,
42    },
43    PipelineError,
44};
45
46#[allow(dead_code)]
47type ResponseType = TwoPartMessage;
48
49#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
50pub struct ServerOptions {
51    #[builder(default = "0")]
52    pub port: u16,
53
54    #[builder(default)]
55    pub interface: Option<String>,
56}
57
58impl ServerOptions {
59    pub fn builder() -> ServerOptionsBuilder {
60        ServerOptionsBuilder::default()
61    }
62}
63
64// todo - rename TcpResponseServer
65// we may need to disambiguate this and a TcpRequestServer
66
67/// A [`TcpStreamServer`] is a TCP service that listens on a port for incoming response connections.
68/// A Response connection is a connection that is established by a client with the intention of sending
69/// specific data back to the server. The key differentiating factor is that a [`ResponseServer`] is
70/// expecting a connection from a client with an established subject.
71pub struct TcpStreamServer {
72    local_ip: String,
73    local_port: u16,
74    state: Arc<Mutex<State>>,
75}
76
77// pub struct TcpStreamReceiver {
78//     address: TcpStreamConnectionInfo,
79//     state: Arc<Mutex<State>>,
80//     rx: mpsc::Receiver<ResponseType>,
81// }
82
83#[allow(dead_code)]
84struct RequestedSendConnection {
85    context: Arc<dyn AsyncEngineContext>,
86    connection: oneshot::Sender<Result<StreamSender, String>>,
87}
88
89struct RequestedRecvConnection {
90    context: Arc<dyn AsyncEngineContext>,
91    connection: oneshot::Sender<Result<StreamReceiver, String>>,
92}
93
94// /// When registering a new TcpStream on the server, the registration method will return a [`Connections`] object.
95// /// This [`Connections`] object will have two [`oneshot::Receiver`] objects, one for the [`TcpStreamSender`] and one for the [`TcpStreamReceiver`].
96// /// The [`Connections`] object can be awaited to get the [`TcpStreamSender`] and [`TcpStreamReceiver`] objects; these objects will
97// /// be made available when the matching Client has connected to the server.
98// pub struct Connections {
99//     pub address: TcpStreamConnectionInfo,
100
101//     /// The [`oneshot::Receiver`] for the [`TcpStreamSender`]. Awaiting this object will return the [`TcpStreamSender`] object once
102//     /// the client has connected to the server.
103//     pub sender: Option<oneshot::Receiver<StreamSender>>,
104
105//     /// The [`oneshot::Receiver`] for the [`TcpStreamReceiver`]. Awaiting this object will return the [`TcpStreamReceiver`] object once
106//     /// the client has connected to the server.
107//     pub receiver: Option<oneshot::Receiver<StreamReceiver>>,
108// }
109
110#[derive(Default)]
111struct State {
112    tx_subjects: HashMap<String, RequestedSendConnection>,
113    rx_subjects: HashMap<String, RequestedRecvConnection>,
114    handle: Option<tokio::task::JoinHandle<()>>,
115}
116
117impl TcpStreamServer {
118    pub fn options_builder() -> ServerOptionsBuilder {
119        ServerOptionsBuilder::default()
120    }
121
122    pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
123        let local_ip = match options.interface {
124            Some(interface) => {
125                let interfaces: HashMap<String, std::net::IpAddr> =
126                    list_afinet_netifas()?.into_iter().collect();
127
128                interfaces
129                    .get(&interface)
130                    .ok_or(PipelineError::Generic(format!(
131                        "Interface not found: {}",
132                        interface
133                    )))?
134                    .to_string()
135            }
136            None => local_ip().unwrap().to_string(),
137        };
138
139        let state = Arc::new(Mutex::new(State::default()));
140
141        let local_port = Self::start(local_ip.clone(), options.port, state.clone())
142            .await
143            .map_err(|e| {
144                PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
145            })?;
146
147        tracing::info!("TcpStreamServer started on {}:{}", local_ip, local_port);
148
149        Ok(Arc::new(Self {
150            local_ip,
151            local_port,
152            state,
153        }))
154    }
155
156    #[allow(clippy::await_holding_lock)]
157    async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
158        let addr = format!("{}:{}", local_ip, local_port);
159        let state_clone = state.clone();
160        let mut guard = state.lock().await;
161        if guard.handle.is_some() {
162            panic!("TcpStreamServer already started");
163        }
164        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
165        let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
166        guard.handle = Some(handle);
167        drop(guard);
168        let local_port = ready_rx.await??;
169        Ok(local_port)
170    }
171}
172
173// todo - possible rename ResponseService to ResponseServer
174#[async_trait::async_trait]
175impl ResponseService for TcpStreamServer {
176    /// Register a new subject and sender with the response subscriber
177    /// Produces an RAII object that will deregister the subject when dropped
178    ///
179    /// we need to register both data in and data out entries
180    /// there might be forward pipeline that want to consume the data out stream
181    /// and there might be a response stream that wants to consume the data in stream
182    /// on registration, we need to specific if we want data-in, data-out or both
183    /// this will map to the type of service that is runniing, i.e. Single or Many In //
184    /// Single or Many Out
185    ///
186    /// todo(ryan) - return a connection object that can be awaited. when successfully connected,
187    /// can ask for the sender and receiver
188    ///
189    /// OR
190    ///
191    /// we make it into register sender and register receiver, both would return a connection object
192    /// and when a connection is established, we'd get the respective sender or receiver
193    ///
194    /// the registration probably needs to be done in one-go, so we should use a builder object for
195    /// requesting a receiver and optional sender
196    async fn register(&self, options: StreamOptions) -> PendingConnections {
197        // oneshot channels to pass back the sender and receiver objects
198
199        let address = format!("{}:{}", self.local_ip, self.local_port);
200        tracing::debug!("Registering new TcpStream on {}", address);
201
202        let send_stream = if options.enable_request_stream {
203            let sender_subject = uuid::Uuid::new_v4().to_string();
204
205            let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
206
207            let connection_info = RequestedSendConnection {
208                context: options.context.clone(),
209                connection: pending_sender_tx,
210            };
211
212            let mut state = self.state.lock().await;
213            state
214                .tx_subjects
215                .insert(sender_subject.clone(), connection_info);
216
217            let registered_stream = RegisteredStream {
218                connection_info: TcpStreamConnectionInfo {
219                    address: address.clone(),
220                    subject: sender_subject.clone(),
221                    context: options.context.id().to_string(),
222                    stream_type: StreamType::Request,
223                }
224                .into(),
225                stream_provider: pending_sender_rx,
226            };
227
228            Some(registered_stream)
229        } else {
230            None
231        };
232
233        let recv_stream = if options.enable_response_stream {
234            let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
235            let receiver_subject = uuid::Uuid::new_v4().to_string();
236
237            let connection_info = RequestedRecvConnection {
238                context: options.context.clone(),
239                connection: pending_recver_tx,
240            };
241
242            let mut state = self.state.lock().await;
243            state
244                .rx_subjects
245                .insert(receiver_subject.clone(), connection_info);
246
247            let registered_stream = RegisteredStream {
248                connection_info: TcpStreamConnectionInfo {
249                    address: address.clone(),
250                    subject: receiver_subject.clone(),
251                    context: options.context.id().to_string(),
252                    stream_type: StreamType::Response,
253                }
254                .into(),
255                stream_provider: pending_recver_rx,
256            };
257
258            Some(registered_stream)
259        } else {
260            None
261        };
262
263        PendingConnections {
264            send_stream,
265            recv_stream,
266        }
267    }
268}
269
270// this method listens on a tcp port for incoming connections
271// new connections are expected to send a protocol specific handshake
272// for us to determine the subject they are interested in, in this case,
273// we expect the first message to be [`FirstMessage`] from which we find
274// the sender, then we spawn a task to forward all bytes from the tcp stream
275// to the sender
276async fn tcp_listener(
277    addr: String,
278    state: Arc<Mutex<State>>,
279    read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
280) {
281    let listener = tokio::net::TcpListener::bind(&addr)
282        .await
283        .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
284
285    let listener = match listener {
286        Ok(listener) => {
287            let addr = listener
288                .local_addr()
289                .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
290                .unwrap();
291
292            read_tx
293                .send(Ok(addr.port()))
294                .expect("Failed to send ready signal");
295
296            listener
297        }
298        Err(e) => {
299            read_tx.send(Err(e)).expect("Failed to send ready signal");
300            return;
301        }
302    };
303
304    loop {
305        let (stream, _addr) = listener.accept().await.unwrap();
306        stream.set_nodelay(true).unwrap();
307        tokio::spawn(handle_connection(stream, state.clone()));
308    }
309
310    // #[instrument(level = "trace"), skip(state)]
311    // todo - clone before spawn and trace process_stream
312    async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
313        let result = process_stream(stream, state).await;
314        match result {
315            Ok(_) => tracing::trace!("TcpStream connection closed"),
316            Err(e) => tracing::error!("TcpStream connection failed: {}", e),
317        }
318    }
319
320    /// This method is responsible for the internal tcp stream handshake
321    /// The handshake will specialize the stream as a request/sender or response/receiver stream
322    async fn process_stream(
323        stream: tokio::net::TcpStream,
324        state: Arc<Mutex<State>>,
325    ) -> Result<(), String> {
326        // split the socket in to a reader and writer
327        let (read_half, write_half) = tokio::io::split(stream);
328
329        // attach the codec to the reader and writer to get framed readers and writers
330        let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
331        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
332
333        // the internal tcp [`CallHomeHandshake`] connects the socket to the requester
334        // here we await this first message as a raw bytes two part message
335        let first_message = framed_reader
336            .next()
337            .await
338            .ok_or("Connection closed without a ControlMessge".to_string())?
339            .map_err(|e| e.to_string())?;
340
341        // we await on the raw bytes which should come in as a header only message
342        // todo - improve error handling - check for no data
343        if first_message.header().is_none() {
344            return Err("Expected ControlMessage, got DataMessage".to_string());
345        }
346
347        // deserialize the [`CallHomeHandshake`] message
348        let handshake: CallHomeHandshake = serde_json::from_slice(first_message.header().unwrap())
349            .map_err(|e| {
350                format!(
351                    "Failed to deserialize the first message as a valid `CallHomeHandshake`: {}",
352                    e
353                )
354            })?;
355
356        // branch here to handle sender stream or receiver stream
357        match handshake.stream_type {
358            StreamType::Request => process_request_stream().await,
359            StreamType::Response => {
360                process_response_stream(handshake.subject, state, framed_reader, framed_writer)
361                    .await
362            }
363        }
364        .map_err(|e| format!("Failed to process stream: {}", e))
365    }
366
367    async fn process_request_stream() -> Result<(), String> {
368        Ok(())
369    }
370
371    async fn process_response_stream(
372        subject: String,
373        state: Arc<Mutex<State>>,
374        mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
375        writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
376    ) -> Result<(), String> {
377        let response_stream = state
378            .lock().await
379            .rx_subjects
380            .remove(&subject)
381            .ok_or(format!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
382
383        // unwrap response_stream
384        let RequestedRecvConnection {
385            context,
386            connection,
387        } = response_stream;
388
389        // the [`Prologue`]
390        // there must be a second control message it indicate the other segment's generate method was successful
391        let prologue = reader
392            .next()
393            .await
394            .ok_or("Connection closed without a ControlMessge".to_string())?
395            .map_err(|e| e.to_string())?;
396
397        // deserialize prologue
398        let prologue = match prologue.into_message_type() {
399            TwoPartMessageType::HeaderOnly(header) => {
400                let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
401                    .map_err(|e| format!("Failed to deserialize ControlMessage: {}", e))?;
402                prologue
403            }
404            _ => {
405                panic!("Expected HeaderOnly ControlMessage; internally logic error")
406            }
407        };
408
409        // await the control message of GTG or Error, if error, then connection.send(Err(String)), which should fail the
410        // generate call chain
411        //
412        // note: this second control message might be delayed, but the expensive part of setting up the connection
413        // is both complete and ready for data flow; awaiting here is not a performance hit or problem and it allows
414        // us to trace the initial setup time vs the time to prologue
415        if let Some(error) = &prologue.error {
416            let _ = connection.send(Err(error.clone()));
417            return Err(format!("Received error prologue: {}", error));
418        }
419
420        // we need to know the buffer size from the registration options; add this to the RequestRecvConnection object
421        let (tx, rx) = mpsc::channel(16);
422
423        if connection
424            .send(Ok(crate::pipeline::network::StreamReceiver { rx }))
425            .is_err()
426        {
427            return Err("The requester of the stream has been dropped before the connection was established".to_string());
428        }
429
430        let (alive_tx, alive_rx) = mpsc::channel::<()>(1);
431        let (control_tx, _control_rx) = mpsc::channel::<Bytes>(8);
432
433        // monitor task
434        // if the context is cancelled, we need to forward the message across the transport layer
435        // we only determine the forwarding task on a kill signal, on a stop signal, we issue the stop signal, then await for the producer
436        // to naturally close the stream
437        let monitor_task = tokio::spawn(monitor(writer, context.clone(), alive_tx));
438
439        // forward task
440        let forward_task = tokio::spawn(handle_response_stream(
441            reader,
442            tx,
443            control_tx,
444            context.clone(),
445            alive_rx,
446        ));
447
448        // check the results of each of the tasks
449        let (monitor_result, forward_result) = tokio::join!(monitor_task, forward_task);
450
451        // if either of the tasks failed, we need to return an error
452        if let Err(e) = monitor_result {
453            return Err(format!("Monitor task failed: {}", e));
454        }
455        if let Err(e) = forward_result {
456            return Err(format!("Forward task failed: {}", e));
457        }
458
459        Ok(())
460    }
461
462    async fn handle_response_stream(
463        mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
464        response_tx: mpsc::Sender<Bytes>,
465        control_tx: mpsc::Sender<Bytes>,
466        context: Arc<dyn AsyncEngineContext>,
467        alive_rx: mpsc::Receiver<()>,
468    ) -> Result<(), String> {
469        // loop over reading the tcp stream and checking if the writer is closed
470        loop {
471            tokio::select! {
472                msg = framed_reader.next() => {
473                    match msg {
474                        Some(Ok(msg)) => {
475                            let (header, data) = msg.into_parts();
476
477                            if !header.is_empty() && (control_tx.send(header).await).is_err() {
478                                tracing::trace!("Control channel closed")
479                            }
480
481                            if !data.is_empty() {
482                                response_tx.send(data).await.unwrap();
483                            }
484                        }
485                        Some(Err(e)) => {
486                            return Err(format!("Failed to read TwoPartCodec message from TcpStream: {}", e));
487                        }
488                        None => {
489                            tracing::trace!("TcpStream closed naturally");
490                            break;
491                        }
492                    }
493                }
494                _ = response_tx.closed() => {
495                    break;
496                }
497                _ = context.killed() => { break; }
498            }
499        }
500        drop(alive_rx);
501        Ok(())
502    }
503
504    #[allow(dead_code)]
505    async fn handle_control_message(
506        mut control_rx: mpsc::Receiver<Bytes>,
507        context: Arc<dyn AsyncEngineContext>,
508        alive_tx: mpsc::Sender<()>,
509    ) -> Result<(), String> {
510        loop {
511            tokio::select! {
512                msg = control_rx.recv() => {
513                    match msg {
514                        Some(_msg) => {
515                            // handle control message
516                        }
517                        None => {
518                            tracing::trace!("Control channel closed");
519                            break;
520                        }
521                    }
522                }
523                _ = context.killed() => {
524                    break;
525                }
526            }
527        }
528        drop(alive_tx);
529        Ok(())
530    }
531
532    async fn monitor(
533        _socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
534        ctx: Arc<dyn AsyncEngineContext>,
535        alive_tx: mpsc::Sender<()>,
536    ) {
537        let alive_tx = alive_tx;
538        tokio::select! {
539            _ = ctx.stopped() => {
540                // send cancellation message
541                panic!("impl cancellation signal");
542            }
543            _ = alive_tx.closed() => {
544                tracing::trace!("response stream closed naturally")
545            }
546        }
547        let mut framed_writer = _socket_tx;
548        framed_writer.get_mut().shutdown().await.unwrap();
549    }
550}