swiftide_agents/
default_context.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
//! Manages agent history and provides an interface for the external world
//!
//! This is the default for agents. It is fully async and shareable between agents.
//!
//! By default uses the `LocalExecutor` for tool execution.
//!
//! If chat messages include a `ChatMessage::Summary`, all previous messages are ignored except the
//! system prompt. This is useful for maintaining focus in long conversations or managing token
//! limits.
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::chat_completion::ChatMessage;
use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
use tokio::sync::Mutex;

use crate::tools::local_executor::LocalExecutor;

// TODO: Remove unit as executor and implement a local executor instead
// NOTE(review): `Default` already wires in `LocalExecutor`, so this TODO looks resolved —
// confirm and delete.
/// Agent context that stores the conversation history and runs commands through a tool executor.
///
/// The history and both pointers live behind `Arc`, so a clone shares them with the original;
/// `stop_on_assistant` is a plain `bool` and is copied per clone.
#[derive(Clone)]
pub struct DefaultContext {
    /// Every message added to this context, in insertion order, guarded by an async mutex
    completion_history: Arc<Mutex<Vec<ChatMessage>>>,
    /// Index in the conversation history where the next completion will start
    completions_ptr: Arc<AtomicUsize>,

    /// Index in the conversation history where the current completion started
    /// Allows for retrieving only new messages since the last completion
    current_completions_ptr: Arc<AtomicUsize>,

    /// The executor used to run tools. I.e. local, remote, docker
    tool_executor: Arc<dyn ToolExecutor>,

    /// Stop if last message is from the assistant
    stop_on_assistant: bool,
}

impl Default for DefaultContext {
    /// Builds an empty context: no history, both completion pointers at zero, a
    /// `LocalExecutor` for tool execution, and stopping on assistant messages enabled.
    fn default() -> Self {
        // The annotation performs the unsized coercion to the trait object.
        let tool_executor: Arc<dyn ToolExecutor> = Arc::new(LocalExecutor::default());

        Self {
            completion_history: Arc::default(),
            completions_ptr: Arc::default(),
            current_completions_ptr: Arc::default(),
            tool_executor,
            stop_on_assistant: true,
        }
    }
}

impl DefaultContext {
    /// Create a new context that runs tools through `executor`
    ///
    /// All other settings take the values from `Default`.
    pub fn from_executor<T>(executor: T) -> Self
    where
        T: Into<Arc<dyn ToolExecutor>>,
    {
        let tool_executor = executor.into();

        Self {
            tool_executor,
            ..Self::default()
        }
    }

    /// Controls whether the agent stops once the last message is from the assistant (i.e. no
    /// new tool calls, summaries or user messages followed it)
    ///
    /// Returns `self` so further setters can be chained.
    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
        self.stop_on_assistant = stop;
        self
    }
}
#[async_trait]
impl AgentContext for DefaultContext {
    /// Retrieve messages for the next completion
    ///
    /// Returns `None` when no messages were added since the previous completion, or when
    /// `stop_on_assistant` is set and the last message in the history is from the assistant.
    /// Otherwise both completion pointers are advanced and the history is returned, reduced to
    /// the system messages plus everything from the most recent summary onwards.
    async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
        // Lock first, then read the pointer: the emptiness check and the pointer update below
        // must be one step relative to other tasks sharing this context, otherwise two callers
        // could both observe the same pending messages and complete on them twice.
        let history = self.completion_history.lock().await;
        let current = self.completions_ptr.load(Ordering::SeqCst);

        let nothing_new = current >= history.len();
        let ends_on_assistant = self.stop_on_assistant
            && matches!(history.last(), Some(ChatMessage::Assistant(_, _)));

        if nothing_new || ends_on_assistant {
            return None;
        }

        // The next completion starts after everything currently stored; the previous start is
        // remembered so `current_new_messages` can report what this completion added.
        let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
        self.current_completions_ptr
            .store(previous, Ordering::SeqCst);

        Some(filter_messages_since_summary(history.clone()))
    }

    /// Returns the messages the agent is currently completing on
    ///
    /// This is the slice of the history between the start of the current completion and the
    /// start of the next one, with the same summary filtering as `next_completion`.
    async fn current_new_messages(&self) -> Vec<ChatMessage> {
        // Read both pointers while holding the lock so they form a consistent pair.
        let history = self.completion_history.lock().await;
        let current = self.current_completions_ptr.load(Ordering::SeqCst);
        let end = self.completions_ptr.load(Ordering::SeqCst);

        filter_messages_since_summary(history[current..end].to_vec())
    }

    /// Retrieve all messages in the conversation history
    async fn history(&self) -> Vec<ChatMessage> {
        self.completion_history.lock().await.clone()
    }

    /// Add multiple messages to the conversation history
    ///
    /// The whole batch is appended under a single lock, so other tasks never observe it
    /// partially added.
    async fn add_messages(&self, messages: Vec<ChatMessage>) {
        self.completion_history.lock().await.extend(messages);
    }

    /// Add a single message to the conversation history
    async fn add_message(&self, item: ChatMessage) {
        self.completion_history.lock().await.push(item);
    }

    /// Execute a command in the tool executor
    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
        self.tool_executor.exec_cmd(cmd).await
    }
}

/// Drops everything before the most recent summary, except system messages.
///
/// Messages from the last `ChatMessage::Summary` onwards are kept as they are. Of the messages
/// before it, only `ChatMessage::System` entries survive. Without any summary the input is
/// returned unchanged. Relative order is preserved.
fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
    let last_summary = messages
        .iter()
        .rposition(|message| matches!(message, ChatMessage::Summary(_)));

    match last_summary {
        None => messages,
        Some(cutoff) => messages
            .into_iter()
            .enumerate()
            .filter(|(index, message)| {
                *index >= cutoff || matches!(message, ChatMessage::System(_))
            })
            .map(|(_, message)| message)
            .collect(),
    }
}

#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall, ToolOutput};

    /// A completion is only handed out when new messages arrived, and never when the history
    /// ends on an assistant message unless `stop_on_assistant` is disabled.
    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // If the last message is from the assistant, we should not get any more completions
        context.add_messages(vec![assistant!("I am fine")]).await;

        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    /// A tool output following an assistant tool call triggers another completion, and
    /// `current_new_messages` reports only the messages added for that completion.
    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        // Record initial chat messages
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    /// Messages before a summary are dropped, except for the system message.
    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        // Despite its text, the final user message comes after the summary and is retained;
        // the user and assistant messages before the summary are the ones filtered out.
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    /// With `stop_on_assistant` disabled, each new summary replaces everything before it
    /// (system message aside), including an earlier summary.
    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        // Record initial chat messages
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        // Pin the length before indexing so a regression fails with a clear assertion.
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }
}