1use std::{
11 collections::HashMap,
12 sync::{
13 Arc, Mutex,
14 atomic::{AtomicUsize, Ordering},
15 },
16};
17
18use anyhow::Result;
19use async_trait::async_trait;
20use swiftide_core::{
21 AgentContext, Command, CommandError, CommandOutput, MessageHistory, ToolExecutor,
22};
23use swiftide_core::{
24 ToolFeedback,
25 chat_completion::{ChatMessage, ToolCall},
26};
27
28use crate::tools::local_executor::LocalExecutor;
29
/// Default implementation of [`AgentContext`].
///
/// All shared state (message history, completion pointers, tool executor and
/// pending tool feedback) lives behind `Arc`s, so a clone shares that state
/// with the original. Only `stop_on_assistant` is copied per clone.
#[derive(Clone)]
pub struct DefaultContext {
    /// Backing store for the conversation. `Default` uses an in-memory
    /// `Mutex<Vec<ChatMessage>>`.
    message_history: Arc<dyn MessageHistory>,
    /// Index into the history up to which messages have been handed to a
    /// completion. `next_completion` advances it to the history length.
    completions_ptr: Arc<AtomicUsize>,

    /// Value `completions_ptr` had before the most recent completion, i.e.
    /// where that completion's new messages start. Read by
    /// `current_new_messages`, `redrive` and `feedback_received`.
    current_completions_ptr: Arc<AtomicUsize>,

    /// Executor that `exec_cmd` delegates to.
    tool_executor: Arc<dyn ToolExecutor>,

    /// When `true`, `next_completion` yields `None` if the last message is from
    /// the assistant and no tool feedback is pending.
    stop_on_assistant: bool,

    /// Tool feedback received per tool call. Entries are removed again by
    /// `has_received_feedback`.
    feedback_received: Arc<Mutex<HashMap<ToolCall, ToolFeedback>>>,
}
52
53impl Default for DefaultContext {
54 fn default() -> Self {
55 DefaultContext {
56 message_history: Arc::new(Mutex::new(Vec::new())),
57 completions_ptr: Arc::new(AtomicUsize::new(0)),
58 current_completions_ptr: Arc::new(AtomicUsize::new(0)),
59 tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
60 stop_on_assistant: true,
61 feedback_received: Arc::new(Mutex::new(HashMap::new())),
62 }
63 }
64}
65
66impl std::fmt::Debug for DefaultContext {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("DefaultContext")
69 .field("completion_history", &self.message_history)
70 .field("completions_ptr", &self.completions_ptr)
71 .field("current_completions_ptr", &self.current_completions_ptr)
72 .field("tool_executor", &"Arc<dyn ToolExecutor>")
73 .field("stop_on_assistant", &self.stop_on_assistant)
74 .finish()
75 }
76}
77
78impl DefaultContext {
79 pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
81 DefaultContext {
82 tool_executor: executor.into(),
83 ..Default::default()
84 }
85 }
86
87 pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
90 self.stop_on_assistant = stop;
91 self
92 }
93
94 pub fn with_message_history(&mut self, backend: impl MessageHistory + 'static) -> &mut Self {
95 self.message_history = Arc::new(backend) as Arc<dyn MessageHistory>;
96 self
97 }
98
99 pub async fn with_existing_messages<I: IntoIterator<Item = ChatMessage>>(
109 &mut self,
110 message_history: I,
111 ) -> Result<&mut Self> {
112 self.message_history
113 .overwrite(message_history.into_iter().collect())
114 .await?;
115
116 Ok(self)
117 }
118
119 pub fn with_tool_feedback(&mut self, feedback: impl Into<HashMap<ToolCall, ToolFeedback>>) {
125 self.feedback_received
126 .lock()
127 .unwrap()
128 .extend(feedback.into());
129 }
130}
131#[async_trait]
132impl AgentContext for DefaultContext {
133 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
135 let history = self.message_history.history().await?;
136
137 let current = self.completions_ptr.load(Ordering::SeqCst);
138
139 if history[current..].is_empty()
140 || (self.stop_on_assistant
141 && matches!(history.last(), Some(ChatMessage::Assistant(_, _)))
142 && self.feedback_received.lock().unwrap().is_empty())
143 {
144 tracing::debug!(?history, "No new messages for completion");
145 Ok(None)
146 } else {
147 let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
148 self.current_completions_ptr
149 .store(previous, Ordering::SeqCst);
150
151 Ok(Some(filter_messages_since_summary(history)))
152 }
153 }
154
155 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
157 let current = self.current_completions_ptr.load(Ordering::SeqCst);
158 let end = self.completions_ptr.load(Ordering::SeqCst);
159
160 let history = self.message_history.history().await?;
161
162 Ok(filter_messages_since_summary(
163 history[current..end].to_vec(),
164 ))
165 }
166
167 async fn history(&self) -> Result<Vec<ChatMessage>> {
169 self.message_history.history().await
170 }
171
172 async fn add_messages(&self, messages: Vec<ChatMessage>) -> Result<()> {
174 self.message_history.extend_owned(messages).await
175 }
176
177 async fn add_message(&self, item: ChatMessage) -> Result<()> {
179 self.message_history.push_owned(item).await
180 }
181
182 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
184 self.tool_executor.exec_cmd(cmd).await
185 }
186
187 fn executor(&self) -> &Arc<dyn ToolExecutor> {
188 &self.tool_executor
189 }
190
191 async fn redrive(&self) -> Result<()> {
196 let mut history = self.message_history.history().await?;
197 let previous = self.current_completions_ptr.load(Ordering::SeqCst);
198 let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
199
200 history.truncate(redrive_ptr);
202
203 self.message_history.overwrite(history).await?;
204
205 Ok(())
206 }
207
208 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
209 let mut lock = self.feedback_received.lock().unwrap();
213 lock.remove(tool_call)
214 }
215
216 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
217 let mut lock = self.feedback_received.lock().unwrap();
218 if lock.is_empty() {
221 let previous = self.current_completions_ptr.load(Ordering::SeqCst);
222 self.completions_ptr.swap(previous, Ordering::SeqCst);
223 }
224 tracing::debug!(?tool_call, context = ?self, "feedback received");
225 lock.insert(tool_call.clone(), feedback.clone());
226
227 Ok(())
228 }
229}
230
231fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
232 let mut summary_found = false;
233 let mut messages = messages
234 .into_iter()
235 .rev()
236 .filter(|m| {
237 if summary_found {
238 return matches!(m, ChatMessage::System(_));
239 }
240 if let ChatMessage::Summary(_) = m {
241 summary_found = true;
242 }
243 true
244 })
245 .collect::<Vec<_>>();
246
247 messages.reverse();
248
249 messages
250}
251
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // An assistant reply as the last message stops completion by default...
        context
            .add_messages(vec![assistant!("I am fine")])
            .await
            .unwrap();

        assert!(context.next_completion().await.unwrap().is_none());

        // ...but not once `stop_on_assistant` is disabled.
        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        // A tool output after the assistant's tool call must trigger a completion.
        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            // Messages after the summary are kept; only those before it are dropped.
            ChatMessage::User("This should be kept".into()),
        ];
        let context = DefaultContext::default();
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await.unwrap();
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await
            .unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.unwrap().is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        // System + last summary + the two messages added after it.
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await
            .unwrap();

        // Only the newest summary survives; the earlier one is filtered out.
        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        // After a redrive the same completion is offered again.
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        context.redrive().await.unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Messages appended after the last completion are discarded by a redrive.
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await
            .unwrap();

        context.redrive().await.unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        context.redrive().await.unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());
    }
}