// rust_pipe/transport/stdio.rs
use async_trait::async_trait;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use tokio::sync::{mpsc, RwLock};
use tracing;

use super::{Message, Transport, TransportError};

/// Configuration for one worker process launched by [`StdioTransport`].
#[derive(Debug, Clone)]
pub struct StdioProcess {
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`.
    pub args: Vec<String>,
    /// Identifier the transport uses to address this worker and to tag the
    /// messages it produces.
    pub worker_id: String,
    /// Task names advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
}

20pub struct StdioTransport {
22 processes: Arc<RwLock<HashMap<String, StdioWorker>>>,
23 configs: Vec<StdioProcess>,
24 on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
25}
26
27struct StdioWorker {
28 stdin_tx: mpsc::UnboundedSender<String>,
29}
30
31impl StdioTransport {
32 pub fn new(
33 configs: Vec<StdioProcess>,
34 on_message: impl Fn(String, Message) + Send + Sync + 'static,
35 ) -> Self {
36 Self {
37 processes: Arc::new(RwLock::new(HashMap::new())),
38 configs,
39 on_message: Arc::new(on_message),
40 }
41 }
42}
43
44#[async_trait]
45impl Transport for StdioTransport {
46 async fn start(&self) -> Result<(), TransportError> {
47 for config in &self.configs {
48 let worker_id = config.worker_id.clone();
49 let on_message = self.on_message.clone();
50 let processes = self.processes.clone();
51
52 let mut child = Command::new(&config.command)
53 .args(&config.args)
54 .stdin(Stdio::piped())
55 .stdout(Stdio::piped())
56 .stderr(Stdio::piped())
57 .spawn()
58 .map_err(|e| {
59 TransportError::ConnectionFailed(format!(
60 "Failed to spawn '{}': {}",
61 config.command, e
62 ))
63 })?;
64
65 let stdin = child.stdin.take().expect("stdin piped");
66 let stdout = child.stdout.take().expect("stdout piped");
67
68 let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
69
70 let reg_msg = Message::WorkerRegister {
72 registration: super::WorkerRegistration {
73 worker_id: worker_id.clone(),
74 supported_tasks: config.supported_tasks.clone(),
75 max_concurrency: 1,
76 language: super::WorkerLanguage::Other("stdio".to_string()),
77 },
78 };
79 on_message(worker_id.clone(), reg_msg);
80
81 let wid = worker_id.clone();
83 tokio::spawn(async move {
84 let mut stdin = stdin;
85 while let Some(line) = stdin_rx.recv().await {
86 if stdin.write_all(line.as_bytes()).await.is_err() {
87 tracing::error!(worker_id = %wid, "Failed to write to stdin");
88 break;
89 }
90 if stdin.write_all(b"\n").await.is_err() {
91 break;
92 }
93 let _ = stdin.flush().await;
94 }
95 });
96
97 let wid = worker_id.clone();
99 tokio::spawn(async move {
100 let reader = BufReader::new(stdout);
101 let mut lines = reader.lines();
102
103 while let Ok(Some(line)) = lines.next_line().await {
104 if line.trim().is_empty() {
105 continue;
106 }
107 match serde_json::from_str::<Message>(&line) {
108 Ok(msg) => on_message(wid.clone(), msg),
109 Err(e) => {
110 tracing::debug!(
111 worker_id = %wid,
112 line = %line,
113 error = %e,
114 "Non-JSON line from worker, ignoring"
115 );
116 }
117 }
118 }
119 tracing::info!(worker_id = %wid, "Stdio worker stdout closed");
120 });
121
122 processes
123 .write()
124 .await
125 .insert(worker_id.clone(), StdioWorker { stdin_tx });
126
127 tracing::info!(
128 worker_id = %worker_id,
129 command = %config.command,
130 "Stdio worker spawned"
131 );
132 }
133
134 Ok(())
135 }
136
137 async fn stop(&self) -> Result<(), TransportError> {
138 let processes = self.processes.read().await;
139 for (worker_id, worker) in processes.iter() {
140 let shutdown = Message::Shutdown { graceful: true };
141 let json = serde_json::to_string(&shutdown).unwrap_or_default();
142 let _ = worker.stdin_tx.send(json);
143 tracing::info!(worker_id = %worker_id, "Sent shutdown to stdio worker");
144 }
145 Ok(())
146 }
147
148 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
149 let processes = self.processes.read().await;
150 let worker = processes
151 .get(worker_id)
152 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
153
154 let json = serde_json::to_string(&message)
155 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
156
157 worker
158 .stdin_tx
159 .send(json)
160 .map_err(|e| TransportError::SendFailed(e.to_string()))
161 }
162
163 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
164 let processes = self.processes.read().await;
165 let json = serde_json::to_string(&message)
166 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
167
168 for (_, worker) in processes.iter() {
169 let _ = worker.stdin_tx.send(json.clone());
170 }
171 Ok(())
172 }
173}