swiftide_agents/
default_context.rs

1//! Manages agent history and provides an interface for the external world
2//!
3//! This is the default for agents. It is fully async and shareable between agents.
4//!
5//! By default uses the `LocalExecutor` for tool execution.
6//!
7//! If chat messages include a `ChatMessage::Summary`, all previous messages are ignored except the
8//! system prompt. This is useful for maintaining focus in long conversations or managing token
9//! limits.
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12
13use anyhow::Result;
14use async_trait::async_trait;
15use swiftide_core::chat_completion::ChatMessage;
16use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
17use tokio::sync::Mutex;
18
19use crate::tools::local_executor::LocalExecutor;
20
21// TODO: Remove unit as executor and implement a local executor instead
22#[derive(Clone)]
23pub struct DefaultContext {
24    completion_history: Arc<Mutex<Vec<ChatMessage>>>,
25    /// Index in the conversation history where the next completion will start
26    completions_ptr: Arc<AtomicUsize>,
27
28    /// Index in the conversation history where the current completion started
29    /// Allows for retrieving only new messages since the last completion
30    current_completions_ptr: Arc<AtomicUsize>,
31
32    /// The executor used to run tools. I.e. local, remote, docker
33    tool_executor: Arc<dyn ToolExecutor>,
34
35    /// Stop if last message is from the assistant
36    stop_on_assistant: bool,
37}
38
39impl Default for DefaultContext {
40    fn default() -> Self {
41        DefaultContext {
42            completion_history: Arc::new(Mutex::new(Vec::new())),
43            completions_ptr: Arc::new(AtomicUsize::new(0)),
44            current_completions_ptr: Arc::new(AtomicUsize::new(0)),
45            tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
46            stop_on_assistant: true,
47        }
48    }
49}
50
51impl std::fmt::Debug for DefaultContext {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("DefaultContext")
54            .field("completion_history", &self.completion_history)
55            .field("completions_ptr", &self.completions_ptr)
56            .field("current_completions_ptr", &self.current_completions_ptr)
57            .field("tool_executor", &"Arc<dyn ToolExecutor>")
58            .field("stop_on_assistant", &self.stop_on_assistant)
59            .finish()
60    }
61}
62
63impl DefaultContext {
64    /// Create a new context with a custom executor
65    pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
66        DefaultContext {
67            tool_executor: executor.into(),
68            ..Default::default()
69        }
70    }
71
72    /// If set to true, the agent will stop if the last message is from the assistant (i.e. no new
73    /// tool calls, summaries or user messages)
74    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
75        self.stop_on_assistant = stop;
76        self
77    }
78}
79#[async_trait]
80impl AgentContext for DefaultContext {
81    /// Retrieve messages for the next completion
82    async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
83        let history = self.completion_history.lock().await;
84
85        let current = self.completions_ptr.load(Ordering::SeqCst);
86
87        if history[current..].is_empty()
88            || (self.stop_on_assistant
89                && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
90        {
91            None
92        } else {
93            let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
94            self.current_completions_ptr
95                .store(previous, Ordering::SeqCst);
96
97            Some(filter_messages_since_summary(history.clone()))
98        }
99    }
100
101    /// Returns the messages the agent is currently completing on
102    async fn current_new_messages(&self) -> Vec<ChatMessage> {
103        let current = self.current_completions_ptr.load(Ordering::SeqCst);
104        let end = self.completions_ptr.load(Ordering::SeqCst);
105
106        let history = self.completion_history.lock().await;
107
108        filter_messages_since_summary(history[current..end].to_vec())
109    }
110
111    /// Retrieve all messages in the conversation history
112    async fn history(&self) -> Vec<ChatMessage> {
113        self.completion_history.lock().await.clone()
114    }
115
116    /// Add multiple messages to the conversation history
117    async fn add_messages(&self, messages: Vec<ChatMessage>) {
118        for item in messages {
119            self.add_message(item).await;
120        }
121    }
122
123    /// Add a single message to the conversation history
124    async fn add_message(&self, item: ChatMessage) {
125        self.completion_history.lock().await.push(item);
126    }
127
128    /// Execute a command in the tool executor
129    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
130        self.tool_executor.exec_cmd(cmd).await
131    }
132
133    /// Pops the last messages up until the previous completion
134    ///
135    /// LLMs failing completion for various reasons is unfortunately a common occurrence
136    /// This gives a way to redrive the last completion in a generic way
137    async fn redrive(&self) {
138        let mut history = self.completion_history.lock().await;
139        let previous = self.current_completions_ptr.load(Ordering::SeqCst);
140        let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
141
142        // delete everything after the last completion
143        history.truncate(redrive_ptr);
144    }
145}
146
147fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
148    let mut summary_found = false;
149    let mut messages = messages
150        .into_iter()
151        .rev()
152        .filter(|m| {
153            if summary_found {
154                return matches!(m, ChatMessage::System(_));
155            }
156            if let ChatMessage::Summary(_) = m {
157                summary_found = true;
158            }
159            true
160        })
161        .collect::<Vec<_>>();
162
163    messages.reverse();
164
165    messages
166}
167
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall, ToolOutput};

    /// Pointers advance once per completion, and `stop_on_assistant` halts on a trailing
    /// assistant message until it is disabled.
    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // If the last message is from the assistant, we should not get any more completions
        context.add_messages(vec![assistant!("I am fine")]).await;

        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    /// A tool call followed by its output triggers a new completion.
    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    /// Only system messages survive from before a summary; the summary and later messages are
    /// kept.
    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be kept".into()),
        ];
        let context = DefaultContext::default();
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    /// Successive summaries each reset the visible history to `[system, latest summary, ...]`.
    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        // Check the length first so a short result fails here, not on an index panic
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    /// `redrive` drops messages added since the last completion and replays that completion.
    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        // Add more messages
        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive should remove the last set of messages
        context.redrive().await;

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Add more messages
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await;

        // This should remove any additional messages
        context.redrive().await;

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive again
        context.redrive().await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }
}