Skip to main content

rust_pipe/transport/
stdio.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::process::Stdio;
4use std::sync::Arc;
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::process::Command;
7use tokio::sync::{mpsc, RwLock};
8use tracing;
9
10use super::{Message, Transport, TransportError};
11
/// Configuration for a stdio-based worker process.
///
/// Each entry describes one child process that `StdioTransport` spawns and
/// exchanges newline-delimited JSON messages with over its stdin/stdout.
#[derive(Debug, Clone)]
pub struct StdioProcess {
    /// Program to execute (passed to `Command::new`).
    pub command: String,
    /// Arguments passed to `command`.
    pub args: Vec<String>,
    /// Identifier the transport uses to address this worker.
    pub worker_id: String,
    /// Task names advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
}
19
20/// Transport that communicates with workers via stdin/stdout pipes.
21pub struct StdioTransport {
22    processes: Arc<RwLock<HashMap<String, StdioWorker>>>,
23    configs: Vec<StdioProcess>,
24    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
25}
26
27struct StdioWorker {
28    stdin_tx: mpsc::UnboundedSender<String>,
29}
30
31impl StdioTransport {
32    pub fn new(
33        configs: Vec<StdioProcess>,
34        on_message: impl Fn(String, Message) + Send + Sync + 'static,
35    ) -> Self {
36        Self {
37            processes: Arc::new(RwLock::new(HashMap::new())),
38            configs,
39            on_message: Arc::new(on_message),
40        }
41    }
42}
43
44#[async_trait]
45impl Transport for StdioTransport {
46    async fn start(&self) -> Result<(), TransportError> {
47        for config in &self.configs {
48            let worker_id = config.worker_id.clone();
49            let on_message = self.on_message.clone();
50            let processes = self.processes.clone();
51
52            let mut child = Command::new(&config.command)
53                .args(&config.args)
54                .stdin(Stdio::piped())
55                .stdout(Stdio::piped())
56                .stderr(Stdio::piped())
57                .spawn()
58                .map_err(|e| {
59                    TransportError::ConnectionFailed(format!(
60                        "Failed to spawn '{}': {}",
61                        config.command, e
62                    ))
63                })?;
64
65            let stdin = child.stdin.take().expect("stdin piped");
66            let stdout = child.stdout.take().expect("stdout piped");
67
68            let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
69
70            // Register as worker
71            let reg_msg = Message::WorkerRegister {
72                registration: super::WorkerRegistration {
73                    worker_id: worker_id.clone(),
74                    supported_tasks: config.supported_tasks.clone(),
75                    max_concurrency: 1,
76                    language: super::WorkerLanguage::Other("stdio".to_string()),
77                    tags: None,
78                },
79            };
80            on_message(worker_id.clone(), reg_msg);
81
82            // Stdin writer task
83            let wid = worker_id.clone();
84            tokio::spawn(async move {
85                let mut stdin = stdin;
86                while let Some(line) = stdin_rx.recv().await {
87                    if stdin.write_all(line.as_bytes()).await.is_err() {
88                        tracing::error!(worker_id = %wid, "Failed to write to stdin");
89                        break;
90                    }
91                    if stdin.write_all(b"\n").await.is_err() {
92                        break;
93                    }
94                    let _ = stdin.flush().await;
95                }
96            });
97
98            // Stdout reader task
99            let wid = worker_id.clone();
100            tokio::spawn(async move {
101                let reader = BufReader::new(stdout);
102                let mut lines = reader.lines();
103
104                while let Ok(Some(line)) = lines.next_line().await {
105                    if line.trim().is_empty() {
106                        continue;
107                    }
108                    match serde_json::from_str::<Message>(&line) {
109                        Ok(msg) => on_message(wid.clone(), msg),
110                        Err(e) => {
111                            tracing::debug!(
112                                worker_id = %wid,
113                                line = %line,
114                                error = %e,
115                                "Non-JSON line from worker, ignoring"
116                            );
117                        }
118                    }
119                }
120                tracing::info!(worker_id = %wid, "Stdio worker stdout closed");
121            });
122
123            processes
124                .write()
125                .await
126                .insert(worker_id.clone(), StdioWorker { stdin_tx });
127
128            tracing::info!(
129                worker_id = %worker_id,
130                command = %config.command,
131                "Stdio worker spawned"
132            );
133        }
134
135        Ok(())
136    }
137
138    async fn stop(&self) -> Result<(), TransportError> {
139        let processes = self.processes.read().await;
140        for (worker_id, worker) in processes.iter() {
141            let shutdown = Message::Shutdown { graceful: true };
142            let json = serde_json::to_string(&shutdown).unwrap_or_default();
143            let _ = worker.stdin_tx.send(json);
144            tracing::info!(worker_id = %worker_id, "Sent shutdown to stdio worker");
145        }
146        Ok(())
147    }
148
149    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
150        let processes = self.processes.read().await;
151        let worker = processes
152            .get(worker_id)
153            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
154
155        let json = serde_json::to_string(&message)
156            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
157
158        worker
159            .stdin_tx
160            .send(json)
161            .map_err(|e| TransportError::SendFailed(e.to_string()))
162    }
163
164    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
165        let processes = self.processes.read().await;
166        let json = serde_json::to_string(&message)
167            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
168
169        for (_, worker) in processes.iter() {
170            let _ = worker.stdin_tx.send(json.clone());
171        }
172        Ok(())
173    }
174}