swiftide_agents/
default_context.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
//! Manages agent history and provides an interface for the external world
//!
//! This is the default for agents. It is fully async and shareable between agents.
//!
//! By default uses the `LocalExecutor` for tool execution.
//!
//! If the chat messages include a `ChatMessage::Summary`, every message before the most recent
//! summary is left out of completions, except system messages. This is useful for maintaining
//! focus in long conversations or managing token limits.
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::chat_completion::ChatMessage;
use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
use tokio::sync::Mutex;

use crate::tools::local_executor::LocalExecutor;

/// Default [`AgentContext`] implementation: an in-memory conversation history plus a tool
/// executor.
///
/// The history, both completion pointers and the executor live behind `Arc`s, so cloning a
/// `DefaultContext` yields a handle to the same history and pointers rather than an independent
/// copy. The `stop_on_assistant` flag is a plain `bool` and is copied, so changing it on one clone
/// does not affect the others.
#[derive(Clone)]
pub struct DefaultContext {
    /// Every message recorded so far, in insertion order
    completion_history: Arc<Mutex<Vec<ChatMessage>>>,
    /// Length of the history when the last completion was handed out; messages at or after this
    /// index have not been part of a completion yet
    completions_ptr: Arc<AtomicUsize>,

    /// Value `completions_ptr` had before the most recent completion. Together with
    /// `completions_ptr` it delimits the messages that were new for that completion, which allows
    /// retrieving only the new messages and rewinding on a redrive
    current_completions_ptr: Arc<AtomicUsize>,

    /// The executor used to run tools. I.e. local, remote, docker
    tool_executor: Arc<dyn ToolExecutor>,

    /// Stop if last message is from the assistant
    stop_on_assistant: bool,
}

impl Default for DefaultContext {
    /// Builds an empty context: no history, both completion pointers at zero, tools executed by a
    /// default `LocalExecutor`, and stopping when the assistant spoke last.
    fn default() -> Self {
        let history = Vec::new();
        let executor: Arc<dyn ToolExecutor> = Arc::new(LocalExecutor::default());

        Self {
            completion_history: Arc::new(Mutex::new(history)),
            completions_ptr: Arc::new(AtomicUsize::new(0)),
            current_completions_ptr: Arc::new(AtomicUsize::new(0)),
            tool_executor: executor,
            stop_on_assistant: true,
        }
    }
}

impl std::fmt::Debug for DefaultContext {
    /// Formats every field; the tool executor is shown as a fixed type-name placeholder instead
    /// of being formatted itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let executor_placeholder = "Arc<dyn ToolExecutor>";

        let mut out = f.debug_struct("DefaultContext");
        out.field("completion_history", &self.completion_history);
        out.field("completions_ptr", &self.completions_ptr);
        out.field("current_completions_ptr", &self.current_completions_ptr);
        out.field("tool_executor", &executor_placeholder);
        out.field("stop_on_assistant", &self.stop_on_assistant);
        out.finish()
    }
}

impl DefaultContext {
    /// Create a new context with a custom executor
    pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
        DefaultContext {
            tool_executor: executor.into(),
            ..Default::default()
        }
    }

    /// If set to true, the agent will stop if the last message is from the assistant (i.e. no new
    /// tool calls, summaries or user messages)
    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
        self.stop_on_assistant = stop;
        self
    }
}
#[async_trait]
impl AgentContext for DefaultContext {
    /// Retrieve messages for the next completion
    ///
    /// Returns `None` when no message was added since the previous completion, or when
    /// `stop_on_assistant` is enabled and the last message in the history is an assistant
    /// message. Otherwise marks the whole history as consumed and returns it, keeping only system
    /// messages before the most recent summary.
    async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
        let history = self.completion_history.lock().await;

        let current = self.completions_ptr.load(Ordering::SeqCst);

        if history[current..].is_empty()
            || (self.stop_on_assistant
                && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
        {
            None
        } else {
            // Both pointers are updated while the history lock is held, so readers that also
            // take the lock observe them together with the matching history.
            let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
            self.current_completions_ptr
                .store(previous, Ordering::SeqCst);

            Some(filter_messages_since_summary(history.clone()))
        }
    }

    /// Returns the messages the agent is currently completing on
    ///
    /// These are the messages that were new for the most recent `next_completion` call, filtered
    /// the same way as the completion itself.
    async fn current_new_messages(&self) -> Vec<ChatMessage> {
        // Take the lock before reading the pointers: `next_completion` and `redrive` modify them
        // while holding this lock, so reading them here gives a consistent snapshot.
        let history = self.completion_history.lock().await;

        let current = self.current_completions_ptr.load(Ordering::SeqCst);
        let end = self.completions_ptr.load(Ordering::SeqCst);

        filter_messages_since_summary(history[current..end].to_vec())
    }

    /// Retrieve all messages in the conversation history
    async fn history(&self) -> Vec<ChatMessage> {
        self.completion_history.lock().await.clone()
    }

    /// Add multiple messages to the conversation history
    ///
    /// The whole batch is appended under a single lock, so messages added concurrently by another
    /// task cannot end up interleaved with this batch.
    async fn add_messages(&self, messages: Vec<ChatMessage>) {
        self.completion_history.lock().await.extend(messages);
    }

    /// Add a single message to the conversation history
    async fn add_message(&self, item: ChatMessage) {
        self.completion_history.lock().await.push(item);
    }

    /// Execute a command in the tool executor
    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
        self.tool_executor.exec_cmd(cmd).await
    }

    /// Pops the last messages up until the previous completion
    ///
    /// LLMs failing completion for various reasons is unfortunately a common occurrence
    /// This gives a way to redrive the last completion in a generic way
    ///
    /// Concretely: the history is truncated to the length it had at the last completion
    /// (dropping anything added since), and the completion pointer is rewound to where that
    /// completion started, so the next `next_completion` re-sends the same messages.
    ///
    /// NOTE: calling this twice without an intervening `next_completion` truncates to the start
    /// of the last completion, which also removes the messages that completion was made on.
    async fn redrive(&self) {
        let mut history = self.completion_history.lock().await;
        let previous = self.current_completions_ptr.load(Ordering::SeqCst);
        let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);

        // delete everything after the last completion
        history.truncate(redrive_ptr);
    }
}

/// Drops every non-system message that precedes the most recent `ChatMessage::Summary`.
///
/// Messages from the most recent summary onward (the summary included) are kept as-is, and
/// `ChatMessage::System` messages are kept wherever they appear. Without any summary the input is
/// returned unchanged. Relative order is always preserved.
fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
    let last_summary = messages
        .iter()
        .rposition(|message| matches!(message, ChatMessage::Summary(_)));

    match last_summary {
        None => messages,
        Some(cutoff) => messages
            .into_iter()
            .enumerate()
            .filter_map(|(index, message)| {
                let keep = index >= cutoff || matches!(message, ChatMessage::System(_));
                if keep {
                    Some(message)
                } else {
                    None
                }
            })
            .collect(),
    }
}

#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall, ToolOutput};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // If the last message is from the assistant, we should not get any more completions
        context.add_messages(vec![assistant!("I am fine")]).await;

        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        // Only messages *before* the summary are dropped (except the system message); the user
        // message after the summary is kept despite its text.
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        // Add more messages
        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive rewinds to the start of the last completion; nothing was added since, so no
        // messages are removed
        context.redrive().await;

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Add more messages
        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await;

        // This should remove any additional messages
        context.redrive().await;

        // We just redrove with the same messages
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // Redrive again
        context.redrive().await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }
}