1use std::sync::Arc;
15use std::time::Duration;
16use std::{future::Future, pin::Pin};
17
18use anyhow::Result;
19use futures_util::future::join_all;
20use tokio::sync::Mutex;
21
22use crate::llm_client::LlmClient;
23use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt, Usage};
24use crate::repl::runtime::{BatchResp, RpcDispatcher, RpcRequest, RpcResponse, SingleResp};
25use crate::utils::spawn_supervised;
26
/// Longest time, in seconds, a single child `llm_query` call may run before it is
/// abandoned and reported as timed out.
const CHILD_TIMEOUT_SECS: u64 = 2 * 60;
/// `max_tokens` sent with a child request when the caller does not supply one.
const DEFAULT_CHILD_MAX_TOKENS: u32 = 4_096;
/// Largest number of prompts a single batch request may carry; larger batches are rejected.
pub const MAX_BATCH: usize = 16;
33
34pub(crate) trait RlmLlmClient: Send + Sync {
40 fn create_message_boxed(
41 &self,
42 request: MessageRequest,
43 ) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>>;
44}
45
46impl<T> RlmLlmClient for T
47where
48 T: LlmClient + Send + Sync + ?Sized,
49{
50 fn create_message_boxed(
51 &self,
52 request: MessageRequest,
53 ) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>> {
54 Box::pin(self.create_message(request))
55 }
56}
57
58pub struct RlmBridge {
60 client: Arc<dyn crate::llm_client::LlmClient>,
61 child_model: String,
62 depth_remaining: u32,
65 usage: Arc<Mutex<Usage>>,
66}
67
68impl RlmBridge {
69 pub(crate) fn new(
70 client: Arc<dyn crate::llm_client::LlmClient>,
71 child_model: String,
72 depth_remaining: u32,
73 ) -> Self {
74 Self {
75 client,
76 child_model,
77 depth_remaining,
78 usage: Arc::new(Mutex::new(Usage::default())),
79 }
80 }
81
82 pub fn usage_handle(&self) -> Arc<Mutex<Usage>> {
83 Arc::clone(&self.usage)
84 }
85
86 async fn dispatch_llm(
87 &self,
88 prompt: String,
89 _model: Option<String>,
90 max_tokens: Option<u32>,
91 system: Option<String>,
92 ) -> SingleResp {
93 let request = MessageRequest {
94 model: self.child_model.clone(),
99 messages: vec![Message {
100 role: "user".to_string(),
101 content: vec![ContentBlock::Text {
102 text: prompt,
103 cache_control: None,
104 }],
105 }],
106 max_tokens: max_tokens.unwrap_or(DEFAULT_CHILD_MAX_TOKENS),
107 system: system.map(SystemPrompt::Text),
108 tools: None,
109 tool_choice: None,
110 metadata: None,
111 thinking: None,
112 reasoning_effort: None,
113 stream: Some(false),
114 temperature: Some(0.4_f32),
115 top_p: Some(0.9_f32),
116 };
117
118 let fut = self.client.create_message(request);
119 let response =
120 match tokio::time::timeout(Duration::from_secs(CHILD_TIMEOUT_SECS), fut).await {
121 Ok(Ok(r)) => r,
122 Ok(Err(e)) => {
123 return SingleResp {
124 text: String::new(),
125 error: Some(format!("llm_query failed: {e}")),
126 };
127 }
128 Err(_) => {
129 return SingleResp {
130 text: String::new(),
131 error: Some(format!("llm_query timed out after {CHILD_TIMEOUT_SECS}s")),
132 };
133 }
134 };
135
136 let text = response
137 .content
138 .iter()
139 .filter_map(|b| match b {
140 ContentBlock::Text { text, .. } => Some(text.as_str()),
141 _ => None,
142 })
143 .collect::<Vec<_>>()
144 .join("\n");
145
146 {
147 let mut u = self.usage.lock().await;
148 u.input_tokens = u.input_tokens.saturating_add(response.usage.input_tokens);
149 u.output_tokens = u.output_tokens.saturating_add(response.usage.output_tokens);
150 }
151
152 SingleResp { text, error: None }
153 }
154
155 async fn dispatch_llm_batch(&self, prompts: Vec<String>, _model: Option<String>) -> BatchResp {
156 if let Some(resp) = batch_guard(prompts.len()) {
157 return resp;
158 }
159
160 let model = Arc::new(self.child_model.clone());
161
162 let futures = prompts.into_iter().map(|prompt| {
163 let model = Arc::clone(&model);
164 async move {
165 self.dispatch_llm((*prompt).to_string(), Some((*model).clone()), None, None)
166 .await
167 }
168 });
169
170 BatchResp {
171 results: join_all(futures).await,
172 }
173 }
174
175 async fn dispatch_rlm(&self, prompt: String, _model: Option<String>) -> SingleResp {
176 if self.depth_remaining == 0 {
177 return self.dispatch_llm(prompt, None, None, None).await;
181 }
182
183 let (tx, mut rx) = tokio::sync::mpsc::channel(64);
187 let drain = spawn_supervised(
188 "rlm-bridge-drain",
189 std::panic::Location::caller(),
190 async move { while rx.recv().await.is_some() {} },
191 );
192
193 let child_model = self.child_model.clone();
194
195 let result = super::turn::run_rlm_turn_inner(
198 self.client.clone(),
199 child_model.clone(),
200 prompt,
201 None,
202 child_model,
203 tx,
204 self.depth_remaining.saturating_sub(1),
205 )
206 .await;
207
208 drain.abort();
209
210 {
211 let mut u = self.usage.lock().await;
212 u.input_tokens = u.input_tokens.saturating_add(result.usage.input_tokens);
213 u.output_tokens = u.output_tokens.saturating_add(result.usage.output_tokens);
214 }
215
216 SingleResp {
217 text: result.answer,
218 error: result.error,
219 }
220 }
221
222 async fn dispatch_rlm_batch(&self, prompts: Vec<String>, _model: Option<String>) -> BatchResp {
223 if let Some(resp) = batch_guard(prompts.len()) {
224 return resp;
225 }
226
227 let futures = prompts
228 .into_iter()
229 .map(|p| async move { self.dispatch_rlm(p, None).await });
230 BatchResp {
231 results: join_all(futures).await,
232 }
233 }
234}
235
236fn batch_guard(prompt_count: usize) -> Option<BatchResp> {
237 if prompt_count == 0 {
238 return Some(BatchResp { results: vec![] });
239 }
240 if prompt_count > MAX_BATCH {
241 return Some(BatchResp {
242 results: (0..prompt_count)
243 .map(|_| SingleResp {
244 text: String::new(),
245 error: Some(format!("batch too large: {prompt_count} > {MAX_BATCH}")),
246 })
247 .collect(),
248 });
249 }
250 None
251}
252
253impl RpcDispatcher for RlmBridge {
254 fn dispatch<'a>(
255 &'a self,
256 req: RpcRequest,
257 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = RpcResponse> + Send + 'a>> {
258 Box::pin(async move {
259 match req {
260 RpcRequest::Llm {
261 prompt,
262 model,
263 max_tokens,
264 system,
265 } => {
266 RpcResponse::Single(self.dispatch_llm(prompt, model, max_tokens, system).await)
267 }
268 RpcRequest::LlmBatch { prompts, model } => {
269 RpcResponse::Batch(self.dispatch_llm_batch(prompts, model).await)
270 }
271 RpcRequest::Rlm { prompt, model } => {
272 RpcResponse::Single(self.dispatch_rlm(prompt, model).await)
273 }
274 RpcRequest::RlmBatch { prompts, model } => {
275 RpcResponse::Batch(self.dispatch_rlm_batch(prompts, model).await)
276 }
277 }
278 })
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm_client::mock::MockLlmClient;

    /// Builds an assistant response holding one text block and the given token counts.
    fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
        MessageResponse {
            id: "mock_msg".to_string(),
            r#type: "message".to_string(),
            role: "assistant".to_string(),
            content: vec![ContentBlock::Text {
                text: text.to_string(),
                cache_control: None,
            }],
            model: "mock-model".to_string(),
            stop_reason: Some("end_turn".to_string()),
            stop_sequence: None,
            container: None,
            usage: Usage {
                input_tokens,
                output_tokens,
                ..Usage::default()
            },
        }
    }

    /// Wraps `mock` in a bridge pinned to `"child-model"` with the given recursion budget.
    fn bridge_for(mock: Arc<MockLlmClient>, depth_remaining: u32) -> RlmBridge {
        let client: Arc<dyn crate::llm_client::LlmClient> = mock;
        RlmBridge::new(client, "child-model".to_string(), depth_remaining)
    }

    /// Extracts the payload of an `RpcResponse::Single`, failing the test otherwise.
    fn expect_single(response: RpcResponse) -> SingleResp {
        match response {
            RpcResponse::Single(single) => single,
            other => panic!("expected single response, got {other:?}"),
        }
    }

    /// Extracts the payload of an `RpcResponse::Batch`, failing the test otherwise.
    fn expect_batch(response: RpcResponse) -> BatchResp {
        match response {
            RpcResponse::Batch(batch) => batch,
            other => panic!("expected batch response, got {other:?}"),
        }
    }

    /// Asserts that every result in `batch` is an empty-text "batch too large" error.
    fn assert_all_rejected_as_too_large(batch: &BatchResp) {
        assert!(batch.results.iter().all(|result| {
            result.text.is_empty()
                && result
                    .error
                    .as_deref()
                    .is_some_and(|err| err.contains("batch too large"))
        }));
    }

    #[test]
    fn batch_guard_allows_non_empty_batches_at_the_cap() {
        assert!(batch_guard(MAX_BATCH).is_none());
    }

    #[test]
    fn batch_guard_returns_empty_response_for_empty_batches() {
        let response = batch_guard(0).expect("empty batch should be handled");
        assert!(response.results.is_empty());
    }

    #[test]
    fn batch_guard_returns_one_error_per_oversized_prompt() {
        let response = batch_guard(MAX_BATCH + 2).expect("oversized batch should be handled");
        assert_eq!(response.results.len(), MAX_BATCH + 2);
        assert_all_rejected_as_too_large(&response);
    }

    #[tokio::test]
    async fn llm_dispatch_pins_configured_child_model() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        mock.push_message_response(mock_response("child answer", 7, 11));
        let bridge = bridge_for(Arc::clone(&mock), 1);

        let single = expect_single(
            bridge
                .dispatch(RpcRequest::Llm {
                    prompt: "child prompt".to_string(),
                    model: Some("override-model".to_string()),
                    max_tokens: Some(123),
                    system: Some("child system".to_string()),
                })
                .await,
        );
        assert_eq!(single.text, "child answer");
        assert!(single.error.is_none());

        let captured = mock.captured_requests();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].model, "child-model");
        assert_eq!(captured[0].max_tokens, 123);
        assert_eq!(
            captured[0].system,
            Some(SystemPrompt::Text("child system".to_string()))
        );

        let usage = bridge.usage.lock().await;
        assert_eq!(usage.input_tokens, 7);
        assert_eq!(usage.output_tokens, 11);
    }

    #[tokio::test]
    async fn llm_batch_dispatch_pins_configured_child_model() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        mock.push_message_response(mock_response("one", 1, 2));
        mock.push_message_response(mock_response("two", 3, 4));
        mock.push_message_response(mock_response("three", 5, 6));
        let bridge = bridge_for(Arc::clone(&mock), 1);

        let batch = expect_batch(
            bridge
                .dispatch(RpcRequest::LlmBatch {
                    prompts: vec!["a".to_string(), "b".to_string(), "c".to_string()],
                    model: Some("batch-model".to_string()),
                })
                .await,
        );
        let texts: Vec<_> = batch
            .results
            .iter()
            .map(|result| result.text.as_str())
            .collect();
        assert_eq!(texts, ["one", "two", "three"]);
        assert!(batch.results.iter().all(|result| result.error.is_none()));

        let captured = mock.captured_requests();
        assert_eq!(captured.len(), 3);
        assert!(
            captured
                .iter()
                .all(|request| request.model == "child-model")
        );

        let usage = bridge.usage.lock().await;
        assert_eq!(usage.input_tokens, 9);
        assert_eq!(usage.output_tokens, 12);
    }

    #[tokio::test]
    async fn rlm_dispatch_at_depth_zero_pins_configured_child_model() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        mock.push_message_response(mock_response("fallback answer", 3, 5));
        let bridge = bridge_for(Arc::clone(&mock), 0);

        let single = expect_single(
            bridge
                .dispatch(RpcRequest::Rlm {
                    prompt: "nested prompt".to_string(),
                    model: Some("override-model".to_string()),
                })
                .await,
        );
        assert_eq!(single.text, "fallback answer");
        assert!(single.error.is_none());

        let usage = bridge.usage.lock().await;
        assert_eq!(usage.input_tokens, 3);
        assert_eq!(usage.output_tokens, 5);

        let captured = mock.captured_requests();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].model, "child-model");
    }

    /// Empty batches must be answered locally, for both batch variants, without any request.
    #[tokio::test]
    async fn empty_batches_return_no_results_without_calling_the_client() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        let bridge = bridge_for(Arc::clone(&mock), 1);

        let llm = expect_batch(
            bridge
                .dispatch(RpcRequest::LlmBatch {
                    prompts: Vec::new(),
                    model: None,
                })
                .await,
        );
        assert!(llm.results.is_empty());

        let rlm = expect_batch(
            bridge
                .dispatch(RpcRequest::RlmBatch {
                    prompts: Vec::new(),
                    model: None,
                })
                .await,
        );
        assert!(rlm.results.is_empty());

        assert_eq!(mock.captured_requests().len(), 0);
    }

    /// Oversized batches must be rejected per prompt, for both batch variants, before any
    /// request reaches the client.
    #[tokio::test]
    async fn oversized_batches_are_rejected_without_calling_the_client() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        let bridge = bridge_for(Arc::clone(&mock), 1);
        let prompts: Vec<String> = (0..=MAX_BATCH).map(|i| format!("prompt {i}")).collect();

        let requests = [
            RpcRequest::LlmBatch {
                prompts: prompts.clone(),
                model: None,
            },
            RpcRequest::RlmBatch {
                prompts: prompts.clone(),
                model: None,
            },
        ];
        for request in requests {
            let batch = expect_batch(bridge.dispatch(request).await);
            assert_eq!(batch.results.len(), MAX_BATCH + 1);
            assert_all_rejected_as_too_large(&batch);
        }

        assert_eq!(mock.captured_requests().len(), 0);
    }
}