Skip to main content

vex_cli/
proto.rs

1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte written after the length prefix of a control frame (`Frame::Control`).
const TAG_CONTROL: u8 = 0x01;
/// Tag byte written after the length prefix of a data frame (`Frame::Data`).
const TAG_DATA: u8 = 0x02;
/// Largest length-prefix value `read_frame` accepts; the prefix counts the tag
/// byte plus the payload.
const MAX_FRAME_SIZE: usize = 1_048_576; // 1 MiB
12
/// Requests a client can send.
///
/// Serialized with serde's internally tagged representation: the variant name
/// is stored in a `"type"` field next to the variant's own fields.
/// `send_client_message` encodes a value as JSON inside a control frame.
///
/// NOTE(review): how each request is handled is decided by the receiving side,
/// which is not visible in this file; the groupings below follow the variant
/// names only.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    // --- Session requests ---
    CreateSession {
        shell: Option<String>,
        repo: Option<String>,
    },
    ListSessions,
    AttachSession {
        id: Uuid,
        // presumably the client's terminal size in character cells — TODO confirm
        cols: u16,
        rows: u16,
    },
    DetachSession,
    ResizeSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    KillSession {
        id: Uuid,
    },
    // --- Agent requests ---
    AgentList,
    AgentNotifications,
    AgentWatch {
        session_id: Uuid,
    },
    AgentPrompt {
        session_id: Uuid,
        text: String,
    },
    AgentSpawn {
        repo: String,
        workstream: Option<String>,
    },
    // --- Workstream requests ---
    WorkstreamCreate {
        repo: String,
        name: String,
    },
    WorkstreamList {
        repo: Option<String>,
    },
    WorkstreamRemove {
        repo: String,
        name: String,
    },
    // --- Repository requests ---
    RepoAdd {
        name: String,
        path: PathBuf,
    },
    RepoRemove {
        name: String,
    },
    RepoList,
    RepoIntrospectPath {
        path: PathBuf,
    },
}
71
/// Responses and notifications a server can send.
///
/// Serialized with serde's internally tagged representation: the variant name
/// is stored in a `"type"` field next to the variant's own fields.
/// `send_server_message` encodes a value as JSON inside a control frame.
///
/// NOTE(review): which request triggers which message is decided by the
/// server implementation, which is not visible in this file; the groupings
/// below follow the variant names only.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    // --- Session messages ---
    SessionCreated {
        id: Uuid,
    },
    Sessions {
        sessions: Vec<SessionInfo>,
    },
    Attached {
        id: Uuid,
    },
    Detached,
    SessionEnded {
        id: Uuid,
        // `None` presumably means no exit status was available — TODO confirm
        exit_code: Option<i32>,
    },
    ClientJoined {
        session_id: Uuid,
        client_id: Uuid,
    },
    ClientLeft {
        session_id: Uuid,
        client_id: Uuid,
    },
    // --- Generic failure report ---
    Error {
        message: String,
    },
    // --- Agent messages ---
    AgentListResponse {
        agents: Vec<AgentEntry>,
    },
    AgentPromptSent {
        session_id: Uuid,
    },
    AgentConversationLine {
        session_id: Uuid,
        line: String,
    },
    AgentWatchEnd {
        session_id: Uuid,
    },
    // --- Repository messages ---
    RepoAdded {
        name: String,
        path: PathBuf,
    },
    RepoRemoved {
        name: String,
    },
    Repos {
        repos: Vec<RepoEntry>,
    },
    RepoIntrospected {
        suggested_name: String,
        path: PathBuf,
        git_remote: Option<String>,
        git_branch: Option<String>,
        children: Vec<String>,
    },
    // --- Workstream messages ---
    WorkstreamCreated {
        repo: String,
        name: String,
        worktree_path: PathBuf,
    },
    WorkstreamRemoved {
        repo: String,
        name: String,
    },
    Workstreams {
        workstreams: Vec<WorkstreamInfo>,
    },
}
143
/// Summary of one session, carried in `ServerMessage::Sessions`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    pub id: Uuid,
    // presumably the session's terminal size in character cells — TODO confirm
    pub cols: u16,
    pub rows: u16,
    pub created_at: DateTime<Utc>,
    // presumably the number of currently attached clients — TODO confirm
    pub client_count: usize,
}
152
/// Description of one agent, carried in `ServerMessage::AgentListResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    pub vex_session_id: Uuid,
    pub claude_session_id: String,
    // presumably the OS process id of the agent process — TODO confirm
    pub claude_pid: u32,
    pub cwd: PathBuf,
    pub detected_at: DateTime<Utc>,
    pub needs_intervention: bool,
}
162
/// A repository name paired with a filesystem path, carried in
/// `ServerMessage::Repos`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RepoEntry {
    pub name: String,
    pub path: PathBuf,
}
168
/// Description of one workstream, carried in `ServerMessage::Workstreams`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkstreamInfo {
    pub repo: String,
    pub name: String,
    // presumably the directory of the git worktree backing this workstream — TODO confirm
    pub worktree_path: PathBuf,
    pub branch: String,
    pub created_at: DateTime<Utc>,
}
177
/// One decoded frame as returned by `read_frame`.
///
/// Each variant holds the frame's payload bytes with the length prefix and
/// tag byte already stripped.
#[derive(Debug)]
pub enum Frame {
    /// Payload of a frame tagged `TAG_CONTROL`.
    Control(Vec<u8>),
    /// Payload of a frame tagged `TAG_DATA`.
    Data(Vec<u8>),
}
183
184pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
185    let len = (1 + payload.len()) as u32;
186    w.write_all(&len.to_be_bytes()).await?;
187    w.write_u8(TAG_CONTROL).await?;
188    w.write_all(payload).await?;
189    w.flush().await?;
190    Ok(())
191}
192
193pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
194    let len = (1 + payload.len()) as u32;
195    w.write_all(&len.to_be_bytes()).await?;
196    w.write_u8(TAG_DATA).await?;
197    w.write_all(payload).await?;
198    w.flush().await?;
199    Ok(())
200}
201
202pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
203    let mut len_buf = [0u8; 4];
204    match r.read_exact(&mut len_buf).await {
205        Ok(_) => {}
206        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
207        Err(e) => return Err(e.into()),
208    }
209    let len = u32::from_be_bytes(len_buf) as usize;
210    if len == 0 {
211        bail!("invalid frame: zero length");
212    }
213    if len > MAX_FRAME_SIZE {
214        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
215    }
216    let tag = {
217        let mut tag_buf = [0u8; 1];
218        r.read_exact(&mut tag_buf).await?;
219        tag_buf[0]
220    };
221    let payload_len = len - 1;
222    let mut payload = vec![0u8; payload_len];
223    if payload_len > 0 {
224        r.read_exact(&mut payload).await?;
225    }
226    match tag {
227        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
228        TAG_DATA => Ok(Some(Frame::Data(payload))),
229        other => bail!("unknown frame tag: 0x{:02x}", other),
230    }
231}
232
233/// Convenience: serialize a ClientMessage and write as a control frame.
234pub async fn send_client_message<W: AsyncWrite + Unpin>(
235    w: &mut W,
236    msg: &ClientMessage,
237) -> Result<()> {
238    let json = serde_json::to_vec(msg)?;
239    write_control(w, &json).await
240}
241
242/// Convenience: serialize a ServerMessage and write as a control frame.
243pub async fn send_server_message<W: AsyncWrite + Unpin>(
244    w: &mut W,
245    msg: &ServerMessage,
246) -> Result<()> {
247    let json = serde_json::to_vec(msg)?;
248    write_control(w, &json).await
249}
250
// Unit tests for the message types (JSON round trips) and the frame codec
// (round trips over an in-memory duplex pipe, plus error paths).
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ClientMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
            // Optional fields are exercised in both their `None` and `Some` forms.
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: None,
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: Some("feature-x".into()),
            },
            ClientMessage::WorkstreamCreate {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::WorkstreamList { repo: None },
            ClientMessage::WorkstreamList {
                repo: Some("vex".into()),
            },
            ClientMessage::WorkstreamRemove {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                session_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
                children: vec!["src".into(), "tests".into()],
            },
            ServerMessage::WorkstreamCreated {
                repo: "vex".into(),
                name: "feature-x".into(),
                worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
            },
            ServerMessage::WorkstreamRemoved {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ServerMessage::Workstreams {
                workstreams: vec![WorkstreamInfo {
                    repo: "vex".into(),
                    name: "feature-x".into(),
                    worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
                    branch: "feature-x".into(),
                    created_at: Utc::now(),
                }],
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// A frame written by `write_control` is read back as `Frame::Control`
    /// with the same payload.
    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    /// A frame written by `write_data` is read back as `Frame::Data` with the
    /// same payload.
    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// Closing the writer before sending anything makes `read_frame` return
    /// `Ok(None)`.
    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    /// A well-formed frame carrying an unrecognized tag byte is rejected.
    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    /// A length prefix above `MAX_FRAME_SIZE` is rejected from the header alone.
    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write only a frame header whose length prefix claims a 2 MiB frame;
        // no tag or payload bytes follow.
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    /// `send_client_message` produces a control frame whose payload decodes
    /// back to the original message.
    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}