rust_pipe/transport/
stdio.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::process::Stdio;
4use std::sync::Arc;
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::process::Command;
7use tokio::sync::{mpsc, RwLock};
8use tracing;
9
10use super::{Message, Transport, TransportError};
11
/// Launch configuration for one worker process driven over stdin/stdout.
///
/// `Debug` and `Clone` are derived so configurations can be logged and
/// reused (for example to build several transports from one template).
#[derive(Debug, Clone)]
pub struct StdioProcess {
    /// Program to execute; passed to `Command::new` when the transport starts.
    pub command: String,
    /// Command-line arguments handed to the program.
    pub args: Vec<String>,
    /// Identifier under which the worker is registered and addressed.
    pub worker_id: String,
    /// Task names advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
}
19
20pub struct StdioTransport {
22 processes: Arc<RwLock<HashMap<String, StdioWorker>>>,
23 configs: Vec<StdioProcess>,
24 on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
25}
26
27struct StdioWorker {
28 stdin_tx: mpsc::UnboundedSender<String>,
29}
30
31impl StdioTransport {
32 pub fn new(
33 configs: Vec<StdioProcess>,
34 on_message: impl Fn(String, Message) + Send + Sync + 'static,
35 ) -> Self {
36 Self {
37 processes: Arc::new(RwLock::new(HashMap::new())),
38 configs,
39 on_message: Arc::new(on_message),
40 }
41 }
42}
43
44#[async_trait]
45impl Transport for StdioTransport {
46 async fn start(&self) -> Result<(), TransportError> {
47 for config in &self.configs {
48 let worker_id = config.worker_id.clone();
49 let on_message = self.on_message.clone();
50 let processes = self.processes.clone();
51
52 let mut child = Command::new(&config.command)
53 .args(&config.args)
54 .stdin(Stdio::piped())
55 .stdout(Stdio::piped())
56 .stderr(Stdio::piped())
57 .spawn()
58 .map_err(|e| {
59 TransportError::ConnectionFailed(format!(
60 "Failed to spawn '{}': {}",
61 config.command, e
62 ))
63 })?;
64
65 let stdin = child.stdin.take().expect("stdin piped");
66 let stdout = child.stdout.take().expect("stdout piped");
67
68 let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
69
70 let reg_msg = Message::WorkerRegister {
72 registration: super::WorkerRegistration {
73 worker_id: worker_id.clone(),
74 supported_tasks: config.supported_tasks.clone(),
75 max_concurrency: 1,
76 language: super::WorkerLanguage::Other("stdio".to_string()),
77 tags: None,
78 },
79 };
80 on_message(worker_id.clone(), reg_msg);
81
82 let wid = worker_id.clone();
84 tokio::spawn(async move {
85 let mut stdin = stdin;
86 while let Some(line) = stdin_rx.recv().await {
87 if stdin.write_all(line.as_bytes()).await.is_err() {
88 tracing::error!(worker_id = %wid, "Failed to write to stdin");
89 break;
90 }
91 if stdin.write_all(b"\n").await.is_err() {
92 break;
93 }
94 let _ = stdin.flush().await;
95 }
96 });
97
98 let wid = worker_id.clone();
100 tokio::spawn(async move {
101 let reader = BufReader::new(stdout);
102 let mut lines = reader.lines();
103
104 while let Ok(Some(line)) = lines.next_line().await {
105 if line.trim().is_empty() {
106 continue;
107 }
108 match serde_json::from_str::<Message>(&line) {
109 Ok(msg) => on_message(wid.clone(), msg),
110 Err(e) => {
111 tracing::debug!(
112 worker_id = %wid,
113 line = %line,
114 error = %e,
115 "Non-JSON line from worker, ignoring"
116 );
117 }
118 }
119 }
120 tracing::info!(worker_id = %wid, "Stdio worker stdout closed");
121 });
122
123 processes
124 .write()
125 .await
126 .insert(worker_id.clone(), StdioWorker { stdin_tx });
127
128 tracing::info!(
129 worker_id = %worker_id,
130 command = %config.command,
131 "Stdio worker spawned"
132 );
133 }
134
135 Ok(())
136 }
137
138 async fn stop(&self) -> Result<(), TransportError> {
139 let processes = self.processes.read().await;
140 for (worker_id, worker) in processes.iter() {
141 let shutdown = Message::Shutdown { graceful: true };
142 let json = serde_json::to_string(&shutdown).unwrap_or_default();
143 let _ = worker.stdin_tx.send(json);
144 tracing::info!(worker_id = %worker_id, "Sent shutdown to stdio worker");
145 }
146 Ok(())
147 }
148
149 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
150 let processes = self.processes.read().await;
151 let worker = processes
152 .get(worker_id)
153 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
154
155 let json = serde_json::to_string(&message)
156 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
157
158 worker
159 .stdin_tx
160 .send(json)
161 .map_err(|e| TransportError::SendFailed(e.to_string()))
162 }
163
164 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
165 let processes = self.processes.read().await;
166 let json = serde_json::to_string(&message)
167 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
168
169 for (_, worker) in processes.iter() {
170 let _ = worker.stdin_tx.send(json.clone());
171 }
172 Ok(())
173 }
174}