1use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15use async_trait::async_trait;
16
17use crate::traits::output_transformer::OutputTransformer;
18use crate::traits::provider::Provider;
19use crate::types::agent_state::AgentState;
20use crate::types::completion::{CompletionRequest, ResponseContent};
21use crate::types::message::Message;
22
/// Output transformer that caps tool output at a character budget.
///
/// The budget is `max_chars`, halved whenever the agent's context utilization
/// rises above `aggressive_threshold`.
#[derive(Debug, Clone, Copy)]
pub struct BudgetAwareTruncator {
    /// Character budget used while context utilization is at or below the threshold.
    max_chars: usize,
    /// Utilization above which the budget is halved; clamped to `[0.0, 1.0]`.
    aggressive_threshold: f32,
}

impl BudgetAwareTruncator {
    /// Creates a truncator with the given character budget and aggressive-mode threshold.
    ///
    /// `aggressive_threshold` is clamped into `[0.0, 1.0]`.
    #[must_use]
    pub fn new(max_chars: usize, aggressive_threshold: f32) -> Self {
        Self {
            max_chars,
            aggressive_threshold: aggressive_threshold.clamp(0.0, 1.0),
        }
    }
}

impl Default for BudgetAwareTruncator {
    /// A 10 000-character budget that is halved above 0.8 context utilization.
    fn default() -> Self {
        Self::new(10_000, 0.8)
    }
}
64
65#[async_trait]
66impl OutputTransformer for BudgetAwareTruncator {
67 async fn transform(&self, output: String, _tool_name: &str, state: &AgentState) -> String {
68 let limit = if state.context_utilization() > self.aggressive_threshold {
69 self.max_chars / 2
70 } else {
71 self.max_chars
72 };
73
74 if output.len() <= limit {
75 return output;
76 }
77
78 let truncated: String = output.chars().take(limit).collect();
80 format!(
81 "{truncated}\n\n[output truncated from {} to {limit} chars]",
82 output.len()
83 )
84 }
85}
86
87pub struct JsonExtractor;
95
96#[async_trait]
97impl OutputTransformer for JsonExtractor {
98 async fn transform(&self, output: String, _tool_name: &str, _state: &AgentState) -> String {
99 if let Some(start) = output.find('{') {
101 if let Some(end) = output.rfind('}') {
102 if end >= start {
103 return output[start..=end].to_string();
104 }
105 }
106 }
107 if let Some(start) = output.find('[') {
108 if let Some(end) = output.rfind(']') {
109 if end >= start {
110 return output[start..=end].to_string();
111 }
112 }
113 }
114 output
116 }
117}
118
119pub struct TransformerChain {
125 transformers: Vec<Box<dyn OutputTransformer>>,
126}
127
128impl TransformerChain {
129 #[must_use]
131 pub fn new(transformers: Vec<Box<dyn OutputTransformer>>) -> Self {
132 Self { transformers }
133 }
134}
135
136#[async_trait]
137impl OutputTransformer for TransformerChain {
138 async fn transform(&self, mut output: String, tool_name: &str, state: &AgentState) -> String {
139 for t in &self.transformers {
140 output = t.transform(output, tool_name, state).await;
141 }
142 output
143 }
144}
145
/// Default summarization prompt template; `{output}` is replaced with the raw tool output.
const DEFAULT_SUMMARY_PROMPT: &str = concat!(
    "Summarize the following tool output concisely, preserving all key data points and values. ",
    "Be brief but complete:\n\n{output}",
);
154
/// Output transformer that replaces oversized tool output with an LLM-written summary.
///
/// Output longer than `max_summary_length` is stored in a cache that is shared
/// with `FullOutputRetriever` handles (see `retriever_tool`). The output is then
/// replaced by a summary produced by `provider`; if summarization does not yield
/// text, the output is truncated instead.
pub struct ProgressiveTransformer {
    /// Provider whose `complete` call produces the summaries.
    provider: Arc<dyn Provider>,
    /// Outputs whose length is at most this pass through untouched; the
    /// truncation fallback keeps this many characters.
    max_summary_length: usize,
    /// Prompt template; occurrences of `{output}` are replaced with the tool output.
    summary_prompt: String,
    /// Full outputs keyed by tool name; a newer output for the same tool
    /// overwrites the older entry.
    cache: Arc<RwLock<HashMap<String, String>>>,
}
183
184impl ProgressiveTransformer {
185 #[must_use]
190 pub fn new(provider: Arc<dyn Provider>, max_summary_length: usize) -> Self {
191 Self {
192 provider,
193 max_summary_length,
194 summary_prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
195 cache: Arc::new(RwLock::new(HashMap::new())),
196 }
197 }
198
199 #[must_use]
203 pub fn with_summary_prompt(mut self, prompt: impl Into<String>) -> Self {
204 self.summary_prompt = prompt.into();
205 self
206 }
207
208 #[must_use]
212 pub fn retriever_tool(&self) -> FullOutputRetriever {
213 FullOutputRetriever {
214 cache: Arc::clone(&self.cache),
215 }
216 }
217
218 fn cache_output(&self, tool_name: &str, output: &str) {
220 let mut cache = self
221 .cache
222 .write()
223 .expect("ProgressiveTransformer cache lock poisoned");
224 cache.insert(tool_name.to_string(), output.to_string());
225 }
226
227 fn build_prompt(&self, output: &str) -> String {
229 self.summary_prompt.replace("{output}", output)
230 }
231}
232
233#[async_trait]
234impl OutputTransformer for ProgressiveTransformer {
235 async fn transform(&self, output: String, tool_name: &str, _state: &AgentState) -> String {
236 if output.len() <= self.max_summary_length {
238 return output;
239 }
240
241 self.cache_output(tool_name, &output);
243
244 let prompt = self.build_prompt(&output);
246 let request = CompletionRequest {
247 model: self.provider.model_info().name.clone(),
248 messages: vec![Message::user(prompt)],
249 tools: vec![],
250 max_tokens: Some(500),
251 temperature: Some(0.3),
252 response_format: None,
253 stream: false,
254 };
255
256 match self.provider.complete(request).await {
257 Ok(response) => {
258 let summary = match response.content {
260 ResponseContent::Text(t) => t,
261 ResponseContent::ToolCalls(_) => {
262 let truncated: String =
264 output.chars().take(self.max_summary_length).collect();
265 return format!(
266 "{truncated}\n\n\
267 [output truncated from {} chars — summarizer returned tool calls]",
268 output.len()
269 );
270 }
271 };
272 format!(
273 "{summary}\n\n\
274 [Full output ({} chars) cached. \
275 Call __get_full_output with {{\"tool_name\": \"{tool_name}\"}} to retrieve it.]",
276 output.len()
277 )
278 }
279 Err(e) => {
280 tracing::warn!(
282 "ProgressiveTransformer: LLM summarization failed for '{tool_name}': {e}. \
283 Falling back to truncation."
284 );
285 let truncated: String = output.chars().take(self.max_summary_length).collect();
286 format!(
287 "{truncated}\n\n\
288 [output truncated from {} chars — LLM summarization failed]",
289 output.len()
290 )
291 }
292 }
293 }
294}
295
/// Read-only handle onto the full outputs cached by a `ProgressiveTransformer`.
///
/// Obtained from `ProgressiveTransformer::retriever_tool`. It shares the same
/// cache, so outputs cached after the handle was created are visible through it.
pub struct FullOutputRetriever {
    cache: Arc<RwLock<HashMap<String, String>>>,
}

impl FullOutputRetriever {
    /// Acquires the cache for reading.
    ///
    /// A poisoned lock only means a writer panicked; the map itself is still
    /// readable, so the guard is recovered instead of propagating the panic.
    fn read_cache(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, String>> {
        self.cache
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns the cached full output for `tool_name`, or a bracketed notice
    /// explaining that nothing is cached under that name.
    #[must_use]
    pub fn retrieve(&self, tool_name: &str) -> String {
        match self.read_cache().get(tool_name) {
            Some(output) => output.clone(),
            None => format!(
                "[No cached output found for tool '{tool_name}'. \
                 The output may have expired or the tool name is incorrect.]"
            ),
        }
    }

    /// Reports whether a full output is cached for `tool_name`.
    #[must_use]
    pub fn has_cached(&self, tool_name: &str) -> bool {
        self.read_cache().contains_key(tool_name)
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::{CompletionStream, StreamEvent};

    /// Builds an `AgentState` whose context token count is `util` of a
    /// 1000-token window.
    fn state_with_utilization(util: f64) -> AgentState {
        let window = 1000;
        let mut state = AgentState::new(ModelTier::Medium, window);
        state.total_context_tokens = (util * window as f64) as usize;
        state
    }

    #[tokio::test]
    async fn test_budget_truncator_under_limit() {
        let t = BudgetAwareTruncator::new(100, 0.8);
        let state = state_with_utilization(0.5);
        let result = t.transform("short".to_string(), "test", &state).await;
        assert_eq!(result, "short");
    }

    #[tokio::test]
    async fn test_budget_truncator_over_limit() {
        let t = BudgetAwareTruncator::new(10, 0.8);
        let state = state_with_utilization(0.5);
        let result = t.transform("a".repeat(100), "test", &state).await;
        assert!(result.contains("[output truncated"));
        assert!(result.starts_with("aaaaaaaaaa"));
    }

    #[tokio::test]
    async fn test_budget_truncator_aggressive() {
        let t = BudgetAwareTruncator::new(20, 0.8);
        // Utilization above the threshold halves the 20-char budget to 10.
        let state = state_with_utilization(0.9);
        let result = t.transform("a".repeat(50), "test", &state).await;
        assert!(result.contains("[output truncated"));
        let first_line: &str = result.lines().next().unwrap();
        assert_eq!(first_line.len(), 10);
    }

    #[tokio::test]
    async fn test_json_extractor_object() {
        let t = JsonExtractor;
        let state = state_with_utilization(0.0);
        let result = t
            .transform(
                "Here is the result: {\"key\": \"value\"} done.".to_string(),
                "test",
                &state,
            )
            .await;
        assert_eq!(result, "{\"key\": \"value\"}");
    }

    #[tokio::test]
    async fn test_json_extractor_array() {
        let t = JsonExtractor;
        let state = state_with_utilization(0.0);
        let result = t
            .transform("Output: [1, 2, 3] end".to_string(), "test", &state)
            .await;
        assert_eq!(result, "[1, 2, 3]");
    }

    #[tokio::test]
    async fn test_json_extractor_no_json() {
        let t = JsonExtractor;
        let state = state_with_utilization(0.0);
        let result = t.transform("plain text".to_string(), "test", &state).await;
        assert_eq!(result, "plain text");
    }

    #[tokio::test]
    async fn test_transformer_chain() {
        // Extraction runs first, then the extracted JSON is truncated.
        let chain = TransformerChain::new(vec![
            Box::new(JsonExtractor),
            Box::new(BudgetAwareTruncator::new(5, 0.8)),
        ]);
        let state = state_with_utilization(0.5);
        let result = chain
            .transform(
                "Result: {\"key\": \"long_value_here\"}".to_string(),
                "test",
                &state,
            )
            .await;
        assert!(result.contains("[output truncated"));
    }

    /// Provider stub that returns a fixed text response, or an error when
    /// `should_fail` is set.
    struct MockProvider {
        info: ModelInfo,
        response: String,
        should_fail: bool,
    }

    impl MockProvider {
        /// Model metadata shared by every mock instance.
        fn mock_info() -> ModelInfo {
            ModelInfo {
                name: "mock-model".to_string(),
                tier: ModelTier::Medium,
                context_window: 8_192,
                supports_tools: true,
                supports_vision: false,
                supports_structured: false,
            }
        }

        fn ok(response: impl Into<String>) -> Self {
            Self {
                info: Self::mock_info(),
                response: response.into(),
                should_fail: false,
            }
        }

        fn failing() -> Self {
            Self {
                info: Self::mock_info(),
                response: String::new(),
                should_fail: true,
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            if self.should_fail {
                return Err(crate::error::Error::Provider {
                    message: "mock failure".into(),
                    status_code: None,
                });
            }
            Ok(CompletionResponse {
                content: ResponseContent::Text(self.response.clone()),
                usage: Usage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                },
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            Ok(Box::pin(tokio_stream::once(Ok(StreamEvent::Done))))
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[tokio::test]
    async fn test_progressive_short_output_passthrough() {
        // A failing provider proves the short path never calls the LLM.
        let provider = Arc::new(MockProvider::failing());
        let transformer = ProgressiveTransformer::new(provider, 500);
        let state = state_with_utilization(0.0);

        let short = "short output".to_string();
        let result = transformer
            .transform(short.clone(), "my_tool", &state)
            .await;
        assert_eq!(result, short);
    }

    #[tokio::test]
    async fn test_progressive_large_output_summarized() {
        let provider = Arc::new(MockProvider::ok("This is the summary."));
        let transformer = ProgressiveTransformer::new(provider, 50);
        let state = state_with_utilization(0.0);

        let large_output = "x".repeat(500);
        let result = transformer
            .transform(large_output.clone(), "search_tool", &state)
            .await;

        assert!(result.contains("This is the summary."));
        assert!(result.contains("__get_full_output"));
        assert!(result.contains("search_tool"));

        // The retriever shares the transformer's cache.
        let retriever = transformer.retriever_tool();
        assert!(retriever.has_cached("search_tool"));
        assert_eq!(retriever.retrieve("search_tool"), large_output);
    }

    #[tokio::test]
    async fn test_progressive_llm_failure_fallback() {
        let provider = Arc::new(MockProvider::failing());
        let transformer = ProgressiveTransformer::new(provider, 20);
        let state = state_with_utilization(0.0);

        let large_output = "a".repeat(200);
        let result = transformer.transform(large_output, "tool_x", &state).await;

        assert!(result.starts_with("aaaaaaaaaaaaaaaaaaaa"));
        assert!(result.contains("LLM summarization failed"));
    }

    #[tokio::test]
    async fn test_full_output_retriever_not_found() {
        let transformer = ProgressiveTransformer::new(Arc::new(MockProvider::ok("x")), 50);
        let retriever = transformer.retriever_tool();
        let result = retriever.retrieve("nonexistent_tool");
        assert!(result.contains("No cached output found"));
    }

    #[tokio::test]
    async fn test_progressive_custom_prompt() {
        let provider = Arc::new(MockProvider::ok("custom summary"));
        let transformer =
            ProgressiveTransformer::new(provider, 10).with_summary_prompt("Brief: {output}");
        let state = state_with_utilization(0.0);

        let result = transformer.transform("x".repeat(100), "t", &state).await;
        assert!(result.contains("custom summary"));
    }
}