Skip to main content

vex_cli/
proto.rs

1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte identifying a control frame (`send_client_message` / `send_server_message`
/// put JSON-encoded messages in these).
const TAG_CONTROL: u8 = 0x01;
/// Tag byte identifying a data frame, whose payload is passed through as opaque bytes.
const TAG_DATA: u8 = 0x02;
/// Largest accepted value of a frame's length field (tag byte + payload): 1 MiB.
const MAX_FRAME_SIZE: usize = 1024 * 1024;
12
/// Messages written by the client side of the control channel.
///
/// Serialized as an internally tagged JSON object (`#[serde(tag = "type")]`): the
/// variant name is stored in a `"type"` field next to the variant's own fields.
/// `send_client_message` encodes a value with `serde_json` and sends it as a single
/// control frame.
///
/// Field declaration order determines the order of keys in the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Create a session. Both fields are optional.
    /// NOTE(review): what the server does when they are `None` is not visible here.
    CreateSession {
        shell: Option<String>,
        repo: Option<String>,
    },
    ListSessions,
    /// Attach to session `id`, supplying a size as `cols` x `rows`.
    AttachSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    DetachSession,
    /// Change the size of session `id` to `cols` x `rows`.
    ResizeSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    KillSession {
        id: Uuid,
    },
    AgentList,
    AgentNotifications,
    AgentWatch {
        session_id: Uuid,
    },
    /// Send `text` to the agent in `session_id`.
    AgentPrompt {
        session_id: Uuid,
        text: String,
    },
    /// `workstream` is optional (the tests exercise both `None` and `Some`).
    AgentSpawn {
        repo: String,
        workstream: Option<String>,
    },
    WorkstreamCreate {
        repo: String,
        name: String,
    },
    /// `repo: None` presumably means "all repositories" — confirm against the server.
    WorkstreamList {
        repo: Option<String>,
    },
    WorkstreamRemove {
        repo: String,
        name: String,
    },
    RepoAdd {
        name: String,
        path: PathBuf,
    },
    RepoRemove {
        name: String,
    },
    RepoList,
    RepoIntrospectPath {
        path: PathBuf,
    },
}
71
/// Messages written by the server side of the control channel.
///
/// Like [`ClientMessage`], this is internally tagged (`#[serde(tag = "type")]`): the
/// variant name goes in a `"type"` field of the JSON object. `send_server_message`
/// encodes a value with `serde_json` and sends it as one control frame.
///
/// Field declaration order determines the order of keys in the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    SessionCreated {
        id: Uuid,
    },
    /// Session listing; see [`SessionInfo`].
    Sessions {
        sessions: Vec<SessionInfo>,
    },
    Attached {
        id: Uuid,
    },
    Detached,
    /// Session `id` ended. NOTE(review): `exit_code` is `None` presumably when no
    /// exit status was available — confirm against the server.
    SessionEnded {
        id: Uuid,
        exit_code: Option<i32>,
    },
    ClientJoined {
        session_id: Uuid,
        client_id: Uuid,
    },
    ClientLeft {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// Failure report carrying a human-readable `message`.
    Error {
        message: String,
    },
    /// Agent listing; see [`AgentEntry`].
    AgentListResponse {
        agents: Vec<AgentEntry>,
    },
    AgentPromptSent {
        session_id: Uuid,
    },
    /// One line of agent output for `session_id`.
    /// NOTE(review): presumably streamed while an `AgentWatch` is active — confirm.
    AgentConversationLine {
        session_id: Uuid,
        line: String,
    },
    AgentWatchEnd {
        session_id: Uuid,
    },
    RepoAdded {
        name: String,
        path: PathBuf,
    },
    RepoRemoved {
        name: String,
    },
    /// Repository listing; see [`RepoEntry`].
    Repos {
        repos: Vec<RepoEntry>,
    },
    /// Result of inspecting a path. The git fields are optional.
    /// NOTE(review): presumably `None` when the path is not a git checkout — confirm.
    RepoIntrospected {
        suggested_name: String,
        path: PathBuf,
        git_remote: Option<String>,
        git_branch: Option<String>,
    },
    WorkstreamCreated {
        repo: String,
        name: String,
        worktree_path: PathBuf,
    },
    WorkstreamRemoved {
        repo: String,
        name: String,
    },
    /// Workstream listing; see [`WorkstreamInfo`].
    Workstreams {
        workstreams: Vec<WorkstreamInfo>,
    },
}
142
/// One session, as carried in `ServerMessage::Sessions`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    pub id: Uuid,
    // NOTE(review): presumably the size last set via attach/resize — confirm.
    pub cols: u16,
    pub rows: u16,
    /// Creation time, in UTC.
    pub created_at: DateTime<Utc>,
    // NOTE(review): presumably the number of currently attached clients — confirm.
    pub client_count: usize,
}
151
/// One detected agent, as carried in `ServerMessage::AgentListResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    /// The vex session the agent belongs to.
    pub vex_session_id: Uuid,
    // Identifier and process id reported for the agent; the format of the session
    // id string is not constrained here.
    pub claude_session_id: String,
    pub claude_pid: u32,
    pub cwd: PathBuf,
    /// When the agent was detected, in UTC.
    pub detected_at: DateTime<Utc>,
    // NOTE(review): how the server decides this flag is not visible in this file.
    pub needs_intervention: bool,
}
161
/// One registered repository (name and filesystem path), as carried in
/// `ServerMessage::Repos`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RepoEntry {
    pub name: String,
    pub path: PathBuf,
}
167
/// One workstream, as carried in `ServerMessage::Workstreams`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkstreamInfo {
    /// Name of the repository this workstream belongs to.
    pub repo: String,
    pub name: String,
    // NOTE(review): presumably a git worktree checkout for `branch` — confirm.
    pub worktree_path: PathBuf,
    pub branch: String,
    /// Creation time, in UTC.
    pub created_at: DateTime<Utc>,
}
176
/// A decoded wire frame, as returned by [`read_frame`].
///
/// On the wire, a frame is:
/// - a 4-byte big-endian length covering the tag byte plus the payload;
/// - the 1-byte tag;
/// - the payload bytes.
///
/// The variant records which tag was seen.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a `TAG_CONTROL` frame. `send_client_message` / `send_server_message`
    /// put JSON-encoded messages here.
    Control(Vec<u8>),
    /// Payload of a `TAG_DATA` frame, passed through as opaque bytes.
    Data(Vec<u8>),
}
182
183pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
184    let len = (1 + payload.len()) as u32;
185    w.write_all(&len.to_be_bytes()).await?;
186    w.write_u8(TAG_CONTROL).await?;
187    w.write_all(payload).await?;
188    w.flush().await?;
189    Ok(())
190}
191
192pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
193    let len = (1 + payload.len()) as u32;
194    w.write_all(&len.to_be_bytes()).await?;
195    w.write_u8(TAG_DATA).await?;
196    w.write_all(payload).await?;
197    w.flush().await?;
198    Ok(())
199}
200
201pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
202    let mut len_buf = [0u8; 4];
203    match r.read_exact(&mut len_buf).await {
204        Ok(_) => {}
205        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
206        Err(e) => return Err(e.into()),
207    }
208    let len = u32::from_be_bytes(len_buf) as usize;
209    if len == 0 {
210        bail!("invalid frame: zero length");
211    }
212    if len > MAX_FRAME_SIZE {
213        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
214    }
215    let tag = {
216        let mut tag_buf = [0u8; 1];
217        r.read_exact(&mut tag_buf).await?;
218        tag_buf[0]
219    };
220    let payload_len = len - 1;
221    let mut payload = vec![0u8; payload_len];
222    if payload_len > 0 {
223        r.read_exact(&mut payload).await?;
224    }
225    match tag {
226        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
227        TAG_DATA => Ok(Some(Frame::Data(payload))),
228        other => bail!("unknown frame tag: 0x{:02x}", other),
229    }
230}
231
232/// Convenience: serialize a ClientMessage and write as a control frame.
233pub async fn send_client_message<W: AsyncWrite + Unpin>(
234    w: &mut W,
235    msg: &ClientMessage,
236) -> Result<()> {
237    let json = serde_json::to_vec(msg)?;
238    write_control(w, &json).await
239}
240
241/// Convenience: serialize a ServerMessage and write as a control frame.
242pub async fn send_server_message<W: AsyncWrite + Unpin>(
243    w: &mut W,
244    msg: &ServerMessage,
245) -> Result<()> {
246    let json = serde_json::to_vec(msg)?;
247    write_control(w, &json).await
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ClientMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: None,
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: Some("feature-x".into()),
            },
            ClientMessage::WorkstreamCreate {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::WorkstreamList { repo: None },
            ClientMessage::WorkstreamList {
                repo: Some("vex".into()),
            },
            ClientMessage::WorkstreamRemove {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                session_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
            },
            ServerMessage::WorkstreamCreated {
                repo: "vex".into(),
                name: "feature-x".into(),
                worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
            },
            ServerMessage::WorkstreamRemoved {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ServerMessage::Workstreams {
                workstreams: vec![WorkstreamInfo {
                    repo: "vex".into(),
                    name: "feature-x".into(),
                    worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
                    branch: "feature-x".into(),
                    created_at: Utc::now(),
                }],
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A zero-byte payload is legal: the length field is just the tag byte.
    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, &[]).await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// Several frames on one stream are read back in order, followed by a clean EOF.
    #[tokio::test]
    async fn frames_back_to_back() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    /// A length field of zero cannot even hold the tag byte and must be rejected.
    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame header claiming 2 MiB payload
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    /// A frame whose length field is exactly `MAX_FRAME_SIZE` is still accepted.
    #[tokio::test]
    async fn frame_at_max_size_accepted() {
        let (mut client, mut server) = tokio::io::duplex(64 * 1024);
        // The length field counts the tag byte too, so this payload hits the limit exactly.
        let payload = vec![0xABu8; MAX_FRAME_SIZE - 1];
        // The payload is larger than the duplex buffer, so write and read concurrently.
        let (written, read) = tokio::join!(
            write_data(&mut client, &payload),
            read_frame(&mut server)
        );
        written.unwrap();
        match read.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut server_end, mut client_end) = tokio::io::duplex(4096);
        let msg = ServerMessage::SessionEnded {
            id: Uuid::nil(),
            exit_code: Some(3),
        };
        send_server_message(&mut server_end, &msg).await.unwrap();
        drop(server_end);
        match read_frame(&mut client_end).await.unwrap().unwrap() {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}