1use std::path::PathBuf;
2
3use anyhow::{Result, bail};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use uuid::Uuid;
8
/// Tag byte marking a control frame. `write_control` emits it; `send_client_message` and
/// `send_server_message` use control frames to carry serde_json-encoded messages.
const TAG_CONTROL: u8 = 0x01;
/// Tag byte marking a data frame, emitted by `write_data`. This module treats data-frame
/// payloads as opaque bytes.
const TAG_DATA: u8 = 0x02;
/// Largest value `read_frame` accepts in a frame's 4-byte length field (1 MiB). The length
/// counts the tag byte plus the payload, so the largest payload is `MAX_FRAME_SIZE - 1` bytes.
const MAX_FRAME_SIZE: usize = 1_048_576;

/// A message from a client, sent as a control frame by `send_client_message`.
///
/// Serialized as internally tagged JSON (`#[serde(tag = "type")]`): the variant name is stored
/// in a `"type"` field alongside the variant's own fields, e.g.
/// `{"type":"KillSession","id":"<uuid>"}` or `{"type":"ListSessions"}`.
///
/// NOTE(review): serde serializes `PathBuf` fields as strings and returns an error for paths
/// that are not valid UTF-8, so such paths cannot be sent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Request a new session. Both fields are optional; how the server fills in a `None`
    /// is not visible in this module.
    CreateSession {
        shell: Option<String>,
        repo: Option<String>,
    },
    ListSessions,
    /// Attach to session `id`; `cols`/`rows` presumably give the client's terminal size —
    /// TODO confirm against the server.
    AttachSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    /// Carries no session id (unlike `AttachSession`).
    DetachSession,
    /// Same `id`/`cols`/`rows` shape as `AttachSession`.
    ResizeSession {
        id: Uuid,
        cols: u16,
        rows: u16,
    },
    KillSession {
        id: Uuid,
    },
    AgentList,
    AgentNotifications,
    AgentWatch {
        session_id: Uuid,
    },
    /// `text` is sent as a free-form string.
    AgentPrompt {
        session_id: Uuid,
        text: String,
    },
    /// `repo` is a string name (presumably a `RepoEntry::name` — verify); `workstream` is
    /// optional.
    AgentSpawn {
        repo: String,
        workstream: Option<String>,
    },
    WorkstreamCreate {
        repo: String,
        name: String,
    },
    /// `repo: None` presumably lists workstreams across all repos — confirm in the server.
    WorkstreamList {
        repo: Option<String>,
    },
    WorkstreamRemove {
        repo: String,
        name: String,
    },
    RepoAdd {
        name: String,
        path: PathBuf,
    },
    RepoRemove {
        name: String,
    },
    RepoList,
    RepoIntrospectPath {
        path: PathBuf,
    },
}
71
/// A message from the server, sent as a control frame by `send_server_message`.
///
/// Serialized as internally tagged JSON (`#[serde(tag = "type")]`), the same representation
/// `ClientMessage` uses: the variant name goes in a `"type"` field next to the variant's
/// fields. The nested payload structs (`SessionInfo`, `AgentEntry`, `RepoEntry`,
/// `WorkstreamInfo`) are serialized inline as JSON objects.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    SessionCreated {
        id: Uuid,
    },
    Sessions {
        sessions: Vec<SessionInfo>,
    },
    Attached {
        id: Uuid,
    },
    Detached,
    /// `exit_code` is optional; `None` presumably means no exit status was available —
    /// confirm in the server.
    SessionEnded {
        id: Uuid,
        exit_code: Option<i32>,
    },
    ClientJoined {
        session_id: Uuid,
        client_id: Uuid,
    },
    ClientLeft {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// Human-readable error text.
    Error {
        message: String,
    },
    AgentListResponse {
        agents: Vec<AgentEntry>,
    },
    AgentPromptSent {
        session_id: Uuid,
    },
    /// One line of output for `session_id`; presumably streamed in response to
    /// `ClientMessage::AgentWatch` — verify against the server.
    AgentConversationLine {
        session_id: Uuid,
        line: String,
    },
    AgentWatchEnd {
        session_id: Uuid,
    },
    RepoAdded {
        name: String,
        path: PathBuf,
    },
    RepoRemoved {
        name: String,
    },
    Repos {
        repos: Vec<RepoEntry>,
    },
    /// Result of inspecting a path. The git fields are optional (presumably `None` when the
    /// path is not a git checkout or has no remote/branch — TODO confirm).
    RepoIntrospected {
        suggested_name: String,
        path: PathBuf,
        git_remote: Option<String>,
        git_branch: Option<String>,
    },
    WorkstreamCreated {
        repo: String,
        name: String,
        worktree_path: PathBuf,
    },
    WorkstreamRemoved {
        repo: String,
        name: String,
    },
    Workstreams {
        workstreams: Vec<WorkstreamInfo>,
    },
}
142
/// Summary of one session, carried in `ServerMessage::Sessions`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    /// Session id (the same `Uuid` type used by `AttachSession`/`ResizeSession`/`KillSession`).
    pub id: Uuid,
    /// Column count — presumably the terminal width in character cells; TODO confirm.
    pub cols: u16,
    /// Row count — presumably the terminal height in character cells; TODO confirm.
    pub rows: u16,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
    /// Number of clients — presumably those currently attached; verify against the server.
    pub client_count: usize,
}
151
/// One agent record, carried in `ServerMessage::AgentListResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEntry {
    /// Id of the vex session the agent is associated with (presumably a `SessionInfo::id`).
    pub vex_session_id: Uuid,
    /// The agent's own session identifier, kept as an opaque string.
    pub claude_session_id: String,
    /// OS process id of the agent process.
    pub claude_pid: u32,
    /// Working directory reported for the agent.
    pub cwd: PathBuf,
    /// When the agent was detected, in UTC.
    pub detected_at: DateTime<Utc>,
    /// Flag set by the server; the criteria that set it are not visible in this module.
    pub needs_intervention: bool,
}
161
/// A registered repository, carried in `ServerMessage::Repos`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RepoEntry {
    /// Name the repo is referred to by (the string used in `repo` fields of other messages,
    /// presumably — verify).
    pub name: String,
    /// Filesystem path of the repository.
    pub path: PathBuf,
}
167
/// One workstream record, carried in `ServerMessage::Workstreams`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkstreamInfo {
    /// Name of the owning repo (presumably a `RepoEntry::name`).
    pub repo: String,
    /// Workstream name, unique within `repo` presumably — TODO confirm.
    pub name: String,
    /// Filesystem path of the workstream's worktree.
    pub worktree_path: PathBuf,
    /// Branch name as a plain string.
    pub branch: String,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
}
176
/// One decoded frame, as returned by `read_frame`.
///
/// The variant reflects the frame's tag byte (`TAG_CONTROL` or `TAG_DATA`). The inner bytes
/// are the payload only: the 4-byte length prefix and the tag have already been stripped.
///
/// Derives `Clone`/`PartialEq`/`Eq` in addition to `Debug` so frames can be compared and
/// duplicated (e.g. `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Payload of a control frame. `send_client_message`/`send_server_message` put
    /// serde_json-encoded messages here.
    Control(Vec<u8>),
    /// Payload of a data frame, passed through as raw bytes.
    Data(Vec<u8>),
}
182
183pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
184 let len = (1 + payload.len()) as u32;
185 w.write_all(&len.to_be_bytes()).await?;
186 w.write_u8(TAG_CONTROL).await?;
187 w.write_all(payload).await?;
188 w.flush().await?;
189 Ok(())
190}
191
192pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
193 let len = (1 + payload.len()) as u32;
194 w.write_all(&len.to_be_bytes()).await?;
195 w.write_u8(TAG_DATA).await?;
196 w.write_all(payload).await?;
197 w.flush().await?;
198 Ok(())
199}
200
201pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
202 let mut len_buf = [0u8; 4];
203 match r.read_exact(&mut len_buf).await {
204 Ok(_) => {}
205 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
206 Err(e) => return Err(e.into()),
207 }
208 let len = u32::from_be_bytes(len_buf) as usize;
209 if len == 0 {
210 bail!("invalid frame: zero length");
211 }
212 if len > MAX_FRAME_SIZE {
213 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
214 }
215 let tag = {
216 let mut tag_buf = [0u8; 1];
217 r.read_exact(&mut tag_buf).await?;
218 tag_buf[0]
219 };
220 let payload_len = len - 1;
221 let mut payload = vec![0u8; payload_len];
222 if payload_len > 0 {
223 r.read_exact(&mut payload).await?;
224 }
225 match tag {
226 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
227 TAG_DATA => Ok(Some(Frame::Data(payload))),
228 other => bail!("unknown frame tag: 0x{:02x}", other),
229 }
230}
231
232pub async fn send_client_message<W: AsyncWrite + Unpin>(
234 w: &mut W,
235 msg: &ClientMessage,
236) -> Result<()> {
237 let json = serde_json::to_vec(msg)?;
238 write_control(w, &json).await
239}
240
241pub async fn send_server_message<W: AsyncWrite + Unpin>(
243 w: &mut W,
244 msg: &ServerMessage,
245) -> Result<()> {
246 let json = serde_json::to_vec(msg)?;
247 write_control(w, &json).await
248}
249
#[cfg(test)]
mod tests {
    //! Round-trip tests for the JSON message enums and the length-prefixed framing.

    use super::*;

    /// Every `ClientMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_client() {
        let msgs = [
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
                repo: None,
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::AgentList,
            ClientMessage::AgentNotifications,
            ClientMessage::AgentWatch {
                session_id: Uuid::nil(),
            },
            ClientMessage::AgentPrompt {
                session_id: Uuid::nil(),
                text: "hello".into(),
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: None,
            },
            ClientMessage::AgentSpawn {
                repo: "vex".into(),
                workstream: Some("feature-x".into()),
            },
            ClientMessage::WorkstreamCreate {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::WorkstreamList { repo: None },
            ClientMessage::WorkstreamList {
                repo: Some("vex".into()),
            },
            ClientMessage::WorkstreamRemove {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ClientMessage::RepoAdd {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ClientMessage::RepoRemove { name: "vex".into() },
            ClientMessage::RepoList,
            ClientMessage::RepoIntrospectPath {
                path: PathBuf::from("/tmp"),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON encode/decode round trip.
    #[test]
    fn serde_round_trip_server() {
        let msgs = [
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentListResponse {
                agents: vec![AgentEntry {
                    vex_session_id: Uuid::nil(),
                    claude_session_id: "abc123".into(),
                    claude_pid: 1234,
                    cwd: PathBuf::from("/tmp"),
                    detected_at: Utc::now(),
                    needs_intervention: true,
                }],
            },
            ServerMessage::AgentPromptSent {
                session_id: Uuid::nil(),
            },
            ServerMessage::AgentConversationLine {
                session_id: Uuid::nil(),
                line: "test line".into(),
            },
            ServerMessage::AgentWatchEnd {
                session_id: Uuid::nil(),
            },
            ServerMessage::RepoAdded {
                name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
            },
            ServerMessage::RepoRemoved { name: "vex".into() },
            ServerMessage::Repos {
                repos: vec![RepoEntry {
                    name: "vex".into(),
                    path: PathBuf::from("/tmp/vex"),
                }],
            },
            ServerMessage::RepoIntrospected {
                suggested_name: "vex".into(),
                path: PathBuf::from("/tmp/vex"),
                git_remote: Some("git@github.com:user/vex.git".into()),
                git_branch: Some("main".into()),
            },
            ServerMessage::WorkstreamCreated {
                repo: "vex".into(),
                name: "feature-x".into(),
                worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
            },
            ServerMessage::WorkstreamRemoved {
                repo: "vex".into(),
                name: "feature-x".into(),
            },
            ServerMessage::Workstreams {
                workstreams: vec![WorkstreamInfo {
                    repo: "vex".into(),
                    name: "feature-x".into(),
                    worktree_path: PathBuf::from("/tmp/workstreams/vex/feature-x"),
                    branch: "feature-x".into(),
                    created_at: Utc::now(),
                }],
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A frame with an empty payload (length 1: just the tag) decodes to an empty buffer,
    /// and the stream is then at a clean boundary.
    #[tokio::test]
    async fn frame_empty_payload_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    /// Consecutive frames are decoded in order, then EOF yields `None`.
    #[tokio::test]
    async fn frames_read_in_order() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"a").await.unwrap();
        write_data(&mut client, b"bc").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"a"),
            Frame::Data(_) => panic!("expected control frame first"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"bc"),
            Frame::Control(_) => panic!("expected data frame second"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("unknown frame tag"));
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("frame too large"));
    }

    /// A frame whose payload is cut short by EOF is an error, not a clean end of stream.
    #[tokio::test]
    async fn frame_truncated_payload_errors() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 10;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_all(b"abc").await.unwrap();
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
            repo: None,
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut server, mut client) = tokio::io::duplex(4096);
        let msg = ServerMessage::Error {
            message: "boom".into(),
        };
        send_server_message(&mut server, &msg).await.unwrap();
        drop(server);
        match read_frame(&mut client).await.unwrap().unwrap() {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}