1use std::sync::{
11 Arc, Mutex,
12 atomic::{AtomicUsize, Ordering},
13};
14
15use anyhow::Result;
16use async_trait::async_trait;
17use swiftide_core::chat_completion::ChatMessage;
18use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
19
20use crate::tools::local_executor::LocalExecutor;
21
22#[derive(Clone)]
24pub struct DefaultContext {
25 completion_history: Arc<Mutex<Vec<ChatMessage>>>,
26 completions_ptr: Arc<AtomicUsize>,
28
29 current_completions_ptr: Arc<AtomicUsize>,
32
33 tool_executor: Arc<dyn ToolExecutor>,
35
36 stop_on_assistant: bool,
38}
39
40impl Default for DefaultContext {
41 fn default() -> Self {
42 DefaultContext {
43 completion_history: Arc::new(Mutex::new(Vec::new())),
44 completions_ptr: Arc::new(AtomicUsize::new(0)),
45 current_completions_ptr: Arc::new(AtomicUsize::new(0)),
46 tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
47 stop_on_assistant: true,
48 }
49 }
50}
51
52impl std::fmt::Debug for DefaultContext {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("DefaultContext")
55 .field("completion_history", &self.completion_history)
56 .field("completions_ptr", &self.completions_ptr)
57 .field("current_completions_ptr", &self.current_completions_ptr)
58 .field("tool_executor", &"Arc<dyn ToolExecutor>")
59 .field("stop_on_assistant", &self.stop_on_assistant)
60 .finish()
61 }
62}
63
64impl DefaultContext {
65 pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
67 DefaultContext {
68 tool_executor: executor.into(),
69 ..Default::default()
70 }
71 }
72
73 pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
76 self.stop_on_assistant = stop;
77 self
78 }
79
80 pub fn with_message_history<I: IntoIterator<Item = ChatMessage>>(
86 &mut self,
87 message_history: I,
88 ) -> &mut Self {
89 self.completion_history
90 .lock()
91 .unwrap()
92 .extend(message_history);
93
94 self
95 }
96}
97#[async_trait]
98impl AgentContext for DefaultContext {
99 async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
101 let history = self.completion_history.lock().unwrap();
102
103 let current = self.completions_ptr.load(Ordering::SeqCst);
104
105 if history[current..].is_empty()
106 || (self.stop_on_assistant
107 && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
108 {
109 None
110 } else {
111 let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
112 self.current_completions_ptr
113 .store(previous, Ordering::SeqCst);
114
115 Some(filter_messages_since_summary(history.clone()))
116 }
117 }
118
119 async fn current_new_messages(&self) -> Vec<ChatMessage> {
121 let current = self.current_completions_ptr.load(Ordering::SeqCst);
122 let end = self.completions_ptr.load(Ordering::SeqCst);
123
124 let history = self.completion_history.lock().unwrap();
125
126 filter_messages_since_summary(history[current..end].to_vec())
127 }
128
129 async fn history(&self) -> Vec<ChatMessage> {
131 self.completion_history.lock().unwrap().clone()
132 }
133
134 async fn add_messages(&self, messages: Vec<ChatMessage>) {
136 for item in messages {
137 self.add_message(item).await;
138 }
139 }
140
141 async fn add_message(&self, item: ChatMessage) {
143 self.completion_history.lock().unwrap().push(item);
144 }
145
146 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
148 self.tool_executor.exec_cmd(cmd).await
149 }
150
151 async fn redrive(&self) {
156 let mut history = self.completion_history.lock().unwrap();
157 let previous = self.current_completions_ptr.load(Ordering::SeqCst);
158 let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
159
160 history.truncate(redrive_ptr);
162 }
163}
164
165fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
166 let mut summary_found = false;
167 let mut messages = messages
168 .into_iter()
169 .rev()
170 .filter(|m| {
171 if summary_found {
172 return matches!(m, ChatMessage::System(_));
173 }
174 if let ChatMessage::Summary(_) = m {
175 summary_found = true;
176 }
177 true
178 })
179 .collect::<Vec<_>>();
180
181 messages.reverse();
182
183 messages
184}
185
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall};

    /// `next_completion` only yields after new messages were added and honours
    /// `stop_on_assistant`.
    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.add_messages(vec![assistant!("I am fine")]).await;

        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    /// A tool output following an assistant tool call triggers a new
    /// completion, and `current_new_messages` reports only the new messages.
    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    /// Messages before a summary are dropped, except system messages.
    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    /// Summary filtering across several completions, with `stop_on_assistant`
    /// disabled, always uses the most recent summary.
    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    /// `redrive` replays the last completion and discards messages added after it.
    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await;

        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.redrive().await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }
}