1use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12
13use anyhow::Result;
14use async_trait::async_trait;
15use swiftide_core::chat_completion::ChatMessage;
16use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
17use tokio::sync::Mutex;
18
19use crate::tools::local_executor::LocalExecutor;
20
21#[derive(Clone)]
23pub struct DefaultContext {
24 completion_history: Arc<Mutex<Vec<ChatMessage>>>,
25 completions_ptr: Arc<AtomicUsize>,
27
28 current_completions_ptr: Arc<AtomicUsize>,
31
32 tool_executor: Arc<dyn ToolExecutor>,
34
35 stop_on_assistant: bool,
37}
38
39impl Default for DefaultContext {
40 fn default() -> Self {
41 DefaultContext {
42 completion_history: Arc::new(Mutex::new(Vec::new())),
43 completions_ptr: Arc::new(AtomicUsize::new(0)),
44 current_completions_ptr: Arc::new(AtomicUsize::new(0)),
45 tool_executor: Arc::new(LocalExecutor::default()) as Arc<dyn ToolExecutor>,
46 stop_on_assistant: true,
47 }
48 }
49}
50
51impl std::fmt::Debug for DefaultContext {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("DefaultContext")
54 .field("completion_history", &self.completion_history)
55 .field("completions_ptr", &self.completions_ptr)
56 .field("current_completions_ptr", &self.current_completions_ptr)
57 .field("tool_executor", &"Arc<dyn ToolExecutor>")
58 .field("stop_on_assistant", &self.stop_on_assistant)
59 .finish()
60 }
61}
62
63impl DefaultContext {
64 pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
66 DefaultContext {
67 tool_executor: executor.into(),
68 ..Default::default()
69 }
70 }
71
72 pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
75 self.stop_on_assistant = stop;
76 self
77 }
78}
79#[async_trait]
80impl AgentContext for DefaultContext {
81 async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
83 let history = self.completion_history.lock().await;
84
85 let current = self.completions_ptr.load(Ordering::SeqCst);
86
87 if history[current..].is_empty()
88 || (self.stop_on_assistant
89 && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
90 {
91 None
92 } else {
93 let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
94 self.current_completions_ptr
95 .store(previous, Ordering::SeqCst);
96
97 Some(filter_messages_since_summary(history.clone()))
98 }
99 }
100
101 async fn current_new_messages(&self) -> Vec<ChatMessage> {
103 let current = self.current_completions_ptr.load(Ordering::SeqCst);
104 let end = self.completions_ptr.load(Ordering::SeqCst);
105
106 let history = self.completion_history.lock().await;
107
108 filter_messages_since_summary(history[current..end].to_vec())
109 }
110
111 async fn history(&self) -> Vec<ChatMessage> {
113 self.completion_history.lock().await.clone()
114 }
115
116 async fn add_messages(&self, messages: Vec<ChatMessage>) {
118 for item in messages {
119 self.add_message(item).await;
120 }
121 }
122
123 async fn add_message(&self, item: ChatMessage) {
125 self.completion_history.lock().await.push(item);
126 }
127
128 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
130 self.tool_executor.exec_cmd(cmd).await
131 }
132
133 async fn redrive(&self) {
138 let mut history = self.completion_history.lock().await;
139 let previous = self.current_completions_ptr.load(Ordering::SeqCst);
140 let redrive_ptr = self.completions_ptr.swap(previous, Ordering::SeqCst);
141
142 history.truncate(redrive_ptr);
144 }
145}
146
147fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
148 let mut summary_found = false;
149 let mut messages = messages
150 .into_iter()
151 .rev()
152 .filter(|m| {
153 if summary_found {
154 return matches!(m, ChatMessage::System(_));
155 }
156 if let ChatMessage::Summary(_) = m {
157 summary_found = true;
158 }
159 true
160 })
161 .collect::<Vec<_>>();
162
163 messages.reverse();
164
165 messages
166}
167
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall, ToolOutput};

    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // With `stop_on_assistant` enabled (the default), a trailing assistant
        // message yields no completion...
        context.add_messages(vec![assistant!("I am fine")]).await;

        assert!(context.next_completion().await.is_none());

        // ...and once it is disabled, the same history completes again.
        context.with_stop_on_assistant(false);

        assert!(context.next_completion().await.is_some());
    }

    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();
        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            // Messages after the summary are kept.
            ChatMessage::User("Message after summary".into()),
        ];
        let context = DefaultContext::default();
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );

        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];

        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();

        // Pin the length before indexing so a regression fails clearly
        // instead of panicking on an out-of-bounds index.
        assert_eq!(new_messages.len(), 4);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);

        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }

    #[tokio::test]
    async fn test_redrive() {
        let context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("System message".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert!(context.next_completion().await.is_none());

        // Redrive rewinds to the start of the last completion so it replays.
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);

        context
            .add_messages(vec![ChatMessage::User("Hey?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(context.next_completion().await.is_none());
        context.redrive().await;

        context
            .add_messages(vec![ChatMessage::User("How are you?".into())])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                ChatMessage::User("How are you really?".into()),
                ChatMessage::User("How are you really?".into()),
            ])
            .await;

        // Messages added after the last completion are dropped by redrive.
        context.redrive().await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        context.redrive().await;
        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }
}