Skip to main content

vex_cli/
proto.rs

1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte identifying a control frame (its payload is decoded as a JSON message
/// by the helpers in this file).
const TAG_CONTROL: u8 = 0x01;
/// Tag byte identifying a data frame (its payload is passed through as raw bytes).
const TAG_DATA: u8 = 0x02;
/// Largest value accepted for a frame's length field, which counts the tag byte
/// plus the payload: 1 MiB.
const MAX_FRAME_SIZE: usize = 1 << 20;
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14#[serde(tag = "type")]
15pub enum ClientMessage {
16    CreateSession {
17        shell: Option<String>,
18        repo: Option<String>,
19    },
20    ListSessions,
21    AttachSession {
22        id: Uuid,
23        cols: u16,
24        rows: u16,
25    },
26    DetachSession,
27    ResizeSession {
28        id: Uuid,
29        cols: u16,
30        rows: u16,
31    },
32    KillSession {
33        id: Uuid,
34    },
35    AgentList,
36    AgentNotifications,
37    AgentWatch {
38        session_id: Uuid,
39    },
40    AgentPrompt {
41        session_id: Uuid,
42        text: String,
43    },
44    AgentSpawn {
45        repo: String,
46    },
47    RepoAdd {
48        name: String,
49        path: PathBuf,
50    },
51    RepoRemove {
52        name: String,
53    },
54    RepoList,
55    RepoIntrospectPath {
56        path: PathBuf,
57    },
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
61#[serde(tag = "type")]
62pub enum ServerMessage {
63    SessionCreated {
64        id: Uuid,
65    },
66    Sessions {
67        sessions: Vec<SessionInfo>,
68    },
69    Attached {
70        id: Uuid,
71    },
72    Detached,
73    SessionEnded {
74        id: Uuid,
75        exit_code: Option<i32>,
76    },
77    ClientJoined {
78        session_id: Uuid,
79        client_id: Uuid,
80    },
81    ClientLeft {
82        session_id: Uuid,
83        client_id: Uuid,
84    },
85    Error {
86        message: String,
87    },
88    AgentListResponse {
89        agents: Vec<AgentEntry>,
90    },
91    AgentPromptSent {
92        session_id: Uuid,
93    },
94    AgentConversationLine {
95        session_id: Uuid,
96        line: String,
97    },
98    AgentWatchEnd {
99        session_id: Uuid,
100    },
101    RepoAdded {
102        name: String,
103        path: PathBuf,
104    },
105    RepoRemoved {
106        name: String,
107    },
108    Repos {
109        repos: Vec<RepoEntry>,
110    },
111    RepoIntrospected {
112        suggested_name: String,
113        path: PathBuf,
114        git_remote: Option<String>,
115        git_branch: Option<String>,
116    },
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
120pub struct SessionInfo {
121    pub id: Uuid,
122    pub cols: u16,
123    pub rows: u16,
124    pub created_at: DateTime<Utc>,
125    pub client_count: usize,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
129pub struct AgentEntry {
130    pub vex_session_id: Uuid,
131    pub claude_session_id: String,
132    pub claude_pid: u32,
133    pub cwd: PathBuf,
134    pub detected_at: DateTime<Utc>,
135    pub needs_intervention: bool,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
139pub struct RepoEntry {
140    pub name: String,
141    pub path: PathBuf,
142}
143
/// A decoded wire frame; each variant owns the frame's payload with the tag byte
/// already stripped.
///
/// `Clone`, `PartialEq` and `Eq` are derived so frames can be compared directly
/// (for example with `assert_eq!`) and duplicated when needed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a frame tagged as control.
    Control(Vec<u8>),
    /// Payload of a frame tagged as data.
    Data(Vec<u8>),
}
149
150pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
151    let len = (1 + payload.len()) as u32;
152    w.write_all(&len.to_be_bytes()).await?;
153    w.write_u8(TAG_CONTROL).await?;
154    w.write_all(payload).await?;
155    w.flush().await?;
156    Ok(())
157}
158
159pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
160    let len = (1 + payload.len()) as u32;
161    w.write_all(&len.to_be_bytes()).await?;
162    w.write_u8(TAG_DATA).await?;
163    w.write_all(payload).await?;
164    w.flush().await?;
165    Ok(())
166}
167
168pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
169    let mut len_buf = [0u8; 4];
170    match r.read_exact(&mut len_buf).await {
171        Ok(_) => {}
172        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
173        Err(e) => return Err(e.into()),
174    }
175    let len = u32::from_be_bytes(len_buf) as usize;
176    if len == 0 {
177        bail!("invalid frame: zero length");
178    }
179    if len > MAX_FRAME_SIZE {
180        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
181    }
182    let tag = {
183        let mut tag_buf = [0u8; 1];
184        r.read_exact(&mut tag_buf).await?;
185        tag_buf[0]
186    };
187    let payload_len = len - 1;
188    let mut payload = vec![0u8; payload_len];
189    if payload_len > 0 {
190        r.read_exact(&mut payload).await?;
191    }
192    match tag {
193        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
194        TAG_DATA => Ok(Some(Frame::Data(payload))),
195        other => bail!("unknown frame tag: 0x{:02x}", other),
196    }
197}
198
199/// Convenience: serialize a ClientMessage and write as a control frame.
200pub async fn send_client_message<W: AsyncWrite + Unpin>(
201    w: &mut W,
202    msg: &ClientMessage,
203) -> Result<()> {
204    let json = serde_json::to_vec(msg)?;
205    write_control(w, &json).await
206}
207
208/// Convenience: serialize a ServerMessage and write as a control frame.
209pub async fn send_server_message<W: AsyncWrite + Unpin>(
210    w: &mut W,
211    msg: &ServerMessage,
212) -> Result<()> {
213    let json = serde_json::to_vec(msg)?;
214    write_control(w, &json).await
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `ClientMessage` variant must survive a JSON encode/decode cycle unchanged.
    #[test]
    fn serde_round_trip_client() {
        let nil = Uuid::nil();
        let samples = [
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: nil,
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: nil,
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: nil },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch { session_id: nil },
            ClientMessage::AgentPrompt {
                session_id: nil,
                text: "hello".into(),
            },
            ClientMessage::AgentSpawn { repo: "vex".into() },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for original in samples {
            let encoded = serde_json::to_string(&original).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// Each `ServerMessage` variant must survive a JSON encode/decode cycle unchanged.
    #[test]
    fn serde_round_trip_server() {
        let nil = Uuid::nil();
        let samples = [
            ServerMessage::SessionCreated { id: nil },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: nil,
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: nil },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: nil,
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: nil,
                client_id: nil,
            },
            ServerMessage::ClientLeft {
                session_id: nil,
                client_id: nil,
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: nil,
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent { session_id: nil },
            ServerMessage::AgentConversationLine {
                session_id: nil,
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd { session_id: nil },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
            },
        ];
        for original in samples {
            let encoded = serde_json::to_string(&original).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// A control frame written by `write_control` is read back as `Frame::Control`.
    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut tx, payload).await.unwrap();
        drop(tx);
        let Some(Frame::Control(data)) = read_frame(&mut rx).await.unwrap() else {
            panic!("expected control frame");
        };
        assert_eq!(data, payload);
    }

    /// A data frame written by `write_data` is read back as `Frame::Data`.
    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut tx, payload).await.unwrap();
        drop(tx);
        let Some(Frame::Data(data)) = read_frame(&mut rx).await.unwrap() else {
            panic!("expected data frame");
        };
        assert_eq!(data, payload);
    }

    /// A stream closed before any byte arrives yields `None` rather than an error.
    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (tx, mut rx) = tokio::io::duplex(1024);
        drop(tx);
        let frame = read_frame(&mut rx).await.unwrap();
        assert!(frame.is_none());
    }

    /// A frame carrying an unrecognized tag byte is rejected.
    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        // Length 2 = one tag byte (0xFF, unknown) + one payload byte.
        let declared_len: u32 = 2;
        tx.write_all(&declared_len.to_be_bytes()).await.unwrap();
        tx.write_all(&[0xFF, 0x00]).await.unwrap();
        drop(tx);
        let err = read_frame(&mut rx).await.unwrap_err();
        assert!(err.to_string().contains("unknown frame tag"));
    }

    /// A header announcing more than the maximum frame size is rejected.
    #[tokio::test]
    async fn frame_too_large() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        // Only the header is sent; it claims a 2 MiB frame.
        let declared_len: u32 = 2 * 1024 * 1024;
        tx.write_all(&declared_len.to_be_bytes()).await.unwrap();
        drop(tx);
        let err = read_frame(&mut rx).await.unwrap_err();
        assert!(err.to_string().contains("frame too large"));
    }

    /// `send_client_message` produces a control frame whose payload decodes to the
    /// original message.
    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut tx, mut rx) = tokio::io::duplex(4096);
        let sent = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut tx, &sent).await.unwrap();
        drop(tx);
        let Some(Frame::Control(data)) = read_frame(&mut rx).await.unwrap() else {
            panic!("expected control frame");
        };
        let received: ClientMessage = serde_json::from_slice(&data).unwrap();
        assert_eq!(received, sent);
    }
}