1use std::{
11 collections::HashMap,
12 sync::{
13 Arc, Mutex,
14 atomic::{AtomicUsize, Ordering},
15 },
16};
17
18use anyhow::Result;
19use async_trait::async_trait;
20use swiftide_core::{
21 AgentContext, Command, CommandError, CommandOutput, MessageHistory, ToolExecutor,
22};
23use swiftide_core::{
24 ToolFeedback,
25 chat_completion::{ChatMessage, ToolCall},
26};
27
28use crate::tools::local_executor::LocalExecutor;
29
30#[derive(Clone)]
32pub struct DefaultContext {
33 message_history: Arc<dyn MessageHistory>,
37 completions_ptr: Arc<AtomicUsize>,
39
40 current_completions_ptr: Arc<AtomicUsize>,
43
44 tool_executor: Arc<dyn ToolExecutor>,
46
47 stop_on_assistant: bool,
49
50 feedback_received: Arc<Mutex<HashMap<ToolCall, ToolFeedback>>>,
51}
52
53impl Default for DefaultContext {
54 fn default() -> Self {
55 DefaultContext {
56 message_history: Arc::new(Mutex::new(Vec::new())),
57 completions_ptr: Arc::new(AtomicUsize::new(0)),
58 current_completions_ptr: Arc::new(AtomicUsize::new(0)),
59 tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
60 stop_on_assistant: true,
61 feedback_received: Arc::new(Mutex::new(HashMap::new())),
62 }
63 }
64}
65
66impl std::fmt::Debug for DefaultContext {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("DefaultContext")
69 .field("completion_history", &self.message_history)
70 .field("completions_ptr", &self.completions_ptr)
71 .field("current_completions_ptr", &self.current_completions_ptr)
72 .field("tool_executor", &"Arc<dyn ToolExecutor>")
73 .field("stop_on_assistant", &self.stop_on_assistant)
74 .finish()
75 }
76}
77
78impl DefaultContext {
79 pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
81 DefaultContext {
82 tool_executor: executor.into(),
83 ..Default::default()
84 }
85 }
86
87 pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
90 self.stop_on_assistant = stop;
91 self
92 }
93
94 pub fn with_message_history(&mut self, backend: impl MessageHistory + 'static) -> &mut Self {
95 self.message_history = Arc::new(backend) as Arc<dyn MessageHistory>;
96 self
97 }
98
99 pub async fn with_existing_messages<I: IntoIterator<Item = ChatMessage>>(
109 &mut self,
110 message_history: I,
111 ) -> Result<&mut Self> {
112 self.message_history
113 .overwrite(message_history.into_iter().collect())
114 .await?;
115
116 Ok(self)
117 }
118
119 pub fn with_tool_feedback(&mut self, feedback: impl Into<HashMap<ToolCall, ToolFeedback>>) {
125 self.feedback_received
126 .lock()
127 .unwrap()
128 .extend(feedback.into());
129 }
130}
131#[async_trait]
132impl AgentContext for DefaultContext {
133 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
135 let history = self.message_history.history().await?;
136
137 let mut current = self.completions_ptr.load(Ordering::SeqCst);
138
139 if history.is_empty() {
142 tracing::debug!("No messages in history for completion");
143 return Ok(None);
144 }
145
146 if current > history.len() {
147 tracing::warn!(
148 current,
149 len = history.len(),
150 "Completions index was higher than history length, resetting to 0; this might be a bug"
151 );
152 self.completions_ptr.store(0, Ordering::SeqCst);
153 self.current_completions_ptr.store(0, Ordering::SeqCst);
154
155 current = 0;
156 }
157
158 if history[current..].is_empty()
159 || (self.stop_on_assistant
160 && matches!(history.last(), Some(ChatMessage::Assistant(_, _)))
161 && self.feedback_received.lock().unwrap().is_empty())
162 {
163 tracing::debug!(?history, "No new messages for completion");
164 Ok(None)
165 } else {
166 let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
167 self.current_completions_ptr
168 .store(previous, Ordering::SeqCst);
169
170 Ok(Some(filter_messages_since_summary(history)))
171 }
172 }
173
174 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
176 let current = self.current_completions_ptr.load(Ordering::SeqCst);
177 let end = self.completions_ptr.load(Ordering::SeqCst);
178
179 let history = self.message_history.history().await?;
180
181 Ok(filter_messages_since_summary(
182 history[current..end].to_vec(),
183 ))
184 }
185
186 async fn history(&self) -> Result<Vec<ChatMessage>> {
188 self.message_history.history().await
189 }
190
191 async fn add_messages(&self, messages: Vec<ChatMessage>) -> Result<()> {
193 self.message_history.extend_owned(messages).await
194 }
195
196 async fn add_message(&self, item: ChatMessage) -> Result<()> {
198 self.message_history.push_owned(item).await
199 }
200
201 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
203 self.tool_executor.exec_cmd(cmd).await
204 }
205
206 fn executor(&self) -> &Arc<dyn ToolExecutor> {
207 &self.tool_executor
208 }
209
210 async fn redrive(&self) -> Result<()> {
215 let mut history = self.message_history.history().await?;
216 let previous = self.current_completions_ptr.load(Ordering::SeqCst);
217 let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
218
219 history.truncate(redrive_ptr);
221
222 self.message_history.overwrite(history).await?;
223
224 Ok(())
225 }
226
227 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
228 let mut lock = self.feedback_received.lock().unwrap();
232 lock.remove(tool_call)
233 }
234
235 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
236 let mut lock = self.feedback_received.lock().unwrap();
237 if lock.is_empty() {
240 let previous = self.current_completions_ptr.load(Ordering::SeqCst);
241 self.completions_ptr.swap(previous, Ordering::SeqCst);
242 }
243 tracing::debug!(?tool_call, context = ?self, "feedback received");
244 lock.insert(tool_call.clone(), feedback.clone());
245
246 Ok(())
247 }
248}
249
250fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
251 let mut summary_found = false;
252 let mut messages = messages
253 .into_iter()
254 .rev()
255 .filter(|m| {
256 if summary_found {
257 return matches!(m, ChatMessage::System(_));
258 }
259 if let ChatMessage::Summary(_) = m {
260 summary_found = true;
261 }
262 true
263 })
264 .collect::<Vec<_>>();
265
266 messages.reverse();
267
268 messages
269}
270
#[cfg(test)]
mod tests {
    //! Unit tests for `DefaultContext` pointer bookkeeping, summary filtering
    //! and redrive behaviour.

    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![assistant!("I am fine")])
            .await
            .unwrap();

        // The assistant spoke last and `stop_on_assistant` defaults to true.
        assert!(context.next_completion().await.unwrap().is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        // Only the user/assistant messages *before* the summary are dropped;
        // the system message, the summary and everything after it are kept.
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This comes after the summary and is kept".into()),
        ];
        let context = DefaultContext::default();
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await.unwrap();
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await
            .unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.unwrap().is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await
            .unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        context.redrive().await.unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Messages added after the last completion are discarded by a redrive.
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await
            .unwrap();

        context.redrive().await.unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        context.redrive().await.unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_next_completion_empty_history() {
        let context = DefaultContext::default();
        let next = context.next_completion().await;
        assert!(next.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_next_completion_out_of_bounds_ptr() {
        let context = DefaultContext::default();
        context
            .add_messages(vec![
                ChatMessage::System("System".into()),
                ChatMessage::User("Hi".into()),
            ])
            .await
            .unwrap();

        // Force the pointer past the end of the history.
        context
            .completions_ptr
            .store(10, std::sync::atomic::Ordering::SeqCst);

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);

        assert!(context.next_completion().await.unwrap().is_none());
    }
}