//! vex_cli/proto.rs
//!
//! Client/server message types and the length-prefixed framing used to carry them.

use std::path::PathBuf;

use anyhow::{Result, bail};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use uuid::Uuid;

/// Tag byte identifying a frame whose payload is a control message.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte identifying a frame whose payload is raw data bytes.
const TAG_DATA: u8 = 0x02;
/// Largest accepted value of a frame's length field (tag byte plus payload): 1 MiB.
const MAX_FRAME_SIZE: usize = 1_048_576;

13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14#[serde(tag = "type")]
15pub enum ClientMessage {
16    CreateShell {
17        shell: Option<String>,
18        repo: Option<String>,
19        workstream: Option<String>,
20    },
21    ListShells,
22    AttachShell {
23        id: Uuid,
24        cols: u16,
25        rows: u16,
26    },
27    DetachShell,
28    ResizeShell {
29        id: Uuid,
30        cols: u16,
31        rows: u16,
32    },
33    KillShell {
34        id: Uuid,
35    },
36    AgentList,
37    AgentNotifications,
38    AgentWatch {
39        shell_id: Uuid,
40    },
41    AgentPrompt {
42        shell_id: Uuid,
43        text: String,
44    },
45    AgentSpawn {
46        repo: String,
47        workstream: Option<String>,
48    },
49    WorkstreamCreate {
50        repo: String,
51        name: String,
52    },
53    WorkstreamList {
54        repo: Option<String>,
55    },
56    WorkstreamRemove {
57        repo: String,
58        name: String,
59    },
60    RepoAdd {
61        name: String,
62        path: PathBuf,
63    },
64    RepoRemove {
65        name: String,
66    },
67    RepoList,
68    RepoIntrospectPath {
69        path: PathBuf,
70    },
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
74#[serde(tag = "type")]
75pub enum ServerMessage {
76    ShellCreated {
77        id: Uuid,
78    },
79    Shells {
80        shells: Vec<ShellInfo>,
81    },
82    Attached {
83        id: Uuid,
84    },
85    Detached,
86    ShellEnded {
87        id: Uuid,
88        exit_code: Option<i32>,
89    },
90    ClientJoined {
91        shell_id: Uuid,
92        client_id: Uuid,
93    },
94    ClientLeft {
95        shell_id: Uuid,
96        client_id: Uuid,
97    },
98    Error {
99        message: String,
100    },
101    AgentListResponse {
102        agents: Vec<AgentEntry>,
103    },
104    AgentPromptSent {
105        shell_id: Uuid,
106    },
107    AgentConversationLine {
108        shell_id: Uuid,
109        line: String,
110    },
111    AgentWatchEnd {
112        shell_id: Uuid,
113    },
114    RepoAdded {
115        name: String,
116        path: PathBuf,
117    },
118    RepoRemoved {
119        name: String,
120    },
121    Repos {
122        repos: Vec<RepoEntry>,
123    },
124    RepoIntrospected {
125        suggested_name: String,
126        path: PathBuf,
127        git_remote: Option<String>,
128        git_branch: Option<String>,
129        children: Vec<String>,
130    },
131    WorkstreamCreated {
132        repo: String,
133        name: String,
134        worktree_path: PathBuf,
135    },
136    WorkstreamRemoved {
137        repo: String,
138        name: String,
139    },
140    Workstreams {
141        workstreams: Vec<WorkstreamInfo>,
142    },
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
146pub struct ShellInfo {
147    pub id: Uuid,
148    pub cols: u16,
149    pub rows: u16,
150    pub created_at: DateTime<Utc>,
151    pub client_count: usize,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
155pub struct AgentEntry {
156    pub vex_shell_id: Uuid,
157    pub claude_session_id: String,
158    pub claude_pid: u32,
159    pub cwd: PathBuf,
160    pub detected_at: DateTime<Utc>,
161    pub needs_intervention: bool,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
165pub struct RepoEntry {
166    pub name: String,
167    pub path: PathBuf,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
171pub struct WorkstreamInfo {
172    pub repo: String,
173    pub name: String,
174    pub worktree_path: PathBuf,
175    pub branch: String,
176    pub created_at: DateTime<Utc>,
177}
178
/// One decoded wire frame.
///
/// The variant reflects the frame's tag byte; the vector holds the payload bytes that
/// followed the tag, and may be empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a frame carrying the control tag.
    Control(Vec<u8>),
    /// Payload of a frame carrying the data tag.
    Data(Vec<u8>),
}

185pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
186    let len = (1 + payload.len()) as u32;
187    w.write_all(&len.to_be_bytes()).await?;
188    w.write_u8(TAG_CONTROL).await?;
189    w.write_all(payload).await?;
190    w.flush().await?;
191    Ok(())
192}
193
194pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
195    let len = (1 + payload.len()) as u32;
196    w.write_all(&len.to_be_bytes()).await?;
197    w.write_u8(TAG_DATA).await?;
198    w.write_all(payload).await?;
199    w.flush().await?;
200    Ok(())
201}
202
203pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
204    let mut len_buf = [0u8; 4];
205    match r.read_exact(&mut len_buf).await {
206        Ok(_) => {}
207        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
208        Err(e) => return Err(e.into()),
209    }
210    let len = u32::from_be_bytes(len_buf) as usize;
211    if len == 0 {
212        bail!("invalid frame: zero length");
213    }
214    if len > MAX_FRAME_SIZE {
215        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
216    }
217    let tag = {
218        let mut tag_buf = [0u8; 1];
219        r.read_exact(&mut tag_buf).await?;
220        tag_buf[0]
221    };
222    let payload_len = len - 1;
223    let mut payload = vec![0u8; payload_len];
224    if payload_len > 0 {
225        r.read_exact(&mut payload).await?;
226    }
227    match tag {
228        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
229        TAG_DATA => Ok(Some(Frame::Data(payload))),
230        other => bail!("unknown frame tag: 0x{:02x}", other),
231    }
232}
233
234/// Convenience: serialize a ClientMessage and write as a control frame.
235pub async fn send_client_message<W: AsyncWrite + Unpin>(
236    w: &mut W,
237    msg: &ClientMessage,
238) -> Result<()> {
239    let json = serde_json::to_vec(msg)?;
240    write_control(w, &json).await
241}
242
243/// Convenience: serialize a ServerMessage and write as a control frame.
244pub async fn send_server_message<W: AsyncWrite + Unpin>(
245    w: &mut W,
246    msg: &ServerMessage,
247) -> Result<()> {
248    let json = serde_json::to_vec(msg)?;
249    write_control(w, &json).await
250}
251
/// Unit tests: JSON round trips for every message variant, and framing behavior over an
/// in-memory duplex pipe.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateShell {
                shell: Some("bash".into()),
                repo: None,
                workstream: None,
            },
            ClientMessage::ListShells,
            ClientMessage::AttachShell {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachShell,
            ClientMessage::ResizeShell {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillShell { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                shell_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                shell_id: Uuid::nil(),
                text: "hello".into(),
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: None,
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: Some("feature-x".into()),
            },
            ClientMessage::WorkstreamCreate {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::WorkstreamList { repo: None },
            ClientMessage::WorkstreamList {
                repo: Some("vex".into()),
            },
            ClientMessage::WorkstreamRemove {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::ShellCreated { id: Uuid::nil() },
            ServerMessage::Shells {
                shells: vec![ShellInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::ShellEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                shell_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                shell_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_shell_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                shell_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                shell_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                shell_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
                children: vec!["src".into(), "tests".into()],
            },
            ServerMessage::WorkstreamCreated {
                repo: "vex".into(),
                name: "feature-x".into(),
                worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
            },
            ServerMessage::WorkstreamRemoved {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ServerMessage::Workstreams {
                workstreams: vec![WorkstreamInfo {
                    repo: "vex".into(),
                    name: "feature-x".into(),
                    worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
                    branch: "feature-x".into(),
                    created_at: Utc::now(),
                }],
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"").await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert!(data.is_empty()),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frames_are_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // A length of zero cannot even hold the tag byte.
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_truncated_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Header promises 1 tag byte + 4 payload bytes, but only 2 payload bytes follow.
        let len: u32 = 5;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0x02).await.unwrap();
        client.write_all(&[0xAA, 0xBB]).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame header claiming 2 MiB payload
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateShell {
            shell: Some("zsh".into()),
            repo: None,
            workstream: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}