swiftide_agents/
default_context.rs

//! Manages agent history and provides an interface to the external world.
//!
//! This is the default context for agents. It is fully async, and clones share the same
//! history, so it can be shared between agents.
//!
//! By default it uses the `LocalExecutor` for tool execution.
//!
//! If the chat messages include a `ChatMessage::Summary`, all messages before the most recent
//! summary are ignored except the system prompt. This is useful for maintaining focus in long
//! conversations or managing token limits.
10use std::sync::{
11    atomic::{AtomicUsize, Ordering},
12    Arc, Mutex,
13};
14
15use anyhow::Result;
16use async_trait::async_trait;
17use swiftide_core::chat_completion::ChatMessage;
18use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
19
20use crate::tools::local_executor::LocalExecutor;
21
22// TODO: Remove unit as executor and implement a local executor instead
23#[derive(Clone)]
24pub struct DefaultContext {
25    completion_history: Arc<Mutex<Vec<ChatMessage>>>,
26    /// Index in the conversation history where the next completion will start
27    completions_ptr: Arc<AtomicUsize>,
28
29    /// Index in the conversation history where the current completion started
30    /// Allows for retrieving only new messages since the last completion
31    current_completions_ptr: Arc<AtomicUsize>,
32
33    /// The executor used to run tools. I.e. local, remote, docker
34    tool_executor: Arc<dyn ToolExecutor>,
35
36    /// Stop if last message is from the assistant
37    stop_on_assistant: bool,
38}
39
40impl Default for DefaultContext {
41    fn default() -> Self {
42        DefaultContext {
43            completion_history: Arc::new(Mutex::new(Vec::new())),
44            completions_ptr: Arc::new(AtomicUsize::new(0)),
45            current_completions_ptr: Arc::new(AtomicUsize::new(0)),
46            tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
47            stop_on_assistant: true,
48        }
49    }
50}
51
52impl std::fmt::Debug for DefaultContext {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("DefaultContext")
55            .field("completion_history", &self.completion_history)
56            .field("completions_ptr", &self.completions_ptr)
57            .field("current_completions_ptr", &self.current_completions_ptr)
58            .field("tool_executor", &"Arc<dyn ToolExecutor>")
59            .field("stop_on_assistant", &self.stop_on_assistant)
60            .finish()
61    }
62}
63
64impl DefaultContext {
65    /// Create a new context with a custom executor
66    pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
67        DefaultContext {
68            tool_executor: executor.into(),
69            ..Default::default()
70        }
71    }
72
73    /// If set to true, the agent will stop if the last message is from the assistant (i.e. no new
74    /// tool calls, summaries or user messages)
75    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
76        self.stop_on_assistant = stop;
77        self
78    }
79}
80#[async_trait]
81impl AgentContext for DefaultContext {
82    /// Retrieve messages for the next completion
83    async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
84        let history = self.completion_history.lock().unwrap();
85
86        let current = self.completions_ptr.load(Ordering::SeqCst);
87
88        if history[current..].is_empty()
89            || (self.stop_on_assistant
90                && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
91        {
92            None
93        } else {
94            let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
95            self.current_completions_ptr
96                .store(previous, Ordering::SeqCst);
97
98            Some(filter_messages_since_summary(history.clone()))
99        }
100    }
101
102    /// Returns the messages the agent is currently completing on
103    async fn current_new_messages(&self) -> Vec<ChatMessage> {
104        let current = self.current_completions_ptr.load(Ordering::SeqCst);
105        let end = self.completions_ptr.load(Ordering::SeqCst);
106
107        let history = self.completion_history.lock().unwrap();
108
109        filter_messages_since_summary(history[current..end].to_vec())
110    }
111
112    /// Retrieve all messages in the conversation history
113    async fn history(&self) -> Vec<ChatMessage> {
114        self.completion_history.lock().unwrap().clone()
115    }
116
117    /// Add multiple messages to the conversation history
118    async fn add_messages(&self, messages: Vec<ChatMessage>) {
119        for item in messages {
120            self.add_message(item).await;
121        }
122    }
123
124    /// Add a single message to the conversation history
125    async fn add_message(&self, item: ChatMessage) {
126        self.completion_history.lock().unwrap().push(item);
127    }
128
129    /// Execute a command in the tool executor
130    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
131        self.tool_executor.exec_cmd(cmd).await
132    }
133
134    /// Pops the last messages up until the previous completion
135    ///
136    /// LLMs failing completion for various reasons is unfortunately a common occurrence
137    /// This gives a way to redrive the last completion in a generic way
138    async fn redrive(&self) {
139        let mut history = self.completion_history.lock().unwrap();
140        let previous = self.current_completions_ptr.load(Ordering::SeqCst);
141        let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
142
143        // delete everything after the last completion
144        history.truncate(redrive_ptr);
145    }
146}
147
148fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
149    let mut summary_found = false;
150    let mut messages = messages
151        .into_iter()
152        .rev()
153        .filter(|m| {
154            if summary_found {
155                return matches!(m, ChatMessage::System(_));
156            }
157            if let ChatMessage::Summary(_) = m {
158                summary_found = true;
159            }
160            true
161        })
162        .collect::<Vec<_>>();
163
164    messages.reverse();
165
166    messages
167}
168
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall, ToolOutput};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // If the last message is from the assistant, we should not get any more completions
        context.add_messages(vec![assistant!("I am fine")]).await;

        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        // Despite its text, the user message *after* the summary is kept; only messages
        // before the summary are dropped (except the system prompt).
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        // System prompt, latest summary, and the two messages added after it
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        // Add more messages
        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive should remove the last set of messages
        context.redrive().await;

        // Right after a redrive there is no completion in flight
        assert!(context.current_new_messages().await.is_empty());

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Add more messages
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await;

        // This should remove any additional messages
        context.redrive().await;

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive again
        context.redrive().await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }
}