// vex_cli/proto.rs

1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte that marks a control frame.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte that marks a data frame.
const TAG_DATA: u8 = 0x02;
/// Largest declared frame length (tag byte plus payload) that `read_frame` accepts: 1 MiB.
const MAX_FRAME_SIZE: usize = 1 << 20;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(tag = "type")]
13pub enum ClientMessage {
14    Authenticate { token: String },
15    CreateSession { shell: Option<String> },
16    ListSessions,
17    AttachSession { id: Uuid },
18    DetachSession,
19    ResizeSession { id: Uuid, cols: u16, rows: u16 },
20    KillSession { id: Uuid },
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
24#[serde(tag = "type")]
25pub enum ServerMessage {
26    Authenticated,
27    SessionCreated { id: Uuid },
28    Sessions { sessions: Vec<SessionInfo> },
29    Attached { id: Uuid },
30    Detached,
31    SessionEnded { id: Uuid, exit_code: Option<i32> },
32    ClientJoined { session_id: Uuid, client_id: Uuid },
33    ClientLeft { session_id: Uuid, client_id: Uuid },
34    Error { message: String },
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
38pub struct SessionInfo {
39    pub id: Uuid,
40    pub cols: u16,
41    pub rows: u16,
42    pub created_at: DateTime<Utc>,
43    pub client_count: usize,
44}
45
/// A decoded frame: the payload bytes, classified by the frame's tag byte.
///
/// `Clone`, `PartialEq` and `Eq` are derived so frames can be duplicated and compared
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a frame tagged `TAG_CONTROL`.
    Control(Vec<u8>),
    /// Payload of a frame tagged `TAG_DATA`.
    Data(Vec<u8>),
}
51
52pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
53    let len = (1 + payload.len()) as u32;
54    w.write_all(&len.to_be_bytes()).await?;
55    w.write_u8(TAG_CONTROL).await?;
56    w.write_all(payload).await?;
57    w.flush().await?;
58    Ok(())
59}
60
61pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
62    let len = (1 + payload.len()) as u32;
63    w.write_all(&len.to_be_bytes()).await?;
64    w.write_u8(TAG_DATA).await?;
65    w.write_all(payload).await?;
66    w.flush().await?;
67    Ok(())
68}
69
70pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
71    let mut len_buf = [0u8; 4];
72    match r.read_exact(&mut len_buf).await {
73        Ok(_) => {}
74        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
75        Err(e) => return Err(e.into()),
76    }
77    let len = u32::from_be_bytes(len_buf) as usize;
78    if len == 0 {
79        bail!("invalid frame: zero length");
80    }
81    if len > MAX_FRAME_SIZE {
82        bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
83    }
84    let tag = {
85        let mut tag_buf = [0u8; 1];
86        r.read_exact(&mut tag_buf).await?;
87        tag_buf[0]
88    };
89    let payload_len = len - 1;
90    let mut payload = vec![0u8; payload_len];
91    if payload_len > 0 {
92        r.read_exact(&mut payload).await?;
93    }
94    match tag {
95        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
96        TAG_DATA => Ok(Some(Frame::Data(payload))),
97        other => bail!("unknown frame tag: 0x{:02x}", other),
98    }
99}
100
101/// Convenience: serialize a ClientMessage and write as a control frame.
102pub async fn send_client_message<W: AsyncWrite + Unpin>(
103    w: &mut W,
104    msg: &ClientMessage,
105) -> Result<()> {
106    let json = serde_json::to_vec(msg)?;
107    write_control(w, &json).await
108}
109
110/// Convenience: serialize a ServerMessage and write as a control frame.
111pub async fn send_server_message<W: AsyncWrite + Unpin>(
112    w: &mut W,
113    msg: &ServerMessage,
114) -> Result<()> {
115    let json = serde_json::to_vec(msg)?;
116    write_control(w, &json).await
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ClientMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::Authenticate {
                token: "test-token".into(),
            },
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession { id: Uuid::nil() },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::Authenticated,
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// A control frame written with `write_control` is read back as `Frame::Control`.
    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    /// A data frame written with `write_data` is read back as `Frame::Data`.
    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A frame with no payload (length 1: just the tag byte) decodes to an empty payload.
    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// Frames written back to back are read in order, followed by a clean end of stream.
    #[tokio::test]
    async fn frames_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    /// A stream closed before any byte arrives yields `Ok(None)`.
    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    /// A header declaring a length of zero is rejected.
    #[tokio::test]
    async fn frame_zero_length() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    /// A tag byte other than the two known tags is rejected.
    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    /// A header declaring more than `MAX_FRAME_SIZE` bytes is rejected.
    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame header claiming a 2 MiB frame
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    /// `send_client_message` produces a control frame whose payload decodes to the message.
    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}