1use std::sync::{
11 atomic::{AtomicUsize, Ordering},
12 Arc, Mutex,
13};
14
15use anyhow::Result;
16use async_trait::async_trait;
17use swiftide_core::chat_completion::ChatMessage;
18use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
19
20use crate::tools::local_executor::LocalExecutor;
21
22#[derive(Clone)]
24pub struct DefaultContext {
25 completion_history: Arc<Mutex<Vec<ChatMessage>>>,
26 completions_ptr: Arc<AtomicUsize>,
28
29 current_completions_ptr: Arc<AtomicUsize>,
32
33 tool_executor: Arc<dyn ToolExecutor>,
35
36 stop_on_assistant: bool,
38}
39
40impl Default for DefaultContext {
41 fn default() -> Self {
42 DefaultContext {
43 completion_history: Arc::new(Mutex::new(Vec::new())),
44 completions_ptr: Arc::new(AtomicUsize::new(0)),
45 current_completions_ptr: Arc::new(AtomicUsize::new(0)),
46 tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
47 stop_on_assistant: true,
48 }
49 }
50}
51
52impl std::fmt::Debug for DefaultContext {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("DefaultContext")
55 .field("completion_history", &self.completion_history)
56 .field("completions_ptr", &self.completions_ptr)
57 .field("current_completions_ptr", &self.current_completions_ptr)
58 .field("tool_executor", &"Arc<dyn ToolExecutor>")
59 .field("stop_on_assistant", &self.stop_on_assistant)
60 .finish()
61 }
62}
63
64impl DefaultContext {
65 pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
67 DefaultContext {
68 tool_executor: executor.into(),
69 ..Default::default()
70 }
71 }
72
73 pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
76 self.stop_on_assistant = stop;
77 self
78 }
79}
80#[async_trait]
81impl AgentContext for DefaultContext {
82 async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
84 let history = self.completion_history.lock().unwrap();
85
86 let current = self.completions_ptr.load(Ordering::SeqCst);
87
88 if history[current..].is_empty()
89 || (self.stop_on_assistant
90 && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
91 {
92 None
93 } else {
94 let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
95 self.current_completions_ptr
96 .store(previous, Ordering::SeqCst);
97
98 Some(filter_messages_since_summary(history.clone()))
99 }
100 }
101
102 async fn current_new_messages(&self) -> Vec<ChatMessage> {
104 let current = self.current_completions_ptr.load(Ordering::SeqCst);
105 let end = self.completions_ptr.load(Ordering::SeqCst);
106
107 let history = self.completion_history.lock().unwrap();
108
109 filter_messages_since_summary(history[current..end].to_vec())
110 }
111
112 async fn history(&self) -> Vec<ChatMessage> {
114 self.completion_history.lock().unwrap().clone()
115 }
116
117 async fn add_messages(&self, messages: Vec<ChatMessage>) {
119 for item in messages {
120 self.add_message(item).await;
121 }
122 }
123
124 async fn add_message(&self, item: ChatMessage) {
126 self.completion_history.lock().unwrap().push(item);
127 }
128
129 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
131 self.tool_executor.exec_cmd(cmd).await
132 }
133
134 async fn redrive(&self) {
139 let mut history = self.completion_history.lock().unwrap();
140 let previous = self.current_completions_ptr.load(Ordering::SeqCst);
141 let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
142
143 history.truncate(redrive_ptr);
145 }
146}
147
148fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
149 let mut summary_found = false;
150 let mut messages = messages
151 .into_iter()
152 .rev()
153 .filter(|m| {
154 if summary_found {
155 return matches!(m, ChatMessage::System(_));
156 }
157 if let ChatMessage::Summary(_) = m {
158 summary_found = true;
159 }
160 true
161 })
162 .collect::<Vec<_>>();
163
164 messages.reverse();
165
166 messages
167}
168
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall, ToolOutput};

    #[tokio::test]
    async fn test_iteration_tracking() {
        // `mut` is required for `with_stop_on_assistant` at the end.
        let mut context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.add_messages(vec![assistant!("I am fine")]).await;

        // The assistant spoke last, so by default there is nothing to complete.
        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        // A tool output after an assistant tool call must trigger a new completion.
        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        // Check the length first so a wrong result fails clearly instead of
        // panicking on an out-of-bounds index below.
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Messages added after the last completion are discarded by a redrive.
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await;

        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.redrive().await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }
}