Skip to main content

traitclaw_core/
transformers.rs

1//! Built-in [`OutputTransformer`] implementations for common use cases.
2//!
3//! These transformers can be used directly or composed for more complex processing.
4//!
5//! # Available Transformers
6//!
7//! - [`BudgetAwareTruncator`] — truncate by char count based on context utilization
8//! - [`JsonExtractor`] — extract JSON from verbose output
9//! - [`TransformerChain`] — chain multiple transformers
10//! - [`ProgressiveTransformer`] — summarize large outputs; full output on demand
11
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15use async_trait::async_trait;
16
17use crate::traits::output_transformer::OutputTransformer;
18use crate::traits::provider::Provider;
19use crate::types::agent_state::AgentState;
20use crate::types::completion::{CompletionRequest, ResponseContent};
21use crate::types::message::Message;
22
23// ===========================================================================
24// BudgetAwareTruncator
25// ===========================================================================
26
/// Truncates output to a maximum character count, respecting context utilization.
///
/// When context utilization exceeds the `aggressive_threshold`, the limit is
/// halved to preserve context budget.
///
/// # Example
///
/// ```rust
/// use traitclaw_core::transformers::BudgetAwareTruncator;
///
/// let t = BudgetAwareTruncator::new(1000, 0.8);
/// ```
#[derive(Debug, Clone)]
pub struct BudgetAwareTruncator {
    /// Character limit applied while context utilization is at or below the threshold.
    max_chars: usize,
    /// Context utilization above which the limit is halved.
    aggressive_threshold: f32,
}
43
44impl BudgetAwareTruncator {
45    /// Create a new truncator.
46    ///
47    /// - `max_chars`: Maximum output length in characters.
48    /// - `aggressive_threshold`: Context utilization (0.0–1.0) above which
49    ///   truncation becomes more aggressive (halved limit).
50    #[must_use]
51    pub fn new(max_chars: usize, aggressive_threshold: f32) -> Self {
52        Self {
53            max_chars,
54            aggressive_threshold: aggressive_threshold.clamp(0.0, 1.0),
55        }
56    }
57}
58
59impl Default for BudgetAwareTruncator {
60    fn default() -> Self {
61        Self::new(10_000, 0.8)
62    }
63}
64
65#[async_trait]
66impl OutputTransformer for BudgetAwareTruncator {
67    async fn transform(&self, output: String, _tool_name: &str, state: &AgentState) -> String {
68        let limit = if state.context_utilization() > self.aggressive_threshold {
69            self.max_chars / 2
70        } else {
71            self.max_chars
72        };
73
74        if output.len() <= limit {
75            return output;
76        }
77
78        // Truncate at char boundary
79        let truncated: String = output.chars().take(limit).collect();
80        format!(
81            "{truncated}\n\n[output truncated from {} to {limit} chars]",
82            output.len()
83        )
84    }
85}
86
87// ===========================================================================
88// JsonExtractor
89// ===========================================================================
90
91/// Extracts JSON from tool output, discarding surrounding text.
92///
93/// Useful for tools that embed JSON in verbose output.
94pub struct JsonExtractor;
95
96#[async_trait]
97impl OutputTransformer for JsonExtractor {
98    async fn transform(&self, output: String, _tool_name: &str, _state: &AgentState) -> String {
99        // Try to find JSON object or array in the output
100        if let Some(start) = output.find('{') {
101            if let Some(end) = output.rfind('}') {
102                if end >= start {
103                    return output[start..=end].to_string();
104                }
105            }
106        }
107        if let Some(start) = output.find('[') {
108            if let Some(end) = output.rfind(']') {
109                if end >= start {
110                    return output[start..=end].to_string();
111                }
112            }
113        }
114        // No JSON found, return as-is
115        output
116    }
117}
118
119// ===========================================================================
120// TransformerChain
121// ===========================================================================
122
123/// Pipes output through multiple transformers in order.
124pub struct TransformerChain {
125    transformers: Vec<Box<dyn OutputTransformer>>,
126}
127
128impl TransformerChain {
129    /// Create a chain from a list of transformers.
130    #[must_use]
131    pub fn new(transformers: Vec<Box<dyn OutputTransformer>>) -> Self {
132        Self { transformers }
133    }
134}
135
136#[async_trait]
137impl OutputTransformer for TransformerChain {
138    async fn transform(&self, mut output: String, tool_name: &str, state: &AgentState) -> String {
139        for t in &self.transformers {
140            output = t.transform(output, tool_name, state).await;
141        }
142        output
143    }
144}
145
146// ===========================================================================
147// ProgressiveTransformer
148// ===========================================================================
149
150/// Default summarization prompt template.
151const DEFAULT_SUMMARY_PROMPT: &str =
152    "Summarize the following tool output concisely, preserving all key data points and values. \
153     Be brief but complete:\n\n{output}";
154
155/// A two-phase output transformer that returns an **LLM-generated summary** first,
156/// with the **full output** cached and available on demand via the
157/// `__get_full_output` virtual tool.
158///
159/// # Workflow
160///
161/// 1. Output arrives from a tool call.
162/// 2. If `output.len() <= max_summary_length` → returned unchanged (no LLM call).
163/// 3. If larger → LLM is called to summarize → summary returned + note appended.
164/// 4. Full output cached internally keyed by `tool_name`.
165/// 5. Agent can call `__get_full_output` → [`FullOutputRetriever`] serves it.
166///
167/// # Example
168///
169/// ```rust,no_run
170/// use traitclaw_core::transformers::ProgressiveTransformer;
171/// use std::sync::Arc;
172///
173/// // let transformer = ProgressiveTransformer::new(provider.clone(), 500)
174/// //     .with_summary_prompt("Give a one-sentence summary: {output}");
175/// ```
176pub struct ProgressiveTransformer {
177    provider: Arc<dyn Provider>,
178    max_summary_length: usize,
179    summary_prompt: String,
180    /// Cache: tool_name → full output
181    cache: Arc<RwLock<HashMap<String, String>>>,
182}
183
184impl ProgressiveTransformer {
185    /// Create a new progressive transformer.
186    ///
187    /// - `provider`: LLM used to generate summaries.
188    /// - `max_summary_length`: Outputs shorter than this are passed through unchanged.
189    #[must_use]
190    pub fn new(provider: Arc<dyn Provider>, max_summary_length: usize) -> Self {
191        Self {
192            provider,
193            max_summary_length,
194            summary_prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
195            cache: Arc::new(RwLock::new(HashMap::new())),
196        }
197    }
198
199    /// Override the summarization prompt template.
200    ///
201    /// Use `{output}` as the placeholder for the tool output.
202    #[must_use]
203    pub fn with_summary_prompt(mut self, prompt: impl Into<String>) -> Self {
204        self.summary_prompt = prompt.into();
205        self
206    }
207
208    /// Build a [`FullOutputRetriever`] that reads from this transformer's cache.
209    ///
210    /// Register this tool with the agent so the LLM can call `__get_full_output`.
211    #[must_use]
212    pub fn retriever_tool(&self) -> FullOutputRetriever {
213        FullOutputRetriever {
214            cache: Arc::clone(&self.cache),
215        }
216    }
217
218    /// Store full output in cache keyed by tool name.
219    fn cache_output(&self, tool_name: &str, output: &str) {
220        let mut cache = self
221            .cache
222            .write()
223            .expect("ProgressiveTransformer cache lock poisoned");
224        cache.insert(tool_name.to_string(), output.to_string());
225    }
226
227    /// Build the prompt with the output injected.
228    fn build_prompt(&self, output: &str) -> String {
229        self.summary_prompt.replace("{output}", output)
230    }
231}
232
233#[async_trait]
234impl OutputTransformer for ProgressiveTransformer {
235    async fn transform(&self, output: String, tool_name: &str, _state: &AgentState) -> String {
236        // AC #7: short output → pass through unchanged
237        if output.len() <= self.max_summary_length {
238            return output;
239        }
240
241        // Cache the full output for later retrieval
242        self.cache_output(tool_name, &output);
243
244        // AC #2: call LLM to summarize
245        let prompt = self.build_prompt(&output);
246        let request = CompletionRequest {
247            model: self.provider.model_info().name.clone(),
248            messages: vec![Message::user(prompt)],
249            tools: vec![],
250            max_tokens: Some(500),
251            temperature: Some(0.3),
252            response_format: None,
253            stream: false,
254        };
255
256        match self.provider.complete(request).await {
257            Ok(response) => {
258                // AC #2: return summary + footer note
259                let summary = match response.content {
260                    ResponseContent::Text(t) => t,
261                    ResponseContent::ToolCalls(_) => {
262                        // Unexpected tool calls from summarizer — fallback
263                        let truncated: String =
264                            output.chars().take(self.max_summary_length).collect();
265                        return format!(
266                            "{truncated}\n\n\
267                             [output truncated from {} chars — summarizer returned tool calls]",
268                            output.len()
269                        );
270                    }
271                };
272                format!(
273                    "{summary}\n\n\
274                     [Full output ({} chars) cached. \
275                     Call __get_full_output with {{\"tool_name\": \"{tool_name}\"}} to retrieve it.]",
276                    output.len()
277                )
278            }
279            Err(e) => {
280                // AC #6: fallback to truncation on LLM failure
281                tracing::warn!(
282                    "ProgressiveTransformer: LLM summarization failed for '{tool_name}': {e}. \
283                     Falling back to truncation."
284                );
285                let truncated: String = output.chars().take(self.max_summary_length).collect();
286                format!(
287                    "{truncated}\n\n\
288                     [output truncated from {} chars — LLM summarization failed]",
289                    output.len()
290                )
291            }
292        }
293    }
294}
295
296// ===========================================================================
297// FullOutputRetriever — virtual tool
298// ===========================================================================
299
/// Virtual tool that retrieves the full output cached by [`ProgressiveTransformer`].
///
/// The LLM calls this tool as `__get_full_output` with `{"tool_name": "..."}`.
///
/// Obtain via [`ProgressiveTransformer::retriever_tool()`]. Clones share the
/// same underlying cache.
#[derive(Debug, Clone)]
pub struct FullOutputRetriever {
    /// Cache shared with the transformer: tool_name → full output.
    cache: Arc<RwLock<HashMap<String, String>>>,
}

impl FullOutputRetriever {
    /// Retrieve cached full output for a tool name.
    ///
    /// Returns the cached output or an error message if not found. A
    /// poisoned cache lock is recovered instead of panicking, since the map
    /// of plain strings remains readable.
    #[must_use]
    pub fn retrieve(&self, tool_name: &str) -> String {
        let cache = self
            .cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        cache.get(tool_name).cloned().unwrap_or_else(|| {
            format!(
                "[No cached output found for tool '{tool_name}'. \
                 The output may have expired or the tool name is incorrect.]"
            )
        })
    }

    /// Check if a full output exists in cache for the given tool name.
    ///
    /// Like [`retrieve`](Self::retrieve), this recovers from a poisoned lock.
    #[must_use]
    pub fn has_cached(&self, tool_name: &str) -> bool {
        self.cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .contains_key(tool_name)
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::{CompletionStream, StreamEvent};

    /// Build a `Medium`-tier state constructed with `1000` whose
    /// `total_context_tokens` is set to `util * 1000`.
    fn state_with_utilization(util: f64) -> AgentState {
        let window = 1000;
        let mut state = AgentState::new(ModelTier::Medium, window);
        let used_tokens = (util * window as f64) as usize;
        state.total_context_tokens = used_tokens;
        state
    }

    // ── BudgetAwareTruncator ─────────────────────────────────────────────

    #[tokio::test]
    async fn test_budget_truncator_under_limit() {
        let state = state_with_utilization(0.5);
        let truncator = BudgetAwareTruncator::new(100, 0.8);
        let out = truncator
            .transform("short".to_string(), "test", &state)
            .await;
        assert_eq!(out, "short");
    }

    #[tokio::test]
    async fn test_budget_truncator_over_limit() {
        let state = state_with_utilization(0.5);
        let truncator = BudgetAwareTruncator::new(10, 0.8);
        let out = truncator.transform("a".repeat(100), "test", &state).await;
        assert!(out.contains("[output truncated"));
        assert!(out.starts_with("aaaaaaaaaa"));
    }

    #[tokio::test]
    async fn test_budget_truncator_aggressive() {
        // Utilization 0.9 is above the 0.8 threshold, so the limit becomes 20/2 = 10.
        let state = state_with_utilization(0.9);
        let truncator = BudgetAwareTruncator::new(20, 0.8);
        let out = truncator.transform("a".repeat(50), "test", &state).await;
        assert!(out.contains("[output truncated"));
        // The retained text is the first line and should be 10 chars long.
        let first_line: &str = out.lines().next().unwrap();
        assert_eq!(first_line.len(), 10);
    }

    // ── JsonExtractor ────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_json_extractor_object() {
        let state = state_with_utilization(0.0);
        let input = "Here is the result: {\"key\": \"value\"} done.".to_string();
        let out = JsonExtractor.transform(input, "test", &state).await;
        assert_eq!(out, "{\"key\": \"value\"}");
    }

    #[tokio::test]
    async fn test_json_extractor_array() {
        let state = state_with_utilization(0.0);
        let input = "Output: [1, 2, 3] end".to_string();
        let out = JsonExtractor.transform(input, "test", &state).await;
        assert_eq!(out, "[1, 2, 3]");
    }

    #[tokio::test]
    async fn test_json_extractor_no_json() {
        let state = state_with_utilization(0.0);
        let input = "plain text".to_string();
        let out = JsonExtractor.transform(input, "test", &state).await;
        assert_eq!(out, "plain text");
    }

    // ── TransformerChain ─────────────────────────────────────────────────

    #[tokio::test]
    async fn test_transformer_chain() {
        let stages: Vec<Box<dyn OutputTransformer>> = vec![
            Box::new(JsonExtractor),
            Box::new(BudgetAwareTruncator::new(5, 0.8)),
        ];
        let chain = TransformerChain::new(stages);
        let state = state_with_utilization(0.5);
        let input = "Result: {\"key\": \"long_value_here\"}".to_string();
        let out = chain.transform(input, "test", &state).await;
        // First extracts JSON, then truncates to 5 chars
        assert!(out.contains("[output truncated"));
    }

    // ── ProgressiveTransformer ───────────────────────────────────────────

    /// Provider double that either returns a fixed text response or fails.
    struct MockProvider {
        info: ModelInfo,
        response: String,
        should_fail: bool,
    }

    impl MockProvider {
        /// Shared constructor: same model info for every mock.
        fn with_behavior(response: String, should_fail: bool) -> Self {
            let info = ModelInfo {
                name: "mock-model".to_string(),
                tier: ModelTier::Medium,
                context_window: 8_192,
                supports_tools: true,
                supports_vision: false,
                supports_structured: false,
            };
            Self {
                info,
                response,
                should_fail,
            }
        }

        /// Mock whose `complete` succeeds with `response` as text.
        fn ok(response: impl Into<String>) -> Self {
            Self::with_behavior(response.into(), false)
        }

        /// Mock whose `complete` always returns a provider error.
        fn failing() -> Self {
            Self::with_behavior(String::new(), true)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            if self.should_fail {
                Err(crate::error::Error::Provider {
                    message: "mock failure".into(),
                    status_code: None,
                })
            } else {
                let usage = Usage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                };
                Ok(CompletionResponse {
                    content: ResponseContent::Text(self.response.clone()),
                    usage,
                })
            }
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            use tokio_stream;
            let done = tokio_stream::once(Ok(StreamEvent::Done));
            Ok(Box::pin(done))
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[tokio::test]
    async fn test_progressive_short_output_passthrough() {
        // AC #9: short output → passed through without LLM call.
        // The failing provider would surface in the result if it were called.
        let provider = Arc::new(MockProvider::failing());
        let transformer = ProgressiveTransformer::new(provider, 500);
        let state = state_with_utilization(0.0);

        let short = "short output".to_string();
        let out = transformer
            .transform(short.clone(), "my_tool", &state)
            .await;
        assert_eq!(out, short); // unchanged
    }

    #[tokio::test]
    async fn test_progressive_large_output_summarized() {
        // AC #8: large output → summary returned + cache populated
        let provider = Arc::new(MockProvider::ok("This is the summary."));
        let transformer = ProgressiveTransformer::new(provider, 50);
        let state = state_with_utilization(0.0);

        let large_output = "x".repeat(500);
        let out = transformer
            .transform(large_output.clone(), "search_tool", &state)
            .await;

        assert!(out.contains("This is the summary."));
        assert!(out.contains("__get_full_output"));
        assert!(out.contains("search_tool"));

        // Cache should contain full output
        let retriever = transformer.retriever_tool();
        assert!(retriever.has_cached("search_tool"));
        assert_eq!(retriever.retrieve("search_tool"), large_output);
    }

    #[tokio::test]
    async fn test_progressive_llm_failure_fallback() {
        // AC #10: LLM failure → graceful truncation fallback
        let provider = Arc::new(MockProvider::failing());
        let transformer = ProgressiveTransformer::new(provider, 20);
        let state = state_with_utilization(0.0);

        let out = transformer
            .transform("a".repeat(200), "tool_x", &state)
            .await;

        // Starts with first 20 chars
        assert!(out.starts_with("aaaaaaaaaaaaaaaaaaaa"));
        assert!(out.contains("LLM summarization failed"));
    }

    #[tokio::test]
    async fn test_full_output_retriever_not_found() {
        // AC #8: retriever returns error message when cache is empty
        let provider = Arc::new(MockProvider::ok("x"));
        let transformer = ProgressiveTransformer::new(provider, 50);
        let message = transformer.retriever_tool().retrieve("nonexistent_tool");
        assert!(message.contains("No cached output found"));
    }

    #[tokio::test]
    async fn test_progressive_custom_prompt() {
        let provider = Arc::new(MockProvider::ok("custom summary"));
        let transformer =
            ProgressiveTransformer::new(provider, 10).with_summary_prompt("Brief: {output}");
        let state = state_with_utilization(0.0);

        let out = transformer.transform("x".repeat(100), "t", &state).await;
        assert!(out.contains("custom summary"));
    }
}