1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte that marks a control frame; `send_client_message` and
/// `send_server_message` put JSON-encoded messages in these frames.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte that marks a data frame; its payload is an opaque byte string to this module.
const TAG_DATA: u8 = 0x02;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(tag = "type")]
12pub enum ClientMessage {
13 Authenticate { token: String },
14 CreateSession { shell: Option<String> },
15 ListSessions,
16 AttachSession { id: Uuid },
17 DetachSession,
18 ResizeSession { id: Uuid, cols: u16, rows: u16 },
19 KillSession { id: Uuid },
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23#[serde(tag = "type")]
24pub enum ServerMessage {
25 Authenticated,
26 SessionCreated { id: Uuid },
27 Sessions { sessions: Vec<SessionInfo> },
28 Attached { id: Uuid },
29 Detached,
30 SessionEnded { id: Uuid, exit_code: Option<i32> },
31 ClientJoined { session_id: Uuid, client_id: Uuid },
32 ClientLeft { session_id: Uuid, client_id: Uuid },
33 Error { message: String },
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
37pub struct SessionInfo {
38 pub id: Uuid,
39 pub cols: u16,
40 pub rows: u16,
41 pub created_at: DateTime<Utc>,
42 pub client_count: usize,
43}
44
/// One decoded frame, as returned by [`read_frame`].
///
/// The variant records which tag byte the frame carried. The `Vec<u8>` is the payload with
/// the length prefix and tag byte already stripped.
///
/// `Clone`/`PartialEq`/`Eq` are derived so callers can duplicate and compare frames directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Frame tagged `TAG_CONTROL`; `send_client_message` / `send_server_message` put JSON here.
    Control(Vec<u8>),
    /// Frame tagged `TAG_DATA`.
    Data(Vec<u8>),
}
50
51pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
52 let len = (1 + payload.len()) as u32;
53 w.write_all(&len.to_be_bytes()).await?;
54 w.write_u8(TAG_CONTROL).await?;
55 w.write_all(payload).await?;
56 w.flush().await?;
57 Ok(())
58}
59
60pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
61 let len = (1 + payload.len()) as u32;
62 w.write_all(&len.to_be_bytes()).await?;
63 w.write_u8(TAG_DATA).await?;
64 w.write_all(payload).await?;
65 w.flush().await?;
66 Ok(())
67}
68
69pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
70 let mut len_buf = [0u8; 4];
71 match r.read_exact(&mut len_buf).await {
72 Ok(_) => {}
73 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
74 Err(e) => return Err(e.into()),
75 }
76 let len = u32::from_be_bytes(len_buf) as usize;
77 if len == 0 {
78 bail!("invalid frame: zero length");
79 }
80 let tag = {
81 let mut tag_buf = [0u8; 1];
82 r.read_exact(&mut tag_buf).await?;
83 tag_buf[0]
84 };
85 let payload_len = len - 1;
86 let mut payload = vec![0u8; payload_len];
87 if payload_len > 0 {
88 r.read_exact(&mut payload).await?;
89 }
90 match tag {
91 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
92 TAG_DATA => Ok(Some(Frame::Data(payload))),
93 other => bail!("unknown frame tag: 0x{:02x}", other),
94 }
95}
96
97pub async fn send_client_message<W: AsyncWrite + Unpin>(
99 w: &mut W,
100 msg: &ClientMessage,
101) -> Result<()> {
102 let json = serde_json::to_vec(msg)?;
103 write_control(w, &json).await
104}
105
106pub async fn send_server_message<W: AsyncWrite + Unpin>(
108 w: &mut W,
109 msg: &ServerMessage,
110) -> Result<()> {
111 let json = serde_json::to_vec(msg)?;
112 write_control(w, &json).await
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::Authenticate {
                token: "test-token".into(),
            },
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession { id: Uuid::nil() },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::Authenticated,
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frames_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_truncated_payload_is_error() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // The header promises a tag plus 9 payload bytes, but only 3 arrive before EOF.
        client.write_all(&10u32.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_all(b"abc").await.unwrap();
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Length 2 = one (invalid) tag byte + one payload byte.
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::Error {
            message: "boom".into(),
        };
        send_server_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}