// zshrs_daemon/ipc.rs

// IPC framing + message types.
//
// Wire format (per docs/DAEMON.md "IPC wire format"):
//   [4 bytes: u32 BE length] [length bytes: UTF-8 JSON]
//
// Envelopes:
//   client → daemon (handshake) : { "hello":   { ... } }
//   daemon → client (handshake) : { "welcome": { ... } | "err": { ... } }
//   client → daemon (request)   : { "id": u64, "op": str, "args": {...} }
//   daemon → client (response)  : { "id": u64, "ok": bool, [payload...] | "err": {...} }
//   daemon → client (async evt) : { "event": str, [payload...] }
//
// The message kind is resolved through serde's untagged-enum dispatch (detected by which
// top-level keys exist). This keeps every message a flat object on the wire, debuggable
// with `socat`.
15
16use std::io;
17
18use byteorder::{BigEndian, ByteOrder};
19use serde::{Deserialize, Serialize};
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21
22use super::{DaemonError, Result};
23
/// Wire-protocol revision spoken by this build of the daemon.
///
/// Bump it whenever the wire format changes in an incompatible way.
pub const PROTOCOL_VERSION: u32 = 1;

/// Upper bound, in bytes, on the JSON body of a single frame (64 MiB).
///
/// The frame readers and writers in this module refuse any body larger than this,
/// which guards against a runaway client.
pub const MAX_FRAME_BYTES: usize = 64 << 20;
29/// `ProtocolVersion` — see fields for layout.
30#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
31pub struct ProtocolVersion(pub u32);
32
33/// First message client sends after connect.
34#[derive(Clone, Debug, Serialize, Deserialize)]
35pub struct Hello {
36    /// `version` field.
37    pub version: u32,
38    /// `client_pid` field.
39    pub client_pid: i32,
40    /// `tty` field.
41    pub tty: Option<String>,
42    /// `cwd` field.
43    pub cwd: Option<String>,
44    /// `argv0` field.
45    pub argv0: Option<String>,
46}
47
48/// Daemon's response to a successful Hello.
49#[derive(Clone, Debug, Serialize, Deserialize)]
50pub struct Welcome {
51    /// `version` field.
52    pub version: u32,
53    /// `client_id` field.
54    pub client_id: u64,
55    /// `session_id` field.
56    pub session_id: String,
57    /// `daemon_pid` field.
58    pub daemon_pid: i32,
59    /// `daemon_uptime_ms` field.
60    pub daemon_uptime_ms: u64,
61}
62
63/// Error payload — paired with `ok: false` on response, or `err: {...}` on welcome failure.
64#[derive(Clone, Debug, Serialize, Deserialize)]
65pub struct ErrPayload {
66    /// `code` field.
67    pub code: String,
68    /// `msg` field.
69    pub msg: String,
70}
71
72impl ErrPayload {
73    /// `new` — see implementation.
74    pub fn new<C: Into<String>, M: Into<String>>(code: C, msg: M) -> Self {
75        Self {
76            code: code.into(),
77            msg: msg.into(),
78        }
79    }
80}
81
82impl From<rusqlite::Error> for ErrPayload {
83    fn from(e: rusqlite::Error) -> Self {
84        Self::new("sqlite", e.to_string())
85    }
86}
87
88impl From<std::io::Error> for ErrPayload {
89    fn from(e: std::io::Error) -> Self {
90        Self::new("io", e.to_string())
91    }
92}
93
94impl From<super::DaemonError> for ErrPayload {
95    fn from(e: super::DaemonError) -> Self {
96        Self::new("daemon", e.to_string())
97    }
98}
99
100/// Top-level frame envelope. One of these JSON objects per length-prefixed frame.
101#[derive(Clone, Debug, Serialize, Deserialize)]
102#[serde(untagged)]
103pub enum Frame {
104    /// `Hello` variant.
105    Hello { hello: Hello },
106    /// `Welcome` variant.
107    Welcome { welcome: Welcome },
108    /// `WelcomeErr` variant.
109    WelcomeErr {
110        welcome: serde_json::Value,
111        err: ErrPayload,
112    },
113    /// `Request` variant.
114    Request {
115        id: u64,
116        op: String,
117        #[serde(default)]
118        args: serde_json::Value,
119    },
120    /// `Response` variant.
121    Response {
122        id: u64,
123        ok: bool,
124        #[serde(flatten)]
125        payload: serde_json::Value,
126    },
127    /// `Event` variant.
128    Event {
129        event: String,
130        #[serde(flatten)]
131        payload: serde_json::Value,
132    },
133}
134
135impl Frame {
136    /// `hello` — see implementation.
137    pub fn hello(h: Hello) -> Self {
138        Frame::Hello { hello: h }
139    }
140    /// `welcome` — see implementation.
141    pub fn welcome(w: Welcome) -> Self {
142        Frame::Welcome { welcome: w }
143    }
144    /// `request` — see implementation.
145    pub fn request(id: u64, op: impl Into<String>, args: serde_json::Value) -> Self {
146        Frame::Request {
147            id,
148            op: op.into(),
149            args,
150        }
151    }
152    /// `ok_response` — see implementation.
153    pub fn ok_response(id: u64, payload: serde_json::Value) -> Self {
154        Frame::Response {
155            id,
156            ok: true,
157            payload,
158        }
159    }
160    /// `err_response` — see implementation.
161    pub fn err_response(id: u64, err: ErrPayload) -> Self {
162        let payload = serde_json::json!({ "err": err });
163        Frame::Response {
164            id,
165            ok: false,
166            payload,
167        }
168    }
169    /// `event` — see implementation.
170    pub fn event(name: impl Into<String>, payload: serde_json::Value) -> Self {
171        Frame::Event {
172            event: name.into(),
173            payload,
174        }
175    }
176}
177
178/// Async event types pushed daemon → client. Names match the doc's event table.
179#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
180#[serde(rename_all = "snake_case")]
181pub enum Event {
182    /// `ShardUpdated` variant.
183    ShardUpdated,
184    /// `RebuildComplete` variant.
185    RebuildComplete,
186    /// `CanonicalChanged` variant.
187    CanonicalChanged,
188    /// `Match` variant.
189    Match,
190    /// `CmdExecute` variant.
191    CmdExecute,
192    /// `Notify` variant.
193    Notify,
194    /// `DaemonShutdown` variant.
195    DaemonShutdown,
196    /// `AskPending` variant.
197    AskPending,
198    /// `AskDismissed` variant.
199    AskDismissed,
200    /// `AskProgress` variant.
201    AskProgress,
202    /// `LongCmdComplete` variant.
203    LongCmdComplete,
204    /// `LongCmdStarted` variant.
205    LongCmdStarted,
206    /// `LongCmdFailed` variant.
207    LongCmdFailed,
208    /// `LongCmdSignaled` variant.
209    LongCmdSignaled,
210}
211
212// -------- Wire framing helpers --------
213
214/// Read one length-prefixed frame from the socket.
215pub async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<Frame> {
216    let mut len_buf = [0u8; 4];
217    reader.read_exact(&mut len_buf).await?;
218    let len = BigEndian::read_u32(&len_buf) as usize;
219    if len == 0 {
220        return Err(DaemonError::other("zero-length frame"));
221    }
222    if len > MAX_FRAME_BYTES {
223        return Err(DaemonError::FrameTooLarge {
224            size: len,
225            max: MAX_FRAME_BYTES,
226        });
227    }
228
229    let mut buf = vec![0u8; len];
230    reader.read_exact(&mut buf).await?;
231
232    let frame: Frame = serde_json::from_slice(&buf)?;
233    Ok(frame)
234}
235
236/// Write one length-prefixed frame to the socket.
237pub async fn write_frame<W: AsyncWriteExt + Unpin>(writer: &mut W, frame: &Frame) -> Result<()> {
238    let body = serde_json::to_vec(frame)?;
239    if body.len() > MAX_FRAME_BYTES {
240        return Err(DaemonError::FrameTooLarge {
241            size: body.len(),
242            max: MAX_FRAME_BYTES,
243        });
244    }
245    let mut header = [0u8; 4];
246    BigEndian::write_u32(&mut header, body.len() as u32);
247    writer.write_all(&header).await?;
248    writer.write_all(&body).await?;
249    writer.flush().await?;
250    Ok(())
251}
252
253/// Synchronous variant for client-side blocking IPC (used from non-async builtins).
254pub fn read_frame_sync<R: io::Read>(reader: &mut R) -> Result<Frame> {
255    let mut len_buf = [0u8; 4];
256    reader.read_exact(&mut len_buf)?;
257    let len = BigEndian::read_u32(&len_buf) as usize;
258    if len == 0 {
259        return Err(DaemonError::other("zero-length frame"));
260    }
261    if len > MAX_FRAME_BYTES {
262        return Err(DaemonError::FrameTooLarge {
263            size: len,
264            max: MAX_FRAME_BYTES,
265        });
266    }
267
268    let mut buf = vec![0u8; len];
269    reader.read_exact(&mut buf)?;
270
271    let frame: Frame = serde_json::from_slice(&buf)?;
272    Ok(frame)
273}
274
275/// Synchronous frame write (client side).
276pub fn write_frame_sync<W: io::Write>(writer: &mut W, frame: &Frame) -> Result<()> {
277    let body = serde_json::to_vec(frame)?;
278    if body.len() > MAX_FRAME_BYTES {
279        return Err(DaemonError::FrameTooLarge {
280            size: body.len(),
281            max: MAX_FRAME_BYTES,
282        });
283    }
284    let mut header = [0u8; 4];
285    BigEndian::write_u32(&mut header, body.len() as u32);
286    writer.write_all(&header)?;
287    writer.write_all(&body)?;
288    writer.flush()?;
289    Ok(())
290}
291
#[cfg(test)]
mod tests {
    //! Round-trip and size-limit tests for the framing helpers.

    use super::*;
    use std::io::Cursor;

    /// A `Hello` frame survives a sync write followed by a sync read.
    #[test]
    fn roundtrip_hello_sync() {
        let h = Hello {
            version: PROTOCOL_VERSION,
            client_pid: 12345,
            tty: Some("/dev/ttys003".into()),
            cwd: Some("/home/wizard".into()),
            argv0: Some("zshrs".into()),
        };
        let frame = Frame::hello(h);

        let mut buf = Vec::new();
        write_frame_sync(&mut buf, &frame).unwrap();

        let mut cur = Cursor::new(buf);
        let read = read_frame_sync(&mut cur).unwrap();

        match read {
            Frame::Hello { hello } => {
                assert_eq!(hello.version, PROTOCOL_VERSION);
                assert_eq!(hello.client_pid, 12345);
                assert_eq!(hello.tty.as_deref(), Some("/dev/ttys003"));
            }
            other => panic!("expected Hello, got {:?}", other),
        }
    }

    /// A `Request` frame survives a sync write followed by a sync read.
    #[test]
    fn roundtrip_request_sync() {
        let frame = Frame::request(42, "ping", serde_json::json!({}));
        let mut buf = Vec::new();
        write_frame_sync(&mut buf, &frame).unwrap();

        let mut cur = Cursor::new(buf);
        let read = read_frame_sync(&mut cur).unwrap();

        match read {
            Frame::Request { id, op, args } => {
                assert_eq!(id, 42);
                assert_eq!(op, "ping");
                assert!(args.is_object());
            }
            other => panic!("expected Request, got {:?}", other),
        }
    }

    /// An `Event` frame keeps its name and its flattened payload keys across a round trip.
    #[test]
    fn roundtrip_event_sync() {
        let frame = Frame::event(
            "shard_updated",
            serde_json::json!({"shard":"foo","generation":3}),
        );
        let mut buf = Vec::new();
        write_frame_sync(&mut buf, &frame).unwrap();

        let mut cur = Cursor::new(buf);
        let read = read_frame_sync(&mut cur).unwrap();

        match read {
            Frame::Event { event, payload } => {
                assert_eq!(event, "shard_updated");
                assert_eq!(payload["shard"], "foo");
                assert_eq!(payload["generation"], 3);
            }
            other => panic!("expected Event, got {:?}", other),
        }
    }

    /// Writing a body larger than `MAX_FRAME_BYTES` must fail with `FrameTooLarge`.
    #[test]
    fn frame_too_large_rejected_on_write() {
        let big = "x".repeat(MAX_FRAME_BYTES + 1);
        let frame = Frame::request(1, "ping", serde_json::json!({"big": big}));
        let mut buf = Vec::new();
        let err = write_frame_sync(&mut buf, &frame).unwrap_err();
        // `matches!` only yields a bool; without `assert!` the check is silently dropped.
        assert!(
            matches!(err, DaemonError::FrameTooLarge { .. }),
            "expected FrameTooLarge, got {:?}",
            err
        );
    }

    /// A header declaring more than `MAX_FRAME_BYTES` must fail with `FrameTooLarge`
    /// before any body is read.
    #[test]
    fn frame_too_large_rejected_on_read() {
        let mut buf = Vec::new();
        let bogus_len = (MAX_FRAME_BYTES + 1) as u32;
        let mut hdr = [0u8; 4];
        BigEndian::write_u32(&mut hdr, bogus_len);
        buf.extend_from_slice(&hdr);
        let mut cur = Cursor::new(buf);
        let err = read_frame_sync(&mut cur).unwrap_err();
        // `matches!` only yields a bool; without `assert!` the check is silently dropped.
        assert!(
            matches!(err, DaemonError::FrameTooLarge { .. }),
            "expected FrameTooLarge, got {:?}",
            err
        );
    }

    /// A `Request` frame survives an async write on one end of a duplex pipe and an
    /// async read on the other.
    #[tokio::test]
    async fn roundtrip_async() {
        let frame = Frame::request(7, "info", serde_json::json!({}));
        let (mut a, mut b) = tokio::io::duplex(64 * 1024);
        let writer_frame = frame.clone();
        let writer = tokio::spawn(async move {
            write_frame(&mut a, &writer_frame).await.unwrap();
        });
        let read = read_frame(&mut b).await.unwrap();
        // Join the writer so a panic inside the spawned task fails this test.
        writer.await.expect("writer task panicked");
        match read {
            Frame::Request { id, op, .. } => {
                assert_eq!(id, 7);
                assert_eq!(op, "info");
            }
            other => panic!("expected Request, got {:?}", other),
        }
    }
}
403}