Skip to main content

vex_cli/
proto.rs

1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte that marks a control frame on the wire.
///
/// Written after the length prefix by `write_control` and matched by `read_frame`;
/// changing the value breaks compatibility with existing peers.
const TAG_CONTROL: u8 = 0x01;

/// Tag byte that marks a data frame on the wire.
///
/// Written after the length prefix by `write_data` and matched by `read_frame`;
/// changing the value breaks compatibility with existing peers.
const TAG_DATA: u8 = 0x02;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(tag = "type")]
12pub enum ClientMessage {
13    Authenticate { token: String },
14    CreateSession { shell: Option<String> },
15    ListSessions,
16    AttachSession { id: Uuid },
17    DetachSession,
18    ResizeSession { id: Uuid, cols: u16, rows: u16 },
19    KillSession { id: Uuid },
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23#[serde(tag = "type")]
24pub enum ServerMessage {
25    Authenticated,
26    SessionCreated { id: Uuid },
27    Sessions { sessions: Vec<SessionInfo> },
28    Attached { id: Uuid },
29    Detached,
30    SessionEnded { id: Uuid, exit_code: Option<i32> },
31    ClientJoined { session_id: Uuid, client_id: Uuid },
32    ClientLeft { session_id: Uuid, client_id: Uuid },
33    Error { message: String },
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
37pub struct SessionInfo {
38    pub id: Uuid,
39    pub cols: u16,
40    pub rows: u16,
41    pub created_at: DateTime<Utc>,
42    pub client_count: usize,
43}
44
/// One decoded wire frame, as returned by `read_frame`.
///
/// The variant records which tag byte the frame carried. The contained bytes are the
/// frame body with the length prefix and tag already stripped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Tag `0x01`. The `send_*_message` helpers place JSON-encoded messages here.
    Control(Vec<u8>),
    /// Tag `0x02`. The payload is passed through as opaque bytes.
    Data(Vec<u8>),
}
50
51pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
52    let len = (1 + payload.len()) as u32;
53    w.write_all(&len.to_be_bytes()).await?;
54    w.write_u8(TAG_CONTROL).await?;
55    w.write_all(payload).await?;
56    w.flush().await?;
57    Ok(())
58}
59
60pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
61    let len = (1 + payload.len()) as u32;
62    w.write_all(&len.to_be_bytes()).await?;
63    w.write_u8(TAG_DATA).await?;
64    w.write_all(payload).await?;
65    w.flush().await?;
66    Ok(())
67}
68
69pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
70    let mut len_buf = [0u8; 4];
71    match r.read_exact(&mut len_buf).await {
72        Ok(_) => {}
73        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
74        Err(e) => return Err(e.into()),
75    }
76    let len = u32::from_be_bytes(len_buf) as usize;
77    if len == 0 {
78        bail!("invalid frame: zero length");
79    }
80    let tag = {
81        let mut tag_buf = [0u8; 1];
82        r.read_exact(&mut tag_buf).await?;
83        tag_buf[0]
84    };
85    let payload_len = len - 1;
86    let mut payload = vec![0u8; payload_len];
87    if payload_len > 0 {
88        r.read_exact(&mut payload).await?;
89    }
90    match tag {
91        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
92        TAG_DATA => Ok(Some(Frame::Data(payload))),
93        other => bail!("unknown frame tag: 0x{:02x}", other),
94    }
95}
96
97/// Convenience: serialize a ClientMessage and write as a control frame.
98pub async fn send_client_message<W: AsyncWrite + Unpin>(
99    w: &mut W,
100    msg: &ClientMessage,
101) -> Result<()> {
102    let json = serde_json::to_vec(msg)?;
103    write_control(w, &json).await
104}
105
106/// Convenience: serialize a ServerMessage and write as a control frame.
107pub async fn send_server_message<W: AsyncWrite + Unpin>(
108    w: &mut W,
109    msg: &ServerMessage,
110) -> Result<()> {
111    let json = serde_json::to_vec(msg)?;
112    write_control(w, &json).await
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `msg` to JSON, parses it back, and asserts the value is unchanged.
    fn assert_json_round_trip<T>(msg: &T)
    where
        T: Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let json = serde_json::to_string(msg).unwrap();
        let decoded: T = serde_json::from_str(&json).unwrap();
        assert_eq!(&decoded, msg, "round trip changed the value; json = {}", json);
    }

    /// Reads one frame and returns its payload, panicking unless it is a control frame.
    async fn read_control<R: AsyncRead + Unpin>(r: &mut R) -> Vec<u8> {
        match read_frame(r).await.unwrap().expect("expected a frame, got EOF") {
            Frame::Control(data) => data,
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    /// Reads one frame and returns its payload, panicking unless it is a data frame.
    async fn read_data<R: AsyncRead + Unpin>(r: &mut R) -> Vec<u8> {
        match read_frame(r).await.unwrap().expect("expected a frame, got EOF") {
            Frame::Data(data) => data,
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[test]
    fn serde_round_trip_client() {
        let msgs = [
            ClientMessage::Authenticate {
                token: "test-token".into(),
            },
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::CreateSession { shell: None },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession { id: Uuid::nil() },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
        ];
        for msg in &msgs {
            assert_json_round_trip(msg);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = [
            ServerMessage::Authenticated,
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Sessions { sessions: vec![] },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: None,
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ];
        for msg in &msgs {
            assert_json_round_trip(msg);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        assert_eq!(read_control(&mut server).await, payload);
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        assert_eq!(read_data(&mut server).await, payload);
    }

    #[tokio::test]
    async fn frame_empty_payload_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, &[]).await.unwrap();
        drop(client);
        assert!(read_data(&mut server).await.is_empty());
    }

    #[tokio::test]
    async fn frames_back_to_back_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        assert_eq!(read_control(&mut server).await, b"first");
        assert_eq!(read_data(&mut server).await, b"second");
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_truncated_payload_is_error() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Claims 1 tag byte + 4 payload bytes, but only 1 payload byte follows.
        let len: u32 = 5;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_u8(0xAA).await.unwrap();
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("unknown frame tag"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let data = read_control(&mut server).await;
        let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
        assert_eq!(decoded, msg);
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::Error {
            message: "boom".into(),
        };
        send_server_message(&mut client, &msg).await.unwrap();
        drop(client);
        let data = read_control(&mut server).await;
        let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
        assert_eq!(decoded, msg);
    }
}