Skip to main content

rust_pipe/transport/
stdio.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::process::Stdio;
4use std::sync::Arc;
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::process::Command;
7use tokio::sync::{mpsc, RwLock};
8use tracing;
9
10use super::{Message, Transport, TransportError};
11
/// Configuration for a stdio-based worker process.
///
/// One `StdioProcess` describes a single child process that
/// [`StdioTransport`] spawns and talks to over its stdin/stdout pipes.
#[derive(Debug, Clone)]
pub struct StdioProcess {
    /// Program to execute when spawning the worker.
    pub command: String,
    /// Arguments passed to `command`.
    pub args: Vec<String>,
    /// Identifier the worker is registered and addressed under.
    pub worker_id: String,
    /// Task names advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
}
19
20/// Transport that communicates with workers via stdin/stdout pipes.
21pub struct StdioTransport {
22    processes: Arc<RwLock<HashMap<String, StdioWorker>>>,
23    configs: Vec<StdioProcess>,
24    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
25}
26
27struct StdioWorker {
28    stdin_tx: mpsc::UnboundedSender<String>,
29}
30
31impl StdioTransport {
32    pub fn new(
33        configs: Vec<StdioProcess>,
34        on_message: impl Fn(String, Message) + Send + Sync + 'static,
35    ) -> Self {
36        Self {
37            processes: Arc::new(RwLock::new(HashMap::new())),
38            configs,
39            on_message: Arc::new(on_message),
40        }
41    }
42}
43
44#[async_trait]
45impl Transport for StdioTransport {
46    async fn start(&self) -> Result<(), TransportError> {
47        for config in &self.configs {
48            let worker_id = config.worker_id.clone();
49            let on_message = self.on_message.clone();
50            let processes = self.processes.clone();
51
52            let mut child = Command::new(&config.command)
53                .args(&config.args)
54                .stdin(Stdio::piped())
55                .stdout(Stdio::piped())
56                .stderr(Stdio::piped())
57                .spawn()
58                .map_err(|e| {
59                    TransportError::ConnectionFailed(format!(
60                        "Failed to spawn '{}': {}",
61                        config.command, e
62                    ))
63                })?;
64
65            let stdin = child.stdin.take().expect("stdin piped");
66            let stdout = child.stdout.take().expect("stdout piped");
67
68            let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
69
70            // Register as worker
71            let reg_msg = Message::WorkerRegister {
72                registration: super::WorkerRegistration {
73                    worker_id: worker_id.clone(),
74                    supported_tasks: config.supported_tasks.clone(),
75                    max_concurrency: 1,
76                    language: super::WorkerLanguage::Other("stdio".to_string()),
77                },
78            };
79            on_message(worker_id.clone(), reg_msg);
80
81            // Stdin writer task
82            let wid = worker_id.clone();
83            tokio::spawn(async move {
84                let mut stdin = stdin;
85                while let Some(line) = stdin_rx.recv().await {
86                    if stdin.write_all(line.as_bytes()).await.is_err() {
87                        tracing::error!(worker_id = %wid, "Failed to write to stdin");
88                        break;
89                    }
90                    if stdin.write_all(b"\n").await.is_err() {
91                        break;
92                    }
93                    let _ = stdin.flush().await;
94                }
95            });
96
97            // Stdout reader task
98            let wid = worker_id.clone();
99            tokio::spawn(async move {
100                let reader = BufReader::new(stdout);
101                let mut lines = reader.lines();
102
103                while let Ok(Some(line)) = lines.next_line().await {
104                    if line.trim().is_empty() {
105                        continue;
106                    }
107                    match serde_json::from_str::<Message>(&line) {
108                        Ok(msg) => on_message(wid.clone(), msg),
109                        Err(e) => {
110                            tracing::debug!(
111                                worker_id = %wid,
112                                line = %line,
113                                error = %e,
114                                "Non-JSON line from worker, ignoring"
115                            );
116                        }
117                    }
118                }
119                tracing::info!(worker_id = %wid, "Stdio worker stdout closed");
120            });
121
122            processes
123                .write()
124                .await
125                .insert(worker_id.clone(), StdioWorker { stdin_tx });
126
127            tracing::info!(
128                worker_id = %worker_id,
129                command = %config.command,
130                "Stdio worker spawned"
131            );
132        }
133
134        Ok(())
135    }
136
137    async fn stop(&self) -> Result<(), TransportError> {
138        let processes = self.processes.read().await;
139        for (worker_id, worker) in processes.iter() {
140            let shutdown = Message::Shutdown { graceful: true };
141            let json = serde_json::to_string(&shutdown).unwrap_or_default();
142            let _ = worker.stdin_tx.send(json);
143            tracing::info!(worker_id = %worker_id, "Sent shutdown to stdio worker");
144        }
145        Ok(())
146    }
147
148    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
149        let processes = self.processes.read().await;
150        let worker = processes
151            .get(worker_id)
152            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
153
154        let json = serde_json::to_string(&message)
155            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
156
157        worker
158            .stdin_tx
159            .send(json)
160            .map_err(|e| TransportError::SendFailed(e.to_string()))
161    }
162
163    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
164        let processes = self.processes.read().await;
165        let json = serde_json::to_string(&message)
166            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
167
168        for (_, worker) in processes.iter() {
169            let _ = worker.stdin_tx.send(json.clone());
170        }
171        Ok(())
172    }
173}