1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte of a control frame. `send_client_message` / `send_server_message` put
/// JSON-encoded messages in control frames.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte of a data frame; `read_frame` returns its payload bytes uninterpreted.
const TAG_DATA: u8 = 0x02;
/// Upper bound, in bytes, on a frame's length prefix (tag byte + payload);
/// `read_frame` rejects frames whose prefix exceeds it.
const MAX_FRAME_SIZE: usize = 1_048_576;

/// Messages written by [`send_client_message`] as the JSON payload of a control frame.
///
/// `#[serde(tag = "type")]` stores the variant name in a `"type"` field next to the
/// variant's own fields, e.g. `{"type":"KillSession","id":"<uuid>"}`; unit variants
/// become `{"type":"ListSessions"}`. Variant and field names are therefore part of
/// the wire format — renaming one breaks compatibility with peers built from an
/// older version.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// `shell`: optional shell override; `None` presumably selects a server-side
    /// default — confirm.
    CreateSession { shell: Option<String> },
    ListSessions,
    /// `cols`/`rows`: presumably the attaching client's terminal size — confirm.
    AttachSession { id: Uuid, cols: u16, rows: u16 },
    DetachSession,
    ResizeSession { id: Uuid, cols: u16, rows: u16 },
    KillSession { id: Uuid },
    AgentList,
    AgentNotifications,
    AgentWatch { session_id: Uuid },
    AgentPrompt { session_id: Uuid, text: String },
}
27
/// Messages written by [`send_server_message`] as the JSON payload of a control frame.
///
/// `#[serde(tag = "type")]` stores the variant name in a `"type"` field alongside the
/// variant's fields (e.g. `{"type":"Detached"}`), so variant and field names are part
/// of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    SessionCreated { id: Uuid },
    Sessions { sessions: Vec<SessionInfo> },
    Attached { id: Uuid },
    Detached,
    /// `exit_code`: `None` presumably when no exit status was available — confirm.
    SessionEnded { id: Uuid, exit_code: Option<i32> },
    ClientJoined { session_id: Uuid, client_id: Uuid },
    ClientLeft { session_id: Uuid, client_id: Uuid },
    /// Free-form failure description. NOTE(review): confirm whether it is meant to be
    /// shown to users verbatim.
    Error { message: String },
    AgentListResponse { agents: Vec<AgentEntry> },
    /// One conversation line; presumably streamed after a `ClientMessage::AgentWatch`
    /// request and terminated by `AgentWatchEnd` — confirm.
    AgentConversationLine { session_id: Uuid, line: String },
    AgentWatchEnd { session_id: Uuid },
}
43
/// Per-session summary carried in [`ServerMessage::Sessions`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    pub id: Uuid,
    /// Terminal width — presumably in character cells; confirm.
    pub cols: u16,
    /// Terminal height — presumably in character cells; confirm.
    pub rows: u16,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
    /// Client count for the session — presumably clients currently attached; confirm.
    pub client_count: usize,
}
52
/// One element of [`ServerMessage::AgentListResponse`].
///
/// NOTE(review): serde serializes `PathBuf` through `Path::to_str`, so a non-UTF-8
/// `cwd` makes serializing the whole enclosing message fail — confirm that is acceptable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    /// Presumably the id of the session the agent was found in — confirm.
    pub vex_session_id: Uuid,
    pub claude_session_id: String,
    pub claude_pid: u32,
    pub cwd: PathBuf,
    /// Timestamp in UTC; the name suggests when the agent was detected — confirm.
    pub detected_at: DateTime<Utc>,
    pub needs_intervention: bool,
}
62
/// A frame decoded by [`read_frame`]. The variant is chosen by the frame's tag byte;
/// the inner `Vec<u8>` is the payload only (length prefix and tag already stripped).
#[derive(Debug)]
pub enum Frame {
    /// Tag `TAG_CONTROL`. Frames sent by `send_client_message` / `send_server_message`
    /// carry a JSON-encoded message here.
    Control(Vec<u8>),
    /// Tag `TAG_DATA`; payload bytes are returned uninterpreted.
    Data(Vec<u8>),
}
68
69pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
70 let len = (1 + payload.len()) as u32;
71 w.write_all(&len.to_be_bytes()).await?;
72 w.write_u8(TAG_CONTROL).await?;
73 w.write_all(payload).await?;
74 w.flush().await?;
75 Ok(())
76}
77
78pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
79 let len = (1 + payload.len()) as u32;
80 w.write_all(&len.to_be_bytes()).await?;
81 w.write_u8(TAG_DATA).await?;
82 w.write_all(payload).await?;
83 w.flush().await?;
84 Ok(())
85}
86
87pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
88 let mut len_buf = [0u8; 4];
89 match r.read_exact(&mut len_buf).await {
90 Ok(_) => {}
91 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
92 Err(e) => return Err(e.into()),
93 }
94 let len = u32::from_be_bytes(len_buf) as usize;
95 if len == 0 {
96 bail!("invalid frame: zero length");
97 }
98 if len > MAX_FRAME_SIZE {
99 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
100 }
101 let tag = {
102 let mut tag_buf = [0u8; 1];
103 r.read_exact(&mut tag_buf).await?;
104 tag_buf[0]
105 };
106 let payload_len = len - 1;
107 let mut payload = vec![0u8; payload_len];
108 if payload_len > 0 {
109 r.read_exact(&mut payload).await?;
110 }
111 match tag {
112 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
113 TAG_DATA => Ok(Some(Frame::Data(payload))),
114 other => bail!("unknown frame tag: 0x{:02x}", other),
115 }
116}
117
118pub async fn send_client_message<W: AsyncWrite + Unpin>(
120 w: &mut W,
121 msg: &ClientMessage,
122) -> Result<()> {
123 let json = serde_json::to_vec(msg)?;
124 write_control(w, &json).await
125}
126
127pub async fn send_server_message<W: AsyncWrite + Unpin>(
129 w: &mut W,
130 msg: &ServerMessage,
131) -> Result<()> {
132 let json = serde_json::to_vec(msg)?;
133 write_control(w, &json).await
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    // Pins the internally tagged JSON shape, which is the wire contract with peers.
    #[test]
    fn messages_are_internally_tagged() {
        let json = serde_json::to_value(ClientMessage::KillSession { id: Uuid::nil() }).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "type": "KillSession", "id": Uuid::nil().to_string() })
        );
        let json = serde_json::to_value(ServerMessage::Detached).unwrap();
        assert_eq!(json, serde_json::json!({ "type": "Detached" }));
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    // A frame carrying only the tag byte (length prefix == 1) is valid.
    #[tokio::test]
    async fn frame_empty_payload_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    // Several frames on one stream are decoded in order, then a clean EOF yields None.
    #[tokio::test]
    async fn frames_read_back_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::Error {
            message: "boom".into(),
        };
        send_server_message(&mut client, &msg).await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}