Skip to main content

zagens_runtime/rlm/
turn.rs

1//! RLM turn loop — paper Algorithm 1 driven over a long-lived Python
2//! subprocess + stdin/stdout RPC bridge (no HTTP sidecar).
3
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use tokio::sync::mpsc;
9use uuid::Uuid;
10
11use crate::core::events::Event;
12use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage};
13use crate::repl::PythonRuntime;
14
15use super::bridge::{RlmBridge, RlmLlmClient};
16use super::prompt::rlm_system_prompt;
17
18// ---------------------------------------------------------------------------
19// Constants
20// ---------------------------------------------------------------------------
21
/// Upper bound on root-LLM rounds in a single RLM turn.
const MAX_RLM_ITERATIONS: u32 = 25;
/// Rounds in a row the model may go without a `repl` fence before the turn
/// hard-fails. The paper's contract is `code → REPL → Final`; a response
/// without code does not count as an RLM round.
const MAX_CONSECUTIVE_NO_CODE: u32 = 3;
/// Wall-clock budget for a whole RLM turn (three minutes).
const TURN_TIMEOUT: Duration = Duration::from_secs(3 * 60);
/// Cap on the conversation history carried from round to round.
const MAX_HISTORY_MESSAGES: usize = 20;

/// Output-token budget for the root LLM, which only has to write code.
const ROOT_MAX_TOKENS: u32 = 4 * 1024;
/// Sampling temperature for root-LLM calls.
const ROOT_TEMPERATURE: f32 = 0.3;

/// Chars of a round's stdout echoed back to the root LLM as metadata.
const STDOUT_METADATA_PREVIEW_LEN: usize = 800;
/// Chars of `context` echoed back to the root LLM as a preview.
const PROMPT_PREVIEW_LEN: usize = 500;
40
41// ---------------------------------------------------------------------------
42// Public API
43// ---------------------------------------------------------------------------
44
/// Reason an RLM turn stopped.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RlmTermination {
    /// Ended by `FINAL(value)`: either executed inside the REPL, or written
    /// as a top-level `FINAL(...)` line of the model's response.
    Final,
    /// The model went too many consecutive rounds without a `repl` block.
    /// Rather than discarding its output, the last response text (or the
    /// rejected top-level `FINAL` value) is surfaced as the answer.
    NoCode,
    /// The iteration cap was reached before any `FINAL`.
    Exhausted,
    /// Hard failure: the LLM call failed, the REPL crashed, or the turn
    /// timed out.
    Error,
}
60
/// One round of the RLM loop, as reported back in the tool result so the
/// user can audit what the sub-agent actually did.
#[derive(Clone, Debug)]
pub struct RlmRoundTrace {
    /// 1-based round number.
    pub round: u32,
    /// The executed code, shortened to head/tail lines when long.
    pub code_summary: String,
    /// Trimmed, length-capped stdout of the round.
    pub stdout_preview: String,
    /// Whether the REPL reported an error for this round.
    pub had_error: bool,
    /// Sub-LLM RPCs serviced by the bridge during this round.
    pub rpc_count: u32,
    /// Time the REPL spent running this round's code, in milliseconds.
    pub elapsed_ms: u64,
}
72
73/// Result of an RLM turn.
74#[derive(Debug, Clone)]
75pub struct RlmTurnResult {
76    pub answer: String,
77    pub iterations: u32,
78    pub duration: Duration,
79    pub error: Option<String>,
80    pub usage: Usage,
81    pub termination: RlmTermination,
82    /// Per-round trace. Empty when the loop never reached the REPL.
83    pub trace: Vec<RlmRoundTrace>,
84    /// Total sub-LLM RPCs made by the sub-agent (sum of `rpc_count` across
85    /// rounds). Useful for verifying that the model engaged with `context`
86    /// rather than answering directly.
87    pub total_rpcs: u32,
88}
89
90/// Run a full RLM turn. `prompt` is loaded into the REPL as `context`; it
91/// never enters the root LLM's window.
92pub async fn run_rlm_turn(
93    client: Arc<dyn crate::llm_client::LlmClient>,
94    model: String,
95    prompt: String,
96    child_model: String,
97    tx_event: mpsc::Sender<Event>,
98    max_depth: u32,
99) -> RlmTurnResult {
100    run_rlm_turn_inner(
101        client.clone(),
102        model,
103        prompt,
104        None,
105        child_model,
106        tx_event,
107        max_depth,
108    )
109    .await
110}
111
112/// Variant that also passes a small `root_prompt` (the user-facing task)
113/// shown to the root LLM each iteration so it remembers its objective.
114pub async fn run_rlm_turn_with_root(
115    client: Arc<dyn crate::llm_client::LlmClient>,
116    model: String,
117    prompt: String,
118    root_prompt: Option<String>,
119    child_model: String,
120    tx_event: mpsc::Sender<Event>,
121    max_depth: u32,
122) -> RlmTurnResult {
123    run_rlm_turn_inner(
124        client.clone(),
125        model,
126        prompt,
127        root_prompt,
128        child_model,
129        tx_event,
130        max_depth,
131    )
132    .await
133}
134
135/// Inner entry point — also used by the bridge when it recurses. Returns
136/// a boxed future to break the recursive opaque-future-type cycle:
137/// `run_rlm_turn_inner` → `RlmBridge::dispatch` → `run_rlm_turn_inner`.
138pub(crate) fn run_rlm_turn_inner(
139    client: Arc<dyn crate::llm_client::LlmClient>,
140    model: String,
141    prompt: String,
142    root_prompt: Option<String>,
143    child_model: String,
144    tx_event: mpsc::Sender<Event>,
145    max_depth: u32,
146) -> std::pin::Pin<Box<dyn std::future::Future<Output = RlmTurnResult> + Send>> {
147    Box::pin(run_rlm_turn_impl(
148        client,
149        model,
150        prompt,
151        root_prompt,
152        child_model,
153        tx_event,
154        max_depth,
155    ))
156}
157
158// ---------------------------------------------------------------------------
159// Implementation
160// ---------------------------------------------------------------------------
161
162async fn run_rlm_turn_impl(
163    client: Arc<dyn crate::llm_client::LlmClient>,
164    model: String,
165    prompt: String,
166    root_prompt: Option<String>,
167    child_model: String,
168    tx_event: mpsc::Sender<Event>,
169    max_depth: u32,
170) -> RlmTurnResult {
171    let start = Instant::now();
172    let mut total_usage = Usage::default();
173    let mut trace: Vec<RlmRoundTrace> = Vec::new();
174    let mut total_rpcs: u32 = 0;
175
176    // 1. Stage `context` to a temp file. The REPL reads it on bootstrap so
177    //    the big string never enters the process command line and doesn't
178    //    show up in `ps`.
179    let ctx_path = match write_context_file(&prompt) {
180        Ok(p) => p,
181        Err(e) => {
182            return RlmTurnResult {
183                answer: String::new(),
184                iterations: 0,
185                duration: start.elapsed(),
186                error: Some(format!("rlm: failed to stage context: {e}")),
187                usage: total_usage,
188                termination: RlmTermination::Error,
189                trace,
190                total_rpcs,
191            };
192        }
193    };
194
195    // 2. Spawn the long-lived REPL.
196    let mut repl = match PythonRuntime::spawn_with_context(&ctx_path).await {
197        Ok(rt) => rt,
198        Err(e) => {
199            let _ = tokio::fs::remove_file(&ctx_path).await;
200            return RlmTurnResult {
201                answer: String::new(),
202                iterations: 0,
203                duration: start.elapsed(),
204                error: Some(format!("rlm: failed to spawn REPL: {e}")),
205                usage: total_usage,
206                termination: RlmTermination::Error,
207                trace,
208                total_rpcs,
209            };
210        }
211    };
212
213    // 3. Build the bridge that services llm_query / rlm_query RPCs.
214    let bridge = RlmBridge::new(client.clone(), child_model.clone(), max_depth);
215    let usage_handle = bridge.usage_handle();
216
217    let _ = tx_event
218        .send(Event::status(format!(
219            "RLM: spawned Python REPL (root={model}, child={child_model}, max_depth={max_depth}, ctx={} chars)",
220            prompt.chars().count()
221        )))
222        .await;
223
224    // 4. Build initial metadata-only history.
225    let system = rlm_system_prompt();
226    let mut messages: Vec<Message> = vec![build_metadata_message(
227        &prompt,
228        root_prompt.as_deref(),
229        0,
230        None,
231        None,
232    )];
233
234    let mut consecutive_no_code: u32 = 0;
235    let mut last_response_text = String::new();
236
237    let result = 'turn: {
238        for iteration in 0..MAX_RLM_ITERATIONS {
239            if start.elapsed() > TURN_TIMEOUT {
240                break 'turn RlmTurnResult {
241                    answer: String::new(),
242                    iterations: iteration,
243                    duration: start.elapsed(),
244                    error: Some(format!(
245                        "RLM turn timed out after {}s",
246                        TURN_TIMEOUT.as_secs()
247                    )),
248                    usage: total_usage,
249                    termination: RlmTermination::Error,
250                    trace: trace.clone(),
251                    total_rpcs,
252                };
253            }
254
255            let _ = tx_event
256                .send(Event::status(format!(
257                    "RLM iteration {}/{}",
258                    iteration + 1,
259                    MAX_RLM_ITERATIONS
260                )))
261                .await;
262
263            // 4a. Root LLM generates code from metadata-only context.
264            let request = build_root_request(&model, &messages, &system);
265
266            let response = match client.create_message_boxed(request).await {
267                Ok(r) => r,
268                Err(e) => {
269                    break 'turn RlmTurnResult {
270                        answer: String::new(),
271                        iterations: iteration + 1,
272                        duration: start.elapsed(),
273                        error: Some(format!("Root LLM call failed: {e}")),
274                        usage: total_usage,
275                        termination: RlmTermination::Error,
276                        trace: trace.clone(),
277                        total_rpcs,
278                    };
279                }
280            };
281
282            total_usage.input_tokens = total_usage
283                .input_tokens
284                .saturating_add(response.usage.input_tokens);
285            total_usage.output_tokens = total_usage
286                .output_tokens
287                .saturating_add(response.usage.output_tokens);
288
289            let response_text = extract_text_blocks(&response.content);
290            last_response_text = response_text.clone();
291
292            // 4b. Top-level FINAL(...) lets the model close out without
293            //     touching the REPL — but only if it has done some work
294            //     (non-zero rpc_count) on a prior round. Otherwise it's a
295            //     shortcut and we reject it.
296            if let Some(final_val) = parse_text_final(&response_text) {
297                if total_rpcs == 0 {
298                    // Discard the top-level FINAL — the model is bypassing
299                    // the loop. Force it to use the REPL by appending a
300                    // strict reminder.
301                    consecutive_no_code = consecutive_no_code.saturating_add(1);
302                    if consecutive_no_code >= MAX_CONSECUTIVE_NO_CODE {
303                        break 'turn RlmTurnResult {
304                            answer: final_val,
305                            iterations: iteration + 1,
306                            duration: start.elapsed(),
307                            error: None,
308                            usage: total_usage,
309                            termination: RlmTermination::NoCode,
310                            trace: trace.clone(),
311                            total_rpcs,
312                        };
313                    }
314                    messages.push(Message {
315                        role: "assistant".to_string(),
316                        content: vec![ContentBlock::Text {
317                            text: response_text.clone(),
318                            cache_control: None,
319                        }],
320                    });
321                    messages.push(Message {
322                        role: "user".to_string(),
323                        content: vec![ContentBlock::Text {
324                            text: "You called FINAL(...) without ever running a ```repl block. \
325                                   That defeats the recursive language model — you're guessing \
326                                   from the preview alone. Emit a ```repl block now that uses \
327                                   `llm_query`, `llm_query_batched`, or `rlm_query` against \
328                                   `context` to actually compute the answer."
329                                .to_string(),
330                            cache_control: None,
331                        }],
332                    });
333                    continue;
334                }
335                let _ = tx_event
336                    .send(Event::status(
337                        "RLM: FINAL detected in response text".to_string(),
338                    ))
339                    .await;
340                break 'turn RlmTurnResult {
341                    answer: final_val,
342                    iterations: iteration + 1,
343                    duration: start.elapsed(),
344                    error: None,
345                    usage: total_usage,
346                    termination: RlmTermination::Final,
347                    trace: trace.clone(),
348                    total_rpcs,
349                };
350            }
351
352            // 4c. Extract a ```repl block.
353            let code = extract_repl_code(&response_text);
354            let code_to_run = match code {
355                Some(c) => {
356                    consecutive_no_code = 0;
357                    c
358                }
359                None => {
360                    consecutive_no_code = consecutive_no_code.saturating_add(1);
361                    if consecutive_no_code >= MAX_CONSECUTIVE_NO_CODE {
362                        break 'turn RlmTurnResult {
363                            answer: response_text,
364                            iterations: iteration + 1,
365                            duration: start.elapsed(),
366                            error: Some(format!(
367                                "RLM: model failed to emit ```repl after {MAX_CONSECUTIVE_NO_CODE} consecutive rounds"
368                            )),
369                            usage: total_usage,
370                            termination: RlmTermination::NoCode,
371                            trace: trace.clone(),
372                            total_rpcs,
373                        };
374                    }
375                    messages.push(Message {
376                        role: "assistant".to_string(),
377                        content: vec![ContentBlock::Text {
378                            text: response_text.clone(),
379                            cache_control: None,
380                        }],
381                    });
382                    messages.push(Message {
383                        role: "user".to_string(),
384                        content: vec![ContentBlock::Text {
385                            text: "Reminder: emit Python inside a ```repl … ``` fence. \
386                                   Use `llm_query` / `llm_query_batched` / `rlm_query` to \
387                                   process `context` and call `FINAL(value)` when done."
388                                .to_string(),
389                            cache_control: None,
390                        }],
391                    });
392                    continue;
393                }
394            };
395
396            let _ = tx_event
397                .send(Event::MessageDelta {
398                    index: iteration as usize,
399                    content: format!(
400                        "\n[RLM round {} — code]\n```repl\n{code_to_run}\n```\n",
401                        iteration + 1
402                    ),
403                })
404                .await;
405
406            // 4d. Execute the code in the REPL with the bridge servicing
407            //     llm_query / rlm_query callbacks.
408            let round = match repl.run(&code_to_run, Some(&bridge)).await {
409                Ok(r) => r,
410                Err(e) => {
411                    break 'turn RlmTurnResult {
412                        answer: String::new(),
413                        iterations: iteration + 1,
414                        duration: start.elapsed(),
415                        error: Some(format!("REPL execution failed: {e}")),
416                        usage: total_usage,
417                        termination: RlmTermination::Error,
418                        trace: trace.clone(),
419                        total_rpcs,
420                    };
421                }
422            };
423
424            total_rpcs = total_rpcs.saturating_add(round.rpc_count);
425
426            // Trace this round.
427            let stdout_preview = truncate_text(round.stdout.trim(), STDOUT_METADATA_PREVIEW_LEN);
428            trace.push(RlmRoundTrace {
429                round: iteration + 1,
430                code_summary: summarize_code(&code_to_run),
431                stdout_preview: stdout_preview.clone(),
432                had_error: round.has_error,
433                rpc_count: round.rpc_count,
434                elapsed_ms: round.elapsed.as_millis() as u64,
435            });
436
437            let _ = tx_event
438                .send(Event::status(format!(
439                    "RLM round {}: {} bytes stdout, {} sub-LLM call(s){}",
440                    iteration + 1,
441                    round.full_stdout.len(),
442                    round.rpc_count,
443                    if round.has_error { " (error)" } else { "" },
444                )))
445                .await;
446
447            // 4e. FINAL detection.
448            if let Some(final_val) = round.final_value.clone() {
449                let _ = tx_event
450                    .send(Event::status(
451                        "RLM: FINAL detected in REPL, ending loop".to_string(),
452                    ))
453                    .await;
454                break 'turn RlmTurnResult {
455                    answer: final_val,
456                    iterations: iteration + 1,
457                    duration: start.elapsed(),
458                    error: None,
459                    usage: total_usage,
460                    termination: RlmTermination::Final,
461                    trace: trace.clone(),
462                    total_rpcs,
463                };
464            }
465
466            // 4f. Build metadata for next iteration.
467            messages.push(Message {
468                role: "assistant".to_string(),
469                content: vec![ContentBlock::Text {
470                    text: format!("```repl\n{code_to_run}\n```"),
471                    cache_control: None,
472                }],
473            });
474            messages.push(build_metadata_message(
475                &prompt,
476                root_prompt.as_deref(),
477                iteration + 1,
478                Some(&code_to_run),
479                Some(&stdout_preview),
480            ));
481
482            if messages.len() > MAX_HISTORY_MESSAGES {
483                let drop_from = messages.len() - MAX_HISTORY_MESSAGES + 1;
484                let mut kept = vec![messages[0].clone()];
485                kept.extend(messages.drain(drop_from..));
486                messages = kept;
487            }
488        }
489
490        let _ = last_response_text;
491        RlmTurnResult {
492            answer: String::new(),
493            iterations: MAX_RLM_ITERATIONS,
494            duration: start.elapsed(),
495            error: Some(format!(
496                "RLM loop exhausted after {MAX_RLM_ITERATIONS} iterations without FINAL"
497            )),
498            usage: total_usage,
499            termination: RlmTermination::Exhausted,
500            trace: trace.clone(),
501            total_rpcs,
502        }
503    };
504
505    // Fold bridge usage (children + nested sub_rlm) into totals.
506    let bridge_usage = usage_handle.lock().await;
507    let mut final_usage = result.usage.clone();
508    final_usage.input_tokens = final_usage
509        .input_tokens
510        .saturating_add(bridge_usage.input_tokens);
511    final_usage.output_tokens = final_usage
512        .output_tokens
513        .saturating_add(bridge_usage.output_tokens);
514    drop(bridge_usage);
515
516    repl.shutdown().await;
517
518    RlmTurnResult {
519        usage: final_usage,
520        ..result
521    }
522}
523
524// ---------------------------------------------------------------------------
525// Helpers
526// ---------------------------------------------------------------------------
527
528fn write_context_file(prompt: &str) -> std::io::Result<PathBuf> {
529    let dir = std::env::temp_dir().join("deepseek_rlm_ctx");
530    std::fs::create_dir_all(&dir)?;
531    let path = dir.join(format!(
532        "ctx_{}_{}.txt",
533        std::process::id(),
534        Uuid::new_v4().simple()
535    ));
536    std::fs::write(&path, prompt)?;
537    Ok(path)
538}
539
540fn build_root_request(model: &str, messages: &[Message], system: &SystemPrompt) -> MessageRequest {
541    MessageRequest {
542        model: model.to_string(),
543        messages: messages.to_vec(),
544        max_tokens: ROOT_MAX_TOKENS,
545        system: Some(system.clone()),
546        tools: None,
547        tool_choice: None,
548        metadata: None,
549        thinking: None,
550        reasoning_effort: None,
551        stream: Some(false),
552        temperature: Some(ROOT_TEMPERATURE),
553        top_p: Some(0.9_f32),
554    }
555}
556
557/// Build `Metadata(state)` from the paper. Surfaces:
558/// - the small `root_prompt` (if any) — repeated each iteration
559/// - `context` length + preview
560/// - the REPL helpers
561/// - the previous round's code summary + stdout preview
562fn build_metadata_message(
563    prompt: &str,
564    root_prompt: Option<&str>,
565    iteration: u32,
566    previous_code: Option<&str>,
567    previous_stdout: Option<&str>,
568) -> Message {
569    let prompt_len = prompt.chars().count();
570    let prompt_preview = truncate_text(prompt, PROMPT_PREVIEW_LEN);
571
572    let mut parts = Vec::new();
573    parts.push(format!("## REPL state (round {iteration})"));
574    parts.push(String::new());
575    if let Some(rp) = root_prompt
576        && !rp.trim().is_empty()
577    {
578        parts.push("**Original task** (re-shown every round)".to_string());
579        parts.push(format!("> {}", truncate_text(rp.trim(), 600)));
580        parts.push(String::new());
581    }
582    parts.push("**`context`** — the long input lives in the REPL only".to_string());
583    parts.push(format!("- Length: {prompt_len} chars"));
584    parts.push(format!("- Preview: \"{prompt_preview}\""));
585    parts.push(String::new());
586
587    parts.push("**REPL helpers** (use inside ```repl blocks)".to_string());
588    parts.push("- `context` / `ctx`                       — the full input string".to_string());
589    parts.push("- `len(context)` / `context[a:b]` / `context.splitlines()` — slice it".to_string());
590    parts.push(
591        "- `llm_query(prompt, model=None)`        — one-shot child LLM; `model` is ignored and child calls stay pinned to Flash"
592            .to_string(),
593    );
594    parts.push(
595        "- `llm_query_batched([p1, p2, ...])`     — concurrent fan-out; `model` is ignored"
596            .to_string(),
597    );
598    parts.push(
599        "- `rlm_query(prompt, model=None)`        — recursive sub-RLM; `model` is ignored"
600            .to_string(),
601    );
602    parts.push(
603        "- `rlm_query_batched([p1, p2, ...])`     — concurrent recursive sub-RLMs; `model` is ignored"
604            .to_string(),
605    );
606    parts.push("- `SHOW_VARS()`                          — list user variables".to_string());
607    parts.push("- `repl_set(name, value)` / `repl_get(name)` — explicit store".to_string());
608    parts.push(
609        "- `FINAL(value)`                         — end the loop with this answer".to_string(),
610    );
611    parts.push(
612        "- `FINAL_VAR(name)`                      — end the loop with a variable's value"
613            .to_string(),
614    );
615    parts.push(String::new());
616
617    if iteration > 0 {
618        parts.push("**Previous round**".to_string());
619        if let Some(code) = previous_code {
620            parts.push(format!("- Code: {}", summarize_code(code)));
621        }
622        if let Some(stdout) = previous_stdout {
623            let stdout_clean = stdout.trim();
624            if !stdout_clean.is_empty() {
625                parts.push(format!("- Stdout preview: \"{stdout_clean}\""));
626            } else {
627                parts.push("- Stdout: (empty)".to_string());
628            }
629        }
630    }
631
632    let text = parts.join("\n");
633
634    Message {
635        role: "user".to_string(),
636        content: vec![ContentBlock::Text {
637            text,
638            cache_control: None,
639        }],
640    }
641}
642
/// Shorten `code` for display. Code of at most eight lines is returned
/// verbatim. Longer code becomes `"<N> lines:"` followed by the first four
/// lines, an ellipsis line, and the last four lines.
fn summarize_code(code: &str) -> String {
    const EDGE: usize = 4;

    let all: Vec<&str> = code.lines().collect();
    let total = all.len();
    if total > 2 * EDGE {
        let (head, rest) = all.split_at(EDGE);
        let tail = &rest[rest.len() - EDGE..];
        format!(
            "{total} lines:\n{}\n…\n{}",
            head.join("\n"),
            tail.join("\n")
        )
    } else {
        code.to_string()
    }
}
652
653fn extract_text_blocks(blocks: &[ContentBlock]) -> String {
654    blocks
655        .iter()
656        .filter_map(|b| match b {
657            ContentBlock::Text { text, .. } => Some(text.as_str()),
658            _ => None,
659        })
660        .collect::<Vec<_>>()
661        .join("\n")
662}
663
/// Return the trimmed body of the earliest code fence in `text` whose opener
/// is tagged `repl`, `python` or `py` (LF or CRLF line ending).
///
/// The fence that *starts* first wins regardless of its tag; `python`/`py`
/// are accepted as aliases for prompts that learned the older fence style.
/// The body ends at the next triple-backtick, preferring one at the start of
/// a line. Returns `None` when no opener is found, the fence is never
/// closed, or the body is blank.
fn extract_repl_code(text: &str) -> Option<String> {
    const OPENERS: [&str; 6] = [
        "```repl\n",
        "```repl\r\n",
        "```python\n",
        "```py\n",
        "```python\r\n",
        "```py\r\n",
    ];

    // Byte offset just past the earliest opener found anywhere in `text`.
    let body_start = OPENERS
        .iter()
        .filter_map(|opener| text.find(opener).map(|at| (at, at + opener.len())))
        .min_by_key(|&(at, _)| at)
        .map(|(_, after)| after)?;

    let body = &text[body_start..];
    let close = body.find("\n```").or_else(|| body.find("```"))?;
    let code = body[..close].trim();
    if code.is_empty() {
        None
    } else {
        Some(code.to_owned())
    }
}
705
/// Parse a top-level `FINAL(...)` directive from the model's raw text.
///
/// Mirrors the reference RLM's `find_final_answer`: the directive must start
/// a line (after leading whitespace) and lie *outside* any code fence. The
/// value runs up to the last `)` on that line, is trimmed, and has one layer
/// of matching `"`/`'` quotes removed. Empty `FINAL()` lines are skipped.
/// `FINAL_VAR(...)` cannot be resolved from text alone and is left to the REPL.
fn parse_text_final(text: &str) -> Option<String> {
    strip_code_fences(text).lines().find_map(|line| {
        // `FINAL_VAR(` does not begin with `FINAL(`, so it never matches here.
        let args = line.trim_start().strip_prefix("FINAL(")?.trim_end();
        let close = args.rfind(')')?;
        let value = args[..close].trim();
        (!value.is_empty()).then(|| strip_quotes(value))
    })
}

/// Return `text` with every fenced region removed. Lines starting (after
/// indentation) with a triple-backtick toggle the fence state and are dropped
/// themselves. Each kept line is re-terminated with `\n`.
fn strip_code_fences(text: &str) -> String {
    let mut inside_fence = false;
    text.lines()
        .filter(|line| {
            if line.trim_start().starts_with("```") {
                inside_fence = !inside_fence;
                false
            } else {
                !inside_fence
            }
        })
        .fold(String::with_capacity(text.len()), |mut kept, line| {
            kept.push_str(line);
            kept.push('\n');
            kept
        })
}

/// Remove one pair of matching surrounding quotes (`"…"` or `'…'`) from `s`.
/// Anything else is returned unchanged.
fn strip_quotes(s: &str) -> String {
    ['"', '\'']
        .into_iter()
        .find_map(|quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
        .to_string()
}
757
/// Truncate `text` to at most `max_chars` characters (Unicode scalar values).
///
/// Text that already fits is returned unchanged. Otherwise the result is
/// the first `max_chars - 3` chars followed by `"..."`. The ellipsis itself
/// is cut short when `max_chars < 3`, so the result never exceeds the limit.
///
/// Only the first `max_chars + 1` chars are scanned. This matters because
/// `text` can be the entire (very large) RLM `context`.
fn truncate_text(text: &str, max_chars: usize) -> String {
    const ELLIPSIS: &str = "...";

    // A char at index `max_chars` exists iff the text is too long.
    if text.char_indices().nth(max_chars).is_none() {
        return text.to_string();
    }

    let keep = max_chars.saturating_sub(ELLIPSIS.len());
    // `text` has more than `max_chars >= keep` chars, so this always hits;
    // the fallback is purely defensive.
    let cut = text
        .char_indices()
        .nth(keep)
        .map_or(text.len(), |(byte_idx, _)| byte_idx);

    let mut out = String::with_capacity(cut + ELLIPSIS.len());
    out.push_str(&text[..cut]);
    out.push_str(&ELLIPSIS[..max_chars.min(ELLIPSIS.len())]);
    out
}
768
769// ---------------------------------------------------------------------------
770// Tests
771// ---------------------------------------------------------------------------
772
773#[cfg(test)]
774mod tests {
775    use super::*;
776
777    #[test]
778    fn extract_repl_code_finds_simple_block() {
779        let text = "Here:\n```repl\nprint('hi')\n```\nEnd.";
780        let code = extract_repl_code(text).unwrap();
781        assert_eq!(code, "print('hi')");
782    }
783
784    #[test]
785    fn extract_repl_code_falls_back_to_python_marker() {
786        let text = "Code:\n```python\nx = 1 + 2\n```";
787        let code = extract_repl_code(text).unwrap();
788        assert_eq!(code, "x = 1 + 2");
789    }
790
791    #[test]
792    fn extract_repl_code_returns_none_when_missing() {
793        assert!(extract_repl_code("Just text.").is_none());
794    }
795
796    #[test]
797    fn extract_repl_code_returns_none_on_empty_block() {
798        assert!(extract_repl_code("```repl\n\n```").is_none());
799    }
800
801    #[test]
802    fn extract_repl_code_handles_multiple_blocks() {
803        let text = "```repl\na=1\n```\n```repl\nb=2\n```";
804        let code = extract_repl_code(text).unwrap();
805        assert_eq!(code, "a=1");
806    }
807
808    #[test]
809    fn extract_repl_code_ignores_other_fences() {
810        let text = "```\nfoo\n```\n```repl\nreal_code()\n```";
811        let code = extract_repl_code(text).unwrap();
812        assert_eq!(code, "real_code()");
813    }
814
815    #[test]
816    fn parse_text_final_extracts_simple_value() {
817        let text = "OK.\nFINAL(42)\nThanks.";
818        assert_eq!(parse_text_final(text).as_deref(), Some("42"));
819    }
820
821    #[test]
822    fn parse_text_final_strips_quotes() {
823        let text = "FINAL(\"the answer is yes\")";
824        assert_eq!(parse_text_final(text).as_deref(), Some("the answer is yes"));
825    }
826
827    #[test]
828    fn parse_text_final_ignores_inside_code_fence() {
829        let text =
830            "Some prose.\n```repl\n# Note: when ready, call FINAL(value)\nx = 1\n```\nMore prose.";
831        assert!(parse_text_final(text).is_none());
832    }
833
834    #[test]
835    fn parse_text_final_returns_none_when_absent() {
836        assert!(parse_text_final("just talking, no final.").is_none());
837    }
838
839    #[test]
840    fn build_metadata_contains_key_information() {
841        let msg = build_metadata_message("Hello, world!", None, 0, None, None);
842        let text = extract_text_blocks(&msg.content);
843        assert!(text.contains("context"));
844        assert!(text.contains("Hello, world!"));
845        assert!(text.contains("round 0"));
846        assert!(text.contains("llm_query"));
847        assert!(text.contains("rlm_query"));
848        assert!(text.contains("FINAL"));
849    }
850
851    #[test]
852    fn build_metadata_truncates_long_context_without_leaking_tail() {
853        let secret_tail = "DO_NOT_LEAK_CONTEXT_TAIL";
854        let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail);
855        let msg = build_metadata_message(&prompt, None, 0, None, None);
856        let text = extract_text_blocks(&msg.content);
857
858        assert!(text.contains(&format!("- Length: {} chars", prompt.chars().count())));
859        assert!(text.contains("- Preview: \""));
860        assert!(text.contains("..."));
861        assert!(
862            !text.contains(secret_tail),
863            "metadata leaked the non-preview tail of context"
864        );
865    }
866
867    #[test]
868    fn build_root_request_keeps_context_tail_out_of_root_payload() {
869        let secret_tail = "DO_NOT_LEAK_ROOT_REQUEST";
870        let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail);
871        let messages = vec![build_metadata_message(
872            &prompt,
873            Some("answer from the long context"),
874            0,
875            None,
876            None,
877        )];
878
879        let request = build_root_request("root-model", &messages, &rlm_system_prompt());
880        let payload = serde_json::to_string(&request).expect("request should serialize");
881
882        assert!(payload.contains(&format!("- Length: {} chars", prompt.chars().count())));
883        assert!(
884            !payload.contains(secret_tail),
885            "root LLM request leaked the non-preview tail of context"
886        );
887    }
888
889    #[test]
890    fn build_metadata_with_iteration_shows_previous_code() {
891        let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));
892        let text = extract_text_blocks(&msg.content);
893        assert!(text.contains("round 3"));
894        assert!(text.contains("print('hi')"));
895        assert!(text.contains("hi"));
896    }
897
898    #[test]
899    fn build_metadata_includes_root_prompt() {
900        let msg = build_metadata_message(
901            "long context",
902            Some("Summarize the security model"),
903            1,
904            Some("# noop"),
905            Some("ok"),
906        );
907        let text = extract_text_blocks(&msg.content);
908        assert!(text.contains("Original task"));
909        assert!(text.contains("Summarize the security model"));
910    }
911
912    #[test]
913    fn truncate_text_leaves_short_alone() {
914        assert_eq!(truncate_text("hello", 100), "hello");
915    }
916
917    #[test]
918    fn truncate_text_shortens_long_text() {
919        let long = "a".repeat(1000);
920        let truncated = truncate_text(&long, 10);
921        assert_eq!(truncated.chars().count(), 10);
922        assert!(truncated.ends_with("..."));
923    }
924
925    #[test]
926    fn truncate_text_is_unicode_safe() {
927        let s = "日本語テスト";
928        let out = truncate_text(s, 4);
929        assert_eq!(out.chars().count(), 4);
930        assert!(out.ends_with("..."));
931        assert!(std::str::from_utf8(out.as_bytes()).is_ok());
932    }
933
934    #[test]
935    fn extract_text_blocks_joins_text() {
936        let blocks = vec![
937            ContentBlock::Text {
938                text: "first".to_string(),
939                cache_control: None,
940            },
941            ContentBlock::Thinking {
942                thinking: "skip".to_string(),
943            },
944            ContentBlock::Text {
945                text: "second".to_string(),
946                cache_control: None,
947            },
948        ];
949        assert_eq!(extract_text_blocks(&blocks), "first\nsecond");
950    }
951
952    #[test]
953    fn metadata_msg_role_is_user() {
954        let msg = build_metadata_message("test", None, 0, None, None);
955        assert_eq!(msg.role, "user");
956    }
957
958    #[test]
959    fn summarize_code_keeps_short() {
960        assert_eq!(summarize_code("a\nb\nc"), "a\nb\nc");
961    }
962
963    #[test]
964    fn summarize_code_compresses_long() {
965        let lines: Vec<String> = (0..20).map(|i| format!("line{i}")).collect();
966        let code = lines.join("\n");
967        let s = summarize_code(&code);
968        assert!(s.starts_with("20 lines:"));
969        assert!(s.contains("line0"));
970        assert!(s.contains("line19"));
971        assert!(s.contains("…"));
972    }
973}