Skip to main content

vex_cli/
proto.rs

1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte marking a control frame: written by `write_control` and decoded
/// by `read_frame` into `Frame::Control`.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte marking a data frame: written by `write_data` and decoded by
/// `read_frame` into `Frame::Data`.
const TAG_DATA: u8 = 0x02;
9
/// Control messages originating from the client.
///
/// Serialized by serde as an internally tagged JSON object: the variant name
/// is stored in a `"type"` field next to the variant's own fields.
/// `send_client_message` encodes these as JSON and ships them in a control
/// frame.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Request a new session. `shell` is optional; presumably it names the
    /// shell program to launch and `None` lets the server pick a default
    /// (TODO confirm against the server implementation).
    CreateSession { shell: Option<String> },
    /// Request the list of sessions (presumably answered with
    /// `ServerMessage::Sessions`; confirm against the server).
    ListSessions,
    /// Request to attach to the session identified by `id`.
    AttachSession { id: Uuid },
    /// Request to detach; carries no session id.
    DetachSession,
    /// Request to resize session `id` to `cols` x `rows` (presumably terminal
    /// columns and rows in character cells; confirm).
    ResizeSession { id: Uuid, cols: u16, rows: u16 },
    /// Request to kill the session identified by `id`.
    KillSession { id: Uuid },
}
20
/// Control messages originating from the server.
///
/// Serialized by serde as an internally tagged JSON object: the variant name
/// is stored in a `"type"` field next to the variant's own fields.
/// `send_server_message` encodes these as JSON and ships them in a control
/// frame.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    /// A session was created; `id` identifies it.
    SessionCreated { id: Uuid },
    /// A list of session descriptions.
    Sessions { sessions: Vec<SessionInfo> },
    /// The client is now attached to session `id`.
    Attached { id: Uuid },
    /// The client is now detached; carries no session id.
    Detached,
    /// Session `id` has ended. `exit_code` is optional; presumably `None`
    /// means no exit status was available (TODO confirm with the server).
    SessionEnded { id: Uuid, exit_code: Option<i32> },
    /// A failure report with a human-readable `message`.
    Error { message: String },
}
31
/// Description of one session, as carried in `ServerMessage::Sessions`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    /// Identifier of the session.
    pub id: Uuid,
    /// Width of the session (presumably terminal columns; confirm).
    pub cols: u16,
    /// Height of the session (presumably terminal rows; confirm).
    pub rows: u16,
    /// UTC timestamp recorded for the session's creation.
    pub created_at: DateTime<Utc>,
}
39
/// A decoded wire frame, as returned by `read_frame`.
///
/// Each variant owns the frame's payload bytes exactly as they were read from
/// the stream (the tag byte is not included).
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug`, matching the
/// other protocol types in this file, so frames can be copied and compared
/// directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a frame tagged `TAG_CONTROL`. The `send_*_message` helpers
    /// put a JSON-encoded message here.
    Control(Vec<u8>),
    /// Payload of a frame tagged `TAG_DATA`; the bytes are not interpreted by
    /// this module.
    Data(Vec<u8>),
}
45
46pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
47    let len = (1 + payload.len()) as u32;
48    w.write_all(&len.to_be_bytes()).await?;
49    w.write_u8(TAG_CONTROL).await?;
50    w.write_all(payload).await?;
51    w.flush().await?;
52    Ok(())
53}
54
55pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
56    let len = (1 + payload.len()) as u32;
57    w.write_all(&len.to_be_bytes()).await?;
58    w.write_u8(TAG_DATA).await?;
59    w.write_all(payload).await?;
60    w.flush().await?;
61    Ok(())
62}
63
64pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
65    let mut len_buf = [0u8; 4];
66    match r.read_exact(&mut len_buf).await {
67        Ok(_) => {}
68        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
69        Err(e) => return Err(e.into()),
70    }
71    let len = u32::from_be_bytes(len_buf) as usize;
72    if len == 0 {
73        bail!("invalid frame: zero length");
74    }
75    let tag = {
76        let mut tag_buf = [0u8; 1];
77        r.read_exact(&mut tag_buf).await?;
78        tag_buf[0]
79    };
80    let payload_len = len - 1;
81    let mut payload = vec![0u8; payload_len];
82    if payload_len > 0 {
83        r.read_exact(&mut payload).await?;
84    }
85    match tag {
86        TAG_CONTROL => Ok(Some(Frame::Control(payload))),
87        TAG_DATA => Ok(Some(Frame::Data(payload))),
88        other => bail!("unknown frame tag: 0x{:02x}", other),
89    }
90}
91
92/// Convenience: serialize a ClientMessage and write as a control frame.
93pub async fn send_client_message<W: AsyncWrite + Unpin>(
94    w: &mut W,
95    msg: &ClientMessage,
96) -> Result<()> {
97    let json = serde_json::to_vec(msg)?;
98    write_control(w, &json).await
99}
100
101/// Convenience: serialize a ServerMessage and write as a control frame.
102pub async fn send_server_message<W: AsyncWrite + Unpin>(
103    w: &mut W,
104    msg: &ServerMessage,
105) -> Result<()> {
106    let json = serde_json::to_vec(msg)?;
107    write_control(w, &json).await
108}
109
#[cfg(test)]
mod tests {
    //! Unit tests for message (de)serialization and the frame codec.
    //!
    //! Frame tests use an in-memory `tokio::io::duplex` pipe; dropping the
    //! writing half signals end-of-stream to the reader.

    use super::*;

    /// Every `ClientMessage` variant survives a JSON round trip unchanged.
    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession { id: Uuid::nil() },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON round trip unchanged.
    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// A control frame written by `write_control` is read back as
    /// `Frame::Control` with the same payload.
    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    /// A data frame written by `write_data` is read back as `Frame::Data`
    /// with the same payload.
    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A stream closed before any bytes arrive yields `Ok(None)`.
    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    /// A frame carrying an unrecognized tag byte is rejected.
    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Write a frame with tag 0xFF
        let len: u32 = 2; // 1 byte tag + 1 byte payload
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    /// A length prefix of zero (no room even for the tag) is rejected.
    #[tokio::test]
    async fn frame_zero_length_is_error() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    /// A frame consisting of only a tag decodes to an empty payload.
    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A stream that ends before the promised payload is complete is an
    /// error, not a short frame.
    #[tokio::test]
    async fn frame_truncated_payload_is_error() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 5; // 1 byte tag + 4 bytes of payload promised
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_all(b"ab").await.unwrap(); // only 2 of the 4 bytes
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    /// Back-to-back frames are decoded in order, followed by `None` at EOF.
    #[tokio::test]
    async fn frames_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"one").await.unwrap();
        write_data(&mut client, b"two").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"one"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"two"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    /// `send_client_message` produces a control frame whose payload decodes
    /// back to the original message.
    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    /// `send_server_message` produces a control frame whose payload decodes
    /// back to the original message.
    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::SessionEnded {
            id: Uuid::nil(),
            exit_code: None,
        };
        send_server_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}