1use std::path::PathBuf;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use tokio::sync::mpsc;
9use uuid::Uuid;
10
11use crate::core::events::Event;
12use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage};
13use crate::repl::PythonRuntime;
14
15use super::bridge::{RlmBridge, RlmLlmClient};
16use super::prompt::rlm_system_prompt;
17
// ---- Root-loop limits ------------------------------------------------------

/// Upper bound on root-model iterations in a single RLM turn.
const MAX_RLM_ITERATIONS: u32 = 25;
/// Rounds in a row without a usable ```repl block (or with a premature prose
/// `FINAL`) before the turn is abandoned with `RlmTermination::NoCode`.
const MAX_CONSECUTIVE_NO_CODE: u32 = 3;
/// Wall-clock budget for the whole turn. It is checked only at the start of
/// each iteration, so one slow LLM call or REPL run can overrun it.
const TURN_TIMEOUT: Duration = Duration::from_secs(3 * 60);
/// Maximum number of messages kept in the root model's history.
const MAX_HISTORY_MESSAGES: usize = 20;

// ---- Root-model sampling ---------------------------------------------------

/// `max_tokens` for every root-model request.
const ROOT_MAX_TOKENS: u32 = 4_096;
/// Sampling temperature for the root model.
const ROOT_TEMPERATURE: f32 = 0.3;

// ---- Preview sizes (in chars) ----------------------------------------------

/// REPL stdout kept per round in the trace and echoed back to the root model.
const STDOUT_METADATA_PREVIEW_LEN: usize = 800;
/// Leading slice of `context` shown to the root model.
const PROMPT_PREVIEW_LEN: usize = 500;
40
/// Why an RLM turn stopped. Recorded in `RlmTurnResult::termination`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RlmTermination {
    /// A `FINAL` value was produced, either from the REPL or from prose after
    /// at least one sub-LLM call had been made.
    Final,
    /// The model went `MAX_CONSECUTIVE_NO_CODE` rounds in a row without
    /// emitting a usable ```repl block (or kept answering with a prose `FINAL`
    /// before making any sub-LLM call).
    NoCode,
    /// `MAX_RLM_ITERATIONS` rounds ran without any `FINAL`.
    Exhausted,
    /// Staging the context, spawning the REPL, the root LLM call, or REPL
    /// execution failed, or the turn exceeded `TURN_TIMEOUT`.
    Error,
}
60
/// Record of one executed REPL round. The turn driver collects these in
/// `RlmTurnResult::trace`.
#[derive(Debug, Clone)]
pub struct RlmRoundTrace {
    /// 1-based number of the root-loop iteration whose code was executed.
    pub round: u32,
    /// The executed code, abbreviated to its first and last lines when long
    /// (produced by `summarize_code`).
    pub code_summary: String,
    /// Trimmed REPL stdout, truncated to `STDOUT_METADATA_PREVIEW_LEN` chars.
    pub stdout_preview: String,
    /// Whether the REPL reported an error for this round.
    pub had_error: bool,
    /// Number of sub-LLM calls (`llm_query` / `rlm_query` family) made during
    /// the round.
    pub rpc_count: u32,
    /// Time the round took, in milliseconds, as reported by the REPL runtime.
    pub elapsed_ms: u64,
}
72
/// Outcome of one RLM turn. It is always returned, even on failure; check
/// `termination` and `error`.
#[derive(Debug, Clone)]
pub struct RlmTurnResult {
    /// The `FINAL` value. When the model never emitted code (`NoCode` via the
    /// missing-```repl path), this is its last raw response instead. Empty on
    /// errors, timeouts and exhaustion.
    pub answer: String,
    /// Root-model iterations that were started (or completed, for a timeout
    /// detected before the next call).
    pub iterations: u32,
    /// Wall-clock time of the whole turn.
    pub duration: Duration,
    /// Human-readable failure description. `None` when the turn ended with a
    /// `FINAL` value.
    pub error: Option<String>,
    /// Token usage of root calls plus the sub-LLM calls recorded by the bridge.
    pub usage: Usage,
    /// Why the loop stopped.
    pub termination: RlmTermination,
    /// One entry per executed REPL round, in order.
    pub trace: Vec<RlmRoundTrace>,
    /// Total sub-LLM calls across all rounds (sum of `trace[..].rpc_count`).
    pub total_rpcs: u32,
}
89
90pub async fn run_rlm_turn(
93 client: Arc<dyn crate::llm_client::LlmClient>,
94 model: String,
95 prompt: String,
96 child_model: String,
97 tx_event: mpsc::Sender<Event>,
98 max_depth: u32,
99) -> RlmTurnResult {
100 run_rlm_turn_inner(
101 client.clone(),
102 model,
103 prompt,
104 None,
105 child_model,
106 tx_event,
107 max_depth,
108 )
109 .await
110}
111
112pub async fn run_rlm_turn_with_root(
115 client: Arc<dyn crate::llm_client::LlmClient>,
116 model: String,
117 prompt: String,
118 root_prompt: Option<String>,
119 child_model: String,
120 tx_event: mpsc::Sender<Event>,
121 max_depth: u32,
122) -> RlmTurnResult {
123 run_rlm_turn_inner(
124 client.clone(),
125 model,
126 prompt,
127 root_prompt,
128 child_model,
129 tx_event,
130 max_depth,
131 )
132 .await
133}
134
135pub(crate) fn run_rlm_turn_inner(
139 client: Arc<dyn crate::llm_client::LlmClient>,
140 model: String,
141 prompt: String,
142 root_prompt: Option<String>,
143 child_model: String,
144 tx_event: mpsc::Sender<Event>,
145 max_depth: u32,
146) -> std::pin::Pin<Box<dyn std::future::Future<Output = RlmTurnResult> + Send>> {
147 Box::pin(run_rlm_turn_impl(
148 client,
149 model,
150 prompt,
151 root_prompt,
152 child_model,
153 tx_event,
154 max_depth,
155 ))
156}
157
158async fn run_rlm_turn_impl(
163 client: Arc<dyn crate::llm_client::LlmClient>,
164 model: String,
165 prompt: String,
166 root_prompt: Option<String>,
167 child_model: String,
168 tx_event: mpsc::Sender<Event>,
169 max_depth: u32,
170) -> RlmTurnResult {
171 let start = Instant::now();
172 let mut total_usage = Usage::default();
173 let mut trace: Vec<RlmRoundTrace> = Vec::new();
174 let mut total_rpcs: u32 = 0;
175
176 let ctx_path = match write_context_file(&prompt) {
180 Ok(p) => p,
181 Err(e) => {
182 return RlmTurnResult {
183 answer: String::new(),
184 iterations: 0,
185 duration: start.elapsed(),
186 error: Some(format!("rlm: failed to stage context: {e}")),
187 usage: total_usage,
188 termination: RlmTermination::Error,
189 trace,
190 total_rpcs,
191 };
192 }
193 };
194
195 let mut repl = match PythonRuntime::spawn_with_context(&ctx_path).await {
197 Ok(rt) => rt,
198 Err(e) => {
199 let _ = tokio::fs::remove_file(&ctx_path).await;
200 return RlmTurnResult {
201 answer: String::new(),
202 iterations: 0,
203 duration: start.elapsed(),
204 error: Some(format!("rlm: failed to spawn REPL: {e}")),
205 usage: total_usage,
206 termination: RlmTermination::Error,
207 trace,
208 total_rpcs,
209 };
210 }
211 };
212
213 let bridge = RlmBridge::new(client.clone(), child_model.clone(), max_depth);
215 let usage_handle = bridge.usage_handle();
216
217 let _ = tx_event
218 .send(Event::status(format!(
219 "RLM: spawned Python REPL (root={model}, child={child_model}, max_depth={max_depth}, ctx={} chars)",
220 prompt.chars().count()
221 )))
222 .await;
223
224 let system = rlm_system_prompt();
226 let mut messages: Vec<Message> = vec![build_metadata_message(
227 &prompt,
228 root_prompt.as_deref(),
229 0,
230 None,
231 None,
232 )];
233
234 let mut consecutive_no_code: u32 = 0;
235 let mut last_response_text = String::new();
236
237 let result = 'turn: {
238 for iteration in 0..MAX_RLM_ITERATIONS {
239 if start.elapsed() > TURN_TIMEOUT {
240 break 'turn RlmTurnResult {
241 answer: String::new(),
242 iterations: iteration,
243 duration: start.elapsed(),
244 error: Some(format!(
245 "RLM turn timed out after {}s",
246 TURN_TIMEOUT.as_secs()
247 )),
248 usage: total_usage,
249 termination: RlmTermination::Error,
250 trace: trace.clone(),
251 total_rpcs,
252 };
253 }
254
255 let _ = tx_event
256 .send(Event::status(format!(
257 "RLM iteration {}/{}",
258 iteration + 1,
259 MAX_RLM_ITERATIONS
260 )))
261 .await;
262
263 let request = build_root_request(&model, &messages, &system);
265
266 let response = match client.create_message_boxed(request).await {
267 Ok(r) => r,
268 Err(e) => {
269 break 'turn RlmTurnResult {
270 answer: String::new(),
271 iterations: iteration + 1,
272 duration: start.elapsed(),
273 error: Some(format!("Root LLM call failed: {e}")),
274 usage: total_usage,
275 termination: RlmTermination::Error,
276 trace: trace.clone(),
277 total_rpcs,
278 };
279 }
280 };
281
282 total_usage.input_tokens = total_usage
283 .input_tokens
284 .saturating_add(response.usage.input_tokens);
285 total_usage.output_tokens = total_usage
286 .output_tokens
287 .saturating_add(response.usage.output_tokens);
288
289 let response_text = extract_text_blocks(&response.content);
290 last_response_text = response_text.clone();
291
292 if let Some(final_val) = parse_text_final(&response_text) {
297 if total_rpcs == 0 {
298 consecutive_no_code = consecutive_no_code.saturating_add(1);
302 if consecutive_no_code >= MAX_CONSECUTIVE_NO_CODE {
303 break 'turn RlmTurnResult {
304 answer: final_val,
305 iterations: iteration + 1,
306 duration: start.elapsed(),
307 error: None,
308 usage: total_usage,
309 termination: RlmTermination::NoCode,
310 trace: trace.clone(),
311 total_rpcs,
312 };
313 }
314 messages.push(Message {
315 role: "assistant".to_string(),
316 content: vec![ContentBlock::Text {
317 text: response_text.clone(),
318 cache_control: None,
319 }],
320 });
321 messages.push(Message {
322 role: "user".to_string(),
323 content: vec![ContentBlock::Text {
324 text: "You called FINAL(...) without ever running a ```repl block. \
325 That defeats the recursive language model — you're guessing \
326 from the preview alone. Emit a ```repl block now that uses \
327 `llm_query`, `llm_query_batched`, or `rlm_query` against \
328 `context` to actually compute the answer."
329 .to_string(),
330 cache_control: None,
331 }],
332 });
333 continue;
334 }
335 let _ = tx_event
336 .send(Event::status(
337 "RLM: FINAL detected in response text".to_string(),
338 ))
339 .await;
340 break 'turn RlmTurnResult {
341 answer: final_val,
342 iterations: iteration + 1,
343 duration: start.elapsed(),
344 error: None,
345 usage: total_usage,
346 termination: RlmTermination::Final,
347 trace: trace.clone(),
348 total_rpcs,
349 };
350 }
351
352 let code = extract_repl_code(&response_text);
354 let code_to_run = match code {
355 Some(c) => {
356 consecutive_no_code = 0;
357 c
358 }
359 None => {
360 consecutive_no_code = consecutive_no_code.saturating_add(1);
361 if consecutive_no_code >= MAX_CONSECUTIVE_NO_CODE {
362 break 'turn RlmTurnResult {
363 answer: response_text,
364 iterations: iteration + 1,
365 duration: start.elapsed(),
366 error: Some(format!(
367 "RLM: model failed to emit ```repl after {MAX_CONSECUTIVE_NO_CODE} consecutive rounds"
368 )),
369 usage: total_usage,
370 termination: RlmTermination::NoCode,
371 trace: trace.clone(),
372 total_rpcs,
373 };
374 }
375 messages.push(Message {
376 role: "assistant".to_string(),
377 content: vec![ContentBlock::Text {
378 text: response_text.clone(),
379 cache_control: None,
380 }],
381 });
382 messages.push(Message {
383 role: "user".to_string(),
384 content: vec![ContentBlock::Text {
385 text: "Reminder: emit Python inside a ```repl … ``` fence. \
386 Use `llm_query` / `llm_query_batched` / `rlm_query` to \
387 process `context` and call `FINAL(value)` when done."
388 .to_string(),
389 cache_control: None,
390 }],
391 });
392 continue;
393 }
394 };
395
396 let _ = tx_event
397 .send(Event::MessageDelta {
398 index: iteration as usize,
399 content: format!(
400 "\n[RLM round {} — code]\n```repl\n{code_to_run}\n```\n",
401 iteration + 1
402 ),
403 })
404 .await;
405
406 let round = match repl.run(&code_to_run, Some(&bridge)).await {
409 Ok(r) => r,
410 Err(e) => {
411 break 'turn RlmTurnResult {
412 answer: String::new(),
413 iterations: iteration + 1,
414 duration: start.elapsed(),
415 error: Some(format!("REPL execution failed: {e}")),
416 usage: total_usage,
417 termination: RlmTermination::Error,
418 trace: trace.clone(),
419 total_rpcs,
420 };
421 }
422 };
423
424 total_rpcs = total_rpcs.saturating_add(round.rpc_count);
425
426 let stdout_preview = truncate_text(round.stdout.trim(), STDOUT_METADATA_PREVIEW_LEN);
428 trace.push(RlmRoundTrace {
429 round: iteration + 1,
430 code_summary: summarize_code(&code_to_run),
431 stdout_preview: stdout_preview.clone(),
432 had_error: round.has_error,
433 rpc_count: round.rpc_count,
434 elapsed_ms: round.elapsed.as_millis() as u64,
435 });
436
437 let _ = tx_event
438 .send(Event::status(format!(
439 "RLM round {}: {} bytes stdout, {} sub-LLM call(s){}",
440 iteration + 1,
441 round.full_stdout.len(),
442 round.rpc_count,
443 if round.has_error { " (error)" } else { "" },
444 )))
445 .await;
446
447 if let Some(final_val) = round.final_value.clone() {
449 let _ = tx_event
450 .send(Event::status(
451 "RLM: FINAL detected in REPL, ending loop".to_string(),
452 ))
453 .await;
454 break 'turn RlmTurnResult {
455 answer: final_val,
456 iterations: iteration + 1,
457 duration: start.elapsed(),
458 error: None,
459 usage: total_usage,
460 termination: RlmTermination::Final,
461 trace: trace.clone(),
462 total_rpcs,
463 };
464 }
465
466 messages.push(Message {
468 role: "assistant".to_string(),
469 content: vec![ContentBlock::Text {
470 text: format!("```repl\n{code_to_run}\n```"),
471 cache_control: None,
472 }],
473 });
474 messages.push(build_metadata_message(
475 &prompt,
476 root_prompt.as_deref(),
477 iteration + 1,
478 Some(&code_to_run),
479 Some(&stdout_preview),
480 ));
481
482 if messages.len() > MAX_HISTORY_MESSAGES {
483 let drop_from = messages.len() - MAX_HISTORY_MESSAGES + 1;
484 let mut kept = vec![messages[0].clone()];
485 kept.extend(messages.drain(drop_from..));
486 messages = kept;
487 }
488 }
489
490 let _ = last_response_text;
491 RlmTurnResult {
492 answer: String::new(),
493 iterations: MAX_RLM_ITERATIONS,
494 duration: start.elapsed(),
495 error: Some(format!(
496 "RLM loop exhausted after {MAX_RLM_ITERATIONS} iterations without FINAL"
497 )),
498 usage: total_usage,
499 termination: RlmTermination::Exhausted,
500 trace: trace.clone(),
501 total_rpcs,
502 }
503 };
504
505 let bridge_usage = usage_handle.lock().await;
507 let mut final_usage = result.usage.clone();
508 final_usage.input_tokens = final_usage
509 .input_tokens
510 .saturating_add(bridge_usage.input_tokens);
511 final_usage.output_tokens = final_usage
512 .output_tokens
513 .saturating_add(bridge_usage.output_tokens);
514 drop(bridge_usage);
515
516 repl.shutdown().await;
517
518 RlmTurnResult {
519 usage: final_usage,
520 ..result
521 }
522}
523
524fn write_context_file(prompt: &str) -> std::io::Result<PathBuf> {
529 let dir = std::env::temp_dir().join("deepseek_rlm_ctx");
530 std::fs::create_dir_all(&dir)?;
531 let path = dir.join(format!(
532 "ctx_{}_{}.txt",
533 std::process::id(),
534 Uuid::new_v4().simple()
535 ));
536 std::fs::write(&path, prompt)?;
537 Ok(path)
538}
539
540fn build_root_request(model: &str, messages: &[Message], system: &SystemPrompt) -> MessageRequest {
541 MessageRequest {
542 model: model.to_string(),
543 messages: messages.to_vec(),
544 max_tokens: ROOT_MAX_TOKENS,
545 system: Some(system.clone()),
546 tools: None,
547 tool_choice: None,
548 metadata: None,
549 thinking: None,
550 reasoning_effort: None,
551 stream: Some(false),
552 temperature: Some(ROOT_TEMPERATURE),
553 top_p: Some(0.9_f32),
554 }
555}
556
557fn build_metadata_message(
563 prompt: &str,
564 root_prompt: Option<&str>,
565 iteration: u32,
566 previous_code: Option<&str>,
567 previous_stdout: Option<&str>,
568) -> Message {
569 let prompt_len = prompt.chars().count();
570 let prompt_preview = truncate_text(prompt, PROMPT_PREVIEW_LEN);
571
572 let mut parts = Vec::new();
573 parts.push(format!("## REPL state (round {iteration})"));
574 parts.push(String::new());
575 if let Some(rp) = root_prompt
576 && !rp.trim().is_empty()
577 {
578 parts.push("**Original task** (re-shown every round)".to_string());
579 parts.push(format!("> {}", truncate_text(rp.trim(), 600)));
580 parts.push(String::new());
581 }
582 parts.push("**`context`** — the long input lives in the REPL only".to_string());
583 parts.push(format!("- Length: {prompt_len} chars"));
584 parts.push(format!("- Preview: \"{prompt_preview}\""));
585 parts.push(String::new());
586
587 parts.push("**REPL helpers** (use inside ```repl blocks)".to_string());
588 parts.push("- `context` / `ctx` — the full input string".to_string());
589 parts.push("- `len(context)` / `context[a:b]` / `context.splitlines()` — slice it".to_string());
590 parts.push(
591 "- `llm_query(prompt, model=None)` — one-shot child LLM; `model` is ignored and child calls stay pinned to Flash"
592 .to_string(),
593 );
594 parts.push(
595 "- `llm_query_batched([p1, p2, ...])` — concurrent fan-out; `model` is ignored"
596 .to_string(),
597 );
598 parts.push(
599 "- `rlm_query(prompt, model=None)` — recursive sub-RLM; `model` is ignored"
600 .to_string(),
601 );
602 parts.push(
603 "- `rlm_query_batched([p1, p2, ...])` — concurrent recursive sub-RLMs; `model` is ignored"
604 .to_string(),
605 );
606 parts.push("- `SHOW_VARS()` — list user variables".to_string());
607 parts.push("- `repl_set(name, value)` / `repl_get(name)` — explicit store".to_string());
608 parts.push(
609 "- `FINAL(value)` — end the loop with this answer".to_string(),
610 );
611 parts.push(
612 "- `FINAL_VAR(name)` — end the loop with a variable's value"
613 .to_string(),
614 );
615 parts.push(String::new());
616
617 if iteration > 0 {
618 parts.push("**Previous round**".to_string());
619 if let Some(code) = previous_code {
620 parts.push(format!("- Code: {}", summarize_code(code)));
621 }
622 if let Some(stdout) = previous_stdout {
623 let stdout_clean = stdout.trim();
624 if !stdout_clean.is_empty() {
625 parts.push(format!("- Stdout preview: \"{stdout_clean}\""));
626 } else {
627 parts.push("- Stdout: (empty)".to_string());
628 }
629 }
630 }
631
632 let text = parts.join("\n");
633
634 Message {
635 role: "user".to_string(),
636 content: vec![ContentBlock::Text {
637 text,
638 cache_control: None,
639 }],
640 }
641}
642
/// Abbreviates code for logs and metadata.
///
/// Code of at most 8 lines is returned verbatim. Longer code becomes
/// `"<N> lines:"` followed by the first 4 lines, a `…` separator line, and the
/// last 4 lines.
fn summarize_code(code: &str) -> String {
    let all: Vec<&str> = code.lines().collect();
    let total = all.len();
    if total > 8 {
        let (first, rest) = all.split_at(4);
        let last = &rest[rest.len() - 4..];
        format!(
            "{total} lines:\n{}\n…\n{}",
            first.join("\n"),
            last.join("\n")
        )
    } else {
        code.to_owned()
    }
}
652
653fn extract_text_blocks(blocks: &[ContentBlock]) -> String {
654 blocks
655 .iter()
656 .filter_map(|b| match b {
657 ContentBlock::Text { text, .. } => Some(text.as_str()),
658 _ => None,
659 })
660 .collect::<Vec<_>>()
661 .join("\n")
662}
663
/// Returns the trimmed body of the earliest ```repl, ```python or ```py fence
/// (LF or CRLF after the tag) in `text`.
///
/// The body runs to the next `\n````, or to the next bare ```` ``` ```` if no
/// fence starts on its own line. Returns `None` in these cases:
/// - no opener is found;
/// - no closing fence is found;
/// - the earliest block is empty (later blocks are not tried).
fn extract_repl_code(text: &str) -> Option<String> {
    const OPENERS: [&str; 6] = [
        "```repl\n",
        "```repl\r\n",
        "```python\n",
        "```py\n",
        "```python\r\n",
        "```py\r\n",
    ];

    // (opener position, byte offset just past the opener); keep the earliest.
    let body_start = OPENERS
        .iter()
        .filter_map(|opener| text.find(opener).map(|pos| (pos, pos + opener.len())))
        .min_by_key(|&(pos, _)| pos)
        .map(|(_, after)| after)?;

    let body = &text[body_start..];
    let close = body.find("\n```").or_else(|| body.find("```"))?;
    let code = body[..close].trim();
    (!code.is_empty()).then(|| code.to_string())
}
705
706fn parse_text_final(text: &str) -> Option<String> {
710 let outside_fence = strip_code_fences(text);
711
712 for line in outside_fence.lines() {
713 let trimmed = line.trim_start();
714 if trimmed.starts_with("FINAL_VAR(") {
715 continue;
717 }
718 if let Some(rest) = trimmed.strip_prefix("FINAL(") {
719 let inner = rest.trim_end();
720 if let Some(end) = inner.rfind(')') {
721 let value = inner[..end].trim();
722 if !value.is_empty() {
723 return Some(strip_quotes(value));
724 }
725 }
726 }
727 }
728 None
729}
730
/// Returns `text` with every fenced code block removed.
///
/// Any line whose first non-whitespace characters are ```` ``` ```` toggles
/// "inside a fence", and the fence lines themselves are dropped. Each kept
/// line is re-terminated with `\n`.
fn strip_code_fences(text: &str) -> String {
    let mut in_block = false;
    text.lines()
        .filter(|line| {
            if line.trim_start().starts_with("```") {
                in_block = !in_block;
                false
            } else {
                !in_block
            }
        })
        .fold(String::with_capacity(text.len()), |mut kept, line| {
            kept.push_str(line);
            kept.push('\n');
            kept
        })
}
746
/// Removes one pair of matching surrounding quotes (`"…"` or `'…'`), if
/// present. Otherwise returns `s` unchanged.
fn strip_quotes(s: &str) -> String {
    ['"', '\'']
        .iter()
        .find_map(|&quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
        .to_string()
}
757
/// Truncates `text` to at most `max_chars` Unicode scalar values.
///
/// - Text that already fits is returned unchanged.
/// - Otherwise the result is the first `max_chars - 3` chars followed by
///   `"..."`, so it is exactly `max_chars` chars long.
/// - When `max_chars < 3`, no ellipsis fits, so the first `max_chars` chars are
///   returned as-is. Previously this returned `"..."`, which broke the limit.
///
/// Only the first `max_chars + 1` chars are scanned. This matters because the
/// full REPL `context` is passed here on every round.
fn truncate_text(text: &str, max_chars: usize) -> String {
    const ELLIPSIS: &str = "...";
    const ELLIPSIS_CHARS: usize = 3;

    // `nth(max_chars)` is `None` exactly when `text` has <= max_chars chars.
    if text.char_indices().nth(max_chars).is_none() {
        return text.to_string();
    }
    if max_chars < ELLIPSIS_CHARS {
        return text.chars().take(max_chars).collect();
    }

    let keep = max_chars - ELLIPSIS_CHARS;
    // Byte offset of the first dropped char (always a char boundary).
    let cut = text
        .char_indices()
        .nth(keep)
        .map_or(text.len(), |(idx, _)| idx);
    let mut out = String::with_capacity(cut + ELLIPSIS.len());
    out.push_str(&text[..cut]);
    out.push_str(ELLIPSIS);
    out
}
768
// Unit tests for the pure helpers in this module:
// - fenced-code extraction;
// - prose `FINAL(...)` parsing;
// - root-metadata construction, including checks that only a preview of
//   `context` ever reaches the root model;
// - text utilities.
// The async turn driver itself (LLM client + Python REPL) is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- extract_repl_code ----

    #[test]
    fn extract_repl_code_finds_simple_block() {
        let text = "Here:\n```repl\nprint('hi')\n```\nEnd.";
        let code = extract_repl_code(text).unwrap();
        assert_eq!(code, "print('hi')");
    }

    #[test]
    fn extract_repl_code_falls_back_to_python_marker() {
        let text = "Code:\n```python\nx = 1 + 2\n```";
        let code = extract_repl_code(text).unwrap();
        assert_eq!(code, "x = 1 + 2");
    }

    #[test]
    fn extract_repl_code_returns_none_when_missing() {
        assert!(extract_repl_code("Just text.").is_none());
    }

    #[test]
    fn extract_repl_code_returns_none_on_empty_block() {
        assert!(extract_repl_code("```repl\n\n```").is_none());
    }

    // Only the earliest block is returned.
    #[test]
    fn extract_repl_code_handles_multiple_blocks() {
        let text = "```repl\na=1\n```\n```repl\nb=2\n```";
        let code = extract_repl_code(text).unwrap();
        assert_eq!(code, "a=1");
    }

    // Untagged fences are not treated as runnable code.
    #[test]
    fn extract_repl_code_ignores_other_fences() {
        let text = "```\nfoo\n```\n```repl\nreal_code()\n```";
        let code = extract_repl_code(text).unwrap();
        assert_eq!(code, "real_code()");
    }

    // ---- parse_text_final ----

    #[test]
    fn parse_text_final_extracts_simple_value() {
        let text = "OK.\nFINAL(42)\nThanks.";
        assert_eq!(parse_text_final(text).as_deref(), Some("42"));
    }

    #[test]
    fn parse_text_final_strips_quotes() {
        let text = "FINAL(\"the answer is yes\")";
        assert_eq!(parse_text_final(text).as_deref(), Some("the answer is yes"));
    }

    // A FINAL mentioned inside a code fence must not end the loop.
    #[test]
    fn parse_text_final_ignores_inside_code_fence() {
        let text =
            "Some prose.\n```repl\n# Note: when ready, call FINAL(value)\nx = 1\n```\nMore prose.";
        assert!(parse_text_final(text).is_none());
    }

    #[test]
    fn parse_text_final_returns_none_when_absent() {
        assert!(parse_text_final("just talking, no final.").is_none());
    }

    // ---- build_metadata_message / build_root_request ----

    #[test]
    fn build_metadata_contains_key_information() {
        let msg = build_metadata_message("Hello, world!", None, 0, None, None);
        let text = extract_text_blocks(&msg.content);
        assert!(text.contains("context"));
        assert!(text.contains("Hello, world!"));
        assert!(text.contains("round 0"));
        assert!(text.contains("llm_query"));
        assert!(text.contains("rlm_query"));
        assert!(text.contains("FINAL"));
    }

    // The tail beyond PROMPT_PREVIEW_LEN must stay inside the REPL.
    #[test]
    fn build_metadata_truncates_long_context_without_leaking_tail() {
        let secret_tail = "DO_NOT_LEAK_CONTEXT_TAIL";
        let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail);
        let msg = build_metadata_message(&prompt, None, 0, None, None);
        let text = extract_text_blocks(&msg.content);

        assert!(text.contains(&format!("- Length: {} chars", prompt.chars().count())));
        assert!(text.contains("- Preview: \""));
        assert!(text.contains("..."));
        assert!(
            !text.contains(secret_tail),
            "metadata leaked the non-preview tail of context"
        );
    }

    // Same guarantee, checked on the fully serialized root request.
    #[test]
    fn build_root_request_keeps_context_tail_out_of_root_payload() {
        let secret_tail = "DO_NOT_LEAK_ROOT_REQUEST";
        let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail);
        let messages = vec![build_metadata_message(
            &prompt,
            Some("answer from the long context"),
            0,
            None,
            None,
        )];

        let request = build_root_request("root-model", &messages, &rlm_system_prompt());
        let payload = serde_json::to_string(&request).expect("request should serialize");

        assert!(payload.contains(&format!("- Length: {} chars", prompt.chars().count())));
        assert!(
            !payload.contains(secret_tail),
            "root LLM request leaked the non-preview tail of context"
        );
    }

    #[test]
    fn build_metadata_with_iteration_shows_previous_code() {
        let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));
        let text = extract_text_blocks(&msg.content);
        assert!(text.contains("round 3"));
        assert!(text.contains("print('hi')"));
        assert!(text.contains("hi"));
    }

    #[test]
    fn build_metadata_includes_root_prompt() {
        let msg = build_metadata_message(
            "long context",
            Some("Summarize the security model"),
            1,
            Some("# noop"),
            Some("ok"),
        );
        let text = extract_text_blocks(&msg.content);
        assert!(text.contains("Original task"));
        assert!(text.contains("Summarize the security model"));
    }

    // ---- text utilities ----

    #[test]
    fn truncate_text_leaves_short_alone() {
        assert_eq!(truncate_text("hello", 100), "hello");
    }

    #[test]
    fn truncate_text_shortens_long_text() {
        let long = "a".repeat(1000);
        let truncated = truncate_text(&long, 10);
        assert_eq!(truncated.chars().count(), 10);
        assert!(truncated.ends_with("..."));
    }

    // Truncation counts chars, not bytes, so multi-byte text is never split.
    #[test]
    fn truncate_text_is_unicode_safe() {
        let s = "日本語テスト";
        let out = truncate_text(s, 4);
        assert_eq!(out.chars().count(), 4);
        assert!(out.ends_with("..."));
        assert!(std::str::from_utf8(out.as_bytes()).is_ok());
    }

    // Non-text blocks (e.g. thinking) are dropped.
    #[test]
    fn extract_text_blocks_joins_text() {
        let blocks = vec![
            ContentBlock::Text {
                text: "first".to_string(),
                cache_control: None,
            },
            ContentBlock::Thinking {
                thinking: "skip".to_string(),
            },
            ContentBlock::Text {
                text: "second".to_string(),
                cache_control: None,
            },
        ];
        assert_eq!(extract_text_blocks(&blocks), "first\nsecond");
    }

    #[test]
    fn metadata_msg_role_is_user() {
        let msg = build_metadata_message("test", None, 0, None, None);
        assert_eq!(msg.role, "user");
    }

    #[test]
    fn summarize_code_keeps_short() {
        assert_eq!(summarize_code("a\nb\nc"), "a\nb\nc");
    }

    #[test]
    fn summarize_code_compresses_long() {
        let lines: Vec<String> = (0..20).map(|i| format!("line{i}")).collect();
        let code = lines.join("\n");
        let s = summarize_code(&code);
        assert!(s.starts_with("20 lines:"));
        assert!(s.contains("line0"));
        assert!(s.contains("line19"));
        assert!(s.contains("…"));
    }
}