Skip to main content

vex_cli/
proto.rs

1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte that marks a frame as a control frame.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte that marks a frame as a data frame.
const TAG_DATA: u8 = 0x02;
/// Largest frame length accepted from the wire, in bytes (1 MiB).
///
/// The length counts the tag byte together with the payload.
const MAX_FRAME_SIZE: usize = 1024 * 1024;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(tag = "type")]
13pub enum ClientMessage {
14    CreateSession {
15        shell: Option<String>,
16    },
17    ListSessions,
18    AttachSession {
19        id: Uuid,
20        cols: u16,
21        rows: u16,
22    },
23    DetachSession,
24    ResizeSession {
25        id: Uuid,
26        cols: u16,
27        rows: u16,
28    },
29    KillSession {
30        id: Uuid,
31    },
32    CreateAgent {
33        model: Option<String>,
34        permission_mode: Option<String>,
35        allowed_tools: Vec<String>,
36        max_turns: Option<u32>,
37        cwd: Option<String>,
38    },
39    AgentPrompt {
40        id: Uuid,
41        prompt: String,
42    },
43    AgentStatus {
44        id: Uuid,
45    },
46    ListAgents,
47    KillAgent {
48        id: Uuid,
49    },
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
53#[serde(tag = "type")]
54pub enum ServerMessage {
55    SessionCreated {
56        id: Uuid,
57    },
58    Sessions {
59        sessions: Vec<SessionInfo>,
60    },
61    Attached {
62        id: Uuid,
63    },
64    Detached,
65    SessionEnded {
66        id: Uuid,
67        exit_code: Option<i32>,
68    },
69    ClientJoined {
70        session_id: Uuid,
71        client_id: Uuid,
72    },
73    ClientLeft {
74        session_id: Uuid,
75        client_id: Uuid,
76    },
77    Error {
78        message: String,
79    },
80    AgentCreated {
81        id: Uuid,
82    },
83    AgentOutput {
84        id: Uuid,
85        event: AgentEvent,
86    },
87    AgentPromptDone {
88        id: Uuid,
89        turn_count: u32,
90    },
91    AgentStatusResponse {
92        id: Uuid,
93        status: AgentState,
94        claude_session_id: Option<String>,
95        model: Option<String>,
96        turn_count: u32,
97    },
98    Agents {
99        agents: Vec<AgentInfo>,
100    },
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
104pub struct SessionInfo {
105    pub id: Uuid,
106    pub cols: u16,
107    pub rows: u16,
108    pub created_at: DateTime<Utc>,
109    pub client_count: usize,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
113pub enum AgentState {
114    Idle,
115    Processing,
116    Error(String),
117}
118
119impl std::fmt::Display for AgentState {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            AgentState::Idle => write!(f, "idle"),
123            AgentState::Processing => write!(f, "processing"),
124            AgentState::Error(msg) => write!(f, "error: {}", msg),
125        }
126    }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
130pub struct AgentInfo {
131    pub id: Uuid,
132    pub status: AgentState,
133    pub model: Option<String>,
134    pub turn_count: u32,
135    pub created_at: DateTime<Utc>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
139pub struct AgentEvent {
140    pub event_type: String,
141    pub raw_json: String,
142}
143
/// A decoded wire frame; each variant owns the payload bytes that followed the tag byte.
#[derive(Debug)]
pub enum Frame {
    /// Payload of a control frame. The `send_*_message` helpers in this file
    /// put a JSON-serialized message here.
    Control(Vec<u8>),
    /// Payload of a data frame: opaque bytes as far as this file is concerned.
    Data(Vec<u8>),
}
149
150pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
151    let len = (1 + payload.len()) as u32;
152    w.write_all(&len.to_be_bytes()).await?;
153    w.write_u8(TAG_CONTROL).await?;
154    w.write_all(payload).await?;
155    w.flush().await?;
156    Ok(())
157}
158
159pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
160    let len = (1 + payload.len()) as u32;
161    w.write_all(&len.to_be_bytes()).await?;
162    w.write_u8(TAG_DATA).await?;
163    w.write_all(payload).await?;
164    w.flush().await?;
165    Ok(())
166}
167
168pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
169    let mut len_buf = [0u8; 4];
170    match r.read_exact(&mut len_buf).await {
171        Ok(_) => {}
172        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
173        Err(e) => return Err(e.into()),
174    }
175    let len = u32::from_be_bytes(len_buf) as usize;
176    if len == 0 {
177        bail!("invalid frame: zero length");
178    }
179    if len > MAX_FRAME_SIZE {
180        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
181    }
182    let tag = {
183        let mut tag_buf = [0u8; 1];
184        r.read_exact(&mut tag_buf).await?;
185        tag_buf[0]
186    };
187    let payload_len = len - 1;
188    let mut payload = vec![0u8; payload_len];
189    if payload_len > 0 {
190        r.read_exact(&mut payload).await?;
191    }
192    match tag {
193        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
194        TAG_DATA => Ok(Some(Frame::Data(payload))),
195        other => bail!("unknown frame tag: 0x{:02x}", other),
196    }
197}
198
199/// Convenience: serialize a ClientMessage and write as a control frame.
200pub async fn send_client_message<W: AsyncWrite + Unpin>(
201    w: &mut W,
202    msg: &ClientMessage,
203) -> Result<()> {
204    let json = serde_json::to_vec(msg)?;
205    write_control(w, &json).await
206}
207
208/// Convenience: serialize a ServerMessage and write as a control frame.
209pub async fn send_server_message<W: AsyncWrite + Unpin>(
210    w: &mut W,
211    msg: &ServerMessage,
212) -> Result<()> {
213    let json = serde_json::to_vec(msg)?;
214    write_control(w, &json).await
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ClientMessage` variant survives a JSON encode/decode cycle unchanged.
    #[test]
    fn serde_round_trip_client() {
        let nil = Uuid::nil();
        let cases = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: nil,
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: nil,
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: nil },
            ClientMessage::CreateAgent {
                model: Some("sonnet".into()),
                permission_mode: Some("plan".into()),
                allowed_tools: vec!["Read".into(), "Write".into()],
                max_turns: Some(10),
                cwd: Some("/tmp".into()),
            },
            ClientMessage::CreateAgent {
                model: None,
                permission_mode: None,
                allowed_tools: vec![],
                max_turns: None,
                cwd: None,
            },
            ClientMessage::AgentPrompt {
                id: nil,
                prompt: "do something".into(),
            },
            ClientMessage::AgentStatus { id: nil },
            ClientMessage::ListAgents,
            ClientMessage::KillAgent { id: nil },
        ];
        for original in &cases {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, &decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON encode/decode cycle unchanged.
    #[test]
    fn serde_round_trip_server() {
        let nil = Uuid::nil();
        let cases = vec![
            ServerMessage::SessionCreated { id: nil },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: nil,
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: nil },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: nil,
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: nil,
                client_id: nil,
            },
            ServerMessage::ClientLeft {
                session_id: nil,
                client_id: nil,
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentCreated { id: nil },
            ServerMessage::AgentOutput {
                id: nil,
                event: AgentEvent {
                    event_type: "content_block_delta".into(),
                    raw_json: r#"{"type":"content_block_delta"}"#.into(),
                },
            },
            ServerMessage::AgentPromptDone {
                id: nil,
                turn_count: 3,
            },
            ServerMessage::AgentStatusResponse {
                id: nil,
                status: AgentState::Idle,
                claude_session_id: Some("sess-123".into()),
                model: Some("sonnet".into()),
                turn_count: 5,
            },
            ServerMessage::AgentStatusResponse {
                id: nil,
                status: AgentState::Processing,
                claude_session_id: None,
                model: None,
                turn_count: 0,
            },
            ServerMessage::AgentStatusResponse {
                id: nil,
                status: AgentState::Error("something broke".into()),
                claude_session_id: None,
                model: None,
                turn_count: 0,
            },
            ServerMessage::Agents {
                agents: vec![AgentInfo {
                    id: nil,
                    status: AgentState::Idle,
                    model: Some("sonnet".into()),
                    turn_count: 2,
                    created_at: Utc::now(),
                }],
            },
        ];
        for original in &cases {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, &decoded);
        }
    }

    /// A control frame written with `write_control` is read back as `Frame::Control`.
    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        let body = b"hello control";
        write_control(&mut tx, body).await.unwrap();
        drop(tx);
        let Some(Frame::Control(received)) = read_frame(&mut rx).await.unwrap() else {
            panic!("expected control frame");
        };
        assert_eq!(received, body);
    }

    /// A data frame written with `write_data` is read back as `Frame::Data`.
    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        let body = b"hello data";
        write_data(&mut tx, body).await.unwrap();
        drop(tx);
        let Some(Frame::Data(received)) = read_frame(&mut rx).await.unwrap() else {
            panic!("expected data frame");
        };
        assert_eq!(received, body);
    }

    /// A stream closed before any byte arrives yields `None`, not an error.
    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (tx, mut rx) = tokio::io::duplex(1024);
        drop(tx);
        let outcome = read_frame(&mut rx).await.unwrap();
        assert!(outcome.is_none());
    }

    /// A frame carrying an unrecognized tag byte is rejected.
    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        // Declared length 2 = one tag byte (0xFF, not a known tag) + one payload byte.
        let declared: u32 = 2;
        tx.write_all(&declared.to_be_bytes()).await.unwrap();
        tx.write_all(&[0xFF, 0x00]).await.unwrap();
        drop(tx);
        let err = read_frame(&mut rx).await.unwrap_err();
        assert!(err.to_string().contains("unknown frame tag"));
    }

    /// A header declaring more than `MAX_FRAME_SIZE` bytes is rejected.
    #[tokio::test]
    async fn frame_too_large() {
        let (mut tx, mut rx) = tokio::io::duplex(1024);
        // Only the header is sent; it claims a 2 MiB frame.
        let declared: u32 = 2 * 1024 * 1024;
        tx.write_all(&declared.to_be_bytes()).await.unwrap();
        drop(tx);
        let err = read_frame(&mut rx).await.unwrap_err();
        assert!(err.to_string().contains("frame too large"));
    }

    /// `send_client_message` produces a control frame whose payload decodes to the same message.
    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut tx, mut rx) = tokio::io::duplex(4096);
        let sent = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut tx, &sent).await.unwrap();
        drop(tx);
        let Some(Frame::Control(bytes)) = read_frame(&mut rx).await.unwrap() else {
            panic!("expected control frame");
        };
        let received: ClientMessage = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(received, sent);
    }
}