1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Frame tag byte for control frames; the `send_*_message` helpers use these to carry
/// JSON-encoded protocol messages.
const TAG_CONTROL: u8 = 1;
/// Frame tag byte for data frames carrying opaque payload bytes (see `write_data`).
const TAG_DATA: u8 = 2;
/// Upper bound, in bytes, on a frame's declared length (tag byte + payload); 1 MiB.
/// `read_frame` rejects any frame whose length prefix exceeds it.
const MAX_FRAME_SIZE: usize = 1_048_576;

/// Requests sent from a client to the server.
///
/// Encoded as JSON by `send_client_message` and carried in a control frame. Because of
/// `#[serde(tag = "type")]` the JSON is internally tagged: every object has a `"type"`
/// field holding the variant name, next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Request a new session. Both fields are optional; how the server fills in
    /// defaults when they are `None` is not visible in this file.
    CreateSession {
        shell: Option<String>,
        repo: Option<String>,
    },
    /// Request the list of sessions (presumably answered with `ServerMessage::Sessions`).
    ListSessions,
    /// Attach to session `id`. `cols`/`rows` presumably describe the client's terminal
    /// size — confirm against the server handler.
    AttachSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    /// Detach from the currently attached session (carries no session id).
    DetachSession,
    /// Report new dimensions (`cols` x `rows`) for session `id`.
    ResizeSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    /// Ask the server to kill session `id`.
    KillSession {
        id: Uuid,
    },
    /// Request the detected agents (presumably answered with
    /// `ServerMessage::AgentListResponse`).
    AgentList,
    /// NOTE(review): semantics not visible here — presumably subscribes to agent
    /// notifications; confirm against the server handler.
    AgentNotifications,
    /// Watch the agent running in vex session `session_id`; presumably streamed back as
    /// `ServerMessage::AgentConversationLine` until `ServerMessage::AgentWatchEnd`.
    AgentWatch {
        session_id: Uuid,
    },
    /// Send `text` as a prompt to the agent in vex session `session_id`.
    AgentPrompt {
        session_id: Uuid,
        text: String,
    },
    /// Spawn an agent for `repo` (assumed to be a name registered via `RepoAdd`).
    AgentSpawn {
        repo: String,
    },
    /// Register the repository at `path` under `name`.
    RepoAdd {
        name: String,
        path: PathBuf,
    },
    /// Unregister the repository called `name`.
    RepoRemove {
        name: String,
    },
    /// Request all registered repositories (presumably answered with `ServerMessage::Repos`).
    RepoList,
    /// Ask the server to inspect `path` and suggest repository metadata
    /// (cf. `ServerMessage::RepoIntrospected`).
    RepoIntrospectPath {
        path: PathBuf,
    },
}
59
/// Messages sent from the server to a client.
///
/// Encoded as JSON by `send_server_message` and carried in a control frame. Because of
/// `#[serde(tag = "type")]` the JSON is internally tagged: every object has a `"type"`
/// field holding the variant name, next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    /// A session was created with id `id`.
    SessionCreated {
        id: Uuid,
    },
    /// Snapshot of sessions (presumably the reply to `ClientMessage::ListSessions`).
    Sessions {
        sessions: Vec<SessionInfo>,
    },
    /// The client is now attached to session `id`.
    Attached {
        id: Uuid,
    },
    /// The client is no longer attached to a session.
    Detached,
    /// Session `id` ended. `exit_code` is optional — presumably `None` when no exit
    /// status was available (e.g. killed by a signal); confirm in the server.
    SessionEnded {
        id: Uuid,
        exit_code: Option<i32>,
    },
    /// Client `client_id` joined session `session_id`.
    ClientJoined {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// Client `client_id` left session `session_id`.
    ClientLeft {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// A request failed; `message` describes why.
    Error {
        message: String,
    },
    /// Detected agents (presumably the reply to `ClientMessage::AgentList`).
    AgentListResponse {
        agents: Vec<AgentEntry>,
    },
    /// Acknowledges a prompt sent to the agent in vex session `session_id`.
    AgentPromptSent {
        session_id: Uuid,
    },
    /// One line of the agent conversation for `session_id`; the format of `line` is not
    /// defined in this file.
    AgentConversationLine {
        session_id: Uuid,
        line: String,
    },
    /// Marks the end of a watch stream for `session_id`.
    AgentWatchEnd {
        session_id: Uuid,
    },
    /// Repository `name` was registered at `path`.
    RepoAdded {
        name: String,
        path: PathBuf,
    },
    /// Repository `name` was unregistered.
    RepoRemoved {
        name: String,
    },
    /// All registered repositories.
    Repos {
        repos: Vec<RepoEntry>,
    },
    /// Result of inspecting `path`: a suggested repository name plus optional git
    /// details (assumed `None` when they could not be determined).
    RepoIntrospected {
        suggested_name: String,
        path: PathBuf,
        git_remote: Option<String>,
        git_branch: Option<String>,
    },
}
118
/// Summary of one session, as listed in `ServerMessage::Sessions`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: Uuid,
    /// Current width (presumably terminal columns).
    pub cols: u16,
    /// Current height (presumably terminal rows).
    pub rows: u16,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
    /// Number of clients associated with the session (presumably currently attached).
    pub client_count: usize,
}
127
/// One detected agent, as listed in `ServerMessage::AgentListResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    /// The vex session the agent runs in.
    pub vex_session_id: Uuid,
    /// The agent's own session identifier (opaque string).
    pub claude_session_id: String,
    /// Process id of the agent process.
    pub claude_pid: u32,
    /// Working directory of the agent process.
    pub cwd: PathBuf,
    /// When the agent was detected, in UTC.
    pub detected_at: DateTime<Utc>,
    /// Whether the agent is flagged as needing user attention; the detection criteria are
    /// not visible in this file.
    pub needs_intervention: bool,
}
137
/// A registered repository, as listed in `ServerMessage::Repos`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RepoEntry {
    /// Name the repository was registered under.
    pub name: String,
    /// Filesystem location of the repository.
    pub path: PathBuf,
}
143
/// A single decoded frame, as returned by `read_frame`.
///
/// `Clone`, `PartialEq` and `Eq` are derived alongside `Debug` so frames can be compared
/// directly (e.g. with `assert_eq!`) and duplicated by callers that buffer them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a frame tagged `TAG_CONTROL`; the `send_*_message` helpers put a
    /// JSON-encoded `ClientMessage`/`ServerMessage` here.
    Control(Vec<u8>),
    /// Payload of a frame tagged `TAG_DATA`: opaque bytes as passed to `write_data`.
    Data(Vec<u8>),
}
149
150pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
151 let len = (1 + payload.len()) as u32;
152 w.write_all(&len.to_be_bytes()).await?;
153 w.write_u8(TAG_CONTROL).await?;
154 w.write_all(payload).await?;
155 w.flush().await?;
156 Ok(())
157}
158
159pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
160 let len = (1 + payload.len()) as u32;
161 w.write_all(&len.to_be_bytes()).await?;
162 w.write_u8(TAG_DATA).await?;
163 w.write_all(payload).await?;
164 w.flush().await?;
165 Ok(())
166}
167
168pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
169 let mut len_buf = [0u8; 4];
170 match r.read_exact(&mut len_buf).await {
171 Ok(_) => {}
172 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
173 Err(e) => return Err(e.into()),
174 }
175 let len = u32::from_be_bytes(len_buf) as usize;
176 if len == 0 {
177 bail!("invalid frame: zero length");
178 }
179 if len > MAX_FRAME_SIZE {
180 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
181 }
182 let tag = {
183 let mut tag_buf = [0u8; 1];
184 r.read_exact(&mut tag_buf).await?;
185 tag_buf[0]
186 };
187 let payload_len = len - 1;
188 let mut payload = vec![0u8; payload_len];
189 if payload_len > 0 {
190 r.read_exact(&mut payload).await?;
191 }
192 match tag {
193 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
194 TAG_DATA => Ok(Some(Frame::Data(payload))),
195 other => bail!("unknown frame tag: 0x{:02x}", other),
196 }
197}
198
199pub async fn send_client_message<W: AsyncWrite + Unpin>(
201 w: &mut W,
202 msg: &ClientMessage,
203) -> Result<()> {
204 let json = serde_json::to_vec(msg)?;
205 write_control(w, &json).await
206}
207
208pub async fn send_server_message<W: AsyncWrite + Unpin>(
210 w: &mut W,
211 msg: &ServerMessage,
212) -> Result<()> {
213 let json = serde_json::to_vec(msg)?;
214 write_control(w, &json).await
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ClientMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
            ClientMessage::AgentSpawn { repo: "vex".into() },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                session_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A frame consisting only of the tag byte decodes to an empty payload.
    #[tokio::test]
    async fn frame_empty_payload_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// Consecutive frames are read back in order, followed by a clean EOF.
    #[tokio::test]
    async fn frames_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    /// A length prefix of zero is rejected (a frame must contain at least its tag byte).
    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    /// A length just one byte over the limit is rejected before any payload is read.
    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len = (MAX_FRAME_SIZE + 1) as u32;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}