traitclaw_core/transformers.rs
1//! Built-in [`OutputTransformer`] implementations for common use cases.
2//!
3//! These transformers can be used directly or composed for more complex processing.
4//!
5//! # Available Transformers
6//!
7//! - [`BudgetAwareTruncator`] — truncate by char count based on context utilization
8//! - [`JsonExtractor`] — extract JSON from verbose output
9//! - [`TransformerChain`] — chain multiple transformers
10//! - [`ProgressiveTransformer`] — summarize large outputs; full output on demand
11
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15use async_trait::async_trait;
16
17use crate::traits::output_transformer::OutputTransformer;
18use crate::traits::provider::Provider;
19use crate::types::agent_state::AgentState;
20use crate::types::completion::{CompletionRequest, ResponseContent};
21use crate::types::message::Message;
22
23// ===========================================================================
24// BudgetAwareTruncator
25// ===========================================================================
26
/// Truncates output to a maximum character count, respecting context utilization.
///
/// When context utilization exceeds the `aggressive_threshold`, the limit is
/// halved to preserve context budget.
///
/// # Example
///
/// ```rust
/// use traitclaw_core::transformers::BudgetAwareTruncator;
///
/// let t = BudgetAwareTruncator::new(1000, 0.8);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct BudgetAwareTruncator {
    /// Character limit applied while utilization is at or below the threshold.
    max_chars: usize,
    /// Context utilization above which the limit is halved.
    /// Clamped into `0.0..=1.0` by [`BudgetAwareTruncator::new`].
    aggressive_threshold: f32,
}

impl BudgetAwareTruncator {
    /// Create a new truncator.
    ///
    /// - `max_chars`: Maximum output length in characters.
    /// - `aggressive_threshold`: Context utilization (0.0–1.0) above which
    ///   truncation becomes more aggressive (halved limit). Values outside
    ///   that range are clamped into it.
    #[must_use]
    pub fn new(max_chars: usize, aggressive_threshold: f32) -> Self {
        Self {
            max_chars,
            aggressive_threshold: aggressive_threshold.clamp(0.0, 1.0),
        }
    }
}

impl Default for BudgetAwareTruncator {
    /// A 10 000-character limit that turns aggressive above 80 % utilization.
    fn default() -> Self {
        Self::new(10_000, 0.8)
    }
}
64
65#[async_trait]
66impl OutputTransformer for BudgetAwareTruncator {
67 async fn transform(&self, output: String, _tool_name: &str, state: &AgentState) -> String {
68 let limit = if state.context_utilization() > self.aggressive_threshold {
69 self.max_chars / 2
70 } else {
71 self.max_chars
72 };
73
74 if output.len() <= limit {
75 return output;
76 }
77
78 // Truncate at char boundary
79 let truncated: String = output.chars().take(limit).collect();
80 format!(
81 "{truncated}\n\n[output truncated from {} to {limit} chars]",
82 output.len()
83 )
84 }
85}
86
87// ===========================================================================
88// JsonExtractor
89// ===========================================================================
90
/// Extracts JSON from tool output, discarding surrounding text.
///
/// Useful for tools that embed JSON in verbose output. The type carries no
/// state, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonExtractor;
95
96#[async_trait]
97impl OutputTransformer for JsonExtractor {
98 async fn transform(&self, output: String, _tool_name: &str, _state: &AgentState) -> String {
99 // Try to find JSON object or array in the output
100 if let Some(start) = output.find('{') {
101 if let Some(end) = output.rfind('}') {
102 if end >= start {
103 return output[start..=end].to_string();
104 }
105 }
106 }
107 if let Some(start) = output.find('[') {
108 if let Some(end) = output.rfind(']') {
109 if end >= start {
110 return output[start..=end].to_string();
111 }
112 }
113 }
114 // No JSON found, return as-is
115 output
116 }
117}
118
119// ===========================================================================
120// TransformerChain
121// ===========================================================================
122
123/// Pipes output through multiple transformers in order.
124pub struct TransformerChain {
125 transformers: Vec<Box<dyn OutputTransformer>>,
126}
127
128impl TransformerChain {
129 /// Create a chain from a list of transformers.
130 #[must_use]
131 pub fn new(transformers: Vec<Box<dyn OutputTransformer>>) -> Self {
132 Self { transformers }
133 }
134}
135
136#[async_trait]
137impl OutputTransformer for TransformerChain {
138 async fn transform(&self, mut output: String, tool_name: &str, state: &AgentState) -> String {
139 for t in &self.transformers {
140 output = t.transform(output, tool_name, state).await;
141 }
142 output
143 }
144}
145
146// ===========================================================================
147// ProgressiveTransformer
148// ===========================================================================
149
/// Summarization prompt template used when no custom prompt is supplied.
///
/// The `{output}` placeholder marks where the tool output is substituted.
const DEFAULT_SUMMARY_PROMPT: &str = concat!(
    "Summarize the following tool output concisely, preserving all key data points and values. ",
    "Be brief but complete:\n\n{output}",
);
154
155/// A two-phase output transformer that returns an **LLM-generated summary** first,
156/// with the **full output** cached and available on demand via the
157/// `__get_full_output` virtual tool.
158///
159/// # Workflow
160///
161/// 1. Output arrives from a tool call.
162/// 2. If `output.len() <= max_summary_length` → returned unchanged (no LLM call).
163/// 3. If larger → LLM is called to summarize → summary returned + note appended.
164/// 4. Full output cached internally keyed by `tool_name`.
165/// 5. Agent can call `__get_full_output` → [`FullOutputRetriever`] serves it.
166///
167/// # Example
168///
169/// ```rust,no_run
170/// use traitclaw_core::transformers::ProgressiveTransformer;
171/// use std::sync::Arc;
172///
173/// // let transformer = ProgressiveTransformer::new(provider.clone(), 500)
174/// // .with_summary_prompt("Give a one-sentence summary: {output}");
175/// ```
176pub struct ProgressiveTransformer {
177 provider: Arc<dyn Provider>,
178 max_summary_length: usize,
179 summary_prompt: String,
180 /// Cache: tool_name → full output
181 cache: Arc<RwLock<HashMap<String, String>>>,
182}
183
184impl ProgressiveTransformer {
185 /// Create a new progressive transformer.
186 ///
187 /// - `provider`: LLM used to generate summaries.
188 /// - `max_summary_length`: Outputs shorter than this are passed through unchanged.
189 #[must_use]
190 pub fn new(provider: Arc<dyn Provider>, max_summary_length: usize) -> Self {
191 Self {
192 provider,
193 max_summary_length,
194 summary_prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
195 cache: Arc::new(RwLock::new(HashMap::new())),
196 }
197 }
198
199 /// Override the summarization prompt template.
200 ///
201 /// Use `{output}` as the placeholder for the tool output.
202 #[must_use]
203 pub fn with_summary_prompt(mut self, prompt: impl Into<String>) -> Self {
204 self.summary_prompt = prompt.into();
205 self
206 }
207
208 /// Build a [`FullOutputRetriever`] that reads from this transformer's cache.
209 ///
210 /// Register this tool with the agent so the LLM can call `__get_full_output`.
211 #[must_use]
212 pub fn retriever_tool(&self) -> FullOutputRetriever {
213 FullOutputRetriever {
214 cache: Arc::clone(&self.cache),
215 }
216 }
217
218 /// Store full output in cache keyed by tool name.
219 fn cache_output(&self, tool_name: &str, output: &str) {
220 let mut cache = self
221 .cache
222 .write()
223 .expect("ProgressiveTransformer cache lock poisoned");
224 cache.insert(tool_name.to_string(), output.to_string());
225 }
226
227 /// Build the prompt with the output injected.
228 fn build_prompt(&self, output: &str) -> String {
229 self.summary_prompt.replace("{output}", output)
230 }
231}
232
233#[async_trait]
234impl OutputTransformer for ProgressiveTransformer {
235 async fn transform(&self, output: String, tool_name: &str, _state: &AgentState) -> String {
236 // AC #7: short output → pass through unchanged
237 if output.len() <= self.max_summary_length {
238 return output;
239 }
240
241 // Cache the full output for later retrieval
242 self.cache_output(tool_name, &output);
243
244 // AC #2: call LLM to summarize
245 let prompt = self.build_prompt(&output);
246 let request = CompletionRequest {
247 model: self.provider.model_info().name.clone(),
248 messages: vec![Message::user(prompt)],
249 tools: vec![],
250 max_tokens: Some(500),
251 temperature: Some(0.3),
252 response_format: None,
253 stream: false,
254 };
255
256 match self.provider.complete(request).await {
257 Ok(response) => {
258 // AC #2: return summary + footer note
259 let summary = match response.content {
260 ResponseContent::Text(t) => t,
261 ResponseContent::ToolCalls(_) => {
262 // Unexpected tool calls from summarizer — fallback
263 let truncated: String =
264 output.chars().take(self.max_summary_length).collect();
265 return format!(
266 "{truncated}\n\n\
267 [output truncated from {} chars — summarizer returned tool calls]",
268 output.len()
269 );
270 }
271 };
272 format!(
273 "{summary}\n\n\
274 [Full output ({} chars) cached. \
275 Call __get_full_output with {{\"tool_name\": \"{tool_name}\"}} to retrieve it.]",
276 output.len()
277 )
278 }
279 Err(e) => {
280 // AC #6: fallback to truncation on LLM failure
281 tracing::warn!(
282 "ProgressiveTransformer: LLM summarization failed for '{tool_name}': {e}. \
283 Falling back to truncation."
284 );
285 let truncated: String = output.chars().take(self.max_summary_length).collect();
286 format!(
287 "{truncated}\n\n\
288 [output truncated from {} chars — LLM summarization failed]",
289 output.len()
290 )
291 }
292 }
293 }
294}
295
296// ===========================================================================
297// FullOutputRetriever — virtual tool
298// ===========================================================================
299
/// Read-only view of the full outputs cached by [`ProgressiveTransformer`].
///
/// Backs the `__get_full_output` virtual tool, which the LLM is told to call
/// with `{"tool_name": "..."}`.
///
/// Obtain via [`ProgressiveTransformer::retriever_tool()`].
pub struct FullOutputRetriever {
    /// Shared handle to the map of tool name → full output.
    cache: Arc<RwLock<HashMap<String, String>>>,
}

impl FullOutputRetriever {
    /// Look up the full output cached for `tool_name`.
    ///
    /// Returns a copy of the cached text, or a bracketed explanatory message
    /// when nothing is cached under that name.
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn retrieve(&self, tool_name: &str) -> String {
        let entries = self
            .cache
            .read()
            .expect("FullOutputRetriever cache lock poisoned");
        entries.get(tool_name).cloned().unwrap_or_else(|| {
            format!(
                "[No cached output found for tool '{tool_name}'. \
                 The output may have expired or the tool name is incorrect.]"
            )
        })
    }

    /// Report whether a full output is currently cached for `tool_name`.
    ///
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn has_cached(&self, tool_name: &str) -> bool {
        self.cache
            .read()
            .expect("FullOutputRetriever cache lock poisoned")
            .contains_key(tool_name)
    }
}