1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Frame tag for control frames. [`send_client_message`] and
/// [`send_server_message`] place JSON-encoded messages in these.
const TAG_CONTROL: u8 = 0x01;
/// Frame tag for data frames; [`read_frame`] returns their payload as raw bytes.
const TAG_DATA: u8 = 0x02;
/// Largest length prefix (tag byte + payload) that [`read_frame`] accepts: 1 MiB.
const MAX_FRAME_SIZE: usize = 1_048_576;

/// Messages sent from a client to the server inside control frames.
///
/// Serialized as JSON with serde internal tagging: the variant name is stored
/// in a `"type"` field next to the variant's own fields
/// (e.g. `{"type":"ListSessions"}`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Request a new session with an optional shell and repo.
    /// NOTE(review): `None` presumably means "server default" — confirm.
    CreateSession {
        shell: Option<String>,
        repo: Option<String>,
    },
    /// Request the list of sessions.
    ListSessions,
    /// Attach to session `id`; `cols`/`rows` presumably give the client's
    /// terminal size — confirm.
    AttachSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    /// Detach from the currently attached session (presumably; carries no id).
    DetachSession,
    /// Report a new `cols` x `rows` size for session `id`.
    ResizeSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    /// Ask the server to terminate session `id`.
    KillSession {
        id: Uuid,
    },
    /// Request the list of detected agents.
    AgentList,
    /// Agent notification request; exact semantics are defined by the
    /// server — TODO confirm.
    AgentNotifications,
    /// Start watching the agent in `session_id`.
    AgentWatch {
        session_id: Uuid,
    },
    /// Send `text` as a prompt to the agent in `session_id`.
    AgentPrompt {
        session_id: Uuid,
        text: String,
    },
    /// Register the repo at `path` under `name`.
    RepoAdd {
        name: String,
        path: PathBuf,
    },
    /// Unregister the repo called `name`.
    RepoRemove {
        name: String,
    },
    /// Request the list of registered repos.
    RepoList,
    /// Ask the server to inspect `path` and suggest repo details.
    RepoIntrospectPath {
        path: PathBuf,
    },
}
56
/// Messages sent from the server to a client inside control frames.
///
/// Serialized as JSON with serde internal tagging: the variant name is stored
/// in a `"type"` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    /// A session was created; `id` identifies it.
    SessionCreated {
        id: Uuid,
    },
    /// Snapshot of sessions (presumably the reply to `ListSessions`).
    Sessions {
        sessions: Vec<SessionInfo>,
    },
    /// The client is now attached to session `id`.
    Attached {
        id: Uuid,
    },
    /// The client is no longer attached to a session.
    Detached,
    /// Session `id` ended. `exit_code` is optional; when it is `None` is
    /// decided by the server — TODO confirm (e.g. signal termination).
    SessionEnded {
        id: Uuid,
        exit_code: Option<i32>,
    },
    /// Client `client_id` joined session `session_id`.
    ClientJoined {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// Client `client_id` left session `session_id`.
    ClientLeft {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// A request failed; `message` describes the failure.
    Error {
        message: String,
    },
    /// Detected agents (presumably the reply to `AgentList`).
    AgentListResponse {
        agents: Vec<AgentEntry>,
    },
    /// Acknowledges a prompt sent to the agent in `session_id`.
    AgentPromptSent {
        session_id: Uuid,
    },
    /// One line from the agent conversation in `session_id`
    /// (presumably streamed after `AgentWatch`).
    AgentConversationLine {
        session_id: Uuid,
        line: String,
    },
    /// The watch on `session_id` has ended.
    AgentWatchEnd {
        session_id: Uuid,
    },
    /// A repo was registered under `name` at `path`.
    RepoAdded {
        name: String,
        path: PathBuf,
    },
    /// The repo called `name` was unregistered.
    RepoRemoved {
        name: String,
    },
    /// Registered repos (presumably the reply to `RepoList`).
    Repos {
        repos: Vec<RepoEntry>,
    },
    /// Result of inspecting `path`. The git fields are `None` when that
    /// information is unavailable (presumably not a git checkout / no remote).
    RepoIntrospected {
        suggested_name: String,
        path: PathBuf,
        git_remote: Option<String>,
        git_branch: Option<String>,
    },
}
115
/// Summary of one session, carried in [`ServerMessage::Sessions`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: Uuid,
    /// Width in columns (presumably terminal cells — confirm).
    pub cols: u16,
    /// Height in rows.
    pub rows: u16,
    /// Creation time, in UTC.
    pub created_at: DateTime<Utc>,
    /// Number of clients associated with the session (presumably attached
    /// clients — confirm against the server).
    pub client_count: usize,
}
124
/// One detected agent, carried in [`ServerMessage::AgentListResponse`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    /// The session the agent was found in.
    pub vex_session_id: Uuid,
    /// The agent's own session identifier (opaque string).
    pub claude_session_id: String,
    /// Process id of the agent.
    pub claude_pid: u32,
    /// The agent's working directory.
    pub cwd: PathBuf,
    /// When the agent was detected, in UTC.
    pub detected_at: DateTime<Utc>,
    /// Whether the agent is flagged as needing user intervention
    /// (the detection criteria live in the server — TODO confirm).
    pub needs_intervention: bool,
}
134
/// A registered repo, carried in [`ServerMessage::Repos`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RepoEntry {
    /// Name the repo is registered under.
    pub name: String,
    /// Filesystem location of the repo.
    pub path: PathBuf,
}
140
/// One decoded frame, as returned by [`read_frame`].
#[derive(Debug)]
pub enum Frame {
    /// Payload of a `TAG_CONTROL` frame. The `send_*_message` helpers put
    /// JSON-encoded messages here.
    Control(Vec<u8>),
    /// Payload of a `TAG_DATA` frame, as raw bytes.
    Data(Vec<u8>),
}
146
147pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
148 let len = (1 + payload.len()) as u32;
149 w.write_all(&len.to_be_bytes()).await?;
150 w.write_u8(TAG_CONTROL).await?;
151 w.write_all(payload).await?;
152 w.flush().await?;
153 Ok(())
154}
155
156pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
157 let len = (1 + payload.len()) as u32;
158 w.write_all(&len.to_be_bytes()).await?;
159 w.write_u8(TAG_DATA).await?;
160 w.write_all(payload).await?;
161 w.flush().await?;
162 Ok(())
163}
164
165pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
166 let mut len_buf = [0u8; 4];
167 match r.read_exact(&mut len_buf).await {
168 Ok(_) => {}
169 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
170 Err(e) => return Err(e.into()),
171 }
172 let len = u32::from_be_bytes(len_buf) as usize;
173 if len == 0 {
174 bail!("invalid frame: zero length");
175 }
176 if len > MAX_FRAME_SIZE {
177 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
178 }
179 let tag = {
180 let mut tag_buf = [0u8; 1];
181 r.read_exact(&mut tag_buf).await?;
182 tag_buf[0]
183 };
184 let payload_len = len - 1;
185 let mut payload = vec![0u8; payload_len];
186 if payload_len > 0 {
187 r.read_exact(&mut payload).await?;
188 }
189 match tag {
190 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
191 TAG_DATA => Ok(Some(Frame::Data(payload))),
192 other => bail!("unknown frame tag: 0x{:02x}", other),
193 }
194}
195
196pub async fn send_client_message<W: AsyncWrite + Unpin>(
198 w: &mut W,
199 msg: &ClientMessage,
200) -> Result<()> {
201 let json = serde_json::to_vec(msg)?;
202 write_control(w, &json).await
203}
204
205pub async fn send_server_message<W: AsyncWrite + Unpin>(
207 w: &mut W,
208 msg: &ServerMessage,
209) -> Result<()> {
210 let json = serde_json::to_vec(msg)?;
211 write_control(w, &json).await
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Pins the wire format: variants are tagged by a `"type"` field.
    #[test]
    fn client_message_uses_internal_type_tag() {
        assert_eq!(
            serde_json::to_string(&ClientMessage::ListSessions).unwrap(),
            r#"{"type":"ListSessions"}"#
        );
        let json = serde_json::to_value(ClientMessage::KillSession { id: Uuid::nil() }).unwrap();
        assert_eq!(json["type"], "KillSession");
        assert_eq!(json["id"], Uuid::nil().to_string());
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                session_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert!(data.is_empty()),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frames_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_zero_length() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::Error {
            message: "boom".into(),
        };
        send_server_message(&mut server, &msg).await.unwrap();
        drop(server);
        let frame = read_frame(&mut client).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}