1use crate::core::rpc_protocol::{
10 AssistantEvent, RpcAttachment, RpcCommand, RpcEvent, TurnUsage,
11};
12use crate::{AgentEvent, LlmEvent, SessionEvent, StreamEvent};
13
/// One mebibyte (2^20), expressed in bytes.
///
/// Suitable as the `max_bytes` argument of [`parse_frame`].
pub const MAX_FRAME_BYTES: usize = 1 << 20;
18
19pub fn parse_frame(line: &str, max_bytes: usize) -> Result<RpcCommand, RpcEvent> {
26 if line.len() > max_bytes {
27 return Err(RpcEvent::Error {
28 id: None,
29 message: "frame exceeds 1 MiB limit".to_string(),
30 });
31 }
32 serde_json::from_str::<RpcCommand>(line).map_err(|e| RpcEvent::Error {
33 id: None,
34 message: e.to_string(),
35 })
36}
37
38pub fn map_stream_event(ev: &StreamEvent) -> Option<RpcEvent> {
52 match ev {
53 StreamEvent::Llm(LlmEvent::Thinking(s)) => Some(RpcEvent::MessageUpdate {
54 event: AssistantEvent::ThinkingDelta { delta: s.clone() },
55 }),
56 StreamEvent::Llm(LlmEvent::Text(s)) => Some(RpcEvent::MessageUpdate {
57 event: AssistantEvent::TextDelta { delta: s.clone() },
58 }),
59 StreamEvent::Llm(LlmEvent::ToolUseStart { tool_name, tool_id }) => {
60 Some(RpcEvent::MessageUpdate {
61 event: AssistantEvent::ToolcallStart {
62 tool_id: tool_id.clone(),
63 tool_name: tool_name.clone(),
64 },
65 })
66 }
67 StreamEvent::Llm(LlmEvent::ToolUseDelta { tool_id, delta }) => {
68 Some(RpcEvent::MessageUpdate {
69 event: AssistantEvent::ToolcallInputDelta {
70 tool_id: tool_id.clone(),
71 delta: delta.clone(),
72 },
73 })
74 }
75 StreamEvent::Llm(LlmEvent::ToolUse { tool_id, input, .. }) => {
77 Some(RpcEvent::MessageUpdate {
78 event: AssistantEvent::ToolcallInput {
79 tool_id: tool_id.clone(),
80 input: input.clone(),
81 },
82 })
83 }
84 StreamEvent::Llm(LlmEvent::ToolResult { tool_id, result }) => {
85 Some(RpcEvent::MessageUpdate {
86 event: AssistantEvent::ToolcallResult {
87 tool_id: tool_id.clone(),
88 result: result.clone(),
89 },
90 })
91 }
92 StreamEvent::Llm(LlmEvent::ToolResultDelta { .. }) => None,
94
95 StreamEvent::Agent(AgentEvent::SubagentStart {
96 subagent_id,
97 agent_name,
98 task_preview,
99 }) => Some(RpcEvent::SubagentStart {
100 subagent_id: *subagent_id,
101 agent_name: agent_name.clone(),
102 task_preview: task_preview.clone(),
103 }),
104 StreamEvent::Agent(AgentEvent::SubagentUpdate {
105 subagent_id,
106 agent_name,
107 status,
108 }) => Some(RpcEvent::SubagentUpdate {
109 subagent_id: *subagent_id,
110 agent_name: agent_name.clone(),
111 status: status.clone(),
112 }),
113 StreamEvent::Agent(AgentEvent::SubagentDone {
114 subagent_id,
115 agent_name,
116 result_preview,
117 duration_secs,
118 }) => Some(RpcEvent::SubagentDone {
119 subagent_id: *subagent_id,
120 agent_name: agent_name.clone(),
121 result_preview: result_preview.clone(),
122 duration_secs: *duration_secs,
123 }),
124 StreamEvent::Agent(AgentEvent::SteeringDelivered { .. }) => None,
126
127 StreamEvent::Session(_) => None,
130 }
131}
132
133pub fn accumulate_usage(acc: &mut TurnUsage, event: &SessionEvent) {
141 if let SessionEvent::Usage {
142 input_tokens,
143 output_tokens,
144 cache_read_input_tokens,
145 cache_creation_input_tokens,
146 model,
147 } = event
148 {
149 acc.input_tokens += input_tokens;
150 acc.output_tokens += output_tokens;
151 acc.cache_read_input_tokens += cache_read_input_tokens;
152 acc.cache_creation_input_tokens += cache_creation_input_tokens;
153 if acc.model.is_none() {
154 acc.model = model.clone();
155 }
156 }
157}
158
/// Wraps `p` in double quotes, prefixing every backslash and double quote
/// inside it with a backslash.
fn quote_path(p: &str) -> String {
    // Two bytes for the surrounding quotes; escapes grow the buffer on demand.
    let mut quoted = String::with_capacity(p.len() + 2);
    quoted.push('"');
    for ch in p.chars() {
        if matches!(ch, '\\' | '"') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
169
170pub fn build_user_content(message: &str, attachments: &[RpcAttachment]) -> String {
171 if attachments.is_empty() {
172 return message.to_string();
173 }
174 let parts: Vec<String> = attachments.iter().map(|a| quote_path(&a.path)).collect();
175 format!("[user attached files: {}]\n{}", parts.join(", "), message)
176}
177
#[cfg(test)]
mod tests {
    //! Unit tests for frame parsing, stream-event mapping, usage accumulation
    //! and user-content building.

    use super::*;
    use crate::core::rpc_protocol::{
        AssistantEvent, RpcAttachment, RpcCommand, RpcEvent, TurnUsage,
    };
    use crate::{AgentEvent, LlmEvent, SessionEvent, StreamEvent};
    use serde_json::json;

    // ---- parse_frame ------------------------------------------------------

    #[test]
    fn parse_frame_valid_prompt() {
        let line = r#"{"type":"prompt","id":"abc","message":"hello"}"#;
        let result = parse_frame(line, MAX_FRAME_BYTES);
        assert!(result.is_ok(), "should parse valid prompt frame");
        match result.unwrap() {
            RpcCommand::Prompt {
                id,
                message,
                attachments,
            } => {
                assert_eq!(id, "abc");
                assert_eq!(message, "hello");
                assert!(attachments.is_empty());
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[test]
    fn parse_frame_valid_shutdown() {
        let line = r#"{"type":"shutdown"}"#;
        let result = parse_frame(line, MAX_FRAME_BYTES);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), RpcCommand::Shutdown));
    }

    #[test]
    fn parse_frame_valid_follow_up() {
        let line = r#"{"type":"follow_up","id":"f1","message":"and then?"}"#;
        let result = parse_frame(line, MAX_FRAME_BYTES);
        match result.unwrap() {
            RpcCommand::FollowUp { id, message } => {
                assert_eq!(id, "f1");
                assert_eq!(message, "and then?");
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn parse_frame_valid_abort() {
        let line = r#"{"type":"abort","id":"x"}"#;
        assert!(matches!(
            parse_frame(line, MAX_FRAME_BYTES).unwrap(),
            RpcCommand::Abort { .. }
        ));
    }

    #[test]
    fn parse_frame_malformed_json() {
        let line = "not json at all";
        let result = parse_frame(line, MAX_FRAME_BYTES);
        assert!(result.is_err());
        match result.unwrap_err() {
            RpcEvent::Error { id, message } => {
                assert!(id.is_none(), "malformed-JSON error must have id=None");
                assert!(!message.is_empty(), "error message must be non-empty");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn parse_frame_valid_json_unknown_type() {
        let line = r#"{"type":"does_not_exist","id":"1"}"#;
        let result = parse_frame(line, MAX_FRAME_BYTES);
        assert!(result.is_err(), "unknown type should fail to deserialise");
    }

    #[test]
    fn parse_frame_oversize() {
        let oversize = "x".repeat(MAX_FRAME_BYTES + 1);
        let result = parse_frame(&oversize, MAX_FRAME_BYTES);
        assert!(result.is_err());
        match result.unwrap_err() {
            RpcEvent::Error { id, message } => {
                assert!(id.is_none());
                assert!(
                    message.contains("1 MiB"),
                    "expected '1 MiB' in message, got: {message}"
                );
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    /// Pins the boundary of the size check: a frame whose byte length equals
    /// the limit is accepted, and the same frame is rejected once the limit is
    /// one byte smaller.
    #[test]
    fn parse_frame_exactly_at_limit_valid_json() {
        let line = r#"{"type":"get_state","id":"x"}"#;
        assert!(line.len() <= MAX_FRAME_BYTES);

        // Use the frame's own length as the limit so the comparison is
        // exercised exactly at the boundary rather than far below it.
        let at_limit = parse_frame(line, line.len());
        assert!(at_limit.is_ok(), "frame of exactly max_bytes must parse");

        let over_limit = parse_frame(line, line.len() - 1);
        assert!(
            over_limit.is_err(),
            "frame one byte over the limit must be rejected"
        );
    }

    #[test]
    fn parse_frame_custom_small_limit() {
        let line = r#"{"type":"shutdown"}"#;
        let result = parse_frame(line, 5);
        assert!(result.is_err());
        match result.unwrap_err() {
            RpcEvent::Error { id, .. } => assert!(id.is_none()),
            other => panic!("unexpected: {:?}", other),
        }
    }

    // ---- map_stream_event: LLM events -------------------------------------

    #[test]
    fn map_llm_thinking() {
        let ev = StreamEvent::Llm(LlmEvent::Thinking("hmm".to_string()));
        let rpc = map_stream_event(&ev).expect("Thinking must produce an event");
        match rpc {
            RpcEvent::MessageUpdate {
                event: AssistantEvent::ThinkingDelta { delta },
            } => assert_eq!(delta, "hmm"),
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_llm_text() {
        let ev = StreamEvent::Llm(LlmEvent::Text("hi".to_string()));
        let rpc = map_stream_event(&ev).expect("Text must produce an event");
        match rpc {
            RpcEvent::MessageUpdate {
                event: AssistantEvent::TextDelta { delta },
            } => assert_eq!(delta, "hi"),
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_llm_tool_use_start() {
        let ev = StreamEvent::Llm(LlmEvent::ToolUseStart {
            tool_name: "bash".to_string(),
            tool_id: "tid1".to_string(),
        });
        let rpc = map_stream_event(&ev).expect("ToolUseStart must produce an event");
        match rpc {
            RpcEvent::MessageUpdate {
                event: AssistantEvent::ToolcallStart { tool_id, tool_name },
            } => {
                assert_eq!(tool_id, "tid1");
                assert_eq!(tool_name, "bash");
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_llm_tool_use_delta() {
        let ev = StreamEvent::Llm(LlmEvent::ToolUseDelta {
            tool_id: "tid1".to_string(),
            delta: r#"{"cmd":"#.to_string(),
        });
        let rpc = map_stream_event(&ev).expect("ToolUseDelta must produce an event");
        match rpc {
            RpcEvent::MessageUpdate {
                event: AssistantEvent::ToolcallInputDelta { tool_id, delta },
            } => {
                assert_eq!(tool_id, "tid1");
                assert_eq!(delta, r#"{"cmd":"#);
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_llm_tool_use_final_drops_tool_name() {
        let ev = StreamEvent::Llm(LlmEvent::ToolUse {
            tool_name: "bash".to_string(),
            tool_id: "tid1".to_string(),
            input: json!({"cmd": "ls"}),
        });
        let rpc = map_stream_event(&ev).expect("ToolUse must produce an event");
        match rpc {
            RpcEvent::MessageUpdate {
                event: AssistantEvent::ToolcallInput { tool_id, input },
            } => {
                assert_eq!(tool_id, "tid1");
                assert_eq!(input, json!({"cmd": "ls"}));
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_llm_tool_result() {
        let ev = StreamEvent::Llm(LlmEvent::ToolResult {
            tool_id: "tid1".to_string(),
            result: "output here".to_string(),
        });
        let rpc = map_stream_event(&ev).expect("ToolResult must produce an event");
        match rpc {
            RpcEvent::MessageUpdate {
                event: AssistantEvent::ToolcallResult { tool_id, result },
            } => {
                assert_eq!(tool_id, "tid1");
                assert_eq!(result, "output here");
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_llm_tool_result_delta_is_dropped() {
        let ev = StreamEvent::Llm(LlmEvent::ToolResultDelta {
            tool_id: "tid1".to_string(),
            delta: "partial".to_string(),
        });
        assert!(
            map_stream_event(&ev).is_none(),
            "ToolResultDelta must be dropped — wire format has no streaming-result variant"
        );
    }

    // ---- map_stream_event: agent events -----------------------------------

    #[test]
    fn map_agent_subagent_start() {
        let ev = StreamEvent::Agent(AgentEvent::SubagentStart {
            subagent_id: 7,
            agent_name: "worker".to_string(),
            task_preview: "do thing".to_string(),
        });
        let rpc = map_stream_event(&ev).expect("SubagentStart must produce an event");
        match rpc {
            RpcEvent::SubagentStart {
                subagent_id,
                agent_name,
                task_preview,
            } => {
                assert_eq!(subagent_id, 7);
                assert_eq!(agent_name, "worker");
                assert_eq!(task_preview, "do thing");
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_agent_subagent_update() {
        let ev = StreamEvent::Agent(AgentEvent::SubagentUpdate {
            subagent_id: 7,
            agent_name: "worker".to_string(),
            status: "running".to_string(),
        });
        let rpc = map_stream_event(&ev).expect("SubagentUpdate must produce an event");
        match rpc {
            RpcEvent::SubagentUpdate {
                subagent_id,
                agent_name,
                status,
            } => {
                assert_eq!(subagent_id, 7);
                assert_eq!(agent_name, "worker");
                assert_eq!(status, "running");
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_agent_subagent_done() {
        let ev = StreamEvent::Agent(AgentEvent::SubagentDone {
            subagent_id: 7,
            agent_name: "worker".to_string(),
            result_preview: "done!".to_string(),
            duration_secs: 1.5,
        });
        let rpc = map_stream_event(&ev).expect("SubagentDone must produce an event");
        match rpc {
            RpcEvent::SubagentDone {
                subagent_id,
                agent_name,
                result_preview,
                duration_secs,
            } => {
                assert_eq!(subagent_id, 7);
                assert_eq!(agent_name, "worker");
                assert_eq!(result_preview, "done!");
                assert!((duration_secs - 1.5).abs() < f64::EPSILON);
            }
            other => panic!("unexpected: {:?}", other),
        }
    }

    #[test]
    fn map_agent_steering_delivered_is_dropped() {
        let ev = StreamEvent::Agent(AgentEvent::SteeringDelivered {
            message: "steer".to_string(),
        });
        assert!(
            map_stream_event(&ev).is_none(),
            "SteeringDelivered must be dropped — internal hook signal"
        );
    }

    // ---- map_stream_event: session events ---------------------------------

    #[test]
    fn map_session_events_all_return_none() {
        let events: &[StreamEvent] = &[
            StreamEvent::Session(SessionEvent::Done),
            StreamEvent::Session(SessionEvent::Error("oops".to_string())),
            StreamEvent::Session(SessionEvent::MessageHistory(vec![])),
            StreamEvent::Session(SessionEvent::Usage {
                input_tokens: 1,
                output_tokens: 2,
                cache_read_input_tokens: 0,
                cache_creation_input_tokens: 0,
                model: None,
            }),
        ];
        for ev in events {
            assert!(
                map_stream_event(ev).is_none(),
                "Session event {:?} should return None",
                ev
            );
        }
    }

    // ---- accumulate_usage -------------------------------------------------

    /// A `TurnUsage` with every counter at zero and no model recorded.
    fn zero_usage() -> TurnUsage {
        TurnUsage {
            input_tokens: 0,
            output_tokens: 0,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
            model: None,
        }
    }

    #[test]
    fn accumulate_usage_basic() {
        let mut acc = zero_usage();
        let ev = SessionEvent::Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_input_tokens: 10,
            cache_creation_input_tokens: 5,
            model: Some("claude-3-5".to_string()),
        };
        accumulate_usage(&mut acc, &ev);
        assert_eq!(acc.input_tokens, 100);
        assert_eq!(acc.output_tokens, 50);
        assert_eq!(acc.cache_read_input_tokens, 10);
        assert_eq!(acc.cache_creation_input_tokens, 5);
        assert_eq!(acc.model.as_deref(), Some("claude-3-5"));
    }

    #[test]
    fn accumulate_usage_additive_across_calls() {
        let mut acc = TurnUsage {
            input_tokens: 10,
            output_tokens: 5,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
            model: Some("first-model".to_string()),
        };
        let ev = SessionEvent::Usage {
            input_tokens: 20,
            output_tokens: 8,
            cache_read_input_tokens: 2,
            cache_creation_input_tokens: 1,
            model: Some("second-model".to_string()),
        };
        accumulate_usage(&mut acc, &ev);
        assert_eq!(acc.input_tokens, 30);
        assert_eq!(acc.output_tokens, 13);
        assert_eq!(acc.cache_read_input_tokens, 2);
        assert_eq!(acc.cache_creation_input_tokens, 1);
        // An already-recorded model is never overwritten by a later event.
        assert_eq!(acc.model.as_deref(), Some("first-model"));
    }

    #[test]
    fn accumulate_usage_sets_model_when_none() {
        let mut acc = zero_usage();
        let ev = SessionEvent::Usage {
            input_tokens: 1,
            output_tokens: 1,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
            model: Some("my-model".to_string()),
        };
        accumulate_usage(&mut acc, &ev);
        assert_eq!(acc.model.as_deref(), Some("my-model"));
    }

    #[test]
    fn accumulate_usage_ignores_done() {
        let mut acc = zero_usage();
        acc.input_tokens = 5;
        accumulate_usage(&mut acc, &SessionEvent::Done);
        assert_eq!(acc.input_tokens, 5, "Done must not mutate the accumulator");
    }

    #[test]
    fn accumulate_usage_ignores_error() {
        let mut acc = zero_usage();
        acc.output_tokens = 3;
        accumulate_usage(&mut acc, &SessionEvent::Error("boom".to_string()));
        assert_eq!(acc.output_tokens, 3, "Error must not mutate the accumulator");
    }

    #[test]
    fn accumulate_usage_ignores_message_history() {
        let mut acc = zero_usage();
        acc.input_tokens = 7;
        accumulate_usage(&mut acc, &SessionEvent::MessageHistory(vec![]));
        assert_eq!(
            acc.input_tokens, 7,
            "MessageHistory must not mutate the accumulator"
        );
    }

    // ---- build_user_content -----------------------------------------------

    #[test]
    fn build_user_content_no_attachments() {
        assert_eq!(build_user_content("hello", &[]), "hello");
    }

    #[test]
    fn build_user_content_single_attachment() {
        let attachments = vec![RpcAttachment {
            path: "/tmp/a.txt".to_string(),
            name: None,
            mime: None,
        }];
        let msg = build_user_content("check this", &attachments);
        assert!(msg.starts_with("[user attached files: \"/tmp/a.txt\"]"));
        assert!(msg.contains("check this"));
    }

    #[test]
    fn build_user_content_multiple_attachments() {
        let attachments = vec![
            RpcAttachment {
                path: "/tmp/a.txt".to_string(),
                name: None,
                mime: None,
            },
            RpcAttachment {
                path: "/tmp/b.pdf".to_string(),
                name: None,
                mime: None,
            },
        ];
        let msg = build_user_content("check these", &attachments);
        assert!(
            msg.contains("[user attached files: \"/tmp/a.txt\", \"/tmp/b.pdf\"]"),
            "paths must be quoted and comma-separated: {msg}"
        );
        assert!(msg.contains("check these"));
    }

    #[test]
    fn build_user_content_preserves_original_message() {
        let attachments = vec![RpcAttachment {
            path: "/tmp/x".to_string(),
            name: Some("x".to_string()),
            mime: Some("text/plain".to_string()),
        }];
        let original = "multi\nline\nmessage";
        let msg = build_user_content(original, &attachments);
        assert!(
            msg.ends_with(original),
            "original message must appear verbatim at the end"
        );
    }

    #[test]
    fn build_user_content_path_with_comma_is_quoted() {
        let attachments = vec![RpcAttachment {
            path: "/tmp/a,b.pdf".to_string(),
            name: None,
            mime: None,
        }];
        let msg = build_user_content("look", &attachments);
        assert!(
            msg.contains("\"/tmp/a,b.pdf\""),
            "comma path must be wrapped in quotes: {msg}"
        );
        assert!(
            !msg.contains("[user attached files: /tmp/a,b.pdf]"),
            "bare unquoted comma path must not appear: {msg}"
        );
    }

    #[test]
    fn build_user_content_multiple_paths_each_quoted() {
        let attachments = vec![
            RpcAttachment {
                path: "/p1".to_string(),
                name: None,
                mime: None,
            },
            RpcAttachment {
                path: "/p2".to_string(),
                name: None,
                mime: None,
            },
        ];
        let msg = build_user_content("x", &attachments);
        assert!(
            msg.contains("\"/p1\", \"/p2\""),
            "each path must be individually quoted: {msg}"
        );
    }

    #[test]
    fn build_user_content_path_with_embedded_quote_is_escaped() {
        let attachments = vec![RpcAttachment {
            path: "/tmp/he\"llo".to_string(),
            name: None,
            mime: None,
        }];
        let msg = build_user_content("x", &attachments);
        assert!(
            msg.contains("\"/tmp/he\\\"llo\""),
            "embedded double-quote must be backslash-escaped: {msg}"
        );
    }

    #[test]
    fn build_user_content_path_with_backslash_is_escaped() {
        let attachments = vec![RpcAttachment {
            path: "/tmp/a\\b".to_string(),
            name: None,
            mime: None,
        }];
        let msg = build_user_content("x", &attachments);
        assert!(
            msg.contains("\"/tmp/a\\\\b\""),
            "backslash in path must be doubled: {msg}"
        );
    }

    // ---- locking pattern --------------------------------------------------

    /// Models a lock / release / slow-await / re-lock sequence and checks that
    /// the mutex is free while the slow phase is in progress.
    ///
    /// This exercises a stand-in task, not `handle_compact` itself. The phases
    /// are sequenced with one-shot channels rather than sleeps, so the check
    /// does not depend on wall-clock timing.
    #[tokio::test]
    async fn handle_compact_releases_lock_before_slow_await() {
        use std::sync::Arc;
        use tokio::sync::{oneshot, Mutex};

        let shared: Arc<Mutex<u32>> = Arc::new(Mutex::new(0));
        let (released_tx, released_rx) = oneshot::channel::<()>();
        let (resume_tx, resume_rx) = oneshot::channel::<()>();

        let shared2 = Arc::clone(&shared);
        let task = tokio::spawn(async move {
            // Phase 1: short critical section; the guard is dropped at the end
            // of this block, before anything slow is awaited.
            let snapshot = {
                let mut g = shared2.lock().await;
                *g += 1;
                *g
            };
            released_tx
                .send(())
                .expect("test body dropped the receiver");

            // Phase 2: stand-in for the slow await, performed without the lock.
            resume_rx.await.expect("test body dropped the sender");

            // Phase 3: re-acquire the lock to publish the result.
            let mut g = shared2.lock().await;
            *g = snapshot + 100;
        });

        // Wait until the worker is inside the slow phase, then probe the lock.
        released_rx
            .await
            .expect("worker exited before releasing the lock");
        let acquired = shared.try_lock();
        assert!(
            acquired.is_ok(),
            "second task must acquire the lock during the slow phase — \
             handle_compact must NOT hold the lock across compact_conversation"
        );
        drop(acquired);

        resume_tx
            .send(())
            .expect("worker exited before the slow phase finished");
        task.await.unwrap();
        assert_eq!(*shared.lock().await, 101);
    }
}