swiftide_core/
agent_traits.rs

1use std::{
2    path::PathBuf,
3    sync::{Arc, Mutex},
4};
5
6use crate::chat_completion::{ChatMessage, ToolCall};
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11
12/// A tool executor that can be used within an `AgentContext`
13#[async_trait]
14pub trait ToolExecutor: Send + Sync {
15    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
16}
17
18#[async_trait]
19impl<T: ToolExecutor> ToolExecutor for &T {
20    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
21        (*self).exec_cmd(cmd).await
22    }
23}
24
25#[async_trait]
26impl ToolExecutor for Arc<dyn ToolExecutor> {
27    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
28        (**self).exec_cmd(cmd).await
29    }
30}
31
32#[async_trait]
33impl ToolExecutor for Box<dyn ToolExecutor> {
34    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
35        (**self).exec_cmd(cmd).await
36    }
37}
38
39#[async_trait]
40impl ToolExecutor for &dyn ToolExecutor {
41    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
42        (**self).exec_cmd(cmd).await
43    }
44}
45
46#[derive(Debug, Error)]
47pub enum CommandError {
48    /// The executor itself failed
49    #[error("executor error: {0:#}")]
50    ExecutorError(#[from] anyhow::Error),
51
52    /// The command failed, i.e. failing tests with stderr. This error might be handled
53    #[error("command failed with NonZeroExit: {0}")]
54    NonZeroExit(CommandOutput),
55}
56
57impl From<std::io::Error> for CommandError {
58    fn from(err: std::io::Error) -> Self {
59        CommandError::NonZeroExit(err.to_string().into())
60    }
61}
62
/// A command that a [`ToolExecutor`] can run.
///
/// `Shell` conceptually accepts any kind of input. The other variants exist so executors can
/// provide more optimized implementations of common operations.
///
/// Making this an associated type on the executor is still under consideration.
///
/// TODO: Should be able to borrow everything?
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum Command {
    /// Free-form shell input.
    Shell(String),
    /// Read the file at this path.
    ReadFile(PathBuf),
    /// Write the given content to the file at this path.
    WriteFile(PathBuf, String),
}

impl Command {
    /// Builds a [`Command::Shell`] from anything convertible into a `String`.
    pub fn shell<C: Into<String>>(cmd: C) -> Self {
        let input: String = cmd.into();
        Self::Shell(input)
    }

    /// Builds a [`Command::ReadFile`] from anything convertible into a `PathBuf`.
    pub fn read_file<T: Into<PathBuf>>(path: T) -> Self {
        let target: PathBuf = path.into();
        Self::ReadFile(target)
    }

    /// Builds a [`Command::WriteFile`] that writes `content` to `path`.
    pub fn write_file<T: Into<PathBuf>, C: Into<String>>(path: T, content: C) -> Self {
        let (target, body): (PathBuf, String) = (path.into(), content.into());
        Self::WriteFile(target, body)
    }
}
91
/// Output from a [`Command`].
#[derive(Debug, Clone)]
pub struct CommandOutput {
    /// Text produced by the command.
    pub output: String,
}

impl CommandOutput {
    /// Creates an output with no text.
    pub fn empty() -> Self {
        Self::new(String::new())
    }

    /// Creates an output from anything convertible into a `String`.
    ///
    /// Every other constructor and conversion funnels through here.
    pub fn new(output: impl Into<String>) -> Self {
        Self {
            output: output.into(),
        }
    }

    /// Returns `true` if the output contains no text.
    pub fn is_empty(&self) -> bool {
        self.output.is_empty()
    }
}

/// Displays the raw output text.
///
/// Delegates to `str`'s `Display`, so width, alignment and precision flags are honoured.
impl std::fmt::Display for CommandOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.output, f)
    }
}

/// Builds an output from any string-like value; equivalent to [`CommandOutput::new`].
impl<T: Into<String>> From<T> for CommandOutput {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}

/// Borrows the output text.
impl AsRef<str> for CommandOutput {
    fn as_ref(&self) -> &str {
        &self.output
    }
}
136
137/// Feedback that can be given on a tool, i.e. with a human in the loop
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub enum ToolFeedback {
140    Approved { payload: Option<serde_json::Value> },
141    Refused { payload: Option<serde_json::Value> },
142}
143
144impl ToolFeedback {
145    pub fn approved() -> Self {
146        ToolFeedback::Approved { payload: None }
147    }
148
149    pub fn refused() -> Self {
150        ToolFeedback::Refused { payload: None }
151    }
152
153    #[must_use]
154    pub fn with_payload(self, payload: serde_json::Value) -> Self {
155        match self {
156            ToolFeedback::Approved { .. } => ToolFeedback::Approved {
157                payload: Some(payload),
158            },
159            ToolFeedback::Refused { .. } => ToolFeedback::Refused {
160                payload: Some(payload),
161            },
162        }
163    }
164}
165
166/// Acts as the interface to the external world and manages messages for completion
167#[async_trait]
168pub trait AgentContext: Send + Sync {
169    /// List of all messages for this agent
170    ///
171    /// Used as main source for the next completion and expects all
172    /// messages to be returned if new messages are present.
173    ///
174    /// Once this method has been called, there should not be new messages
175    ///
176    /// TODO: Figure out a nice way to return a reference instead while still supporting i.e.
177    /// mutexes
178    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>>;
179
180    /// Lists only the new messages after calling `new_completion`
181    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>>;
182
183    /// Add messages for the next completion
184    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()>;
185
186    /// Add messages for the next completion
187    async fn add_message(&self, item: ChatMessage) -> Result<()>;
188
189    /// Execute a command if the context supports it
190    ///
191    /// Deprecated: use executor instead to access the executor directly
192    #[deprecated(note = "use executor instead")]
193    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
194
195    fn executor(&self) -> &Arc<dyn ToolExecutor>;
196
197    async fn history(&self) -> Result<Vec<ChatMessage>>;
198
199    /// Pops the last messages up until the last completion
200    ///
201    /// LLMs failing completion for various reasons is unfortunately a common occurrence
202    /// This gives a way to redrive the last completion in a generic way
203    async fn redrive(&self) -> Result<()>;
204
205    /// Tools that require feedback or approval (i.e. from a human) can use this to check if the
206    /// feedback is received
207    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback>;
208
209    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()>;
210}
211
212#[async_trait]
213impl AgentContext for Box<dyn AgentContext> {
214    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
215        (**self).next_completion().await
216    }
217
218    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
219        (**self).current_new_messages().await
220    }
221
222    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
223        (**self).add_messages(item).await
224    }
225
226    async fn add_message(&self, item: ChatMessage) -> Result<()> {
227        (**self).add_message(item).await
228    }
229
230    #[allow(deprecated)]
231    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
232        (**self).exec_cmd(cmd).await
233    }
234
235    fn executor(&self) -> &Arc<dyn ToolExecutor> {
236        (**self).executor()
237    }
238
239    async fn history(&self) -> Result<Vec<ChatMessage>> {
240        (**self).history().await
241    }
242
243    async fn redrive(&self) -> Result<()> {
244        (**self).redrive().await
245    }
246
247    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
248        (**self).has_received_feedback(tool_call).await
249    }
250
251    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
252        (**self).feedback_received(tool_call, feedback).await
253    }
254}
255
256#[async_trait]
257impl AgentContext for Arc<dyn AgentContext> {
258    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
259        (**self).next_completion().await
260    }
261
262    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
263        (**self).current_new_messages().await
264    }
265
266    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
267        (**self).add_messages(item).await
268    }
269
270    async fn add_message(&self, item: ChatMessage) -> Result<()> {
271        (**self).add_message(item).await
272    }
273
274    #[allow(deprecated)]
275    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
276        (**self).exec_cmd(cmd).await
277    }
278
279    fn executor(&self) -> &Arc<dyn ToolExecutor> {
280        (**self).executor()
281    }
282
283    async fn history(&self) -> Result<Vec<ChatMessage>> {
284        (**self).history().await
285    }
286
287    async fn redrive(&self) -> Result<()> {
288        (**self).redrive().await
289    }
290
291    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
292        (**self).has_received_feedback(tool_call).await
293    }
294
295    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
296        (**self).feedback_received(tool_call, feedback).await
297    }
298}
299
300#[async_trait]
301impl AgentContext for &dyn AgentContext {
302    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
303        (**self).next_completion().await
304    }
305
306    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
307        (**self).current_new_messages().await
308    }
309
310    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
311        (**self).add_messages(item).await
312    }
313
314    async fn add_message(&self, item: ChatMessage) -> Result<()> {
315        (**self).add_message(item).await
316    }
317
318    #[allow(deprecated)]
319    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
320        (**self).exec_cmd(cmd).await
321    }
322
323    fn executor(&self) -> &Arc<dyn ToolExecutor> {
324        (**self).executor()
325    }
326
327    async fn history(&self) -> Result<Vec<ChatMessage>> {
328        (**self).history().await
329    }
330
331    async fn redrive(&self) -> Result<()> {
332        (**self).redrive().await
333    }
334
335    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
336        (**self).has_received_feedback(tool_call).await
337    }
338
339    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
340        (**self).feedback_received(tool_call, feedback).await
341    }
342}
343
344/// Convenience implementation for empty agent context
345///
346/// Errors if tools attempt to execute commands
347#[async_trait]
348impl AgentContext for () {
349    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
350        Ok(None)
351    }
352
353    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
354        Ok(Vec::new())
355    }
356
357    async fn add_messages(&self, _item: Vec<ChatMessage>) -> Result<()> {
358        Ok(())
359    }
360
361    async fn add_message(&self, _item: ChatMessage) -> Result<()> {
362        Ok(())
363    }
364
365    async fn exec_cmd(&self, _cmd: &Command) -> Result<CommandOutput, CommandError> {
366        Err(CommandError::ExecutorError(anyhow::anyhow!(
367            "Empty agent context does not have a tool executor"
368        )))
369    }
370
371    fn executor(&self) -> &Arc<dyn ToolExecutor> {
372        unimplemented!("Empty agent context does not have a tool executor")
373    }
374
375    async fn history(&self) -> Result<Vec<ChatMessage>> {
376        Ok(Vec::new())
377    }
378
379    async fn redrive(&self) -> Result<()> {
380        Ok(())
381    }
382
383    async fn has_received_feedback(&self, _tool_call: &ToolCall) -> Option<ToolFeedback> {
384        Some(ToolFeedback::Approved { payload: None })
385    }
386
387    async fn feedback_received(
388        &self,
389        _tool_call: &ToolCall,
390        _feedback: &ToolFeedback,
391    ) -> Result<()> {
392        Ok(())
393    }
394}
395
396/// A backend for the agent context. A default is provided for Arc<Mutex<Vec<ChatMessage>>>
397///
398/// If you want to use for instance a database, implement this trait and pass it to the agent
399/// context when creating it.
400#[async_trait]
401pub trait MessageHistory: Send + Sync + std::fmt::Debug {
402    /// Returns the history of messages
403    async fn history(&self) -> Result<Vec<ChatMessage>>;
404
405    /// Add a message to the history
406    async fn push_owned(&self, item: ChatMessage) -> Result<()>;
407
408    /// Overwrite the history with the given items
409    async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()>;
410
411    /// Add a message to the history
412    async fn push(&self, item: &ChatMessage) -> Result<()> {
413        self.push_owned(item.clone()).await
414    }
415
416    /// Extend the history with the given items
417    async fn extend(&self, items: &[ChatMessage]) -> Result<()> {
418        self.extend_owned(items.to_vec()).await
419    }
420
421    /// Extend the history with the given items, taking ownership of them
422    async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
423        for item in items {
424            self.push_owned(item).await?;
425        }
426
427        Ok(())
428    }
429}
430
431#[async_trait]
432impl MessageHistory for Mutex<Vec<ChatMessage>> {
433    async fn history(&self) -> Result<Vec<ChatMessage>> {
434        Ok(self.lock().unwrap().clone())
435    }
436
437    async fn push_owned(&self, item: ChatMessage) -> Result<()> {
438        self.lock().unwrap().push(item);
439
440        Ok(())
441    }
442
443    async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
444        let mut lock = self.lock().unwrap();
445        *lock = items;
446
447        Ok(())
448    }
449}