triton_distributed/pipeline/network/tcp/
server.rs1use anyhow::Result;
17use core::panic;
18use std::{collections::HashMap, sync::Arc};
19use tokio::sync::Mutex;
20
21use bytes::Bytes;
22use derive_builder::Builder;
23use futures::StreamExt;
24use local_ip_address::{list_afinet_netifas, local_ip};
25use serde::{Deserialize, Serialize};
26use tokio::{
27 io::AsyncWriteExt,
28 sync::{mpsc, oneshot},
29};
30use tokio_util::codec::{FramedRead, FramedWrite};
31
32use super::{
33 CallHomeHandshake, PendingConnections, RegisteredStream, StreamOptions, StreamReceiver,
34 StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
35};
36use crate::engine::AsyncEngineContext;
37use crate::pipeline::{
38 network::{
39 codec::{TwoPartMessage, TwoPartMessageType},
40 tcp::StreamType,
41 ResponseService, ResponseStreamPrologue,
42 },
43 PipelineError,
44};
45
46#[allow(dead_code)]
47type ResponseType = TwoPartMessage;
48
49#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
50pub struct ServerOptions {
51 #[builder(default = "0")]
52 pub port: u16,
53
54 #[builder(default)]
55 pub interface: Option<String>,
56}
57
58impl ServerOptions {
59 pub fn builder() -> ServerOptionsBuilder {
60 ServerOptionsBuilder::default()
61 }
62}
63
64pub struct TcpStreamServer {
72 local_ip: String,
73 local_port: u16,
74 state: Arc<Mutex<State>>,
75}
76
77#[allow(dead_code)]
84struct RequestedSendConnection {
85 context: Arc<dyn AsyncEngineContext>,
86 connection: oneshot::Sender<Result<StreamSender, String>>,
87}
88
89struct RequestedRecvConnection {
90 context: Arc<dyn AsyncEngineContext>,
91 connection: oneshot::Sender<Result<StreamReceiver, String>>,
92}
93
94#[derive(Default)]
111struct State {
112 tx_subjects: HashMap<String, RequestedSendConnection>,
113 rx_subjects: HashMap<String, RequestedRecvConnection>,
114 handle: Option<tokio::task::JoinHandle<()>>,
115}
116
117impl TcpStreamServer {
118 pub fn options_builder() -> ServerOptionsBuilder {
119 ServerOptionsBuilder::default()
120 }
121
122 pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
123 let local_ip = match options.interface {
124 Some(interface) => {
125 let interfaces: HashMap<String, std::net::IpAddr> =
126 list_afinet_netifas()?.into_iter().collect();
127
128 interfaces
129 .get(&interface)
130 .ok_or(PipelineError::Generic(format!(
131 "Interface not found: {}",
132 interface
133 )))?
134 .to_string()
135 }
136 None => local_ip().unwrap().to_string(),
137 };
138
139 let state = Arc::new(Mutex::new(State::default()));
140
141 let local_port = Self::start(local_ip.clone(), options.port, state.clone())
142 .await
143 .map_err(|e| {
144 PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
145 })?;
146
147 tracing::info!("TcpStreamServer started on {}:{}", local_ip, local_port);
148
149 Ok(Arc::new(Self {
150 local_ip,
151 local_port,
152 state,
153 }))
154 }
155
156 #[allow(clippy::await_holding_lock)]
157 async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
158 let addr = format!("{}:{}", local_ip, local_port);
159 let state_clone = state.clone();
160 let mut guard = state.lock().await;
161 if guard.handle.is_some() {
162 panic!("TcpStreamServer already started");
163 }
164 let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
165 let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
166 guard.handle = Some(handle);
167 drop(guard);
168 let local_port = ready_rx.await??;
169 Ok(local_port)
170 }
171}
172
173#[async_trait::async_trait]
175impl ResponseService for TcpStreamServer {
176 async fn register(&self, options: StreamOptions) -> PendingConnections {
197 let address = format!("{}:{}", self.local_ip, self.local_port);
200 tracing::debug!("Registering new TcpStream on {}", address);
201
202 let send_stream = if options.enable_request_stream {
203 let sender_subject = uuid::Uuid::new_v4().to_string();
204
205 let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
206
207 let connection_info = RequestedSendConnection {
208 context: options.context.clone(),
209 connection: pending_sender_tx,
210 };
211
212 let mut state = self.state.lock().await;
213 state
214 .tx_subjects
215 .insert(sender_subject.clone(), connection_info);
216
217 let registered_stream = RegisteredStream {
218 connection_info: TcpStreamConnectionInfo {
219 address: address.clone(),
220 subject: sender_subject.clone(),
221 context: options.context.id().to_string(),
222 stream_type: StreamType::Request,
223 }
224 .into(),
225 stream_provider: pending_sender_rx,
226 };
227
228 Some(registered_stream)
229 } else {
230 None
231 };
232
233 let recv_stream = if options.enable_response_stream {
234 let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
235 let receiver_subject = uuid::Uuid::new_v4().to_string();
236
237 let connection_info = RequestedRecvConnection {
238 context: options.context.clone(),
239 connection: pending_recver_tx,
240 };
241
242 let mut state = self.state.lock().await;
243 state
244 .rx_subjects
245 .insert(receiver_subject.clone(), connection_info);
246
247 let registered_stream = RegisteredStream {
248 connection_info: TcpStreamConnectionInfo {
249 address: address.clone(),
250 subject: receiver_subject.clone(),
251 context: options.context.id().to_string(),
252 stream_type: StreamType::Response,
253 }
254 .into(),
255 stream_provider: pending_recver_rx,
256 };
257
258 Some(registered_stream)
259 } else {
260 None
261 };
262
263 PendingConnections {
264 send_stream,
265 recv_stream,
266 }
267 }
268}
269
270async fn tcp_listener(
277 addr: String,
278 state: Arc<Mutex<State>>,
279 read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
280) {
281 let listener = tokio::net::TcpListener::bind(&addr)
282 .await
283 .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
284
285 let listener = match listener {
286 Ok(listener) => {
287 let addr = listener
288 .local_addr()
289 .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
290 .unwrap();
291
292 read_tx
293 .send(Ok(addr.port()))
294 .expect("Failed to send ready signal");
295
296 listener
297 }
298 Err(e) => {
299 read_tx.send(Err(e)).expect("Failed to send ready signal");
300 return;
301 }
302 };
303
304 loop {
305 let (stream, _addr) = listener.accept().await.unwrap();
306 stream.set_nodelay(true).unwrap();
307 tokio::spawn(handle_connection(stream, state.clone()));
308 }
309
310 async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
313 let result = process_stream(stream, state).await;
314 match result {
315 Ok(_) => tracing::trace!("TcpStream connection closed"),
316 Err(e) => tracing::error!("TcpStream connection failed: {}", e),
317 }
318 }
319
320 async fn process_stream(
323 stream: tokio::net::TcpStream,
324 state: Arc<Mutex<State>>,
325 ) -> Result<(), String> {
326 let (read_half, write_half) = tokio::io::split(stream);
328
329 let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
331 let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
332
333 let first_message = framed_reader
336 .next()
337 .await
338 .ok_or("Connection closed without a ControlMessge".to_string())?
339 .map_err(|e| e.to_string())?;
340
341 if first_message.header().is_none() {
344 return Err("Expected ControlMessage, got DataMessage".to_string());
345 }
346
347 let handshake: CallHomeHandshake = serde_json::from_slice(first_message.header().unwrap())
349 .map_err(|e| {
350 format!(
351 "Failed to deserialize the first message as a valid `CallHomeHandshake`: {}",
352 e
353 )
354 })?;
355
356 match handshake.stream_type {
358 StreamType::Request => process_request_stream().await,
359 StreamType::Response => {
360 process_response_stream(handshake.subject, state, framed_reader, framed_writer)
361 .await
362 }
363 }
364 .map_err(|e| format!("Failed to process stream: {}", e))
365 }
366
367 async fn process_request_stream() -> Result<(), String> {
368 Ok(())
369 }
370
371 async fn process_response_stream(
372 subject: String,
373 state: Arc<Mutex<State>>,
374 mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
375 writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
376 ) -> Result<(), String> {
377 let response_stream = state
378 .lock().await
379 .rx_subjects
380 .remove(&subject)
381 .ok_or(format!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
382
383 let RequestedRecvConnection {
385 context,
386 connection,
387 } = response_stream;
388
389 let prologue = reader
392 .next()
393 .await
394 .ok_or("Connection closed without a ControlMessge".to_string())?
395 .map_err(|e| e.to_string())?;
396
397 let prologue = match prologue.into_message_type() {
399 TwoPartMessageType::HeaderOnly(header) => {
400 let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
401 .map_err(|e| format!("Failed to deserialize ControlMessage: {}", e))?;
402 prologue
403 }
404 _ => {
405 panic!("Expected HeaderOnly ControlMessage; internally logic error")
406 }
407 };
408
409 if let Some(error) = &prologue.error {
416 let _ = connection.send(Err(error.clone()));
417 return Err(format!("Received error prologue: {}", error));
418 }
419
420 let (tx, rx) = mpsc::channel(16);
422
423 if connection
424 .send(Ok(crate::pipeline::network::StreamReceiver { rx }))
425 .is_err()
426 {
427 return Err("The requester of the stream has been dropped before the connection was established".to_string());
428 }
429
430 let (alive_tx, alive_rx) = mpsc::channel::<()>(1);
431 let (control_tx, _control_rx) = mpsc::channel::<Bytes>(8);
432
433 let monitor_task = tokio::spawn(monitor(writer, context.clone(), alive_tx));
438
439 let forward_task = tokio::spawn(handle_response_stream(
441 reader,
442 tx,
443 control_tx,
444 context.clone(),
445 alive_rx,
446 ));
447
448 let (monitor_result, forward_result) = tokio::join!(monitor_task, forward_task);
450
451 if let Err(e) = monitor_result {
453 return Err(format!("Monitor task failed: {}", e));
454 }
455 if let Err(e) = forward_result {
456 return Err(format!("Forward task failed: {}", e));
457 }
458
459 Ok(())
460 }
461
462 async fn handle_response_stream(
463 mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
464 response_tx: mpsc::Sender<Bytes>,
465 control_tx: mpsc::Sender<Bytes>,
466 context: Arc<dyn AsyncEngineContext>,
467 alive_rx: mpsc::Receiver<()>,
468 ) -> Result<(), String> {
469 loop {
471 tokio::select! {
472 msg = framed_reader.next() => {
473 match msg {
474 Some(Ok(msg)) => {
475 let (header, data) = msg.into_parts();
476
477 if !header.is_empty() && (control_tx.send(header).await).is_err() {
478 tracing::trace!("Control channel closed")
479 }
480
481 if !data.is_empty() {
482 response_tx.send(data).await.unwrap();
483 }
484 }
485 Some(Err(e)) => {
486 return Err(format!("Failed to read TwoPartCodec message from TcpStream: {}", e));
487 }
488 None => {
489 tracing::trace!("TcpStream closed naturally");
490 break;
491 }
492 }
493 }
494 _ = response_tx.closed() => {
495 break;
496 }
497 _ = context.killed() => { break; }
498 }
499 }
500 drop(alive_rx);
501 Ok(())
502 }
503
504 #[allow(dead_code)]
505 async fn handle_control_message(
506 mut control_rx: mpsc::Receiver<Bytes>,
507 context: Arc<dyn AsyncEngineContext>,
508 alive_tx: mpsc::Sender<()>,
509 ) -> Result<(), String> {
510 loop {
511 tokio::select! {
512 msg = control_rx.recv() => {
513 match msg {
514 Some(_msg) => {
515 }
517 None => {
518 tracing::trace!("Control channel closed");
519 break;
520 }
521 }
522 }
523 _ = context.killed() => {
524 break;
525 }
526 }
527 }
528 drop(alive_tx);
529 Ok(())
530 }
531
532 async fn monitor(
533 _socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
534 ctx: Arc<dyn AsyncEngineContext>,
535 alive_tx: mpsc::Sender<()>,
536 ) {
537 let alive_tx = alive_tx;
538 tokio::select! {
539 _ = ctx.stopped() => {
540 panic!("impl cancellation signal");
542 }
543 _ = alive_tx.closed() => {
544 tracing::trace!("response stream closed naturally")
545 }
546 }
547 let mut framed_writer = _socket_tx;
548 framed_writer.get_mut().shutdown().await.unwrap();
549 }
550}