1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte identifying a control frame (its payload is a JSON-encoded message).
const TAG_CONTROL: u8 = 1;
/// Tag byte identifying a data frame.
const TAG_DATA: u8 = 2;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(tag = "type")]
12pub enum ClientMessage {
13 CreateSession { shell: Option<String> },
14 ListSessions,
15 AttachSession { id: Uuid },
16 DetachSession,
17 ResizeSession { id: Uuid, cols: u16, rows: u16 },
18 KillSession { id: Uuid },
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
22#[serde(tag = "type")]
23pub enum ServerMessage {
24 SessionCreated { id: Uuid },
25 Sessions { sessions: Vec<SessionInfo> },
26 Attached { id: Uuid },
27 Detached,
28 SessionEnded { id: Uuid, exit_code: Option<i32> },
29 Error { message: String },
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
33pub struct SessionInfo {
34 pub id: Uuid,
35 pub cols: u16,
36 pub rows: u16,
37 pub created_at: DateTime<Utc>,
38}
39
/// One decoded wire frame, as produced by `read_frame`.
///
/// The variant reflects the frame's tag byte; the inner vector holds the raw
/// payload bytes (the tag and length prefix are not included).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a `TAG_CONTROL` frame; the send helpers in this module put
    /// JSON-encoded messages here.
    Control(Vec<u8>),
    /// Payload of a `TAG_DATA` frame; this module does not interpret it.
    Data(Vec<u8>),
}
45
46pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
47 let len = (1 + payload.len()) as u32;
48 w.write_all(&len.to_be_bytes()).await?;
49 w.write_u8(TAG_CONTROL).await?;
50 w.write_all(payload).await?;
51 w.flush().await?;
52 Ok(())
53}
54
55pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
56 let len = (1 + payload.len()) as u32;
57 w.write_all(&len.to_be_bytes()).await?;
58 w.write_u8(TAG_DATA).await?;
59 w.write_all(payload).await?;
60 w.flush().await?;
61 Ok(())
62}
63
64pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
65 let mut len_buf = [0u8; 4];
66 match r.read_exact(&mut len_buf).await {
67 Ok(_) => {}
68 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
69 Err(e) => return Err(e.into()),
70 }
71 let len = u32::from_be_bytes(len_buf) as usize;
72 if len == 0 {
73 bail!("invalid frame: zero length");
74 }
75 let tag = {
76 let mut tag_buf = [0u8; 1];
77 r.read_exact(&mut tag_buf).await?;
78 tag_buf[0]
79 };
80 let payload_len = len - 1;
81 let mut payload = vec![0u8; payload_len];
82 if payload_len > 0 {
83 r.read_exact(&mut payload).await?;
84 }
85 match tag {
86 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
87 TAG_DATA => Ok(Some(Frame::Data(payload))),
88 other => bail!("unknown frame tag: 0x{:02x}", other),
89 }
90}
91
92pub async fn send_client_message<W: AsyncWrite + Unpin>(
94 w: &mut W,
95 msg: &ClientMessage,
96) -> Result<()> {
97 let json = serde_json::to_vec(msg)?;
98 write_control(w, &json).await
99}
100
101pub async fn send_server_message<W: AsyncWrite + Unpin>(
103 w: &mut W,
104 msg: &ServerMessage,
105) -> Result<()> {
106 let json = serde_json::to_vec(msg)?;
107 write_control(w, &json).await
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::CreateSession { shell: None },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession { id: Uuid::nil() },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Pins the internally-tagged wire shape so a serde attribute change is caught.
    #[test]
    fn client_message_uses_type_tag() {
        let json = serde_json::to_string(&ClientMessage::ListSessions).unwrap();
        assert_eq!(json, r#"{"type":"ListSessions"}"#);
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                }],
            },
            ServerMessage::Sessions { sessions: vec![] },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: None,
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_empty_payload_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frames_read_back_to_back() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_truncated_payload_is_error() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Header promises a tag plus 9 payload bytes, but only 2 follow.
        client.write_all(&10u32.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_all(b"ab").await.unwrap();
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::SessionEnded {
            id: Uuid::nil(),
            exit_code: None,
        };
        send_server_message(&mut server, &msg).await.unwrap();
        drop(server);
        let frame = read_frame(&mut client).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}