Skip to main content

tau_agent_base/
plugin_service.rs

1//! Plugin RPC services for the myelin-based plugin transport.
2//!
3//! This module is the *successor* to `plugin_protocol.rs`: where that module
4//! describes a hand-rolled JSON-lines protocol with two enums
5//! (`PluginRequest` / `PluginMessage`) routed by string-keyed tunnels, this
6//! module defines the same wire as **two myelin services** running over a
7//! single [`DuplexStreamTransport`](myelin::stream::DuplexStreamTransport)
8//! per plugin subprocess.
9//!
10//! Both peers — server (`tau-agent-lib`) and plugin subprocess
11//! (`tau-agent-plugin-worker` / `tau-agent-plugin-tasks`) — *call* and
12//! *serve* methods. Direction is encoded in the API id.
13//!
14//! ## Wire layout
15//!
16//! Each framed payload is `[u8 kind][u16 api_id LE][u8 slot_id][CBOR bytes]`,
17//! length-prefixed. The two API ids in use are [`PLUGIN_API_ID`] (server
18//! calls plugin) and [`PLUGIN_CALLBACK_API_ID`] (plugin calls server) —
19//! both emitted by the `#[myelin::service]` attribute on the trait
20//! definitions below.
21//!
22//! ## Codec
23//!
24//! [`CborCodec`](myelin::stream::CborCodec). Self-describing, debuggable
25//! with `xxd`/CBOR pretty-printers. Postcard would be smaller but loses
26//! debuggability.
27//!
28//! ## Status
29//!
30//! The trait definitions and value types live here; the actual transport
31//! plumbing in `tau-agent-lib::plugin` and the per-plugin executor crates
32//! still uses the JSON-lines protocol from `plugin_protocol.rs`. The
33//! migration replaces one binary at a time. See task #759.
34
// Myelin services use plain `async fn` in trait definitions — not boxed
// futures — so we tolerate the auto-trait-bound lint the same way myelin
// itself does. (`async_fn_in_trait` warns that, for a public trait declared
// with `async fn`, callers cannot require the returned futures to be `Send`.)
#![allow(async_fn_in_trait)]
39
40use serde::{Deserialize, Serialize};
41
42use crate::plugin_protocol::{HookResult, PluginRegistration, PluginToolResult};
43use crate::protocol::{Request, Response};
44
45// ---------------------------------------------------------------------------
46// API ids
47// ---------------------------------------------------------------------------
48//
49// Both `#[myelin::service]` attributes below pin their `api_id` explicitly
50// (rather than letting the macro derive an FNV hash of the trait name). This
51// makes the wire stable across trait renames. The attribute emits public
52// `PLUGIN_API_ID` and `PLUGIN_CALLBACK_API_ID` constants — use those.
53
54// ---------------------------------------------------------------------------
55// Value types — request payloads grouped into structs so the macro-generated
56// enum variants stay tidy. (`#[myelin::service]` would otherwise turn each
57// trait method's parameter list into the variant fields, which is fine for
58// trivial methods but noisy for the bigger ones.)
59// ---------------------------------------------------------------------------
60
61/// Context passed at plugin initialization (`init`) and on every
62/// `session_start`.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct SessionCtx {
65    /// Working directory for the session.
66    pub cwd: String,
67    /// Session id.
68    pub session_id: String,
69    /// Project name for this session, if any.
70    #[serde(skip_serializing_if = "Option::is_none", default)]
71    pub project_name: Option<String>,
72}
73
74/// A tool call dispatch from the server to the plugin.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ToolCallReq {
77    /// Unique id for this tool call (used to correlate cancel/output_delta).
78    pub tool_call_id: String,
79    /// Tool name (must match a name from the plugin's registration).
80    pub name: String,
81    /// JSON-encoded arguments.
82    pub arguments: serde_json::Value,
83    /// Working directory for tool execution.
84    #[serde(skip_serializing_if = "Option::is_none", default)]
85    pub cwd: Option<String>,
86    /// Session id this tool call belongs to.
87    #[serde(skip_serializing_if = "Option::is_none", default)]
88    pub session_id: Option<String>,
89    /// Project name for this session.
90    #[serde(skip_serializing_if = "Option::is_none", default)]
91    pub project_name: Option<String>,
92}
93
94/// A hook invocation.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct HookReq {
97    /// Hook name (e.g. `"before_llm_turn"`, `"after_tool_result"`).
98    pub name: String,
99    /// Hook-specific payload.
100    pub data: serde_json::Value,
101}
102
103// ---------------------------------------------------------------------------
104// Service traits
105// ---------------------------------------------------------------------------
106
/// Methods the **server** calls on the **plugin**.
///
/// The plugin process serves this trait; the server holds a generated
/// `PluginClient` over a `DuplexClientHalf<_, _, …, PluginRequest,
/// PluginResponse>` bound to [`PLUGIN_API_ID`].
///
/// Concurrent in-flight requests are first-class: `tool_call` and
/// `cancel_tool_call` may overlap (in fact `cancel_tool_call` is *only*
/// useful while a `tool_call` is pending — its purpose is to abort it).
/// Myelin's `MuxedSlots` handles the slot routing for free; the plugin
/// must serve them concurrently (e.g. spawn tool_call as a background task
/// so the dispatch loop can pick up the cancel).
///
/// `#[myelin::service]` turns each method's parameter list into the fields
/// of a generated request-enum variant. Method and parameter names are
/// therefore part of the generated API. NOTE(review): with the
/// self-describing CBOR codec they are likely on the wire as well, so treat
/// renames or reorderings as protocol changes until confirmed otherwise.
#[myelin::service(api_id = 0x0001)]
pub trait PluginService {
    /// Initialise the plugin with session context. Sent once after the
    /// plugin has called `register` on the server side.
    async fn init(&self, ctx: SessionCtx);

    /// Execute a hook (e.g. `before_llm_turn`) and return the plugin's
    /// [`HookResult`].
    async fn hook(&self, req: HookReq) -> HookResult;

    /// Execute a tool call. Returns the final result; intermediate output
    /// is reported via `output_delta` on the [`PluginCallbackService`]
    /// service, keyed by the call's `tool_call_id`.
    async fn tool_call(&self, call: ToolCallReq) -> PluginToolResult;

    /// Abort a tool call by id. The plugin should kill any associated
    /// subprocess and return a normal `tool_call` response with
    /// cancellation noted. No-op if the call already completed.
    async fn cancel_tool_call(&self, tool_call_id: String);

    /// Notify the plugin that a (sub-)session is starting.
    async fn session_start(&self, ctx: SessionCtx);

    /// Notify the plugin it has been idle long enough to consider exiting.
    async fn idle(&self);
}
144
/// Methods the **plugin** calls on the **server**.
///
/// The server serves this trait; the plugin holds a generated
/// `PluginCallbackClient` over a `DuplexClientHalf<_, _, …,
/// PluginCallbackRequest, PluginCallbackResponse>` bound to
/// [`PLUGIN_CALLBACK_API_ID`].
///
/// `output_delta` deserves a note. The existing JSON-lines protocol fires it
/// one-way (no reply expected) at up to ~200 events/sec during a long
/// bash command. Myelin's RPC model is request/response, so we keep it
/// as a `()`-returning RPC and *do* await it. This costs an extra
/// response frame per delta. In exchange, in-flight deltas are bounded by
/// the slot pool size ([`DUPLEX_SLOTS`]), which preserves the protocol's
/// flow-control properties. A future myelin streaming feature could
/// collapse this to a true fire-and-forget notification.
#[myelin::service(api_id = 0x0002)]
pub trait PluginCallbackService {
    /// Plugin registration. Sent once at startup, before the plugin serves
    /// any [`PluginService`] methods. The server waits for this call
    /// before considering the plugin ready.
    async fn register(&self, reg: PluginRegistration);

    /// Forward a [`Request`] from the plugin to the server's main
    /// request handler (the same one the TUI/CLI talks to), returning
    /// that handler's [`Response`].
    async fn server_request(&self, req: Request) -> Response;

    /// Streaming tool output delta for the call identified by
    /// `tool_call_id` (the id carried in [`ToolCallReq`]).
    /// Plugin → server, fire-then-await.
    /// See trait-level note on the `()` return.
    async fn output_delta(&self, tool_call_id: String, text: String);
}
175
176// ---------------------------------------------------------------------------
177// Default duplex transport configuration
178// ---------------------------------------------------------------------------
179
/// How many outgoing RPCs may be in flight at once, in each direction.
///
/// The value leaves plenty of headroom over the worst-case fan-out we expect:
/// - a single in-flight `tool_call` per plugin (these are serialised
///   today by `PluginHandle::take_tool_plugin`);
/// - bursts of `output_delta` while a long-running tool streams output;
/// - a `cancel_tool_call` racing that in-flight `tool_call`;
/// - the tasks plugin's merge-worker thread issuing a `server_request`
///   alongside the main dispatch loop.
pub const DUPLEX_SLOTS: usize = 32;
190
/// Per-slot reply buffer, in bytes (128 KiB).
///
/// `PluginToolResult` payloads can be large (a long bash transcript may
/// reach tens of KiB), so 128 KiB leaves headroom above typical replies.
///
/// Memory cost: with one buffer per slot, a full slot table is
/// `DUPLEX_SLOTS` × 128 KiB = 4 MiB. That is too large to place on a
/// thread stack (Rust's default spawned-thread stack is 2 MiB), which is
/// why the slot table is built with `MuxedSlots::new_boxed` so it lives on
/// the heap.
///
/// NOTE(review): how myelin reports a reply larger than this buffer is not
/// visible here. Confirm whether it errors or truncates before assuming
/// tool results always fit.
pub const DUPLEX_BUF: usize = 128 * 1024;
197
198/// Type alias for the duplex transport tau uses on every plugin pipe.
199///
200/// `R` / `W` are the plugin's reader/writer halves of the byte stream
201/// (typically `FuturesIoReader<Async<File>>` / `FuturesIoWriter<…>` for
202/// stdin/stdout pipes).
203pub type PluginDuplex<R, W> = myelin::stream::DuplexStreamTransport<
204    R,
205    W,
206    myelin::stream::LengthPrefixed,
207    myelin::stream::CborCodec,
208    DUPLEX_SLOTS,
209    DUPLEX_BUF,
210>;
211
212// ---------------------------------------------------------------------------
// Tests — round-trip RPCs in both directions over a UnixStream pair (one
// `register` plugin → server; two `init` and two `idle` server → plugin),
// with both halves of both services running concurrently. This is the M2
// smoke test from task #759. It proves the macro expansion, codec, framer,
// transport, and pump all wire together correctly on the runtime tau
// actually uses (smol).
217// ---------------------------------------------------------------------------
218
#[cfg(test)]
mod tests {
    use super::*;
    use myelin::io::futures_io::{FuturesIoReader, FuturesIoWriter};
    use myelin::transport::ServerTransport;
    use smol::Async;
    use std::os::unix::net::UnixStream;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Creates a connected `UnixStream` pair and splits each end into
    /// independent async reader/writer halves.
    ///
    /// Returns `(reader_a, writer_a, reader_b, writer_b)`. Bytes written to
    /// `writer_a` arrive on `reader_b`, and vice versa.
    fn make_pair() -> (
        FuturesIoReader<Async<UnixStream>>,
        FuturesIoWriter<Async<UnixStream>>,
        FuturesIoReader<Async<UnixStream>>,
        FuturesIoWriter<Async<UnixStream>>,
    ) {
        let (sa, sb) = UnixStream::pair().expect("UnixStream::pair");

        // Each side: clone the FD so we have independent reader+writer
        // halves (the futures_io adapter needs distinct objects for read
        // and write). No manual `set_nonblocking` is needed, because
        // `Async::new` switches every handle it wraps into non-blocking
        // mode itself.
        let sa_w = sa.try_clone().expect("clone sa");
        let sb_w = sb.try_clone().expect("clone sb");

        let sa_r = Async::new(sa).expect("Async sa");
        let sa_w = Async::new(sa_w).expect("Async sa_w");
        let sb_r = Async::new(sb).expect("Async sb");
        let sb_w = Async::new(sb_w).expect("Async sb_w");

        (
            FuturesIoReader::new(sa_r),
            FuturesIoWriter::new(sa_w),
            FuturesIoReader::new(sb_r),
            FuturesIoWriter::new(sb_w),
        )
    }

    /// Trivial server impls that count calls; we only need to prove the
    /// generated dispatch + transport plumbing wires up correctly.
    struct PluginSide {
        init_calls: Arc<AtomicU32>,
        idle_calls: Arc<AtomicU32>,
    }

    impl PluginService for PluginSide {
        async fn init(&self, _ctx: SessionCtx) {
            self.init_calls.fetch_add(1, Ordering::SeqCst);
        }
        async fn hook(&self, _req: HookReq) -> HookResult {
            HookResult::default()
        }
        async fn tool_call(&self, call: ToolCallReq) -> PluginToolResult {
            PluginToolResult {
                tool_call_id: call.tool_call_id,
                content: vec![],
                is_error: false,
                summary: None,
                post_persist_actions: vec![],
            }
        }
        async fn cancel_tool_call(&self, _id: String) {}
        async fn session_start(&self, _ctx: SessionCtx) {}
        async fn idle(&self) {
            self.idle_calls.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Server-side counterpart: counts `register` calls, answers the rest
    /// with fixed values.
    struct ServerSide {
        register_calls: Arc<AtomicU32>,
    }

    impl PluginCallbackService for ServerSide {
        async fn register(&self, _reg: PluginRegistration) {
            self.register_calls.fetch_add(1, Ordering::SeqCst);
        }
        async fn server_request(&self, _req: Request) -> Response {
            // Any Response variant works for the round-trip; pick a small one.
            Response::Ok
        }
        async fn output_delta(&self, _id: String, _text: String) {}
    }

    #[test]
    fn duplex_round_trips_both_directions() {
        // Build the duplex transport on each side with the production
        // `PluginDuplex` type alias — this is the exact stack tau will
        // use in M3+.
        let (r_srv, w_srv, r_plg, w_plg) = make_pair();
        let dx_srv: PluginDuplex<_, _> = PluginDuplex::new(r_srv, w_srv);
        let dx_plg: PluginDuplex<_, _> = PluginDuplex::new(r_plg, w_plg);

        // Server: serves PLUGIN_CALLBACK_API_ID (so the plugin can call into us),
        // calls into PLUGIN_API_ID.
        let srv_server = dx_srv
            .server_half::<PluginCallbackRequest, PluginCallbackResponse>(PLUGIN_CALLBACK_API_ID);
        let srv_client = dx_srv.client_half::<PluginRequest, PluginResponse>(PLUGIN_API_ID);

        // Plugin: dual.
        let plg_server = dx_plg.server_half::<PluginRequest, PluginResponse>(PLUGIN_API_ID);
        let plg_client = dx_plg
            .client_half::<PluginCallbackRequest, PluginCallbackResponse>(PLUGIN_CALLBACK_API_ID);

        let (pump_srv, _h_srv) = dx_srv.split();
        let (pump_plg, _h_plg) = dx_plg.split();

        let plugin_init_count = Arc::new(AtomicU32::new(0));
        let plugin_idle_count = Arc::new(AtomicU32::new(0));
        let server_register_count = Arc::new(AtomicU32::new(0));

        let plg_impl = PluginSide {
            init_calls: plugin_init_count.clone(),
            idle_calls: plugin_idle_count.clone(),
        };
        let srv_impl = ServerSide {
            register_calls: server_register_count.clone(),
        };

        smol::block_on(async {
            // One RPC plugin → server (`register`), then four RPCs
            // server → plugin (`init` ×2, `idle` ×2).
            let mut plg_server = plg_server;
            let mut srv_server = srv_server;

            let plugin_dispatch = async move {
                // Serve two `init` calls and two `idle` calls.
                for _ in 0..4 {
                    let (req, token) = plg_server.recv().await.expect("plugin recv");
                    let resp = plugin_dispatch(&plg_impl, req).await;
                    plg_server.reply(token, resp).await.expect("plugin reply");
                }
            };
            let server_dispatch = async move {
                // Serve one `register` call.
                let (req, token) = srv_server.recv().await.expect("server recv");
                let resp = plugin_callback_dispatch(&srv_impl, req).await;
                srv_server.reply(token, resp).await.expect("server reply");
            };

            let work = async {
                // Plugin → server: register. The generated client method
                // returns Result<(), TransportError>. Surface a failure
                // immediately: if it were swallowed, `server_dispatch` would
                // wait on `recv` forever and the test would hang instead of
                // failing.
                let plg_client = PluginCallbackClient::new(plg_client);
                plg_client
                    .register(PluginRegistration {
                        name: "smoke".into(),
                        tools: vec![],
                        hooks: vec![],
                        commands: vec![],
                    })
                    .await
                    .expect("plugin → server register");

                // Server → plugin: two inits and two idles, concurrently.
                let srv_client = PluginClient::new(srv_client);
                let i1 = srv_client.init(SessionCtx {
                    cwd: "/tmp".into(),
                    session_id: "s1".into(),
                    project_name: None,
                });
                let i2 = srv_client.init(SessionCtx {
                    cwd: "/tmp".into(),
                    session_id: "s2".into(),
                    project_name: None,
                });
                let id1 = srv_client.idle();
                let id2 = srv_client.idle();
                let ((init_s1, init_s2), (idle_1, idle_2)) = futures_lite::future::zip(
                    futures_lite::future::zip(i1, i2),
                    futures_lite::future::zip(id1, id2),
                )
                .await;
                init_s1.expect("server → plugin init s1");
                init_s2.expect("server → plugin init s2");
                idle_1.expect("server → plugin idle #1");
                idle_2.expect("server → plugin idle #2");
            };

            // Pumps never complete normally; `or` returns when work does.
            futures_lite::future::or(
                async {
                    futures_lite::future::zip(
                        work,
                        futures_lite::future::zip(plugin_dispatch, server_dispatch),
                    )
                    .await;
                },
                async {
                    let _ = futures_lite::future::zip(pump_srv.run(), pump_plg.run()).await;
                },
            )
            .await;
        });

        assert_eq!(plugin_init_count.load(Ordering::SeqCst), 2);
        assert_eq!(plugin_idle_count.load(Ordering::SeqCst), 2);
        assert_eq!(server_register_count.load(Ordering::SeqCst), 1);
    }
}