1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
/// Tag byte identifying a control frame (JSON-encoded protocol messages are
/// sent this way by `send_client_message` / `send_server_message`).
const TAG_CONTROL: u8 = 1;
/// Tag byte identifying a data frame (opaque bytes).
const TAG_DATA: u8 = 2;
9const MAX_FRAME_SIZE: usize = 1_048_576; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(tag = "type")]
13pub enum ClientMessage {
14 CreateSession {
15 shell: Option<String>,
16 },
17 ListSessions,
18 AttachSession {
19 id: Uuid,
20 cols: u16,
21 rows: u16,
22 },
23 DetachSession,
24 ResizeSession {
25 id: Uuid,
26 cols: u16,
27 rows: u16,
28 },
29 KillSession {
30 id: Uuid,
31 },
32 CreateAgent {
33 model: Option<String>,
34 permission_mode: Option<String>,
35 allowed_tools: Vec<String>,
36 max_turns: Option<u32>,
37 cwd: Option<String>,
38 },
39 AgentPrompt {
40 id: Uuid,
41 prompt: String,
42 },
43 AgentStatus {
44 id: Uuid,
45 },
46 ListAgents,
47 KillAgent {
48 id: Uuid,
49 },
50}
51
/// Messages sent from the server to a client.
///
/// Encoded as JSON using serde's internally tagged representation: the variant
/// name is written into a `"type"` field next to the variant's own fields.
/// `send_server_message` carries them in control frames. Fields serialize in
/// declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ServerMessage {
    /// A session was created with identifier `id`.
    SessionCreated {
        id: Uuid,
    },
    /// A list of sessions (presumably the reply to `ListSessions` — confirm).
    Sessions {
        sessions: Vec<SessionInfo>,
    },
    /// The client is now attached to session `id`.
    Attached {
        id: Uuid,
    },
    /// The client has been detached.
    Detached,
    /// Session `id` ended; `exit_code` may be absent.
    SessionEnded {
        id: Uuid,
        exit_code: Option<i32>,
    },
    /// Client `client_id` joined session `session_id`.
    ClientJoined {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// Client `client_id` left session `session_id`.
    ClientLeft {
        session_id: Uuid,
        client_id: Uuid,
    },
    /// A request failed; `message` is a human-readable description.
    Error {
        message: String,
    },
    /// An agent was created with identifier `id`.
    AgentCreated {
        id: Uuid,
    },
    /// An output event produced by agent `id`.
    AgentOutput {
        id: Uuid,
        event: AgentEvent,
    },
    /// Agent `id` finished handling a prompt; `turn_count` is reported alongside.
    AgentPromptDone {
        id: Uuid,
        turn_count: u32,
    },
    /// Status details for agent `id` (presumably the reply to `AgentStatus` —
    /// confirm).
    AgentStatusResponse {
        id: Uuid,
        status: AgentState,
        claude_session_id: Option<String>,
        model: Option<String>,
        turn_count: u32,
    },
    /// A list of agents (presumably the reply to `ListAgents` — confirm).
    Agents {
        agents: Vec<AgentInfo>,
    },
}
102
/// Summary of one session, as carried in `ServerMessage::Sessions`.
///
/// Fields serialize in declaration order, so reordering them changes the JSON
/// emitted on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: Uuid,
    /// Width in columns.
    pub cols: u16,
    /// Height in rows.
    pub rows: u16,
    /// When the session was created, in UTC.
    pub created_at: DateTime<Utc>,
    /// Number of clients — presumably those currently attached; confirm.
    pub client_count: usize,
}
111
/// Lifecycle state of an agent.
///
/// Uses serde's default (externally tagged) enum representation because no
/// `#[serde(...)]` attribute is set: `Idle`/`Processing` encode as bare strings
/// and `Error` as `{"Error": "<message>"}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AgentState {
    /// Not currently handling a prompt.
    Idle,
    /// Currently handling a prompt.
    Processing,
    /// Failed, with a human-readable description.
    Error(String),
}
118
119impl std::fmt::Display for AgentState {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 match self {
122 AgentState::Idle => write!(f, "idle"),
123 AgentState::Processing => write!(f, "processing"),
124 AgentState::Error(msg) => write!(f, "error: {}", msg),
125 }
126 }
127}
128
/// Summary of one agent, as carried in `ServerMessage::Agents`.
///
/// Fields serialize in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentInfo {
    /// Agent identifier.
    pub id: Uuid,
    /// Current lifecycle state.
    pub status: AgentState,
    /// Model name, if one is set.
    pub model: Option<String>,
    /// Turn counter as reported by the server.
    pub turn_count: u32,
    /// When the agent was created, in UTC.
    pub created_at: DateTime<Utc>,
}
137
/// One agent output event, forwarded in `ServerMessage::AgentOutput`.
///
/// Fields serialize in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentEvent {
    /// Event kind, e.g. `"content_block_delta"`.
    pub event_type: String,
    /// The event as JSON text, carried as an opaque string rather than parsed
    /// (NOTE(review): presumably the agent's original event verbatim — confirm).
    pub raw_json: String,
}
143
/// A decoded wire frame, split by its tag byte.
#[derive(Debug)]
pub enum Frame {
    /// Payload of a frame tagged as data: opaque bytes.
    Data(Vec<u8>),
    /// Payload of a frame tagged as control: the `send_*_message` helpers put
    /// JSON-encoded protocol messages here.
    Control(Vec<u8>),
}
149
150pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
151 let len = (1 + payload.len()) as u32;
152 w.write_all(&len.to_be_bytes()).await?;
153 w.write_u8(TAG_CONTROL).await?;
154 w.write_all(payload).await?;
155 w.flush().await?;
156 Ok(())
157}
158
159pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
160 let len = (1 + payload.len()) as u32;
161 w.write_all(&len.to_be_bytes()).await?;
162 w.write_u8(TAG_DATA).await?;
163 w.write_all(payload).await?;
164 w.flush().await?;
165 Ok(())
166}
167
168pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
169 let mut len_buf = [0u8; 4];
170 match r.read_exact(&mut len_buf).await {
171 Ok(_) => {}
172 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
173 Err(e) => return Err(e.into()),
174 }
175 let len = u32::from_be_bytes(len_buf) as usize;
176 if len == 0 {
177 bail!("invalid frame: zero length");
178 }
179 if len > MAX_FRAME_SIZE {
180 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
181 }
182 let tag = {
183 let mut tag_buf = [0u8; 1];
184 r.read_exact(&mut tag_buf).await?;
185 tag_buf[0]
186 };
187 let payload_len = len - 1;
188 let mut payload = vec![0u8; payload_len];
189 if payload_len > 0 {
190 r.read_exact(&mut payload).await?;
191 }
192 match tag {
193 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
194 TAG_DATA => Ok(Some(Frame::Data(payload))),
195 other => bail!("unknown frame tag: 0x{:02x}", other),
196 }
197}
198
199pub async fn send_client_message<W: AsyncWrite + Unpin>(
201 w: &mut W,
202 msg: &ClientMessage,
203) -> Result<()> {
204 let json = serde_json::to_vec(msg)?;
205 write_control(w, &json).await
206}
207
208pub async fn send_server_message<W: AsyncWrite + Unpin>(
210 w: &mut W,
211 msg: &ServerMessage,
212) -> Result<()> {
213 let json = serde_json::to_vec(msg)?;
214 write_control(w, &json).await
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
            ClientMessage::CreateAgent {
                model: Some("sonnet".into()),
                permission_mode: Some("plan".into()),
                allowed_tools: vec!["Read".into(), "Write".into()],
                max_turns: Some(10),
                cwd: Some("/tmp".into()),
            },
            ClientMessage::CreateAgent {
                model: None,
                permission_mode: None,
                allowed_tools: vec![],
                max_turns: None,
                cwd: None,
            },
            ClientMessage::AgentPrompt {
                id: Uuid::nil(),
                prompt: "do something".into(),
            },
            ClientMessage::AgentStatus { id: Uuid::nil() },
            ClientMessage::ListAgents,
            ClientMessage::KillAgent { id: Uuid::nil() },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Pins the internally tagged wire shape so a serde attribute change is caught.
    #[test]
    fn client_message_wire_shape() {
        let json = serde_json::to_string(&ClientMessage::ListSessions).unwrap();
        assert_eq!(json, r#"{"type":"ListSessions"}"#);
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
            ServerMessage::AgentCreated { id: Uuid::nil() },
            ServerMessage::AgentOutput {
                id: Uuid::nil(),
                event: AgentEvent {
                    event_type: "content_block_delta".into(),
                    raw_json: r#"{"type":"content_block_delta"}"#.into(),
                },
            },
            ServerMessage::AgentPromptDone {
                id: Uuid::nil(),
                turn_count: 3,
            },
            ServerMessage::AgentStatusResponse {
                id: Uuid::nil(),
                status: AgentState::Idle,
                claude_session_id: Some("sess-123".into()),
                model: Some("sonnet".into()),
                turn_count: 5,
            },
            ServerMessage::AgentStatusResponse {
                id: Uuid::nil(),
                status: AgentState::Processing,
                claude_session_id: None,
                model: None,
                turn_count: 0,
            },
            ServerMessage::AgentStatusResponse {
                id: Uuid::nil(),
                status: AgentState::Error("something broke".into()),
                claude_session_id: None,
                model: None,
                turn_count: 0,
            },
            ServerMessage::Agents {
                agents: vec![AgentInfo {
                    id: Uuid::nil(),
                    status: AgentState::Idle,
                    model: Some("sonnet".into()),
                    turn_count: 2,
                    created_at: Utc::now(),
                }],
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn agent_state_display() {
        assert_eq!(AgentState::Idle.to_string(), "idle");
        assert_eq!(AgentState::Processing.to_string(), "processing");
        assert_eq!(
            AgentState::Error("boom".into()).to_string(),
            "error: boom"
        );
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frame_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    #[tokio::test]
    async fn frames_in_sequence_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut writer, mut reader) = tokio::io::duplex(4096);
        let msg = ServerMessage::Error {
            message: "boom".into(),
        };
        send_server_message(&mut writer, &msg).await.unwrap();
        drop(writer);
        match read_frame(&mut reader).await.unwrap().unwrap() {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}