swiftide_agents/
default_context.rs

1//! Manages agent history and provides an interface for the external world
2//!
3//! This is the default for agents. It is fully async and shareable between agents.
4//!
5//! By default uses the `LocalExecutor` for tool execution.
6//!
7//! If chat messages include a `ChatMessage::Summary`, all previous messages are ignored except the
8//! system prompt. This is useful for maintaining focus in long conversations or managing token
9//! limits.
10use std::{
11    collections::HashMap,
12    sync::{
13        Arc, Mutex,
14        atomic::{AtomicUsize, Ordering},
15    },
16};
17
18use anyhow::Result;
19use async_trait::async_trait;
20use swiftide_core::{
21    AgentContext, Command, CommandError, CommandOutput, MessageHistory, ToolExecutor,
22};
23use swiftide_core::{
24    ToolFeedback,
25    chat_completion::{ChatMessage, ToolCall},
26};
27
28use crate::tools::local_executor::LocalExecutor;
29
30// TODO: Remove unit as executor and implement a local executor instead
31#[derive(Clone)]
32pub struct DefaultContext {
33    /// Responsible for managing the conversation history
34    ///
35    /// By default, this is a `Arc<Mutex<Vec<ChatMessage>>>`.
36    message_history: Arc<dyn MessageHistory>,
37    /// Index in the conversation history where the next completion will start
38    completions_ptr: Arc<AtomicUsize>,
39
40    /// Index in the conversation history where the current completion started
41    /// Allows for retrieving only new messages since the last completion
42    current_completions_ptr: Arc<AtomicUsize>,
43
44    /// The executor used to run tools. I.e. local, remote, docker
45    tool_executor: Arc<dyn ToolExecutor>,
46
47    /// Stop if last message is from the assistant
48    stop_on_assistant: bool,
49
50    feedback_received: Arc<Mutex<HashMap<ToolCall, ToolFeedback>>>,
51}
52
53impl Default for DefaultContext {
54    fn default() -> Self {
55        DefaultContext {
56            message_history: Arc::new(Mutex::new(Vec::new())),
57            completions_ptr: Arc::new(AtomicUsize::new(0)),
58            current_completions_ptr: Arc::new(AtomicUsize::new(0)),
59            tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
60            stop_on_assistant: true,
61            feedback_received: Arc::new(Mutex::new(HashMap::new())),
62        }
63    }
64}
65
66impl std::fmt::Debug for DefaultContext {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("DefaultContext")
69            .field("completion_history", &self.message_history)
70            .field("completions_ptr", &self.completions_ptr)
71            .field("current_completions_ptr", &self.current_completions_ptr)
72            .field("tool_executor", &"Arc<dyn ToolExecutor>")
73            .field("stop_on_assistant", &self.stop_on_assistant)
74            .finish()
75    }
76}
77
78impl DefaultContext {
79    /// Create a new context with a custom executor
80    pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
81        DefaultContext {
82            tool_executor: executor.into(),
83            ..Default::default()
84        }
85    }
86
87    /// If set to true, the agent will stop if the last message is from the assistant (i.e. no new
88    /// tool calls, summaries or user messages)
89    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
90        self.stop_on_assistant = stop;
91        self
92    }
93
94    pub fn with_message_history(&mut self, backend: impl MessageHistory + 'static) -> &mut Self {
95        self.message_history = Arc::new(backend) as Arc<dyn MessageHistory>;
96        self
97    }
98
99    /// Build a context from an existing message history
100    ///
101    /// # Errors
102    ///
103    /// Errors if the message history cannot be extended
104    ///
105    /// # Panics
106    ///
107    /// Panics if the inner mutex is poisoned
108    pub async fn with_existing_messages<I: IntoIterator<Item = ChatMessage>>(
109        &mut self,
110        message_history: I,
111    ) -> Result<&mut Self> {
112        self.message_history
113            .overwrite(message_history.into_iter().collect())
114            .await?;
115
116        Ok(self)
117    }
118
119    /// Add existing tool feedback to the context
120    ///
121    /// # Panics
122    ///
123    /// Panics if the inner mutex is poisoned
124    pub fn with_tool_feedback(&mut self, feedback: impl Into<HashMap<ToolCall, ToolFeedback>>) {
125        self.feedback_received
126            .lock()
127            .unwrap()
128            .extend(feedback.into());
129    }
130}
131#[async_trait]
132impl AgentContext for DefaultContext {
133    /// Retrieve messages for the next completion
134    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
135        let history = self.message_history.history().await?;
136
137        let current = self.completions_ptr.load(Ordering::SeqCst);
138
139        if history[current..].is_empty()
140            || (self.stop_on_assistant
141                && matches!(history.last(), Some(ChatMessage::Assistant(_, _)))
142                && self.feedback_received.lock().unwrap().is_empty())
143        {
144            tracing::debug!(?history, "No new messages for completion");
145            Ok(None)
146        } else {
147            let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
148            self.current_completions_ptr
149                .store(previous, Ordering::SeqCst);
150
151            Ok(Some(filter_messages_since_summary(history)))
152        }
153    }
154
155    /// Returns the messages the agent is currently completing on
156    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
157        let current = self.current_completions_ptr.load(Ordering::SeqCst);
158        let end = self.completions_ptr.load(Ordering::SeqCst);
159
160        let history = self.message_history.history().await?;
161
162        Ok(filter_messages_since_summary(
163            history[current..end].to_vec(),
164        ))
165    }
166
167    /// Retrieve all messages in the conversation history
168    async fn history(&self) -> Result<Vec<ChatMessage>> {
169        self.message_history.history().await
170    }
171
172    /// Add multiple messages to the conversation history
173    async fn add_messages(&self, messages: Vec<ChatMessage>) -> Result<()> {
174        self.message_history.extend_owned(messages).await
175    }
176
177    /// Add a single message to the conversation history
178    async fn add_message(&self, item: ChatMessage) -> Result<()> {
179        self.message_history.push_owned(item).await
180    }
181
182    /// Execute a command in the tool executor
183    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
184        self.tool_executor.exec_cmd(cmd).await
185    }
186
187    fn executor(&self) -> &Arc<dyn ToolExecutor> {
188        &self.tool_executor
189    }
190
191    /// Pops the last messages up until the previous completion
192    ///
193    /// LLMs failing completion for various reasons is unfortunately a common occurrence
194    /// This gives a way to redrive the last completion in a generic way
195    async fn redrive(&self) -> Result<()> {
196        let mut history = self.message_history.history().await?;
197        let previous = self.current_completions_ptr.load(Ordering::SeqCst);
198        let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
199
200        // delete everything after the last completion
201        history.truncate(redrive_ptr);
202
203        self.message_history.overwrite(history).await?;
204
205        Ok(())
206    }
207
208    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
209        // If feedback is present, return true with the optional payload,
210        // and remove it
211        // otherwise return false
212        let mut lock = self.feedback_received.lock().unwrap();
213        lock.remove(tool_call)
214    }
215
216    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
217        let mut lock = self.feedback_received.lock().unwrap();
218        // Set the message counter one back so that on a next try, the agent can resume by
219        // trying the tool calls first. Only does this if there are no other approvals
220        if lock.is_empty() {
221            let previous = self.current_completions_ptr.load(Ordering::SeqCst);
222            self.completions_ptr.swap(previous, Ordering::SeqCst);
223        }
224        tracing::debug!(?tool_call, context = ?self, "feedback received");
225        lock.insert(tool_call.clone(), feedback.clone());
226
227        Ok(())
228    }
229}
230
231fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
232    let mut summary_found = false;
233    let mut messages = messages
234        .into_iter()
235        .rev()
236        .filter(|m| {
237            if summary_found {
238                return matches!(m, ChatMessage::System(_));
239            }
240            if let ChatMessage::Summary(_) = m {
241                summary_found = true;
242            }
243            true
244        })
245        .collect::<Vec<_>>();
246
247    messages.reverse();
248
249    messages
250}
251
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // If the last message is from the assistant, we should not get any more completions
        context
            .add_messages(vec![assistant!("I am fine")])
            .await
            .unwrap();

        assert!(context.next_completion().await.unwrap().is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(context.current_new_messages().await.unwrap().len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        // Record initial chat messages
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        // Only the system prompt survives from before the summary; everything after it is kept.
        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await.unwrap();
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        // Record initial chat messages
        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await
            .unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.unwrap().is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await.unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();

        // Check the length first so a regression fails here rather than on an index panic.
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await
            .unwrap();

        let new_messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[test]
    fn test_filter_without_summary_keeps_all_messages() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];

        assert_eq!(filter_messages_since_summary(messages.clone()), messages);
    }

    #[test]
    fn test_filter_keeps_system_messages_before_last_summary() {
        let messages = vec![
            ChatMessage::User("Hello".into()),
            ChatMessage::System("System message".into()),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("After summary".into()),
        ];

        assert_eq!(
            filter_messages_since_summary(messages),
            vec![
                ChatMessage::System("System message".into()),
                ChatMessage::Summary("Summary message".into()),
                ChatMessage::User("After summary".into()),
            ]
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.unwrap().is_none());
        context.redrive().await.unwrap();

        // Add more messages
        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await
            .unwrap();

        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Redrive should remove the last set of messages
        context.redrive().await.unwrap();

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Add more messages
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await
            .unwrap();

        // This should remove any additional messages
        context.redrive().await.unwrap();

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());

        // Redrive again
        context.redrive().await.unwrap();
        let messages = context.next_completion().await.unwrap().unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.unwrap().is_none());
    }
}