Skip to main content

vex_cli/
proto.rs

1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte identifying a control frame (the `send_*_message` helpers put JSON here).
const TAG_CONTROL: u8 = 0x01;
/// Tag byte identifying a data frame (opaque bytes).
const TAG_DATA: u8 = 0x02;
/// Largest accepted length-prefix value (tag byte + payload), in bytes: 1 MiB.
const MAX_FRAME_SIZE: usize = 1 << 20;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(tag = "type")]
13pub enum ClientMessage {
14    CreateSession { shell: Option<String> },
15    ListSessions,
16    AttachSession { id: Uuid, cols: u16, rows: u16 },
17    DetachSession,
18    ResizeSession { id: Uuid, cols: u16, rows: u16 },
19    KillSession { id: Uuid },
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23#[serde(tag = "type")]
24pub enum ServerMessage {
25    SessionCreated { id: Uuid },
26    Sessions { sessions: Vec<SessionInfo> },
27    Attached { id: Uuid },
28    Detached,
29    SessionEnded { id: Uuid, exit_code: Option<i32> },
30    ClientJoined { session_id: Uuid, client_id: Uuid },
31    ClientLeft { session_id: Uuid, client_id: Uuid },
32    Error { message: String },
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
36pub struct SessionInfo {
37    pub id: Uuid,
38    pub cols: u16,
39    pub rows: u16,
40    pub created_at: DateTime<Utc>,
41    pub client_count: usize,
42}
43
/// One decoded frame, as returned by `read_frame`; the variant reflects the tag byte.
#[derive(Debug)]
pub enum Frame {
    /// Payload of a frame tagged `TAG_CONTROL`.
    Control(Vec<u8>),
    /// Payload of a frame tagged `TAG_DATA`.
    Data(Vec<u8>),
}
49
50pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
51    let len = (1 + payload.len()) as u32;
52    w.write_all(&len.to_be_bytes()).await?;
53    w.write_u8(TAG_CONTROL).await?;
54    w.write_all(payload).await?;
55    w.flush().await?;
56    Ok(())
57}
58
59pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
60    let len = (1 + payload.len()) as u32;
61    w.write_all(&len.to_be_bytes()).await?;
62    w.write_u8(TAG_DATA).await?;
63    w.write_all(payload).await?;
64    w.flush().await?;
65    Ok(())
66}
67
68pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
69    let mut len_buf = [0u8; 4];
70    match r.read_exact(&mut len_buf).await {
71        Ok(_) => {}
72        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
73        Err(e) => return Err(e.into()),
74    }
75    let len = u32::from_be_bytes(len_buf) as usize;
76    if len == 0 {
77        bail!("invalid frame: zero length");
78    }
79    if len > MAX_FRAME_SIZE {
80        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
81    }
82    let tag = {
83        let mut tag_buf = [0u8; 1];
84        r.read_exact(&mut tag_buf).await?;
85        tag_buf[0]
86    };
87    let payload_len = len - 1;
88    let mut payload = vec![0u8; payload_len];
89    if payload_len > 0 {
90        r.read_exact(&mut payload).await?;
91    }
92    match tag {
93        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
94        TAG_DATA => Ok(Some(Frame::Data(payload))),
95        other => bail!("unknown frame tag: 0x{:02x}", other),
96    }
97}
98
99/// Convenience: serialize a ClientMessage and write as a control frame.
100pub async fn send_client_message<W: AsyncWrite + Unpin>(
101    w: &mut W,
102    msg: &ClientMessage,
103) -> Result<()> {
104    let json = serde_json::to_vec(msg)?;
105    write_control(w, &json).await
106}
107
108/// Convenience: serialize a ServerMessage and write as a control frame.
109pub async fn send_server_message<W: AsyncWrite + Unpin>(
110    w: &mut W,
111    msg: &ServerMessage,
112) -> Result<()> {
113    let json = serde_json::to_vec(msg)?;
114    write_control(w, &json).await
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::DeserializeOwned;
    use std::fmt::Debug;
    use tokio::io::DuplexStream;

    /// JSON-encode every message, decode it again, and assert nothing changed.
    fn assert_json_round_trips<T>(msgs: &[T])
    where
        T: Serialize + DeserializeOwned + PartialEq + Debug,
    {
        for original in msgs {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: T = serde_json::from_str(&encoded).unwrap();
            assert_eq!(*original, decoded);
        }
    }

    /// Read one frame, failing the test on EOF or on any decode error.
    async fn next_frame(server: &mut DuplexStream) -> Frame {
        read_frame(server).await.unwrap().unwrap()
    }

    /// Read one frame, assert the read failed, and return the error text.
    async fn next_frame_error(server: &mut DuplexStream) -> String {
        let result = read_frame(server).await;
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[test]
    fn serde_round_trip_client() {
        let id = Uuid::nil();
        assert_json_round_trips(&[
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id,
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id,
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id },
        ]);
    }

    #[test]
    fn serde_round_trip_server() {
        let id = Uuid::nil();
        let session = SessionInfo {
            id,
            cols: 80,
            rows: 24,
            created_at: Utc::now(),
            client_count: 2,
        };
        assert_json_round_trips(&[
            ServerMessage::SessionCreated { id },
            ServerMessage::Sessions {
                sessions: vec![session],
            },
            ServerMessage::Attached { id },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id,
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: id,
                client_id: id,
            },
            ServerMessage::ClientLeft {
                session_id: id,
                client_id: id,
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ]);
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        match next_frame(&mut server).await {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        match next_frame(&mut server).await {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Length prefix 2 (tag + one payload byte), tag 0xFF, payload 0x00.
        let raw: [u8; 6] = [0x00, 0x00, 0x00, 0x02, 0xFF, 0x00];
        client.write_all(&raw).await.unwrap();
        drop(client);
        let message = next_frame_error(&mut server).await;
        assert!(message.contains("unknown frame tag"));
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Only a length prefix claiming 2 MiB; no tag or payload follows.
        let oversized: u32 = 2 * 1024 * 1024;
        client.write_all(&oversized.to_be_bytes()).await.unwrap();
        drop(client);
        let message = next_frame_error(&mut server).await;
        assert!(message.contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        match next_frame(&mut server).await {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}