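//! Wire protocol shared by the vex client and server (`vex_cli/proto.rs`).
//!
//! Each frame is a 4-byte big-endian length, a 1-byte tag, and a payload; the
//! length counts the tag byte plus the payload. Control frames carry JSON-encoded
//! `ClientMessage` / `ServerMessage` values; data frames carry raw bytes.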

1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte marking a control frame (JSON-encoded protocol message).
const TAG_CONTROL: u8 = 0x01;
/// Tag byte marking a data frame (payload passed through uninterpreted here).
const TAG_DATA: u8 = 0x02;
/// Largest declared frame length accepted, counting the tag byte and payload.
const MAX_FRAME_SIZE: usize = 1 << 20; // 1 MiB
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14#[serde(tag = "type")]
15pub enum ClientMessage {
16    CreateSession {
17        shell: Option<String>,
18        repo: Option<String>,
19    },
20    ListSessions,
21    AttachSession {
22        id: Uuid,
23        cols: u16,
24        rows: u16,
25    },
26    DetachSession,
27    ResizeSession {
28        id: Uuid,
29        cols: u16,
30        rows: u16,
31    },
32    KillSession {
33        id: Uuid,
34    },
35    AgentList,
36    AgentNotifications,
37    AgentWatch {
38        session_id: Uuid,
39    },
40    AgentPrompt {
41        session_id: Uuid,
42        text: String,
43    },
44    RepoAdd {
45        name: String,
46        path: PathBuf,
47    },
48    RepoRemove {
49        name: String,
50    },
51    RepoList,
52    RepoIntrospectPath {
53        path: PathBuf,
54    },
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
58#[serde(tag = "type")]
59pub enum ServerMessage {
60    SessionCreated {
61        id: Uuid,
62    },
63    Sessions {
64        sessions: Vec<SessionInfo>,
65    },
66    Attached {
67        id: Uuid,
68    },
69    Detached,
70    SessionEnded {
71        id: Uuid,
72        exit_code: Option<i32>,
73    },
74    ClientJoined {
75        session_id: Uuid,
76        client_id: Uuid,
77    },
78    ClientLeft {
79        session_id: Uuid,
80        client_id: Uuid,
81    },
82    Error {
83        message: String,
84    },
85    AgentListResponse {
86        agents: Vec<AgentEntry>,
87    },
88    AgentPromptSent {
89        session_id: Uuid,
90    },
91    AgentConversationLine {
92        session_id: Uuid,
93        line: String,
94    },
95    AgentWatchEnd {
96        session_id: Uuid,
97    },
98    RepoAdded {
99        name: String,
100        path: PathBuf,
101    },
102    RepoRemoved {
103        name: String,
104    },
105    Repos {
106        repos: Vec<RepoEntry>,
107    },
108    RepoIntrospected {
109        suggested_name: String,
110        path: PathBuf,
111        git_remote: Option<String>,
112        git_branch: Option<String>,
113    },
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
117pub struct SessionInfo {
118    pub id: Uuid,
119    pub cols: u16,
120    pub rows: u16,
121    pub created_at: DateTime<Utc>,
122    pub client_count: usize,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
126pub struct AgentEntry {
127    pub vex_session_id: Uuid,
128    pub claude_session_id: String,
129    pub claude_pid: u32,
130    pub cwd: PathBuf,
131    pub detected_at: DateTime<Utc>,
132    pub needs_intervention: bool,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
136pub struct RepoEntry {
137    pub name: String,
138    pub path: PathBuf,
139}
140
/// One decoded frame, as returned by `read_frame`.
///
/// The payload excludes the length prefix and the tag byte.
/// `send_client_message` / `send_server_message` put JSON in control frames;
/// data-frame payloads are not interpreted by this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a frame tagged `0x01` (control).
    Control(Vec<u8>),
    /// Payload of a frame tagged `0x02` (data).
    Data(Vec<u8>),
}
146
147pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
148    let len = (1 + payload.len()) as u32;
149    w.write_all(&len.to_be_bytes()).await?;
150    w.write_u8(TAG_CONTROL).await?;
151    w.write_all(payload).await?;
152    w.flush().await?;
153    Ok(())
154}
155
156pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
157    let len = (1 + payload.len()) as u32;
158    w.write_all(&len.to_be_bytes()).await?;
159    w.write_u8(TAG_DATA).await?;
160    w.write_all(payload).await?;
161    w.flush().await?;
162    Ok(())
163}
164
165pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
166    let mut len_buf = [0u8; 4];
167    match r.read_exact(&mut len_buf).await {
168        Ok(_) => {}
169        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
170        Err(e) => return Err(e.into()),
171    }
172    let len = u32::from_be_bytes(len_buf) as usize;
173    if len == 0 {
174        bail!("invalid frame: zero length");
175    }
176    if len > MAX_FRAME_SIZE {
177        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
178    }
179    let tag = {
180        let mut tag_buf = [0u8; 1];
181        r.read_exact(&mut tag_buf).await?;
182        tag_buf[0]
183    };
184    let payload_len = len - 1;
185    let mut payload = vec![0u8; payload_len];
186    if payload_len > 0 {
187        r.read_exact(&mut payload).await?;
188    }
189    match tag {
190        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
191        TAG_DATA => Ok(Some(Frame::Data(payload))),
192        other => bail!("unknown frame tag: 0x{:02x}", other),
193    }
194}
195
196/// Convenience: serialize a ClientMessage and write as a control frame.
197pub async fn send_client_message<W: AsyncWrite + Unpin>(
198    w: &mut W,
199    msg: &ClientMessage,
200) -> Result<()> {
201    let json = serde_json::to_vec(msg)?;
202    write_control(w, &json).await
203}
204
205/// Convenience: serialize a ServerMessage and write as a control frame.
206pub async fn send_server_message<W: AsyncWrite + Unpin>(
207    w: &mut W,
208    msg: &ServerMessage,
209) -> Result<()> {
210    let json = serde_json::to_vec(msg)?;
211    write_control(w, &json).await
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                session_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Pins the on-the-wire JSON shape: an internal `"type"` tag naming the variant.
    #[test]
    fn wire_format_uses_type_tag() {
        assert_eq!(
            serde_json::to_string(&ClientMessage::ListSessions).unwrap(),
            r#"{"type":"ListSessions"}"#
        );
        assert_eq!(
            serde_json::to_string(&ServerMessage::Detached).unwrap(),
            r#"{"type":"Detached"}"#
        );
        let decoded: ClientMessage = serde_json::from_str(
            r#"{"type":"KillSession","id":"00000000-0000-0000-0000-000000000000"}"#,
        )
        .unwrap();
        assert_eq!(decoded, ClientMessage::KillSession { id: Uuid::nil() });
    }

    #[test]
    fn unknown_or_missing_type_is_rejected() {
        assert!(serde_json::from_str::<ClientMessage>(r#"{"type":"Bogus"}"#).is_err());
        assert!(serde_json::from_str::<ServerMessage>(r#"{"message":"no tag"}"#).is_err());
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frames_back_to_back() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_truncated_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Header promises a tag plus 9 payload bytes, but only 3 arrive.
        client.write_all(&10u32.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_all(b"abc").await.unwrap();
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame header claiming 2 MiB payload
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}