Skip to main content

vtcode_acp_client/
transport.rs

//! Generic JSON-RPC-over-stdio transport for subprocess agents.
//!
//! [`StdioTransport`] handles the low-level framing of newline-delimited JSON
//! over a child process's stdin/stdout pair. It is intentionally protocol-agnostic:
//! it knows nothing about Copilot, ACP sessions, or any other higher-level concept.
//!
//! ## Message routing
//!
//! The internal reader task inspects each incoming line and dispatches it as follows:
//!
//! - **Response** (has a `result` or `error` field and a numeric `id`): looked up in the
//!   pending table populated by [`StdioTransport::call`] and delivered to the waiting
//!   caller via a [`tokio::sync::oneshot`] channel. A response whose `id` has no
//!   pending entry is discarded.
//! - **Request / notification** (anything else): forwarded to the closure registered
//!   via [`StdioTransport::set_notification_handler`]. If no handler is registered,
//!   the message is discarded.
//!
//! Stderr lines are forwarded to `tracing::debug!` under the
//! `vtcode.stdio_transport.stderr` target.
19
20use std::collections::HashMap;
21use std::fmt;
22use std::sync::atomic::{AtomicI64, Ordering};
23use std::sync::{Arc, Mutex as StdMutex};
24use std::time::Duration;
25
26use serde_json::Value;
27use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
28use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
29use tokio::sync::{mpsc, oneshot};
30use tokio::time::timeout;
31
32use crate::error::{AcpError, AcpResult};
33
34/// Callback type for incoming server→client requests and notifications.
35///
36/// The handler receives the raw JSON-RPC message value. It should return
37/// `Ok(())` on success; errors are logged as warnings by the transport.
38type NotificationHandler = Arc<dyn Fn(Value) -> anyhow::Result<()> + Send + Sync>;
39
40/// Generic JSON-RPC-over-stdio transport for local subprocess agents.
41///
42/// Wraps a child process and provides:
43/// - [`call`](Self::call): send a request and await its response.
44/// - [`notify`](Self::notify): send a fire-and-forget notification.
45/// - [`respond`](Self::respond) / [`respond_error`](Self::respond_error): reply to
46///   incoming server-initiated requests.
47/// - [`set_notification_handler`](Self::set_notification_handler): register the handler
48///   that receives all incoming server→client messages.
49///
50/// The child process is killed when this struct is dropped.
51pub struct StdioTransport {
52    write_tx: mpsc::UnboundedSender<String>,
53    pending: Arc<StdMutex<HashMap<i64, oneshot::Sender<AcpResult<Value>>>>>,
54    request_counter: AtomicI64,
55    notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
56    child: StdMutex<Option<Child>>,
57    rpc_timeout: Duration,
58}
59
60impl StdioTransport {
61    /// Wire up transport from a spawned subprocess's stdin/stdout/stderr.
62    ///
63    /// Spawns background tasks for the writer (stdin), stderr logger, and the
64    /// reader (stdout) that dispatches JSON-RPC messages.
65    pub fn from_child(
66        child: Child,
67        stdin: ChildStdin,
68        stdout: ChildStdout,
69        stderr: ChildStderr,
70        rpc_timeout: Duration,
71    ) -> Self {
72        let (write_tx, write_rx) = mpsc::unbounded_channel();
73        let pending = Arc::new(StdMutex::new(HashMap::new()));
74        let notification_handler = Arc::new(StdMutex::new(None));
75
76        spawn_writer(write_rx, stdin);
77        spawn_stderr_logger(stderr);
78        spawn_reader(
79            stdout,
80            Arc::clone(&pending),
81            Arc::clone(&notification_handler),
82        );
83
84        Self {
85            write_tx,
86            pending,
87            request_counter: AtomicI64::new(1),
88            notification_handler,
89            child: StdMutex::new(Some(child)),
90            rpc_timeout,
91        }
92    }
93
94    /// Construct a transport with a pre-wired channel for unit tests.
95    ///
96    /// No subprocess is spawned and no background tasks are started. The caller
97    /// can drive the mock by reading from the paired receiver.
98    pub fn new_for_testing(write_tx: mpsc::UnboundedSender<String>, rpc_timeout: Duration) -> Self {
99        Self {
100            write_tx,
101            pending: Arc::new(StdMutex::new(HashMap::new())),
102            request_counter: AtomicI64::new(1),
103            notification_handler: Arc::new(StdMutex::new(None)),
104            child: StdMutex::new(None),
105            rpc_timeout,
106        }
107    }
108
109    /// Register a handler for incoming server→client requests and notifications.
110    ///
111    /// Must be called once after construction. Subsequent calls overwrite the
112    /// previous handler. The handler receives the raw JSON message value for
113    /// every incoming message that is **not** a response to a pending [`call`](Self::call).
114    pub fn set_notification_handler(&self, handler: NotificationHandler) {
115        if let Ok(mut guard) = self.notification_handler.lock() {
116            *guard = Some(handler);
117        }
118    }
119
120    /// Send a JSON-RPC request and wait for its response.
121    ///
122    /// Assigns a monotonically increasing `id`, inserts it into the pending
123    /// table, serialises the message, and awaits the reply up to `rpc_timeout`.
124    ///
125    /// # Errors
126    ///
127    /// Returns [`AcpError::Timeout`] if the peer does not reply in time, or
128    /// [`AcpError::Internal`] if the transport is shut down.
129    pub async fn call(&self, method: &str, params: Value) -> AcpResult<Value> {
130        let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
131        let (tx, rx) = oneshot::channel();
132        self.pending
133            .lock()
134            .map_err(|_| AcpError::Internal("stdio transport pending mutex poisoned".into()))?
135            .insert(id, tx);
136
137        let payload = serde_json::json!({
138            "jsonrpc": "2.0",
139            "id": id,
140            "method": method,
141            "params": params,
142        });
143        if let Err(e) = self.send_raw(payload) {
144            // Clean up the pending entry so it doesn't linger until timeout.
145            self.pending.lock().ok().map(|mut g| g.remove(&id));
146            return Err(e);
147        }
148
149        timeout(self.rpc_timeout, rx)
150            .await
151            .map_err(|_| AcpError::Timeout(format!("{method} timed out")))?
152            .map_err(|_| AcpError::Internal(format!("{method} response channel closed")))
153            .and_then(|r| r)
154    }
155
156    /// Send a JSON-RPC notification (no response expected).
157    ///
158    /// # Errors
159    ///
160    /// Returns an error if serialisation fails or the writer task has shut down.
161    pub fn notify(&self, method: &str, params: Value) -> AcpResult<()> {
162        let payload = serde_json::json!({
163            "jsonrpc": "2.0",
164            "method": method,
165            "params": params,
166        });
167        self.send_raw(payload)
168    }
169
170    /// Send a JSON-RPC success response to an incoming server request.
171    ///
172    /// Use this to reply to messages received by the notification handler when
173    /// they carry an `id` field (i.e. they expect a response).
174    ///
175    /// # Errors
176    ///
177    /// Returns an error if serialisation fails or the writer task has shut down.
178    pub fn respond(&self, id: i64, result: Value) -> AcpResult<()> {
179        let payload = serde_json::json!({
180            "jsonrpc": "2.0",
181            "id": id,
182            "result": result,
183        });
184        self.send_raw(payload)
185    }
186
187    /// Send a JSON-RPC error response to an incoming server request.
188    ///
189    /// # Errors
190    ///
191    /// Returns an error if serialisation fails or the writer task has shut down.
192    pub fn respond_error(&self, id: i64, code: i32, message: impl Into<String>) -> AcpResult<()> {
193        let payload = serde_json::json!({
194            "jsonrpc": "2.0",
195            "id": id,
196            "error": {
197                "code": code,
198                "message": message.into(),
199            },
200        });
201        self.send_raw(payload)
202    }
203
204    fn send_raw(&self, payload: Value) -> AcpResult<()> {
205        let text = serde_json::to_string(&payload)?;
206        self.write_tx
207            .send(text)
208            .map_err(|_| AcpError::Internal("stdio transport writer channel closed".into()))
209    }
210}
211
212impl fmt::Debug for StdioTransport {
213    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214        f.debug_struct("StdioTransport")
215            .field(
216                "request_counter",
217                &self.request_counter.load(Ordering::Relaxed),
218            )
219            .field("rpc_timeout", &self.rpc_timeout)
220            .finish_non_exhaustive()
221    }
222}
223
224impl Drop for StdioTransport {
225    fn drop(&mut self) {
226        if let Ok(mut child) = self.child.lock()
227            && let Some(child) = child.as_mut()
228        {
229            let _ = child.start_kill();
230        }
231    }
232}
233
234// ============================================================================
235// Background tasks
236// ============================================================================
237
238fn spawn_writer(mut write_rx: mpsc::UnboundedReceiver<String>, mut stdin: ChildStdin) {
239    tokio::spawn(async move {
240        while let Some(payload) = write_rx.recv().await {
241            if stdin.write_all(payload.as_bytes()).await.is_err()
242                || stdin.write_all(b"\n").await.is_err()
243                || stdin.flush().await.is_err()
244            {
245                tracing::warn!(
246                    target: "vtcode.stdio_transport",
247                    "stdin write failed; writer task exiting"
248                );
249                break;
250            }
251        }
252    });
253}
254
255fn spawn_stderr_logger(stderr: ChildStderr) {
256    tokio::spawn(async move {
257        let mut reader = BufReader::new(stderr);
258        let mut line = String::new();
259        loop {
260            line.clear();
261            match reader.read_line(&mut line).await {
262                Ok(0) | Err(_) => break,
263                Ok(_) => {
264                    tracing::debug!(target: "vtcode.stdio_transport.stderr", "{}", line.trim_end())
265                }
266            }
267        }
268    });
269}
270
271fn spawn_reader(
272    stdout: ChildStdout,
273    pending: Arc<StdMutex<HashMap<i64, oneshot::Sender<AcpResult<Value>>>>>,
274    notification_handler: Arc<StdMutex<Option<NotificationHandler>>>,
275) {
276    tokio::spawn(async move {
277        let mut reader = BufReader::new(stdout).lines();
278        while let Ok(Some(line)) = reader.next_line().await {
279            if line.trim().is_empty() {
280                continue;
281            }
282            let message: Value = match serde_json::from_str(&line) {
283                Ok(v) => v,
284                Err(e) => {
285                    tracing::warn!("stdio transport: JSON decode failed: {e}");
286                    continue;
287                }
288            };
289
290            // Dispatch JSON-RPC responses to pending callers.
291            // Extract tx before releasing the lock so `tx.send` runs lock-free.
292            if let Some(id) = response_id(&message) {
293                let result = extract_rpc_result(&message);
294                let tx = pending.lock().ok().and_then(|mut g| g.remove(&id));
295                if let Some(tx) = tx {
296                    let _ = tx.send(result);
297                }
298                continue;
299            }
300
301            // Clone the handler Arc out of the lock so the lock is released
302            // before the handler runs (prevents re-entrancy / call-site latency).
303            if let Some(handler) = notification_handler
304                .lock()
305                .ok()
306                .and_then(|g| g.as_ref().cloned())
307                && let Err(e) = handler(message)
308            {
309                tracing::warn!("stdio transport: notification handler error: {e}");
310            }
311        }
312    });
313}
314
315// ============================================================================
316// Helpers
317// ============================================================================
318
319/// Returns the `id` if the message is a JSON-RPC *response* (has `result` or `error`).
320fn response_id(message: &Value) -> Option<i64> {
321    if message.get("result").is_some() || message.get("error").is_some() {
322        message.get("id").and_then(Value::as_i64)
323    } else {
324        None
325    }
326}
327
328fn extract_rpc_result(message: &Value) -> AcpResult<Value> {
329    if let Some(error) = message.get("error") {
330        let code = error
331            .get("code")
332            .and_then(Value::as_i64)
333            .unwrap_or_default();
334        let detail = error
335            .get("message")
336            .and_then(Value::as_str)
337            .unwrap_or("unknown error");
338        Err(AcpError::RemoteError {
339            agent_id: "stdio".into(),
340            message: format!("rpc error {code}: {detail}"),
341            code: Some(code as i32),
342        })
343    } else {
344        Ok(message.get("result").cloned().unwrap_or(Value::Null))
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A transport wired to an in-memory write queue (5 s RPC timeout), together
    /// with the receiving end of that queue.
    fn mock_transport() -> (StdioTransport, mpsc::UnboundedReceiver<String>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let transport = StdioTransport::new_for_testing(tx, Duration::from_secs(5));
        (transport, rx)
    }

    #[test]
    fn response_id_requires_result_or_error() {
        let cases = [
            // Pure notification: neither result nor error.
            (
                json!({ "jsonrpc": "2.0", "method": "some/notification", "params": {} }),
                None,
            ),
            // Server-initiated request: carries an id but no result.
            (
                json!({
                    "jsonrpc": "2.0",
                    "id": 7,
                    "method": "permission.request",
                    "params": {}
                }),
                None,
            ),
            // Success response.
            (
                json!({ "jsonrpc": "2.0", "id": 3, "result": { "ok": true } }),
                Some(3),
            ),
            // Error response.
            (
                json!({
                    "jsonrpc": "2.0",
                    "id": 5,
                    "error": { "code": -32601, "message": "method not found" }
                }),
                Some(5),
            ),
        ];

        for (message, expected) in &cases {
            assert_eq!(response_id(message), *expected);
        }
    }

    #[test]
    fn extract_rpc_result_propagates_error() {
        let message = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": { "code": -32600, "message": "invalid request" }
        });
        let err = extract_rpc_result(&message).unwrap_err();
        assert!(err.to_string().contains("invalid request"));
    }

    #[test]
    fn extract_rpc_result_returns_result_value() {
        let message = json!({ "jsonrpc": "2.0", "id": 1, "result": { "sessionId": "abc" } });
        let value = extract_rpc_result(&message).unwrap();
        assert_eq!(value["sessionId"], "abc");
    }

    #[test]
    fn notify_serialises_payload_to_write_channel() {
        let (transport, mut rx) = mock_transport();
        transport
            .notify("session/cancel", json!({ "sessionId": "s1" }))
            .unwrap();

        let raw = rx.try_recv().expect("notification payload");
        let frame: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(frame["method"], "session/cancel");
        assert_eq!(frame["params"]["sessionId"], "s1");
        assert!(frame.get("id").is_none(), "notifications must not have id");
    }

    #[test]
    fn respond_writes_jsonrpc_result() {
        let (transport, mut rx) = mock_transport();
        transport.respond(42, json!({ "ok": true })).unwrap();

        let frame: Value = serde_json::from_str(&rx.try_recv().unwrap()).unwrap();
        assert_eq!(frame["jsonrpc"], "2.0");
        assert_eq!(frame["id"], 42);
        assert_eq!(frame["result"]["ok"], true);
    }

    #[test]
    fn respond_error_writes_jsonrpc_error() {
        let (transport, mut rx) = mock_transport();
        transport.respond_error(9, -32601, "method not found").unwrap();

        let frame: Value = serde_json::from_str(&rx.try_recv().unwrap()).unwrap();
        assert_eq!(frame["id"], 9);
        assert_eq!(frame["error"]["code"], -32601);
        assert_eq!(frame["error"]["message"], "method not found");
    }
}