tau_agent_base/plugin_service.rs
1//! Plugin RPC services for the myelin-based plugin transport.
2//!
3//! This module is the *successor* to `plugin_protocol.rs`: where that module
4//! describes a hand-rolled JSON-lines protocol with two enums
5//! (`PluginRequest` / `PluginMessage`) routed by string-keyed tunnels, this
6//! module defines the same wire as **two myelin services** running over a
7//! single [`DuplexStreamTransport`](myelin::stream::DuplexStreamTransport)
8//! per plugin subprocess.
9//!
10//! Both peers — server (`tau-agent-lib`) and plugin subprocess
11//! (`tau-agent-plugin-worker` / `tau-agent-plugin-tasks`) — *call* and
12//! *serve* methods. Direction is encoded in the API id.
13//!
14//! ## Wire layout
15//!
16//! Each framed payload is `[u8 kind][u16 api_id LE][u8 slot_id][CBOR bytes]`,
17//! length-prefixed. The two API ids in use are [`PLUGIN_API_ID`] (server
18//! calls plugin) and [`PLUGIN_CALLBACK_API_ID`] (plugin calls server) —
19//! both emitted by the `#[myelin::service]` attribute on the trait
20//! definitions below.
21//!
22//! ## Codec
23//!
24//! [`CborCodec`](myelin::stream::CborCodec). Self-describing, debuggable
25//! with `xxd`/CBOR pretty-printers. Postcard would be smaller but loses
26//! debuggability.
27//!
28//! ## Status
29//!
30//! The trait definitions and value types live here; the actual transport
31//! plumbing in `tau-agent-lib::plugin` and the per-plugin executor crates
32//! still uses the JSON-lines protocol from `plugin_protocol.rs`. The
33//! migration replaces one binary at a time. See task #759.
34
35// Myelin services use plain `async fn` in trait definitions — not boxed
36// futures — so we tolerate the auto-trait-bound lint the same way myelin
37// itself does.
38#![allow(async_fn_in_trait)]
39
40use serde::{Deserialize, Serialize};
41
42use crate::plugin_protocol::{HookResult, PluginRegistration, PluginToolResult};
43use crate::protocol::{Request, Response};
44
45// ---------------------------------------------------------------------------
46// API ids
47// ---------------------------------------------------------------------------
48//
49// Both `#[myelin::service]` attributes below pin their `api_id` explicitly
50// (rather than letting the macro derive an FNV hash of the trait name). This
51// makes the wire stable across trait renames. The attribute emits public
52// `PLUGIN_API_ID` and `PLUGIN_CALLBACK_API_ID` constants — use those.
53
54// ---------------------------------------------------------------------------
55// Value types — request payloads grouped into structs so the macro-generated
56// enum variants stay tidy. (`#[myelin::service]` would otherwise turn each
57// trait method's parameter list into the variant fields, which is fine for
58// trivial methods but noisy for the bigger ones.)
59// ---------------------------------------------------------------------------
60
61/// Context passed at plugin initialization (`init`) and on every
62/// `session_start`.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct SessionCtx {
65 /// Working directory for the session.
66 pub cwd: String,
67 /// Session id.
68 pub session_id: String,
69 /// Project name for this session, if any.
70 #[serde(skip_serializing_if = "Option::is_none", default)]
71 pub project_name: Option<String>,
72}
73
74/// A tool call dispatch from the server to the plugin.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ToolCallReq {
77 /// Unique id for this tool call (used to correlate cancel/output_delta).
78 pub tool_call_id: String,
79 /// Tool name (must match a name from the plugin's registration).
80 pub name: String,
81 /// JSON-encoded arguments.
82 pub arguments: serde_json::Value,
83 /// Working directory for tool execution.
84 #[serde(skip_serializing_if = "Option::is_none", default)]
85 pub cwd: Option<String>,
86 /// Session id this tool call belongs to.
87 #[serde(skip_serializing_if = "Option::is_none", default)]
88 pub session_id: Option<String>,
89 /// Project name for this session.
90 #[serde(skip_serializing_if = "Option::is_none", default)]
91 pub project_name: Option<String>,
92}
93
94/// A hook invocation.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct HookReq {
97 /// Hook name (e.g. `"before_llm_turn"`, `"after_tool_result"`).
98 pub name: String,
99 /// Hook-specific payload.
100 pub data: serde_json::Value,
101}
102
103// ---------------------------------------------------------------------------
104// Service traits
105// ---------------------------------------------------------------------------
106
/// RPCs issued by the **server** and served by the **plugin**.
///
/// The plugin process implements this trait. The server drives it through
/// the generated `PluginClient`, layered over a
/// `DuplexClientHalf<_, _, …, PluginRequest, PluginResponse>` bound to
/// [`PLUGIN_API_ID`].
///
/// Concurrent in-flight requests are first-class. In particular `tool_call`
/// and `cancel_tool_call` are expected to overlap: a cancel is *only* useful
/// while the `tool_call` it targets is still pending, because its purpose is
/// to abort that call. Myelin's `MuxedSlots` handles the slot routing, but
/// the plugin must actually serve requests concurrently (e.g. spawn
/// `tool_call` as a background task) so its dispatch loop stays free to pick
/// up the cancel.
#[myelin::service(api_id = 0x0001)]
pub trait PluginService {
    /// Initialise the plugin with session context. Sent once, after the
    /// plugin has called `register` on the server side.
    async fn init(&self, ctx: SessionCtx);

    /// Execute a hook (e.g. `before_llm_turn`) and return its result.
    async fn hook(&self, req: HookReq) -> HookResult;

    /// Execute a tool call and return its final result. Intermediate output
    /// is reported separately, via `output_delta` on
    /// [`PluginCallbackService`].
    async fn tool_call(&self, call: ToolCallReq) -> PluginToolResult;

    /// Abort the tool call identified by `tool_call_id`.
    ///
    /// The plugin should kill any subprocess associated with that call; the
    /// aborted `tool_call` then completes with a normal response that notes
    /// the cancellation. No-op if the call has already completed.
    async fn cancel_tool_call(&self, tool_call_id: String);

    /// Notify the plugin that a (sub-)session is starting.
    async fn session_start(&self, ctx: SessionCtx);

    /// Notify the plugin it has been idle long enough to consider exiting.
    async fn idle(&self);
}
144
/// RPCs issued by the **plugin** and served by the **server**.
///
/// The server implements this trait. The plugin drives it through the
/// generated `PluginCallbackClient`, layered over a
/// `DuplexClientHalf<_, _, …, PluginCallbackRequest, PluginCallbackResponse>`
/// bound to [`PLUGIN_CALLBACK_API_ID`].
///
/// # `output_delta` and flow control
///
/// The existing protocol fires `output_delta` one-way (no reply expected) at
/// up to ~200 events/sec during a long bash command. Myelin's RPC model is
/// request/response, so here it stays a `()`-returning RPC that the caller
/// *does* await. That costs one extra response frame per delta, but it
/// bounds the number of in-flight deltas at the slot pool size, preserving
/// the protocol's flow-control properties. A future myelin streaming feature
/// could collapse this to a true fire-and-forget notification.
#[myelin::service(api_id = 0x0002)]
pub trait PluginCallbackService {
    /// Plugin registration. Sent once at startup, before the plugin serves
    /// any [`PluginService`] methods. The server waits for this call before
    /// considering the plugin ready.
    async fn register(&self, reg: PluginRegistration);

    /// Forward a [`Request`] from the plugin to the server's main request
    /// handler (the same one the TUI/CLI talks to) and return its
    /// [`Response`].
    async fn server_request(&self, req: Request) -> Response;

    /// Streaming tool output delta for the call `tool_call_id`. Plugin →
    /// server, fire-then-await; see the trait-level note for why this
    /// returns `()` yet is still awaited.
    async fn output_delta(&self, tool_call_id: String, text: String);
}
175
176// ---------------------------------------------------------------------------
177// Default duplex transport configuration
178// ---------------------------------------------------------------------------
179
/// Slot pool size: how many outgoing RPCs may be in flight at once, per
/// direction.
///
/// 32 sits comfortably above the worst-case fan-out we expect, which is made
/// up of:
/// - a single in-flight `tool_call` per plugin (serialised today by
///   `PluginHandle::take_tool_plugin`);
/// - bursts of `output_delta` from a long-running tool;
/// - a `cancel_tool_call` racing that in-flight `tool_call`;
/// - the merge-worker thread inside the tasks plugin issuing a
///   `server_request` concurrently with the main dispatch loop.
pub const DUPLEX_SLOTS: usize = 32;
190
/// Size in bytes of the reply buffer attached to each slot (128 KiB).
///
/// `PluginToolResult` payloads can be large — a long bash transcript may
/// reach tens of KiB. 128 KiB leaves headroom for that while staying small
/// enough that `MuxedSlots::new_boxed` keeps the entire slot table on the
/// heap.
pub const DUPLEX_BUF: usize = 128 * 1024;
197
198/// Type alias for the duplex transport tau uses on every plugin pipe.
199///
200/// `R` / `W` are the plugin's reader/writer halves of the byte stream
201/// (typically `FuturesIoReader<Async<File>>` / `FuturesIoWriter<…>` for
202/// stdin/stdout pipes).
203pub type PluginDuplex<R, W> = myelin::stream::DuplexStreamTransport<
204 R,
205 W,
206 myelin::stream::LengthPrefixed,
207 myelin::stream::CborCodec,
208 DUPLEX_SLOTS,
209 DUPLEX_BUF,
210>;
211
212// ---------------------------------------------------------------------------
// Tests — round-trip RPCs in both directions over a UnixStream pair (one
// `register` from the plugin side; two `init` and two `idle` from the server
// side), with both halves of both services running concurrently. This is the
// M2 smoke test from task #759: it proves the macro expansion, codec, framer,
// transport, and pump all wire together correctly on the runtime tau actually
// uses (smol).
217// ---------------------------------------------------------------------------
218
#[cfg(test)]
mod tests {
    use super::*;
    use myelin::io::futures_io::{FuturesIoReader, FuturesIoWriter};
    use myelin::transport::ServerTransport;
    use smol::Async;
    use std::os::unix::net::UnixStream;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Read half of one end of the in-process socket pair.
    type Reader = FuturesIoReader<Async<UnixStream>>;
    /// Write half of one end of the in-process socket pair.
    type Writer = FuturesIoWriter<Async<UnixStream>>;

    /// Turn one socket into independent reader and writer halves.
    ///
    /// The FD is cloned because the futures_io adapter needs distinct
    /// objects for read and write.
    fn split_stream(stream: UnixStream) -> (Reader, Writer) {
        stream.set_nonblocking(true).expect("nonblocking read half");
        let write_half = stream.try_clone().expect("clone UnixStream");
        write_half
            .set_nonblocking(true)
            .expect("nonblocking write half");

        let reader = Async::new(stream).expect("Async read half");
        let writer = Async::new(write_half).expect("Async write half");
        (FuturesIoReader::new(reader), FuturesIoWriter::new(writer))
    }

    /// Build a connected socket pair and return
    /// `(a_reader, a_writer, b_reader, b_writer)`.
    fn make_pair() -> (Reader, Writer, Reader, Writer) {
        let (sa, sb) = UnixStream::pair().expect("UnixStream::pair");
        let (sa_r, sa_w) = split_stream(sa);
        let (sb_r, sb_w) = split_stream(sb);
        (sa_r, sa_w, sb_r, sb_w)
    }

    /// Trivial plugin-side impl that counts calls; we only need to prove the
    /// generated dispatch + transport plumbing wires up correctly.
    struct PluginSide {
        init_calls: Arc<AtomicU32>,
        idle_calls: Arc<AtomicU32>,
    }

    impl PluginService for PluginSide {
        async fn init(&self, _ctx: SessionCtx) {
            self.init_calls.fetch_add(1, Ordering::SeqCst);
        }
        async fn hook(&self, _req: HookReq) -> HookResult {
            HookResult::default()
        }
        async fn tool_call(&self, call: ToolCallReq) -> PluginToolResult {
            PluginToolResult {
                tool_call_id: call.tool_call_id,
                content: vec![],
                is_error: false,
                summary: None,
                post_persist_actions: vec![],
            }
        }
        async fn cancel_tool_call(&self, _id: String) {}
        async fn session_start(&self, _ctx: SessionCtx) {}
        async fn idle(&self) {
            self.idle_calls.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Trivial server-side impl that counts `register` calls.
    struct ServerSide {
        register_calls: Arc<AtomicU32>,
    }

    impl PluginCallbackService for ServerSide {
        async fn register(&self, _reg: PluginRegistration) {
            self.register_calls.fetch_add(1, Ordering::SeqCst);
        }
        async fn server_request(&self, _req: Request) -> Response {
            // Any Response variant works for the round-trip; pick a small one.
            Response::Ok
        }
        async fn output_delta(&self, _id: String, _text: String) {}
    }

    /// Smoke test: one `register` (plugin → server) plus two `init` and two
    /// `idle` calls (server → plugin) over a single duplex pipe. Checks both
    /// that each handler ran and that each reply made it back to the caller.
    #[test]
    fn duplex_round_trips_both_directions() {
        // Build the duplex transport on each side with the production
        // `PluginDuplex` type alias — this is the exact stack tau will
        // use in M3+.
        let (r_srv, w_srv, r_plg, w_plg) = make_pair();
        let dx_srv: PluginDuplex<_, _> = PluginDuplex::new(r_srv, w_srv);
        let dx_plg: PluginDuplex<_, _> = PluginDuplex::new(r_plg, w_plg);

        // Server: serves PLUGIN_CALLBACK_API_ID (so the plugin can call into
        // us), calls into PLUGIN_API_ID.
        let mut srv_server = dx_srv
            .server_half::<PluginCallbackRequest, PluginCallbackResponse>(PLUGIN_CALLBACK_API_ID);
        let srv_client = dx_srv.client_half::<PluginRequest, PluginResponse>(PLUGIN_API_ID);

        // Plugin: the mirror image.
        let mut plg_server = dx_plg.server_half::<PluginRequest, PluginResponse>(PLUGIN_API_ID);
        let plg_client = dx_plg
            .client_half::<PluginCallbackRequest, PluginCallbackResponse>(PLUGIN_CALLBACK_API_ID);

        // The split handles are bound to names so they live until the end of
        // the test.
        let (pump_srv, _h_srv) = dx_srv.split();
        let (pump_plg, _h_plg) = dx_plg.split();

        let plugin_init_count = Arc::new(AtomicU32::new(0));
        let plugin_idle_count = Arc::new(AtomicU32::new(0));
        let server_register_count = Arc::new(AtomicU32::new(0));

        let plg_impl = PluginSide {
            init_calls: plugin_init_count.clone(),
            idle_calls: plugin_idle_count.clone(),
        };
        let srv_impl = ServerSide {
            register_calls: server_register_count.clone(),
        };

        smol::block_on(async {
            // Plugin-side serve loop: two `init` calls and two `idle` calls.
            let plugin_loop = async move {
                for _ in 0..4 {
                    let (req, token) = plg_server.recv().await.expect("plugin recv");
                    let resp = plugin_dispatch(&plg_impl, req).await;
                    plg_server.reply(token, resp).await.expect("plugin reply");
                }
            };
            // Server-side serve loop: a single `register` call.
            let server_loop = async move {
                let (req, token) = srv_server.recv().await.expect("server recv");
                let resp = plugin_callback_dispatch(&srv_impl, req).await;
                srv_server.reply(token, resp).await.expect("server reply");
            };

            let work = async {
                // Plugin → server: register. The generated client method
                // returns a `Result`; check it so a broken reply path fails
                // the test instead of being silently dropped.
                let callback_client = PluginCallbackClient::new(plg_client);
                let registered = callback_client
                    .register(PluginRegistration {
                        name: "smoke".into(),
                        tools: vec![],
                        hooks: vec![],
                        commands: vec![],
                    })
                    .await;
                assert!(registered.is_ok(), "register RPC failed");

                // Server → plugin: two inits and two idles, concurrently.
                let plugin_client = PluginClient::new(srv_client);
                let init_1 = plugin_client.init(SessionCtx {
                    cwd: "/tmp".into(),
                    session_id: "s1".into(),
                    project_name: None,
                });
                let init_2 = plugin_client.init(SessionCtx {
                    cwd: "/tmp".into(),
                    session_id: "s2".into(),
                    project_name: None,
                });
                let idle_1 = plugin_client.idle();
                let idle_2 = plugin_client.idle();
                let ((init_a, init_b), (idle_a, idle_b)) = futures_lite::future::zip(
                    futures_lite::future::zip(init_1, init_2),
                    futures_lite::future::zip(idle_1, idle_2),
                )
                .await;
                assert!(init_a.is_ok() && init_b.is_ok(), "init RPC failed");
                assert!(idle_a.is_ok() && idle_b.is_ok(), "idle RPC failed");
            };

            // Pumps never complete normally; `or` returns when work does.
            futures_lite::future::or(
                async {
                    futures_lite::future::zip(
                        work,
                        futures_lite::future::zip(plugin_loop, server_loop),
                    )
                    .await;
                },
                async {
                    let _ = futures_lite::future::zip(pump_srv.run(), pump_plg.run()).await;
                },
            )
            .await;
        });

        assert_eq!(plugin_init_count.load(Ordering::SeqCst), 2);
        assert_eq!(plugin_idle_count.load(Ordering::SeqCst), 2);
        assert_eq!(server_register_count.load(Ordering::SeqCst), 1);
    }
}
415}