Skip to main content

traitclaw_core/
transformers.rs

1//! Built-in [`OutputTransformer`] implementations for common use cases.
2//!
3//! These transformers can be used directly or composed for more complex processing.
4//!
5//! # Available Transformers
6//!
7//! - [`BudgetAwareTruncator`] — truncate by char count based on context utilization
8//! - [`JsonExtractor`] — extract JSON from verbose output
9//! - [`TransformerChain`] — chain multiple transformers
10//! - [`ProgressiveTransformer`] — summarize large outputs; full output on demand
11
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15use async_trait::async_trait;
16
17use crate::traits::output_transformer::OutputTransformer;
18use crate::traits::provider::Provider;
19use crate::types::agent_state::AgentState;
20use crate::types::completion::{CompletionRequest, ResponseContent};
21use crate::types::message::Message;
22
23// ===========================================================================
24// BudgetAwareTruncator
25// ===========================================================================
26
/// Output transformer that caps tool output at a maximum character count.
///
/// The cap reacts to context pressure: once context utilization rises above
/// `aggressive_threshold`, only half of the configured maximum is allowed so
/// that the remaining context budget is preserved.
///
/// # Example
///
/// ```rust
/// use traitclaw_core::transformers::BudgetAwareTruncator;
///
/// let t = BudgetAwareTruncator::new(1000, 0.8);
/// ```
pub struct BudgetAwareTruncator {
    /// Character budget used while utilization is at or below the threshold.
    max_chars: usize,
    /// Utilization level, kept within `0.0..=1.0`, above which the budget is halved.
    aggressive_threshold: f32,
}

impl BudgetAwareTruncator {
    /// Build a truncator from a character budget and a utilization threshold.
    ///
    /// - `max_chars`: Maximum output length in characters.
    /// - `aggressive_threshold`: Context utilization (0.0–1.0) above which the
    ///   budget is halved. Values outside that range are clamped into it.
    #[must_use]
    pub fn new(max_chars: usize, aggressive_threshold: f32) -> Self {
        let aggressive_threshold = aggressive_threshold.clamp(0.0, 1.0);
        Self {
            max_chars,
            aggressive_threshold,
        }
    }
}

impl Default for BudgetAwareTruncator {
    /// A 10 000-character budget that is halved above 80 % context utilization.
    fn default() -> Self {
        let (max_chars, aggressive_threshold) = (10_000, 0.8);
        Self::new(max_chars, aggressive_threshold)
    }
}
64
65#[async_trait]
66impl OutputTransformer for BudgetAwareTruncator {
67    async fn transform(&self, output: String, _tool_name: &str, state: &AgentState) -> String {
68        let limit = if state.context_utilization() > self.aggressive_threshold {
69            self.max_chars / 2
70        } else {
71            self.max_chars
72        };
73
74        if output.len() <= limit {
75            return output;
76        }
77
78        // Truncate at char boundary
79        let truncated: String = output.chars().take(limit).collect();
80        format!(
81            "{truncated}\n\n[output truncated from {} to {limit} chars]",
82            output.len()
83        )
84    }
85}
86
87// ===========================================================================
88// JsonExtractor
89// ===========================================================================
90
91/// Extracts JSON from tool output, discarding surrounding text.
92///
93/// Useful for tools that embed JSON in verbose output.
94pub struct JsonExtractor;
95
96#[async_trait]
97impl OutputTransformer for JsonExtractor {
98    async fn transform(&self, output: String, _tool_name: &str, _state: &AgentState) -> String {
99        // Try to find JSON object or array in the output
100        if let Some(start) = output.find('{') {
101            if let Some(end) = output.rfind('}') {
102                if end >= start {
103                    return output[start..=end].to_string();
104                }
105            }
106        }
107        if let Some(start) = output.find('[') {
108            if let Some(end) = output.rfind(']') {
109                if end >= start {
110                    return output[start..=end].to_string();
111                }
112            }
113        }
114        // No JSON found, return as-is
115        output
116    }
117}
118
119// ===========================================================================
120// TransformerChain
121// ===========================================================================
122
123/// Pipes output through multiple transformers in order.
124pub struct TransformerChain {
125    transformers: Vec<Box<dyn OutputTransformer>>,
126}
127
128impl TransformerChain {
129    /// Create a chain from a list of transformers.
130    #[must_use]
131    pub fn new(transformers: Vec<Box<dyn OutputTransformer>>) -> Self {
132        Self { transformers }
133    }
134}
135
136#[async_trait]
137impl OutputTransformer for TransformerChain {
138    async fn transform(&self, mut output: String, tool_name: &str, state: &AgentState) -> String {
139        for t in &self.transformers {
140            output = t.transform(output, tool_name, state).await;
141        }
142        output
143    }
144}
145
146// ===========================================================================
147// ProgressiveTransformer
148// ===========================================================================
149
150/// Default summarization prompt template.
151const DEFAULT_SUMMARY_PROMPT: &str =
152    "Summarize the following tool output concisely, preserving all key data points and values. \
153     Be brief but complete:\n\n{output}";
154
155/// A two-phase output transformer that returns an **LLM-generated summary** first,
156/// with the **full output** cached and available on demand via the
157/// `__get_full_output` virtual tool.
158///
159/// # Workflow
160///
161/// 1. Output arrives from a tool call.
162/// 2. If `output.len() <= max_summary_length` → returned unchanged (no LLM call).
163/// 3. If larger → LLM is called to summarize → summary returned + note appended.
164/// 4. Full output cached internally keyed by `tool_name`.
165/// 5. Agent can call `__get_full_output` → [`FullOutputRetriever`] serves it.
166///
167/// # Example
168///
169/// ```rust,no_run
170/// use traitclaw_core::transformers::ProgressiveTransformer;
171/// use std::sync::Arc;
172///
173/// // let transformer = ProgressiveTransformer::new(provider.clone(), 500)
174/// //     .with_summary_prompt("Give a one-sentence summary: {output}");
175/// ```
176pub struct ProgressiveTransformer {
177    provider: Arc<dyn Provider>,
178    max_summary_length: usize,
179    summary_prompt: String,
180    /// Cache: tool_name → full output
181    cache: Arc<RwLock<HashMap<String, String>>>,
182}
183
184impl ProgressiveTransformer {
185    /// Create a new progressive transformer.
186    ///
187    /// - `provider`: LLM used to generate summaries.
188    /// - `max_summary_length`: Outputs shorter than this are passed through unchanged.
189    #[must_use]
190    pub fn new(provider: Arc<dyn Provider>, max_summary_length: usize) -> Self {
191        Self {
192            provider,
193            max_summary_length,
194            summary_prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
195            cache: Arc::new(RwLock::new(HashMap::new())),
196        }
197    }
198
199    /// Override the summarization prompt template.
200    ///
201    /// Use `{output}` as the placeholder for the tool output.
202    #[must_use]
203    pub fn with_summary_prompt(mut self, prompt: impl Into<String>) -> Self {
204        self.summary_prompt = prompt.into();
205        self
206    }
207
208    /// Build a [`FullOutputRetriever`] that reads from this transformer's cache.
209    ///
210    /// Register this tool with the agent so the LLM can call `__get_full_output`.
211    #[must_use]
212    pub fn retriever_tool(&self) -> FullOutputRetriever {
213        FullOutputRetriever {
214            cache: Arc::clone(&self.cache),
215        }
216    }
217
218    /// Store full output in cache keyed by tool name.
219    fn cache_output(&self, tool_name: &str, output: &str) {
220        let mut cache = self
221            .cache
222            .write()
223            .expect("ProgressiveTransformer cache lock poisoned");
224        cache.insert(tool_name.to_string(), output.to_string());
225    }
226
227    /// Build the prompt with the output injected.
228    fn build_prompt(&self, output: &str) -> String {
229        self.summary_prompt.replace("{output}", output)
230    }
231}
232
233#[async_trait]
234impl OutputTransformer for ProgressiveTransformer {
235    async fn transform(&self, output: String, tool_name: &str, _state: &AgentState) -> String {
236        // AC #7: short output → pass through unchanged
237        if output.len() <= self.max_summary_length {
238            return output;
239        }
240
241        // Cache the full output for later retrieval
242        self.cache_output(tool_name, &output);
243
244        // AC #2: call LLM to summarize
245        let prompt = self.build_prompt(&output);
246        let request = CompletionRequest {
247            model: self.provider.model_info().name.clone(),
248            messages: vec![Message::user(prompt)],
249            tools: vec![],
250            max_tokens: Some(500),
251            temperature: Some(0.3),
252            response_format: None,
253            stream: false,
254        };
255
256        match self.provider.complete(request).await {
257            Ok(response) => {
258                // AC #2: return summary + footer note
259                let summary = match response.content {
260                    ResponseContent::Text(t) => t,
261                    ResponseContent::ToolCalls(_) => {
262                        // Unexpected tool calls from summarizer — fallback
263                        let truncated: String =
264                            output.chars().take(self.max_summary_length).collect();
265                        return format!(
266                            "{truncated}\n\n\
267                             [output truncated from {} chars — summarizer returned tool calls]",
268                            output.len()
269                        );
270                    }
271                };
272                format!(
273                    "{summary}\n\n\
274                     [Full output ({} chars) cached. \
275                     Call __get_full_output with {{\"tool_name\": \"{tool_name}\"}} to retrieve it.]",
276                    output.len()
277                )
278            }
279            Err(e) => {
280                // AC #6: fallback to truncation on LLM failure
281                tracing::warn!(
282                    "ProgressiveTransformer: LLM summarization failed for '{tool_name}': {e}. \
283                     Falling back to truncation."
284                );
285                let truncated: String = output.chars().take(self.max_summary_length).collect();
286                format!(
287                    "{truncated}\n\n\
288                     [output truncated from {} chars — LLM summarization failed]",
289                    output.len()
290                )
291            }
292        }
293    }
294}
295
296// ===========================================================================
297// FullOutputRetriever — virtual tool
298// ===========================================================================
299
/// Virtual tool that serves the full output cached by [`ProgressiveTransformer`].
///
/// The LLM calls this tool as `__get_full_output` with `{"tool_name": "..."}`.
///
/// Obtain via [`ProgressiveTransformer::retriever_tool()`].
pub struct FullOutputRetriever {
    /// Map from tool name to that tool's cached full output.
    cache: Arc<RwLock<HashMap<String, String>>>,
}

impl FullOutputRetriever {
    /// Look up the cached full output for `tool_name`.
    ///
    /// Returns a copy of the cached text, or a bracketed explanatory message
    /// when nothing is cached under that name.
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn retrieve(&self, tool_name: &str) -> String {
        let entries = self
            .cache
            .read()
            .expect("FullOutputRetriever cache lock poisoned");

        if let Some(full) = entries.get(tool_name) {
            return full.clone();
        }

        format!(
            "[No cached output found for tool '{tool_name}'. \
             The output may have expired or the tool name is incorrect.]"
        )
    }

    /// Report whether a full output is cached under `tool_name`.
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn has_cached(&self, tool_name: &str) -> bool {
        self.cache
            .read()
            .expect("FullOutputRetriever cache lock poisoned")
            .contains_key(tool_name)
    }
}