use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::chat_completion::ChatMessage;
use swiftide_core::{AgentContext, Command, CommandError, CommandOutput, ToolExecutor};
use tokio::sync::Mutex;
use crate::tools::local_executor::LocalExecutor;
/// Agent context that keeps the chat history in memory and tracks which part of
/// it has already been handed out for completion.
///
/// The history, both pointers and the tool executor are held behind `Arc`, so a
/// clone shares them with the value it was cloned from. `stop_on_assistant` is a
/// plain `bool` and is copied by value into each clone.
#[derive(Clone)]
pub struct DefaultContext {
/// Every message added to this context, in insertion order.
completion_history: Arc<Mutex<Vec<ChatMessage>>>,
/// History length recorded by the last `next_completion` call that returned
/// messages; anything at or beyond this index has not been handed out yet.
completions_ptr: Arc<AtomicUsize>,
/// Value `completions_ptr` held before that call; together with
/// `completions_ptr` it bounds the range read by `current_new_messages`.
current_completions_ptr: Arc<AtomicUsize>,
/// Executor that `exec_cmd` delegates to.
tool_executor: Arc<dyn ToolExecutor>,
/// When `true`, `next_completion` returns `None` while the most recent message
/// in the history is an assistant message.
stop_on_assistant: bool,
}
impl Default for DefaultContext {
    /// Creates an empty context: no history, both pointers at zero, a default
    /// `LocalExecutor` for running commands, and `stop_on_assistant` enabled.
    fn default() -> Self {
        let tool_executor: Arc<dyn ToolExecutor> = Arc::new(LocalExecutor::default());
        // Each pointer gets its own allocation; they must not alias each other.
        let zeroed_pointer = || Arc::new(AtomicUsize::new(0));

        Self {
            completion_history: Arc::new(Mutex::new(Vec::new())),
            completions_ptr: zeroed_pointer(),
            current_completions_ptr: zeroed_pointer(),
            tool_executor,
            stop_on_assistant: true,
        }
    }
}
impl DefaultContext {
    /// Builds a context that runs commands through `executor`.
    ///
    /// Every other field takes the value produced by `Default::default()`.
    pub fn from_executor<T: Into<Arc<dyn ToolExecutor>>>(executor: T) -> DefaultContext {
        let tool_executor = executor.into();

        Self {
            tool_executor,
            ..Self::default()
        }
    }

    /// Sets whether `next_completion` should stop once the latest message is an
    /// assistant message, and returns `self` so further calls can be chained.
    ///
    /// The flag is stored as a plain field, so it only changes this value and
    /// not clones that were made earlier.
    pub fn with_stop_on_assistant(&mut self, stop: bool) -> &mut Self {
        self.stop_on_assistant = stop;
        self
    }
}
#[async_trait]
impl AgentContext for DefaultContext {
    /// Returns the messages to send for the next completion, or `None` when
    /// there is nothing to complete.
    ///
    /// `None` is returned when no message was added since the last call that
    /// returned messages, or when `stop_on_assistant` is set and the most recent
    /// message is an assistant message. Otherwise both pointers are advanced and
    /// the whole history is returned after passing through
    /// `filter_messages_since_summary`.
    async fn next_completion(&self) -> Option<Vec<ChatMessage>> {
        // Take the lock before reading the pointer. The pointers are only ever
        // written while the lock is held, so reading them under the lock
        // guarantees they are consistent with the history we look at; reading
        // first could observe a value that a concurrent call is about to advance
        // and hand out the same completion twice.
        let history = self.completion_history.lock().await;
        let current = self.completions_ptr.load(Ordering::SeqCst);

        if history[current..].is_empty()
            || (self.stop_on_assistant
                && matches!(history.last(), Some(ChatMessage::Assistant(_, _))))
        {
            return None;
        }

        // Remember where this batch started so `current_new_messages` can
        // report exactly the messages that were new for this completion.
        let previous = self.completions_ptr.swap(history.len(), Ordering::SeqCst);
        self.current_completions_ptr
            .store(previous, Ordering::SeqCst);

        Some(filter_messages_since_summary(history.clone()))
    }

    /// Returns the messages that were new for the most recent completion, i.e.
    /// the history between the two pointers, passed through
    /// `filter_messages_since_summary`.
    async fn current_new_messages(&self) -> Vec<ChatMessage> {
        // Lock first so both pointers are read as a consistent pair; they are
        // updated together under this same lock in `next_completion`.
        let history = self.completion_history.lock().await;
        let current = self.current_completions_ptr.load(Ordering::SeqCst);
        let end = self.completions_ptr.load(Ordering::SeqCst);

        filter_messages_since_summary(history[current..end].to_vec())
    }

    /// Returns a copy of the complete, unfiltered message history.
    async fn history(&self) -> Vec<ChatMessage> {
        self.completion_history.lock().await.clone()
    }

    /// Appends all `messages` to the history, preserving their order.
    ///
    /// The lock is taken once for the whole batch, so other writers cannot
    /// interleave their messages with it.
    async fn add_messages(&self, messages: Vec<ChatMessage>) {
        self.completion_history.lock().await.extend(messages);
    }

    /// Appends a single message to the history.
    async fn add_message(&self, item: ChatMessage) {
        self.completion_history.lock().await.push(item);
    }

    /// Runs `cmd` through the configured tool executor and returns its result.
    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
        self.tool_executor.exec_cmd(cmd).await
    }
}
/// Trims `messages` down to what is still relevant after the latest summary.
///
/// When the list contains at least one `ChatMessage::Summary`, the last one and
/// everything after it are kept, and of the messages before it only
/// `ChatMessage::System` entries survive. Without any summary the list is
/// returned unchanged. Relative order is preserved in both cases.
fn filter_messages_since_summary(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
    let last_summary = match messages
        .iter()
        .rposition(|message| matches!(message, ChatMessage::Summary(_)))
    {
        Some(index) => index,
        None => return messages,
    };

    messages
        .into_iter()
        .enumerate()
        .filter(|(index, message)| {
            *index >= last_summary || matches!(message, ChatMessage::System(_))
        })
        .map(|(_, message)| message)
        .collect()
}
#[cfg(test)]
mod tests {
    use crate::{assistant, tool_output, user};

    use super::*;
    use swiftide_core::chat_completion::{ChatMessage, ToolCall, ToolOutput};

    // `next_completion` only yields when new messages were added, stops when the
    // last message is an assistant message, and resumes once that stop is disabled.
    #[tokio::test]
    async fn test_iteration_tracking() {
        let mut context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        // Nothing was added in between, so there is nothing to complete.
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![assistant!("Hey?"), user!("How are you?")])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());

        // The last message is now an assistant message: stop by default.
        context.add_messages(vec![assistant!("I am fine")]).await;
        assert!(context.next_completion().await.is_none());

        context.with_stop_on_assistant(false);
        assert!(context.next_completion().await.is_some());
    }

    // When the history ends with the tool output rather than the assistant
    // message, a new completion is produced and the new-message window covers
    // exactly the two messages added for it.
    #[tokio::test]
    async fn test_should_complete_after_tool_call() {
        let context = DefaultContext::default();

        context
            .add_messages(vec![
                ChatMessage::System("You are awesome".into()),
                ChatMessage::User("Hello".into()),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert!(context.next_completion().await.is_none());

        context
            .add_messages(vec![
                assistant!("Hey?", ["test"]),
                tool_output!("test", "Hoi"),
            ])
            .await;

        let messages = context.next_completion().await.unwrap();
        assert_eq!(context.current_new_messages().await.len(), 2);
        assert_eq!(messages.len(), 4);
        assert!(context.next_completion().await.is_none());
    }

    // Non-system messages in front of a summary are dropped; the summary itself
    // and everything after it are kept.
    #[tokio::test]
    async fn test_filters_messages_before_summary() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
            ChatMessage::Summary("Summary message".into()),
            // Despite its text, this message follows the summary and the
            // assertions below expect it to be retained.
            ChatMessage::User("This should be ignored".into()),
        ];
        let context = DefaultContext::default();
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(new_messages[2], ChatMessage::User(_)));

        let current_new_messages = context.current_new_messages().await;
        assert_eq!(current_new_messages.len(), 3);
        assert!(matches!(current_new_messages[0], ChatMessage::System(_)));
        assert!(matches!(current_new_messages[1], ChatMessage::Summary(_)));
        assert!(matches!(current_new_messages[2], ChatMessage::User(_)));

        assert!(context.next_completion().await.is_none());
    }

    // With stop-on-assistant disabled, each newly added summary replaces the
    // earlier non-system messages, including any previous summary.
    #[tokio::test]
    async fn test_filters_messages_before_summary_with_assistant_last() {
        let messages = vec![
            ChatMessage::System("System message".into()),
            ChatMessage::User("Hello".into()),
            ChatMessage::Assistant(Some("Hello there".into()), None),
        ];
        let mut context = DefaultContext::default();
        context.with_stop_on_assistant(false);
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 3);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert!(matches!(new_messages[1], ChatMessage::User(_)));
        assert!(matches!(new_messages[2], ChatMessage::Assistant(_, _)));

        context
            .add_message(ChatMessage::Summary("Summary message 1".into()))
            .await;

        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert!(context.next_completion().await.is_none());

        let messages = vec![
            ChatMessage::User("Hello again".into()),
            ChatMessage::Assistant(Some("Hello there again".into()), None),
        ];
        context.add_messages(messages).await;

        let new_messages = context.next_completion().await.unwrap();
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 1".into())
        );
        assert_eq!(new_messages[2], ChatMessage::User("Hello again".into()));
        assert_eq!(
            new_messages[3],
            ChatMessage::Assistant(Some("Hello there again".to_string()), None)
        );

        context
            .add_message(ChatMessage::Summary("Summary message 2".into()))
            .await;

        // Only the system message and the newest summary remain.
        let new_messages = context.next_completion().await.unwrap();
        assert_eq!(new_messages.len(), 2);
        assert!(matches!(new_messages[0], ChatMessage::System(_)));
        assert_eq!(
            new_messages[1],
            ChatMessage::Summary("Summary message 2".into())
        );
    }
}