1use std::collections::HashMap;
21use std::fmt;
22use std::sync::atomic::{AtomicI64, Ordering};
23use std::sync::{Arc, Mutex as StdMutex};
24use std::time::Duration;
25
26use serde_json::Value;
27use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
28use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
29use tokio::sync::{mpsc, oneshot};
30use tokio::time::timeout;
31
32use crate::error::{AcpError, AcpResult};
33
34type NotificationHandler = Arc<dyn Fn(Value) -> anyhow::Result<()> + Send + Sync>;
39
40pub struct StdioTransport {
52 write_tx: mpsc::UnboundedSender<String>,
53 pending: Arc<StdMutex<HashMap<i64, oneshot::Sender<AcpResult<Value>>>>>,
54 request_counter: AtomicI64,
55 notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
56 child: StdMutex<Option<Child>>,
57 rpc_timeout: Duration,
58}
59
60impl StdioTransport {
61 pub fn from_child(
66 child: Child,
67 stdin: ChildStdin,
68 stdout: ChildStdout,
69 stderr: ChildStderr,
70 rpc_timeout: Duration,
71 ) -> Self {
72 let (write_tx, write_rx) = mpsc::unbounded_channel();
73 let pending = Arc::new(StdMutex::new(HashMap::new()));
74 let notification_handler = Arc::new(StdMutex::new(None));
75
76 spawn_writer(write_rx, stdin);
77 spawn_stderr_logger(stderr);
78 spawn_reader(
79 stdout,
80 Arc::clone(&pending),
81 Arc::clone(¬ification_handler),
82 );
83
84 Self {
85 write_tx,
86 pending,
87 request_counter: AtomicI64::new(1),
88 notification_handler,
89 child: StdMutex::new(Some(child)),
90 rpc_timeout,
91 }
92 }
93
94 pub fn new_for_testing(write_tx: mpsc::UnboundedSender<String>, rpc_timeout: Duration) -> Self {
99 Self {
100 write_tx,
101 pending: Arc::new(StdMutex::new(HashMap::new())),
102 request_counter: AtomicI64::new(1),
103 notification_handler: Arc::new(StdMutex::new(None)),
104 child: StdMutex::new(None),
105 rpc_timeout,
106 }
107 }
108
109 pub fn set_notification_handler(&self, handler: NotificationHandler) {
115 if let Ok(mut guard) = self.notification_handler.lock() {
116 *guard = Some(handler);
117 }
118 }
119
120 pub async fn call(&self, method: &str, params: Value) -> AcpResult<Value> {
130 let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
131 let (tx, rx) = oneshot::channel();
132 self.pending
133 .lock()
134 .map_err(|_| AcpError::Internal("stdio transport pending mutex poisoned".into()))?
135 .insert(id, tx);
136
137 let payload = serde_json::json!({
138 "jsonrpc": "2.0",
139 "id": id,
140 "method": method,
141 "params": params,
142 });
143 if let Err(e) = self.send_raw(payload) {
144 self.pending.lock().ok().map(|mut g| g.remove(&id));
146 return Err(e);
147 }
148
149 timeout(self.rpc_timeout, rx)
150 .await
151 .map_err(|_| AcpError::Timeout(format!("{method} timed out")))?
152 .map_err(|_| AcpError::Internal(format!("{method} response channel closed")))
153 .and_then(|r| r)
154 }
155
156 pub fn notify(&self, method: &str, params: Value) -> AcpResult<()> {
162 let payload = serde_json::json!({
163 "jsonrpc": "2.0",
164 "method": method,
165 "params": params,
166 });
167 self.send_raw(payload)
168 }
169
170 pub fn respond(&self, id: i64, result: Value) -> AcpResult<()> {
179 let payload = serde_json::json!({
180 "jsonrpc": "2.0",
181 "id": id,
182 "result": result,
183 });
184 self.send_raw(payload)
185 }
186
187 pub fn respond_error(&self, id: i64, code: i32, message: impl Into<String>) -> AcpResult<()> {
193 let payload = serde_json::json!({
194 "jsonrpc": "2.0",
195 "id": id,
196 "error": {
197 "code": code,
198 "message": message.into(),
199 },
200 });
201 self.send_raw(payload)
202 }
203
204 fn send_raw(&self, payload: Value) -> AcpResult<()> {
205 let text = serde_json::to_string(&payload)?;
206 self.write_tx
207 .send(text)
208 .map_err(|_| AcpError::Internal("stdio transport writer channel closed".into()))
209 }
210}
211
212impl fmt::Debug for StdioTransport {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 f.debug_struct("StdioTransport")
215 .field(
216 "request_counter",
217 &self.request_counter.load(Ordering::Relaxed),
218 )
219 .field("rpc_timeout", &self.rpc_timeout)
220 .finish_non_exhaustive()
221 }
222}
223
224impl Drop for StdioTransport {
225 fn drop(&mut self) {
226 if let Ok(mut child) = self.child.lock()
227 && let Some(child) = child.as_mut()
228 {
229 let _ = child.start_kill();
230 }
231 }
232}
233
234fn spawn_writer(mut write_rx: mpsc::UnboundedReceiver<String>, mut stdin: ChildStdin) {
239 tokio::spawn(async move {
240 while let Some(payload) = write_rx.recv().await {
241 if stdin.write_all(payload.as_bytes()).await.is_err()
242 || stdin.write_all(b"\n").await.is_err()
243 || stdin.flush().await.is_err()
244 {
245 tracing::warn!(
246 target: "vtcode.stdio_transport",
247 "stdin write failed; writer task exiting"
248 );
249 break;
250 }
251 }
252 });
253}
254
255fn spawn_stderr_logger(stderr: ChildStderr) {
256 tokio::spawn(async move {
257 let mut reader = BufReader::new(stderr);
258 let mut line = String::new();
259 loop {
260 line.clear();
261 match reader.read_line(&mut line).await {
262 Ok(0) | Err(_) => break,
263 Ok(_) => {
264 tracing::debug!(target: "vtcode.stdio_transport.stderr", "{}", line.trim_end())
265 }
266 }
267 }
268 });
269}
270
271fn spawn_reader(
272 stdout: ChildStdout,
273 pending: Arc<StdMutex<HashMap<i64, oneshot::Sender<AcpResult<Value>>>>>,
274 notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
275) {
276 tokio::spawn(async move {
277 let mut reader = BufReader::new(stdout).lines();
278 while let Ok(Some(line)) = reader.next_line().await {
279 if line.trim().is_empty() {
280 continue;
281 }
282 let message: Value = match serde_json::from_str(&line) {
283 Ok(v) => v,
284 Err(e) => {
285 tracing::warn!("stdio transport: JSON decode failed: {e}");
286 continue;
287 }
288 };
289
290 if let Some(id) = response_id(&message) {
293 let result = extract_rpc_result(&message);
294 let tx = pending.lock().ok().and_then(|mut g| g.remove(&id));
295 if let Some(tx) = tx {
296 let _ = tx.send(result);
297 }
298 continue;
299 }
300
301 if let Some(handler) = notification_handler
304 .lock()
305 .ok()
306 .and_then(|g| g.as_ref().cloned())
307 && let Err(e) = handler(message)
308 {
309 tracing::warn!("stdio transport: notification handler error: {e}");
310 }
311 }
312 });
313}
314
315fn response_id(message: &Value) -> Option<i64> {
321 if message.get("result").is_some() || message.get("error").is_some() {
322 message.get("id").and_then(Value::as_i64)
323 } else {
324 None
325 }
326}
327
328fn extract_rpc_result(message: &Value) -> AcpResult<Value> {
329 if let Some(error) = message.get("error") {
330 let code = error
331 .get("code")
332 .and_then(Value::as_i64)
333 .unwrap_or_default();
334 let detail = error
335 .get("message")
336 .and_then(Value::as_str)
337 .unwrap_or("unknown error");
338 Err(AcpError::RemoteError {
339 agent_id: "stdio".into(),
340 message: format!("rpc error {code}: {detail}"),
341 code: Some(code as i32),
342 })
343 } else {
344 Ok(message.get("result").cloned().unwrap_or(Value::Null))
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a child-less transport plus the receiver that captures what it writes.
    fn capture_transport() -> (StdioTransport, mpsc::UnboundedReceiver<String>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let transport = StdioTransport::new_for_testing(tx, Duration::from_secs(5));
        (transport, rx)
    }

    #[test]
    fn response_id_requires_result_or_error() {
        let cases = [
            // Notification: no id and no result/error.
            (
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "method": "some/notification",
                    "params": {}
                }),
                None,
            ),
            // Peer-initiated request: has an id, but carries `method` instead of a result.
            (
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": 7,
                    "method": "permission.request",
                    "params": {}
                }),
                None,
            ),
            (
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": 3,
                    "result": { "ok": true }
                }),
                Some(3),
            ),
            (
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": 5,
                    "error": { "code": -32601, "message": "method not found" }
                }),
                Some(5),
            ),
        ];

        for (message, expected) in cases {
            assert_eq!(response_id(&message), expected, "message: {message}");
        }
    }

    #[test]
    fn extract_rpc_result_propagates_error() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": { "code": -32600, "message": "invalid request" }
        });
        let outcome = extract_rpc_result(&response);
        assert!(outcome.is_err());
        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("invalid request"));
    }

    #[test]
    fn extract_rpc_result_returns_result_value() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": { "sessionId": "abc" }
        });
        let value = extract_rpc_result(&response).unwrap();
        assert_eq!(value["sessionId"], "abc");
    }

    #[test]
    fn notify_serialises_payload_to_write_channel() {
        let (transport, mut rx) = capture_transport();
        transport
            .notify("session/cancel", serde_json::json!({ "sessionId": "s1" }))
            .unwrap();

        let raw = rx.try_recv().expect("notification payload");
        let sent: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(sent["method"], "session/cancel");
        assert_eq!(sent["params"]["sessionId"], "s1");
        assert!(sent.get("id").is_none(), "notifications must not have id");
    }

    #[test]
    fn respond_writes_jsonrpc_result() {
        let (transport, mut rx) = capture_transport();
        transport
            .respond(42, serde_json::json!({ "ok": true }))
            .unwrap();

        let raw = rx.try_recv().unwrap();
        let sent: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(sent["jsonrpc"], "2.0");
        assert_eq!(sent["id"], 42);
        assert_eq!(sent["result"]["ok"], true);
    }

    #[test]
    fn respond_error_writes_jsonrpc_error() {
        let (transport, mut rx) = capture_transport();
        transport
            .respond_error(9, -32601, "method not found")
            .unwrap();

        let raw = rx.try_recv().unwrap();
        let sent: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(sent["id"], 9);
        assert_eq!(sent["error"]["code"], -32601);
        assert_eq!(sent["error"]["message"], "method not found");
    }
}