swiftide_agents/
default_context.rs

//! Manages agent history and provides an interface to the outside world.
//!
//! This is the default context for agents. Its methods are async, and clones share the same
//! history and executor, so a single context can be shared between agents.
//!
//! By default it uses the `LocalExecutor` for tool execution.
//!
//! If the history contains a `ChatMessage::Summary`, every message before the most recent summary
//! is left out of completions, except system messages. This is useful for keeping long
//! conversations focused and for managing token limits.
10use std::sync::{
11    Arc, Mutex,
12    atomic::{AtomicUsize, Ordering},
13};
14
15use anyhow::Result;
16use async_trait::async_trait;
17use swiftide_core::chat_completion::ChatMessage;
18use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
19
20use crate::tools::local_executor::LocalExecutor;
21
22// TODO: Remove unit as executor and implement a local executor instead
23#[derive(Clone)]
24pub struct DefaultContext {
25    completion_history: Arc<Mutex<Vec<ChatMessage>>>,
26    /// Index in the conversation history where the next completion will start
27    completions_ptr: Arc<AtomicUsize>,
28
29    /// Index in the conversation history where the current completion started
30    /// Allows for retrieving only new messages since the last completion
31    current_completions_ptr: Arc<AtomicUsize>,
32
33    /// The executor used to run tools. I.e. local, remote, docker
34    tool_executor: Arc<dyn ToolExecutor>,
35
36    /// Stop if last message is from the assistant
37    stop_on_assistant: bool,
38}
39
40impl Default for DefaultContext {
41    fn default() -> Self {
42        DefaultContext {
43            completion_history: Arc::new(Mutex::new(Vec::new())),
44            completions_ptr: Arc::new(AtomicUsize::new(0)),
45            current_completions_ptr: Arc::new(AtomicUsize::new(0)),
46            tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
47            stop_on_assistant: true,
48        }
49    }
50}
51
52impl std::fmt::Debug for DefaultContext {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("DefaultContext")
55            .field("completion_history", &self.completion_history)
56            .field("completions_ptr", &self.completions_ptr)
57            .field("current_completions_ptr", &self.current_completions_ptr)
58            .field("tool_executor", &"Arc<dyn ToolExecutor>")
59            .field("stop_on_assistant", &self.stop_on_assistant)
60            .finish()
61    }
62}
63
64impl DefaultContext {
65    /// Create a new context with a custom executor
66    pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
67        DefaultContext {
68            tool_executor: executor.into(),
69            ..Default::default()
70        }
71    }
72
73    /// If set to true, the agent will stop if the last message is from the assistant (i.e. no new
74    /// tool calls, summaries or user messages)
75    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
76        self.stop_on_assistant = stop;
77        self
78    }
79
80    /// Build a context from an existing message history
81    ///
82    /// # Panics
83    ///
84    /// Panics if the inner mutex is poisoned
85    pub fn with_message_history<I: IntoIterator<Item = ChatMessage>>(
86        &mut self,
87        message_history: I,
88    ) -> &mut Self {
89        self.completion_history
90            .lock()
91            .unwrap()
92            .extend(message_history);
93
94        self
95    }
96}
97#[async_trait]
98impl AgentContext for DefaultContext {
99    /// Retrieve messages for the next completion
100    async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
101        let history = self.completion_history.lock().unwrap();
102
103        let current = self.completions_ptr.load(Ordering::SeqCst);
104
105        if history[current..].is_empty()
106            || (self.stop_on_assistant
107                && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
108        {
109            None
110        } else {
111            let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
112            self.current_completions_ptr
113                .store(previous, Ordering::SeqCst);
114
115            Some(filter_messages_since_summary(history.clone()))
116        }
117    }
118
119    /// Returns the messages the agent is currently completing on
120    async fn current_new_messages(&self) -> Vec<ChatMessage> {
121        let current = self.current_completions_ptr.load(Ordering::SeqCst);
122        let end = self.completions_ptr.load(Ordering::SeqCst);
123
124        let history = self.completion_history.lock().unwrap();
125
126        filter_messages_since_summary(history[current..end].to_vec())
127    }
128
129    /// Retrieve all messages in the conversation history
130    async fn history(&self) -> Vec<ChatMessage> {
131        self.completion_history.lock().unwrap().clone()
132    }
133
134    /// Add multiple messages to the conversation history
135    async fn add_messages(&self, messages: Vec<ChatMessage>) {
136        for item in messages {
137            self.add_message(item).await;
138        }
139    }
140
141    /// Add a single message to the conversation history
142    async fn add_message(&self, item: ChatMessage) {
143        self.completion_history.lock().unwrap().push(item);
144    }
145
146    /// Execute a command in the tool executor
147    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
148        self.tool_executor.exec_cmd(cmd).await
149    }
150
151    /// Pops the last messages up until the previous completion
152    ///
153    /// LLMs failing completion for various reasons is unfortunately a common occurrence
154    /// This gives a way to redrive the last completion in a generic way
155    async fn redrive(&self) {
156        let mut history = self.completion_history.lock().unwrap();
157        let previous = self.current_completions_ptr.load(Ordering::SeqCst);
158        let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
159
160        // delete everything after the last completion
161        history.truncate(redrive_ptr);
162    }
163}
164
165fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
166    let mut summary_found = false;
167    let mut messages = messages
168        .into_iter()
169        .rev()
170        .filter(|m| {
171            if summary_found {
172                return matches!(m, ChatMessage::System(_));
173            }
174            if let ChatMessage::Summary(_) = m {
175                summary_found = true;
176            }
177            true
178        })
179        .collect::<Vec<_>>();
180
181    messages.reverse();
182
183    messages
184}
185
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // If the last message is from the assistant, we should not get any more completions
        context.add_messages(vec![assistant!("I am fine")]).await;

        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    #[tokio::test]
    async fn test_add_messages_appends_in_order() {
        let context = DefaultContext::default();

        context
            .add_message(ChatMessage::System("System message".into()))
            .await;
        context
            .add_messages(vec![
                ChatMessage::User("One".into()),
                ChatMessage::User("Two".into()),
            ])
            .await;

        assert_eq!(
            context.history().await,
            vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("One".into()),
                ChatMessage::User("Two".into()),
            ]
        );
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        // Check the length first so a regression fails here instead of on an index panic
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        // Add more messages
        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive should remove the last set of messages
        context.redrive().await;

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Add more messages
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await;

        // This should remove any additional messages
        context.redrive().await;

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive again
        context.redrive().await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }
}