// swink_agent/async_context_transformer.rs
use std::future::Future;
use std::pin::Pin;

use crate::context::CompactionReport;
use crate::types::AgentMessage;

14pub type AsyncTransformFuture<'a> =
16 Pin<Box<dyn Future<Output = Option<CompactionReport>> + Send + 'a>>;
17
18pub trait AsyncContextTransformer: Send + Sync {
31 fn transform<'a>(
38 &'a self,
39 messages: &'a mut Vec<AgentMessage>,
40 overflow: bool,
41 ) -> AsyncTransformFuture<'a>;
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, LlmMessage, UserMessage};

    /// Builds a user message containing a single text block.
    fn text_message(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: text.to_owned(),
            }],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// A concrete implementor, called directly. It only compacts when
    /// `overflow` is set, and it reports what it dropped.
    #[tokio::test]
    async fn async_transformer_struct_impl() {
        /// Keeps only the first two messages when `overflow` is set.
        struct OverflowTruncator;

        impl AsyncContextTransformer for OverflowTruncator {
            fn transform<'a>(
                &'a self,
                messages: &'a mut Vec<AgentMessage>,
                overflow: bool,
            ) -> AsyncTransformFuture<'a> {
                Box::pin(async move {
                    if overflow && messages.len() > 2 {
                        let before = messages.len();
                        messages.truncate(2);
                        Some(CompactionReport {
                            dropped_count: before - 2,
                            tokens_before: 0,
                            tokens_after: 0,
                            overflow: true,
                            dropped_messages: Vec::new(),
                        })
                    } else {
                        None
                    }
                })
            }
        }

        let transformer = OverflowTruncator;
        let mut messages = vec![text_message("a"), text_message("b"), text_message("c")];

        // Without the overflow flag, the history must be left untouched.
        let report = transformer.transform(&mut messages, false).await;
        assert!(report.is_none());
        assert_eq!(messages.len(), 3);

        // With the overflow flag, the transformer truncates and reports the
        // single dropped message.
        let report = transformer
            .transform(&mut messages, true)
            .await
            .expect("overflow with more than two messages should produce a report");
        assert_eq!(report.dropped_count, 1);
        assert!(report.overflow);
        assert_eq!(messages.len(), 2);
    }

    /// The trait is usable as a boxed trait object. Implementations may also
    /// mutate `messages` while returning `None`.
    #[tokio::test]
    async fn async_transformer_trait_object() {
        /// Prepends a synthetic summary message and never reports a compaction.
        struct SummaryInjector;

        impl AsyncContextTransformer for SummaryInjector {
            fn transform<'a>(
                &'a self,
                messages: &'a mut Vec<AgentMessage>,
                _overflow: bool,
            ) -> AsyncTransformFuture<'a> {
                Box::pin(async move {
                    messages.insert(0, text_message("[summary of prior context]"));
                    None
                })
            }
        }

        let transformer: Box<dyn AsyncContextTransformer> = Box::new(SummaryInjector);
        let mut messages = vec![text_message("hello")];

        let report = transformer.transform(&mut messages, false).await;
        assert!(report.is_none(), "SummaryInjector never reports a compaction");
        assert_eq!(messages.len(), 2);

        // The injected summary must be at the front of the history.
        if let AgentMessage::Llm(LlmMessage::User(u)) = &messages[0] {
            assert_eq!(
                ContentBlock::extract_text(&u.content),
                "[summary of prior context]"
            );
        } else {
            panic!("expected user message");
        }
    }
}