1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte for a control frame. `write_control` emits it and `read_frame`
/// decodes it into `Frame::Control`.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte for a data frame. `write_data` emits it and `read_frame` decodes
/// it into `Frame::Data`.
const TAG_DATA: u8 = 0x02;
/// Largest accepted value of a frame's length prefix, in bytes (1 MiB = 2^20).
/// The length counts the tag byte plus the payload; `read_frame` rejects any
/// frame that declares more than this.
const MAX_FRAME_SIZE: usize = 1_048_576;

/// Messages sent by a client (see `send_client_message`, which JSON-encodes
/// one and writes it as a single control frame).
///
/// Encoded as internally tagged JSON: the variant name is stored in a
/// `"type"` field next to the variant's own fields, e.g.
/// `{"type":"ListSessions"}` or `{"type":"KillSession","id":<uuid>}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    // --- Terminal sessions ---
    /// Both fields are optional; presumably the server picks defaults when
    /// they are `None` — confirm against the server implementation.
    CreateSession {
        shell: Option<String>,
        repo: Option<String>,
    },
    ListSessions,
    /// Carries the client's terminal size (`cols` x `rows`) at attach time.
    AttachSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    /// Carries no session id; presumably detaches from whichever session
    /// this connection is attached to — confirm.
    DetachSession,
    ResizeSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    KillSession {
        id: Uuid,
    },
    // --- Agents ---
    AgentList,
    AgentNotifications,
    /// Presumably answered with a stream of
    /// `ServerMessage::AgentConversationLine` ended by
    /// `ServerMessage::AgentWatchEnd` — confirm.
    AgentWatch {
        session_id: Uuid,
    },
    AgentPrompt {
        session_id: Uuid,
        text: String,
    },
    AgentSpawn {
        repo: String,
        workstream: Option<String>,
    },
    // --- Workstreams ---
    WorkstreamCreate {
        repo: String,
        name: String,
    },
    /// `repo: None` presumably lists workstreams across all repos — confirm.
    WorkstreamList {
        repo: Option<String>,
    },
    WorkstreamRemove {
        repo: String,
        name: String,
    },
    // --- Repository registry ---
    RepoAdd {
        name: String,
        path: PathBuf,
    },
    RepoRemove {
        name: String,
    },
    RepoList,
    /// Presumably answered with `ServerMessage::RepoIntrospected` — confirm.
    RepoIntrospectPath {
        path: PathBuf,
    },
}
71
/// Messages sent by the server (see `send_server_message`, which JSON-encodes
/// one and writes it as a single control frame).
///
/// Encoded as internally tagged JSON, like `ClientMessage`: the variant name
/// is stored in a `"type"` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    // --- Terminal sessions ---
    SessionCreated {
        id: Uuid,
    },
    Sessions {
        sessions: Vec<SessionInfo>,
    },
    Attached {
        id: Uuid,
    },
    Detached,
    SessionEnded {
        id: Uuid,
        /// `None` presumably means no exit status was available (e.g. the
        /// process was killed by a signal) — confirm against the server.
        exit_code: Option<i32>,
    },
    ClientJoined {
        session_id: Uuid,
        client_id: Uuid,
    },
    ClientLeft {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// Failure report carrying a `message` string.
    Error {
        message: String,
    },
    // --- Agents ---
    AgentListResponse {
        agents: Vec<AgentEntry>,
    },
    AgentPromptSent {
        session_id: Uuid,
    },
    /// Presumably one line of a stream started by `ClientMessage::AgentWatch`
    /// and terminated by `AgentWatchEnd` — confirm.
    AgentConversationLine {
        session_id: Uuid,
        line: String,
    },
    AgentWatchEnd {
        session_id: Uuid,
    },
    // --- Repository registry ---
    RepoAdded {
        name: String,
        path: PathBuf,
    },
    RepoRemoved {
        name: String,
    },
    Repos {
        repos: Vec<RepoEntry>,
    },
    /// Result of inspecting a filesystem path. The git fields are `Option`,
    /// presumably `None` when the path is not a git checkout — confirm.
    RepoIntrospected {
        suggested_name: String,
        path: PathBuf,
        git_remote: Option<String>,
        git_branch: Option<String>,
        children: Vec<String>,
    },
    // --- Workstreams ---
    WorkstreamCreated {
        repo: String,
        name: String,
        worktree_path: PathBuf,
    },
    WorkstreamRemoved {
        repo: String,
        name: String,
    },
    Workstreams {
        workstreams: Vec<WorkstreamInfo>,
    },
}
143
/// Summary of one session, as listed in `ServerMessage::Sessions`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    pub id: Uuid,
    /// Terminal width, presumably in character cells — confirm.
    pub cols: u16,
    /// Terminal height, presumably in character cells — confirm.
    pub rows: u16,
    /// Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
    /// Presumably the number of clients currently attached — confirm.
    pub client_count: usize,
}
152
/// One agent entry, as listed in `ServerMessage::AgentListResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    /// Id of the session the agent belongs to (presumably a
    /// `SessionInfo::id` — confirm).
    pub vex_session_id: Uuid,
    /// Agent-side session identifier; kept as an opaque string here.
    pub claude_session_id: String,
    /// Process id of the agent process.
    pub claude_pid: u32,
    /// Working directory of the agent process (presumably — confirm).
    pub cwd: PathBuf,
    /// When the agent was detected, in UTC.
    pub detected_at: DateTime<Utc>,
    /// Flag set by the server; its exact trigger is not visible here.
    pub needs_intervention: bool,
}
162
/// A registered repository, as listed in `ServerMessage::Repos`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RepoEntry {
    /// Registry name (the `name` used by `ClientMessage::RepoAdd`/`RepoRemove`).
    pub name: String,
    pub path: PathBuf,
}
168
/// One workstream, as listed in `ServerMessage::Workstreams`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkstreamInfo {
    /// Name of the repository this workstream belongs to.
    pub repo: String,
    pub name: String,
    /// Directory of the workstream's checkout (presumably a git worktree —
    /// confirm).
    pub worktree_path: PathBuf,
    /// Branch name; not necessarily equal to `name`.
    pub branch: String,
    /// Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
}
177
/// A decoded wire frame, as returned by `read_frame`.
///
/// On the wire every frame is a 4-byte big-endian length (tag byte plus
/// payload), one tag byte, and then the payload bytes.
///
/// Besides `Debug`, the type derives `Clone`, `PartialEq` and `Eq` so frames
/// can be compared directly instead of only being pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Tag `0x01`. When written by `send_client_message` or
    /// `send_server_message`, the payload is a JSON-encoded message.
    Control(Vec<u8>),
    /// Tag `0x02`. Raw bytes that the framing layer does not interpret.
    Data(Vec<u8>),
}
183
184pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
185 let len = (1 + payload.len()) as u32;
186 w.write_all(&len.to_be_bytes()).await?;
187 w.write_u8(TAG_CONTROL).await?;
188 w.write_all(payload).await?;
189 w.flush().await?;
190 Ok(())
191}
192
193pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
194 let len = (1 + payload.len()) as u32;
195 w.write_all(&len.to_be_bytes()).await?;
196 w.write_u8(TAG_DATA).await?;
197 w.write_all(payload).await?;
198 w.flush().await?;
199 Ok(())
200}
201
202pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
203 let mut len_buf = [0u8; 4];
204 match r.read_exact(&mut len_buf).await {
205 Ok(_) => {}
206 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
207 Err(e) => return Err(e.into()),
208 }
209 let len = u32::from_be_bytes(len_buf) as usize;
210 if len == 0 {
211 bail!("invalid frame: zero length");
212 }
213 if len > MAX_FRAME_SIZE {
214 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
215 }
216 let tag = {
217 let mut tag_buf = [0u8; 1];
218 r.read_exact(&mut tag_buf).await?;
219 tag_buf[0]
220 };
221 let payload_len = len - 1;
222 let mut payload = vec![0u8; payload_len];
223 if payload_len > 0 {
224 r.read_exact(&mut payload).await?;
225 }
226 match tag {
227 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
228 TAG_DATA => Ok(Some(Frame::Data(payload))),
229 other => bail!("unknown frame tag: 0x{:02x}", other),
230 }
231}
232
233pub async fn send_client_message<W: AsyncWrite + Unpin>(
235 w: &mut W,
236 msg: &ClientMessage,
237) -> Result<()> {
238 let json = serde_json::to_vec(msg)?;
239 write_control(w, &json).await
240}
241
242pub async fn send_server_message<W: AsyncWrite + Unpin>(
244 w: &mut W,
245 msg: &ServerMessage,
246) -> Result<()> {
247 let json = serde_json::to_vec(msg)?;
248 write_control(w, &json).await
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: None,
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: Some("feature-x".into()),
            },
            ClientMessage::WorkstreamCreate {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::WorkstreamList { repo: None },
            ClientMessage::WorkstreamList {
                repo: Some("vex".into()),
            },
            ClientMessage::WorkstreamRemove {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                session_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
                children: vec!["src".into(), "tests".into()],
            },
            ServerMessage::WorkstreamCreated {
                repo: "vex".into(),
                name: "feature-x".into(),
                worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
            },
            ServerMessage::WorkstreamRemoved {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ServerMessage::Workstreams {
                workstreams: vec![WorkstreamInfo {
                    repo: "vex".into(),
                    name: "feature-x".into(),
                    worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
                    branch: "feature-x".into(),
                    created_at: Utc::now(),
                }],
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Pins the on-the-wire JSON shape: the variant name goes in a `"type"`
    /// field next to the variant's fields. Peers depend on this shape.
    #[test]
    fn wire_format_is_internally_tagged() {
        let unit = serde_json::to_value(ClientMessage::ListSessions).unwrap();
        assert_eq!(unit, serde_json::json!({ "type": "ListSessions" }));

        let with_fields = serde_json::to_value(ServerMessage::Error {
            message: "boom".into(),
        })
        .unwrap();
        assert_eq!(
            with_fields,
            serde_json::json!({ "type": "Error", "message": "boom" })
        );
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frames_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame first"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame second"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Tag byte plus one payload byte.
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_truncated_payload_is_error() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Declares 9 payload bytes but only 3 arrive before EOF.
        let len: u32 = 10;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_all(b"abc").await.unwrap();
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::SessionEnded {
            id: Uuid::nil(),
            exit_code: Some(1),
        };
        send_server_message(&mut server, &msg).await.unwrap();
        drop(server);
        let frame = read_frame(&mut client).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}