Skip to main content

vex_cli/
proto.rs

1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte identifying a control frame (the JSON-encoded protocol messages).
const TAG_CONTROL: u8 = 1;
/// Tag byte identifying a data frame (an opaque byte payload).
const TAG_DATA: u8 = 2;
/// Largest accepted value of a frame's length prefix, which counts the tag
/// byte plus the payload: 2^20 bytes (1 MiB).
const MAX_FRAME_SIZE: usize = 1 << 20;
12
/// Messages a client sends to the server, carried in control frames.
///
/// `#[serde(tag = "type")]` makes this an internally tagged enum: each message
/// is encoded as a JSON object whose `"type"` key holds the variant name, with
/// the variant's fields alongside it (e.g. `{"type":"KillSession","id":"…"}`).
/// Declaration order matters to the derived serde code (e.g. the variant list
/// in "unknown variant" errors), so avoid reordering casually.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Request a new session. `shell` is optional; presumably `None` selects a
    /// server-side default — confirm against the server implementation.
    CreateSession { shell: Option<String> },
    /// Request the list of sessions.
    ListSessions,
    /// Attach to session `id`, supplying a terminal size in `cols` x `rows`.
    AttachSession { id: Uuid, cols: u16, rows: u16 },
    /// Detach; carries no session id (presumably applies to the session this
    /// connection is attached to — confirm).
    DetachSession,
    /// Set the terminal size of session `id` to `cols` x `rows`.
    ResizeSession { id: Uuid, cols: u16, rows: u16 },
    /// Request that session `id` be terminated.
    KillSession { id: Uuid },
    /// Request the list of detected agents.
    AgentList,
    /// NOTE(review): carries no fields and its handling is not visible in this
    /// file (subscription to agent notifications?) — confirm.
    AgentNotifications,
    /// Request to follow the agent conversation in vex session `session_id`.
    AgentWatch { session_id: Uuid },
    /// Send prompt `text` to the agent running in vex session `session_id`.
    AgentPrompt { session_id: Uuid, text: String },
}
27
/// Messages the server sends to a client, carried in control frames.
///
/// `#[serde(tag = "type")]` makes this an internally tagged enum: each message
/// is encoded as a JSON object whose `"type"` key holds the variant name, with
/// the variant's fields alongside it. Declaration order matters to the derived
/// serde code, so avoid reordering casually.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    /// A session was created; `id` identifies it.
    SessionCreated { id: Uuid },
    /// Session listing, one [`SessionInfo`] per session.
    Sessions { sessions: Vec<SessionInfo> },
    /// The client is attached to session `id`.
    Attached { id: Uuid },
    /// The client is no longer attached.
    Detached,
    /// Session `id` ended. `exit_code` is optional; presumably `None` when no
    /// exit status is available — confirm against the server.
    SessionEnded { id: Uuid, exit_code: Option<i32> },
    /// Client `client_id` joined session `session_id`.
    ClientJoined { session_id: Uuid, client_id: Uuid },
    /// Client `client_id` left session `session_id`.
    ClientLeft { session_id: Uuid, client_id: Uuid },
    /// A failure report; `message` describes the error.
    Error { message: String },
    /// Listing of detected agents, one [`AgentEntry`] per agent.
    AgentListResponse { agents: Vec<AgentEntry> },
    /// One line of the agent conversation being watched in `session_id`.
    AgentConversationLine { session_id: Uuid, line: String },
    /// The conversation watch on `session_id` has ended.
    AgentWatchEnd { session_id: Uuid },
}
43
/// Summary of one session, as carried in `ServerMessage::Sessions`.
///
/// Field order determines the key order of the serialized JSON object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: Uuid,
    /// Terminal width in columns.
    pub cols: u16,
    /// Terminal height in rows.
    pub rows: u16,
    /// Creation time, in UTC.
    pub created_at: DateTime<Utc>,
    /// Number of clients associated with the session (presumably those
    /// currently attached — confirm against the server).
    pub client_count: usize,
}
52
/// One detected agent, as carried in `ServerMessage::AgentListResponse`.
///
/// Field order determines the key order of the serialized JSON object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    /// The vex session the agent was found in.
    pub vex_session_id: Uuid,
    /// The agent's own session identifier, kept as an opaque string.
    pub claude_session_id: String,
    /// Process id of the agent process.
    pub claude_pid: u32,
    /// Working directory associated with the agent.
    pub cwd: PathBuf,
    /// Detection time, in UTC.
    pub detected_at: DateTime<Utc>,
    /// Whether the agent is flagged as needing intervention.
    /// NOTE(review): the criteria for this flag live outside this file — confirm.
    pub needs_intervention: bool,
}
62
/// A single decoded wire frame: the tag byte selects the variant and the
/// payload bytes are carried unmodified.
///
/// Derives `Clone`/`PartialEq`/`Eq` so frames can be copied and compared
/// directly (e.g. with `assert_eq!`) instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Tag `TAG_CONTROL`. When produced by `send_client_message` or
    /// `send_server_message`, the payload is a JSON-encoded message.
    Control(Vec<u8>),
    /// Tag `TAG_DATA`; an opaque byte payload.
    Data(Vec<u8>),
}
68
69pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
70    let len = (1 + payload.len()) as u32;
71    w.write_all(&len.to_be_bytes()).await?;
72    w.write_u8(TAG_CONTROL).await?;
73    w.write_all(payload).await?;
74    w.flush().await?;
75    Ok(())
76}
77
78pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
79    let len = (1 + payload.len()) as u32;
80    w.write_all(&len.to_be_bytes()).await?;
81    w.write_u8(TAG_DATA).await?;
82    w.write_all(payload).await?;
83    w.flush().await?;
84    Ok(())
85}
86
87pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
88    let mut len_buf = [0u8; 4];
89    match r.read_exact(&mut len_buf).await {
90        Ok(_) => {}
91        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
92        Err(e) => return Err(e.into()),
93    }
94    let len = u32::from_be_bytes(len_buf) as usize;
95    if len == 0 {
96        bail!("invalid frame: zero length");
97    }
98    if len > MAX_FRAME_SIZE {
99        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
100    }
101    let tag = {
102        let mut tag_buf = [0u8; 1];
103        r.read_exact(&mut tag_buf).await?;
104        tag_buf[0]
105    };
106    let payload_len = len - 1;
107    let mut payload = vec![0u8; payload_len];
108    if payload_len > 0 {
109        r.read_exact(&mut payload).await?;
110    }
111    match tag {
112        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
113        TAG_DATA => Ok(Some(Frame::Data(payload))),
114        other => bail!("unknown frame tag: 0x{:02x}", other),
115    }
116}
117
118/// Convenience: serialize a ClientMessage and write as a control frame.
119pub async fn send_client_message<W: AsyncWrite + Unpin>(
120    w: &mut W,
121    msg: &ClientMessage,
122) -> Result<()> {
123    let json = serde_json::to_vec(msg)?;
124    write_control(w, &json).await
125}
126
127/// Convenience: serialize a ServerMessage and write as a control frame.
128pub async fn send_server_message<W: AsyncWrite + Unpin>(
129    w: &mut W,
130    msg: &ServerMessage,
131) -> Result<()> {
132    let json = serde_json::to_vec(msg)?;
133    write_control(w, &json).await
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// An empty payload is legal: the frame is just the length prefix (value 1,
    /// for the tag byte) followed by the tag.
    #[tokio::test]
    async fn frame_round_trip_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"").await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert!(data.is_empty()),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    /// Consecutive frames on one stream come back in order, and the end of the
    /// stream after the last complete frame is reported as `None`.
    #[tokio::test]
    async fn frames_read_back_in_order() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // A length prefix of 0 cannot even hold the tag byte.
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame header claiming 2 MiB payload
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut server, mut client) = tokio::io::duplex(4096);
        let msg = ServerMessage::SessionEnded {
            id: Uuid::nil(),
            exit_code: Some(1),
        };
        send_server_message(&mut server, &msg).await.unwrap();
        drop(server);
        let frame = read_frame(&mut client).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}