swiftide_agents/
default_context.rs

1//! Manages agent history and provides an interface for the external world
2//!
3//! This is the default for agents. It is fully async and shareable between agents.
4//!
5//! By default uses the `LocalExecutor` for tool execution.
6//!
7//! If chat messages include a `ChatMessage::Summary`, all previous messages are ignored except the
8//! system prompt. This is useful for maintaining focus in long conversations or managing token
9//! limits.
10use std::{
11    collections::HashMap,
12    sync::{
13        Arc, Mutex,
14        atomic::{AtomicUsize, Ordering},
15    },
16};
17
18use anyhow::Result;
19use async_trait::async_trait;
20use swiftide_core::{
21    AgentContext, Command, CommandError, CommandOutput, MessageHistory, ToolExecutor,
22};
23use swiftide_core::{
24    ToolFeedback,
25    chat_completion::{ChatMessage, ToolCall},
26};
27
28use crate::tools::local_executor::LocalExecutor;
29
30// TODO: Remove unit as executor and implement a local executor instead
31#[derive(Clone)]
32pub struct DefaultContext {
33    /// Responsible for managing the conversation history
34    ///
35    /// By default, this is a `Arc<Mutex<Vec<ChatMessage>>>`.
36    message_history: Arc<dyn MessageHistory>,
37    /// Index in the conversation history where the next completion will start
38    completions_ptr: Arc<AtomicUsize>,
39
40    /// Index in the conversation history where the current completion started
41    /// Allows for retrieving only new messages since the last completion
42    current_completions_ptr: Arc<AtomicUsize>,
43
44    /// The executor used to run tools. I.e. local, remote, docker
45    tool_executor: Arc<dyn ToolExecutor>,
46
47    /// Stop if last message is from the assistant
48    stop_on_assistant: bool,
49
50    feedback_received: Arc<Mutex<HashMap<ToolCall, ToolFeedback>>>,
51}
52
53impl Default for DefaultContext {
54    fn default() -> Self {
55        DefaultContext {
56            message_history: Arc::new(Mutex::new(Vec::new())),
57            completions_ptr: Arc::new(AtomicUsize::new(0)),
58            current_completions_ptr: Arc::new(AtomicUsize::new(0)),
59            tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
60            stop_on_assistant: true,
61            feedback_received: Arc::new(Mutex::new(HashMap::new())),
62        }
63    }
64}
65
66impl std::fmt::Debug for DefaultContext {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("DefaultContext")
69            .field("completion_history", &self.message_history)
70            .field("completions_ptr", &self.completions_ptr)
71            .field("current_completions_ptr", &self.current_completions_ptr)
72            .field("tool_executor", &"Arc<dyn ToolExecutor>")
73            .field("stop_on_assistant", &self.stop_on_assistant)
74            .finish()
75    }
76}
77
78impl DefaultContext {
79    /// Create a new context with a custom executor
80    pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
81        DefaultContext {
82            tool_executor: executor.into(),
83            ..Default::default()
84        }
85    }
86
87    /// If set to true, the agent will stop if the last message is from the assistant (i.e. no new
88    /// tool calls, summaries or user messages)
89    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
90        self.stop_on_assistant = stop;
91        self
92    }
93
94    pub fn with_message_history(&mut self, backend: impl MessageHistory + 'static) -> &mut Self {
95        self.message_history = Arc::new(backend) as Arc<dyn MessageHistory>;
96        self
97    }
98
99    /// Build a context from an existing message history
100    ///
101    /// # Errors
102    ///
103    /// Errors if the message history cannot be extended
104    ///
105    /// # Panics
106    ///
107    /// Panics if the inner mutex is poisoned
108    pub async fn with_existing_messages<I: IntoIterator<Item = ChatMessage>>(
109        &mut self,
110        message_history: I,
111    ) -> Result<&mut Self> {
112        self.message_history
113            .overwrite(message_history.into_iter().collect())
114            .await?;
115
116        Ok(self)
117    }
118
119    /// Add existing tool feedback to the context
120    ///
121    /// # Panics
122    ///
123    /// Panics if the inner mutex is poisoned
124    pub fn with_tool_feedback(&mut self, feedback: impl Into<HashMap<ToolCall, ToolFeedback>>) {
125        self.feedback_received
126            .lock()
127            .unwrap()
128            .extend(feedback.into());
129    }
130}
131#[async_trait]
132impl AgentContext for DefaultContext {
133    /// Retrieve messages for the next completion
134    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
135        let history = self.message_history.history().await?;
136
137        let mut current = self.completions_ptr.load(Ordering::SeqCst);
138
139        // handle out of bounds; if current > length, reset current to 0
140        // if length is 0, return None
141        if history.is_empty() {
142            tracing::debug!("No messages in history for completion");
143            return Ok(None);
144        }
145
146        if current > history.len() {
147            tracing::warn!(
148                current,
149                len = history.len(),
150                "Completions index was higher than history length, resetting to 0; this might be a bug"
151            );
152            self.completions_ptr.store(0, Ordering::SeqCst);
153            self.current_completions_ptr.store(0, Ordering::SeqCst);
154
155            current = 0;
156        }
157
158        if history[current..].is_empty()
159            || (self.stop_on_assistant
160                && matches!(history.last(), Some(ChatMessage::Assistant(_, _)))
161                && self.feedback_received.lock().unwrap().is_empty())
162        {
163            tracing::debug!(?history, "No new messages for completion");
164            Ok(None)
165        } else {
166            let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
167            self.current_completions_ptr
168                .store(previous, Ordering::SeqCst);
169
170            Ok(Some(filter_messages_since_summary(history)))
171        }
172    }
173
174    /// Returns the messages the agent is currently completing on
175    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
176        let current = self.current_completions_ptr.load(Ordering::SeqCst);
177        let end = self.completions_ptr.load(Ordering::SeqCst);
178
179        let history = self.message_history.history().await?;
180
181        Ok(filter_messages_since_summary(
182            history[current..end].to_vec(),
183        ))
184    }
185
186    /// Retrieve all messages in the conversation history
187    async fn history(&self) -> Result<Vec<ChatMessage>> {
188        self.message_history.history().await
189    }
190
191    /// Add multiple messages to the conversation history
192    async fn add_messages(&self, messages: Vec<ChatMessage>) -> Result<()> {
193        self.message_history.extend_owned(messages).await
194    }
195
196    /// Add a single message to the conversation history
197    async fn add_message(&self, item: ChatMessage) -> Result<()> {
198        self.message_history.push_owned(item).await
199    }
200
201    /// Execute a command in the tool executor
202    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
203        self.tool_executor.exec_cmd(cmd).await
204    }
205
206    fn executor(&self) -> &Arc<dyn ToolExecutor> {
207        &self.tool_executor
208    }
209
210    /// Pops the last messages up until the previous completion
211    ///
212    /// LLMs failing completion for various reasons is unfortunately a common occurrence
213    /// This gives a way to redrive the last completion in a generic way
214    async fn redrive(&self) -> Result<()> {
215        let mut history = self.message_history.history().await?;
216        let previous = self.current_completions_ptr.load(Ordering::SeqCst);
217        let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
218
219        // delete everything after the last completion
220        history.truncate(redrive_ptr);
221
222        self.message_history.overwrite(history).await?;
223
224        Ok(())
225    }
226
227    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
228        // If feedback is present, return true with the optional payload,
229        // and remove it
230        // otherwise return false
231        let mut lock = self.feedback_received.lock().unwrap();
232        lock.remove(tool_call)
233    }
234
235    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
236        let mut lock = self.feedback_received.lock().unwrap();
237        // Set the message counter one back so that on a next try, the agent can resume by
238        // trying the tool calls first. Only does this if there are no other approvals
239        if lock.is_empty() {
240            let previous = self.current_completions_ptr.load(Ordering::SeqCst);
241            self.completions_ptr.swap(previous, Ordering::SeqCst);
242        }
243        tracing::debug!(?tool_call, context = ?self, "feedback received");
244        lock.insert(tool_call.clone(), feedback.clone());
245
246        Ok(())
247    }
248}
249
250fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
251    let mut summary_found = false;
252    let mut messages = messages
253        .into_iter()
254        .rev()
255        .filter(|m| {
256            if summary_found {
257                return matches!(m, ChatMessage::System(_));
258            }
259            if let ChatMessage::Summary(_) = m {
260                summary_found = true;
261            }
262            true
263        })
264        .collect::<Vec<_>>();
265
266    messages.reverse();
267
268    messages
269}
270
#[cfg(test)]
mod tests {
    //! Tests for completion pointer tracking, summary filtering and redrive behavior.
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // If the last message is from the assistant, we should not get any more completions
        context
            .add_messages(vec![assistant!("I am fine")])
            .await
            .unwrap();

        assert!(context.next_completion().await.unwrap().is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        // Record initial chat messages
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await.unwrap();
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        // Record initial chat messages
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await
            .unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.unwrap().is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await
            .unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        // Add more messages
        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Redrive should remove the last set of messages
        context.redrive().await.unwrap();

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Add more messages
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await
            .unwrap();

        // This should remove any additional messages
        context.redrive().await.unwrap();

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Redrive again
        context.redrive().await.unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_next_completion_empty_history() {
        let context = DefaultContext::default();
        let next = context.next_completion().await;
        assert!(next.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_next_completion_out_of_bounds_ptr() {
        let context = DefaultContext::default();
        context
            .add_messages(vec![
                ChatMessage::System("System".into()),
                ChatMessage::User("Hi".into()),
            ])
            .await
            .unwrap();

        // Set completions_ptr beyond the length of messages
        context
            .completions_ptr
            .store(10, std::sync::atomic::Ordering::SeqCst);

        // Should reset the pointer and return the full messages
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);

        // Second call should be empty again
        assert!(context.next_completion().await.unwrap().is_none());
    }
}