1use std::fmt;
21use std::sync::atomic::{AtomicI64, Ordering};
22use std::sync::{Arc, Mutex as StdMutex};
23
24use hashbrown::HashMap;
25use std::time::Duration;
26
27use serde_json::Value;
28use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
29use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
30use tokio::sync::{mpsc, oneshot};
31use tokio::time::timeout;
32
33use crate::error::{AcpError, AcpResult};
34
35type NotificationHandler = Arc<dyn Fn(Value) -> anyhow::Result<()> + Send + Sync>;
40
/// Wire-format switches applied to every outgoing message of a `StdioTransport`.
#[derive(Debug, Clone, Copy)]
pub struct StdioTransportOptions {
    /// When `true` (the default) outgoing messages keep their `"jsonrpc": "2.0"`
    /// member; when `false` that member is removed before the message is sent.
    pub include_jsonrpc_version: bool,
}

impl Default for StdioTransportOptions {
    /// Standard JSON-RPC 2.0 framing: the version member is included.
    fn default() -> Self {
        let include_jsonrpc_version = true;
        Self {
            include_jsonrpc_version,
        }
    }
}
53
/// A JSON-RPC client/server endpoint that talks to a child process over its
/// stdin/stdout pipes, one JSON message per line.
///
/// Field declaration order is also drop order; keep it stable.
pub struct StdioTransport {
    /// Serialized outbound messages (without the trailing newline) consumed by the
    /// writer task, which appends `\n` and flushes after each one.
    write_tx: mpsc::UnboundedSender<String>,
    /// In-flight `call`s keyed by `response_id_key` of their request id; the
    /// reader task removes the matching sender and completes it with the result.
    pending: Arc<StdMutex<HashMap<String, oneshot::Sender<AcpResult<Value>>>>>,
    /// Source of request ids for `call`; starts at 1 and is bumped per request.
    request_counter: AtomicI64,
    /// Shared with the reader task, which forwards every inbound message that is
    /// not a response to this handler (if one is set).
    notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
    /// The spawned child process; `None` for the test-only constructors.
    /// `Drop` asks it to terminate via `start_kill`.
    child: StdMutex<Option<Child>>,
    /// Upper bound on how long `call` waits for a response.
    rpc_timeout: Duration,
    /// Framing options applied to every outgoing message.
    options: StdioTransportOptions,
}
74
75impl StdioTransport {
76 pub fn from_child(
81 child: Child,
82 stdin: ChildStdin,
83 stdout: ChildStdout,
84 stderr: ChildStderr,
85 rpc_timeout: Duration,
86 ) -> Self {
87 Self::from_child_with_options(
88 child,
89 stdin,
90 stdout,
91 stderr,
92 rpc_timeout,
93 StdioTransportOptions::default(),
94 )
95 }
96
97 pub fn from_child_with_options(
98 child: Child,
99 stdin: ChildStdin,
100 stdout: ChildStdout,
101 stderr: ChildStderr,
102 rpc_timeout: Duration,
103 options: StdioTransportOptions,
104 ) -> Self {
105 let (write_tx, write_rx) = mpsc::unbounded_channel();
106 let pending = Arc::new(StdMutex::new(HashMap::new()));
107 let notification_handler = Arc::new(StdMutex::new(None));
108
109 spawn_writer(write_rx, stdin);
110 spawn_stderr_logger(stderr);
111 spawn_reader(
112 stdout,
113 Arc::clone(&pending),
114 Arc::clone(¬ification_handler),
115 );
116
117 Self {
118 write_tx,
119 pending,
120 request_counter: AtomicI64::new(1),
121 notification_handler,
122 child: StdMutex::new(Some(child)),
123 rpc_timeout,
124 options,
125 }
126 }
127
128 #[cfg(test)]
133 pub fn new_for_testing(write_tx: mpsc::UnboundedSender<String>, rpc_timeout: Duration) -> Self {
134 Self::new_for_testing_with_options(write_tx, rpc_timeout, StdioTransportOptions::default())
135 }
136
137 #[cfg(test)]
138 pub fn new_for_testing_with_options(
139 write_tx: mpsc::UnboundedSender<String>,
140 rpc_timeout: Duration,
141 options: StdioTransportOptions,
142 ) -> Self {
143 Self {
144 write_tx,
145 pending: Arc::new(StdMutex::new(HashMap::new())),
146 request_counter: AtomicI64::new(1),
147 notification_handler: Arc::new(StdMutex::new(None)),
148 child: StdMutex::new(None),
149 rpc_timeout,
150 options,
151 }
152 }
153
154 pub fn set_notification_handler(&self, handler: NotificationHandler) {
160 if let Ok(mut guard) = self.notification_handler.lock() {
161 *guard = Some(handler);
162 }
163 }
164
165 pub async fn call(&self, method: &str, params: Value) -> AcpResult<Value> {
175 let id = self.request_counter.fetch_add(1, Ordering::Relaxed);
176 let id_value = Value::from(id);
177 let pending_key = response_id_key(&id_value);
178 let (tx, rx) = oneshot::channel();
179 self.pending
180 .lock()
181 .map_err(|_err| AcpError::Internal("stdio transport pending mutex poisoned".into()))?
182 .insert(pending_key.clone(), tx);
183
184 let mut payload = serde_json::json!({
185 "jsonrpc": "2.0",
186 "id": id,
187 "method": method,
188 "params": params,
189 });
190 maybe_strip_jsonrpc_field(&mut payload, self.options);
191 if let Err(e) = self.send_raw(payload) {
192 self.pending.lock().ok().map(|mut g| g.remove(&pending_key));
194 return Err(e);
195 }
196
197 timeout(self.rpc_timeout, rx)
198 .await
199 .map_err(|_err| AcpError::Timeout(format!("{method} timed out")))?
200 .map_err(|_err| AcpError::Internal(format!("{method} response channel closed")))
201 .and_then(|r| r)
202 }
203
204 pub fn notify(&self, method: &str, params: Value) -> AcpResult<()> {
210 let mut payload = serde_json::json!({
211 "jsonrpc": "2.0",
212 "method": method,
213 "params": params,
214 });
215 maybe_strip_jsonrpc_field(&mut payload, self.options);
216 self.send_raw(payload)
217 }
218
219 pub fn respond(&self, id: i64, result: Value) -> AcpResult<()> {
228 self.respond_value(Value::from(id), result)
229 }
230
231 pub fn respond_value(&self, id: Value, result: Value) -> AcpResult<()> {
232 let mut payload = serde_json::json!({
233 "jsonrpc": "2.0",
234 "id": id,
235 "result": result,
236 });
237 maybe_strip_jsonrpc_field(&mut payload, self.options);
238 self.send_raw(payload)
239 }
240
241 pub fn respond_error(&self, id: i64, code: i32, message: impl Into<String>) -> AcpResult<()> {
247 self.respond_error_value(Value::from(id), code, message)
248 }
249
250 pub fn respond_error_value(
251 &self,
252 id: Value,
253 code: i32,
254 message: impl Into<String>,
255 ) -> AcpResult<()> {
256 let mut payload = serde_json::json!({
257 "jsonrpc": "2.0",
258 "id": id,
259 "error": {
260 "code": code,
261 "message": message.into(),
262 },
263 });
264 maybe_strip_jsonrpc_field(&mut payload, self.options);
265 self.send_raw(payload)
266 }
267
268 fn send_raw(&self, payload: Value) -> AcpResult<()> {
269 let text = serde_json::to_string(&payload)?;
270 self.write_tx
271 .send(text)
272 .map_err(|_err| AcpError::Internal("stdio transport writer channel closed".into()))
273 }
274}
275
276impl fmt::Debug for StdioTransport {
277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278 f.debug_struct("StdioTransport")
279 .field(
280 "request_counter",
281 &self.request_counter.load(Ordering::Relaxed),
282 )
283 .field("rpc_timeout", &self.rpc_timeout)
284 .finish_non_exhaustive()
285 }
286}
287
288impl Drop for StdioTransport {
289 fn drop(&mut self) {
290 if let Ok(mut child) = self.child.lock()
291 && let Some(child) = child.as_mut()
292 {
293 let _ = child.start_kill();
294 }
295 }
296}
297
298fn spawn_writer(mut write_rx: mpsc::UnboundedReceiver<String>, mut stdin: ChildStdin) {
303 tokio::spawn(async move {
304 while let Some(payload) = write_rx.recv().await {
305 if stdin.write_all(payload.as_bytes()).await.is_err()
306 || stdin.write_all(b"\n").await.is_err()
307 || stdin.flush().await.is_err()
308 {
309 tracing::warn!(
310 target: "vtcode.stdio_transport",
311 "stdin write failed; writer task exiting"
312 );
313 break;
314 }
315 }
316 });
317}
318
319fn spawn_stderr_logger(stderr: ChildStderr) {
320 tokio::spawn(async move {
321 let mut reader = BufReader::new(stderr);
322 let mut line = String::new();
323 loop {
324 line.clear();
325 match reader.read_line(&mut line).await {
326 Ok(0) | Err(_) => break,
327 Ok(_) => {
328 tracing::debug!(target: "vtcode.stdio_transport.stderr", "{}", line.trim_end())
329 }
330 }
331 }
332 });
333}
334
335fn spawn_reader(
336 stdout: ChildStdout,
337 pending: Arc<StdMutex<HashMap<String, oneshot::Sender<AcpResult<Value>>>>>,
338 notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
339) {
340 tokio::spawn(async move {
341 let mut reader = BufReader::new(stdout).lines();
342 while let Ok(Some(line)) = reader.next_line().await {
343 if line.trim().is_empty() {
344 continue;
345 }
346 let message: Value = match serde_json::from_str(&line) {
347 Ok(v) => v,
348 Err(e) => {
349 tracing::warn!("stdio transport: JSON decode failed: {e}");
350 continue;
351 }
352 };
353
354 if let Some(id) = response_id(&message) {
357 let result = extract_rpc_result(&message);
358 let tx = pending
359 .lock()
360 .ok()
361 .and_then(|mut g| g.remove(&response_id_key(&id)));
362 if let Some(tx) = tx {
363 let _ = tx.send(result);
364 }
365 continue;
366 }
367
368 if let Some(handler) = notification_handler
371 .lock()
372 .ok()
373 .and_then(|g| g.as_ref().cloned())
374 && let Err(e) = handler(message)
375 {
376 tracing::warn!("stdio transport: notification handler error: {e}");
377 }
378 }
379 });
380}
381
382fn response_id(message: &Value) -> Option<Value> {
388 if message.get("result").is_some() || message.get("error").is_some() {
389 message.get("id").cloned()
390 } else {
391 None
392 }
393}
394
395fn response_id_key(id: &Value) -> String {
396 serde_json::to_string(id).unwrap_or_else(|_| "null".to_string())
397}
398
399fn maybe_strip_jsonrpc_field(payload: &mut Value, options: StdioTransportOptions) {
400 if options.include_jsonrpc_version {
401 return;
402 }
403
404 if let Some(object) = payload.as_object_mut() {
405 object.remove("jsonrpc");
406 }
407}
408
409fn extract_rpc_result(message: &Value) -> AcpResult<Value> {
410 if let Some(error) = message.get("error") {
411 let code = error
412 .get("code")
413 .and_then(Value::as_i64)
414 .unwrap_or_default();
415 let detail = error
416 .get("message")
417 .and_then(Value::as_str)
418 .unwrap_or("unknown error");
419 Err(AcpError::RemoteError {
420 agent_id: "stdio".into(),
421 message: format!("rpc error {code}: {detail}"),
422 code: Some(code as i32),
423 })
424 } else {
425 Ok(message.get("result").cloned().unwrap_or(Value::Null))
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default-options transport whose output lands in an in-memory channel.
    fn channel_transport() -> (StdioTransport, mpsc::UnboundedReceiver<String>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let transport = StdioTransport::new_for_testing(tx, Duration::from_secs(5));
        (transport, rx)
    }

    /// Pops the next queued outbound line and parses it as JSON.
    fn next_payload(rx: &mut mpsc::UnboundedReceiver<String>) -> Value {
        let raw = rx.try_recv().expect("outbound payload");
        serde_json::from_str(&raw).expect("outbound payload is JSON")
    }

    #[test]
    fn response_id_requires_result_or_error() {
        let non_responses = [
            serde_json::json!({
                "jsonrpc": "2.0",
                "method": "some/notification",
                "params": {}
            }),
            serde_json::json!({
                "jsonrpc": "2.0",
                "id": 7,
                "method": "permission.request",
                "params": {}
            }),
        ];
        for message in &non_responses {
            assert!(response_id(message).is_none());
        }

        let success = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 3,
            "result": { "ok": true }
        });
        assert_eq!(response_id(&success), Some(Value::from(3)));

        let failure = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 5,
            "error": { "code": -32601, "message": "method not found" }
        });
        assert_eq!(response_id(&failure), Some(Value::from(5)));
    }

    #[test]
    fn extract_rpc_result_propagates_error() {
        let message = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": { "code": -32600, "message": "invalid request" }
        });
        let err = extract_rpc_result(&message).expect_err("error member must yield Err");
        assert!(err.to_string().contains("invalid request"));
    }

    #[test]
    fn extract_rpc_result_returns_result_value() {
        let message = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": { "sessionId": "abc" }
        });
        let value = extract_rpc_result(&message).expect("result member must yield Ok");
        assert_eq!(value["sessionId"], "abc");
    }

    #[test]
    fn notify_serialises_payload_to_write_channel() {
        let (transport, mut rx) = channel_transport();
        transport
            .notify("session/cancel", serde_json::json!({ "sessionId": "s1" }))
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["method"], "session/cancel");
        assert_eq!(payload["params"]["sessionId"], "s1");
        assert!(
            payload.get("id").is_none(),
            "notifications must not have id"
        );
    }

    #[test]
    fn respond_writes_jsonrpc_result() {
        let (transport, mut rx) = channel_transport();
        transport
            .respond(42, serde_json::json!({ "ok": true }))
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["jsonrpc"], "2.0");
        assert_eq!(payload["id"], 42);
        assert_eq!(payload["result"]["ok"], true);
    }

    #[test]
    fn respond_error_writes_jsonrpc_error() {
        let (transport, mut rx) = channel_transport();
        transport
            .respond_error(9, -32601, "method not found")
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["id"], 9);
        assert_eq!(payload["error"]["code"], -32601);
        assert_eq!(payload["error"]["message"], "method not found");
    }

    #[test]
    fn respond_value_supports_string_ids() {
        let (transport, mut rx) = channel_transport();
        transport
            .respond_value(
                Value::String("request-1".to_string()),
                serde_json::json!({ "ok": true }),
            )
            .unwrap();

        let payload = next_payload(&mut rx);
        assert_eq!(payload["id"], "request-1");
        assert_eq!(payload["result"]["ok"], true);
    }

    #[test]
    fn can_omit_jsonrpc_field_for_codex_mode() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let codex_options = StdioTransportOptions {
            include_jsonrpc_version: false,
        };
        let transport = StdioTransport::new_for_testing_with_options(
            tx,
            Duration::from_secs(5),
            codex_options,
        );
        transport
            .notify("initialized", serde_json::json!({}))
            .unwrap();

        let payload = next_payload(&mut rx);
        assert!(payload.get("jsonrpc").is_none());
        assert_eq!(payload["method"], "initialized");
    }
}