swiftide_core/
agent_traits.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::{Arc, Mutex},
4};
5
6use crate::{
7    chat_completion::{ChatMessage, ToolCall},
8    indexing::IndexingStream,
9};
10use anyhow::Result;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15/// A ToolExecutor provides an interface for agents to interact with a system
16/// in an isolated context.
17///
18/// When starting up an agent, it's context expects an executor. For example,
19/// you might want your coding agent to work with a fresh, isolated set of files,
20/// separated from the rest of the system.
21///
22/// See `swiftide-docker-executor` for an executor that uses Docker. By default
23/// the executor is a local executor.
24///
25/// Additionally, the executor can be used stream files files for indexing.
26#[async_trait]
27pub trait ToolExecutor: Send + Sync {
28    /// Execute a command in the executor
29    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
30
31    /// Stream files from the executor
32    async fn stream_files(
33        &self,
34        path: &Path,
35        extensions: Option<Vec<String>>,
36    ) -> Result<IndexingStream>;
37}
38
39#[async_trait]
40impl<T: ToolExecutor> ToolExecutor for &T {
41    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
42        (*self).exec_cmd(cmd).await
43    }
44
45    async fn stream_files(
46        &self,
47        path: &Path,
48        extensions: Option<Vec<String>>,
49    ) -> Result<IndexingStream> {
50        (*self).stream_files(path, extensions).await
51    }
52}
53
54#[async_trait]
55impl ToolExecutor for Arc<dyn ToolExecutor> {
56    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
57        (**self).exec_cmd(cmd).await
58    }
59    async fn stream_files(
60        &self,
61        path: &Path,
62        extensions: Option<Vec<String>>,
63    ) -> Result<IndexingStream> {
64        (*self).stream_files(path, extensions).await
65    }
66}
67
68#[async_trait]
69impl ToolExecutor for Box<dyn ToolExecutor> {
70    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
71        (**self).exec_cmd(cmd).await
72    }
73    async fn stream_files(
74        &self,
75        path: &Path,
76        extensions: Option<Vec<String>>,
77    ) -> Result<IndexingStream> {
78        (*self).stream_files(path, extensions).await
79    }
80}
81
82#[async_trait]
83impl ToolExecutor for &dyn ToolExecutor {
84    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
85        (**self).exec_cmd(cmd).await
86    }
87    async fn stream_files(
88        &self,
89        path: &Path,
90        extensions: Option<Vec<String>>,
91    ) -> Result<IndexingStream> {
92        (*self).stream_files(path, extensions).await
93    }
94}
95
96#[derive(Debug, Error)]
97pub enum CommandError {
98    /// The executor itself failed
99    #[error("executor error: {0:#}")]
100    ExecutorError(#[from] anyhow::Error),
101
102    /// The command failed, i.e. failing tests with stderr. This error might be handled
103    #[error("command failed with NonZeroExit: {0}")]
104    NonZeroExit(CommandOutput),
105}
106
107impl From<std::io::Error> for CommandError {
108    fn from(err: std::io::Error) -> Self {
109        CommandError::NonZeroExit(err.to_string().into())
110    }
111}
112
/// Commands that can be executed by the executor.
///
/// Conceptually, `Shell` allows any kind of input, while the other commands enable more
/// optimized implementations.
///
/// There is an ongoing consideration to make this an associated type on the executor.
///
/// TODO: Should be able to borrow everything?
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum Command {
    /// A shell command line.
    Shell(String),
    /// Read the file at the given path.
    ReadFile(PathBuf),
    /// Write the given content to the file at the given path.
    WriteFile(PathBuf, String),
}

impl Command {
    /// Builds a [`Command::Shell`] from anything convertible into a `String`.
    pub fn shell<S: Into<String>>(cmd: S) -> Self {
        let line: String = cmd.into();
        Self::Shell(line)
    }

    /// Builds a [`Command::ReadFile`] from anything convertible into a `PathBuf`.
    pub fn read_file<P: Into<PathBuf>>(path: P) -> Self {
        let target: PathBuf = path.into();
        Self::ReadFile(target)
    }

    /// Builds a [`Command::WriteFile`] that writes `content` to `path`.
    pub fn write_file<P: Into<PathBuf>, S: Into<String>>(path: P, content: S) -> Self {
        let (target, body): (PathBuf, String) = (path.into(), content.into());
        Self::WriteFile(target, body)
    }
}
141
/// Output from a [`Command`].
#[derive(Debug, Clone)]
pub struct CommandOutput {
    /// Text produced by the command.
    pub output: String,
    // NOTE: a status code and a success flag were considered as future fields here.
}

impl CommandOutput {
    /// Creates an output containing no text.
    pub fn empty() -> Self {
        Self::new(String::new())
    }

    /// Creates an output from anything convertible into a `String`.
    pub fn new(output: impl Into<String>) -> Self {
        let text: String = output.into();
        Self { output: text }
    }

    /// Returns `true` when the output contains no text.
    pub fn is_empty(&self) -> bool {
        self.output.is_empty()
    }
}

impl std::fmt::Display for CommandOutput {
    /// Writes the raw output text by delegating to `String`'s `Display` implementation,
    /// so formatter flags such as width and alignment are honored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.output, f)
    }
}

impl<T: Into<String>> From<T> for CommandOutput {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}

impl AsRef<str> for CommandOutput {
    fn as_ref(&self) -> &str {
        self.output.as_str()
    }
}
186
187/// Feedback that can be given on a tool, i.e. with a human in the loop
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub enum ToolFeedback {
190    Approved { payload: Option<serde_json::Value> },
191    Refused { payload: Option<serde_json::Value> },
192}
193
194impl ToolFeedback {
195    pub fn approved() -> Self {
196        ToolFeedback::Approved { payload: None }
197    }
198
199    pub fn refused() -> Self {
200        ToolFeedback::Refused { payload: None }
201    }
202
203    #[must_use]
204    pub fn with_payload(self, payload: serde_json::Value) -> Self {
205        match self {
206            ToolFeedback::Approved { .. } => ToolFeedback::Approved {
207                payload: Some(payload),
208            },
209            ToolFeedback::Refused { .. } => ToolFeedback::Refused {
210                payload: Some(payload),
211            },
212        }
213    }
214}
215
216/// Acts as the interface to the external world and manages messages for completion
217#[async_trait]
218pub trait AgentContext: Send + Sync {
219    /// List of all messages for this agent
220    ///
221    /// Used as main source for the next completion and expects all
222    /// messages to be returned if new messages are present.
223    ///
224    /// Once this method has been called, there should not be new messages
225    ///
226    /// TODO: Figure out a nice way to return a reference instead while still supporting i.e.
227    /// mutexes
228    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>>;
229
230    /// Lists only the new messages after calling `new_completion`
231    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>>;
232
233    /// Add messages for the next completion
234    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()>;
235
236    /// Add messages for the next completion
237    async fn add_message(&self, item: ChatMessage) -> Result<()>;
238
239    /// Execute a command if the context supports it
240    ///
241    /// Deprecated: use executor instead to access the executor directly
242    #[deprecated(note = "use executor instead")]
243    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
244
245    fn executor(&self) -> &Arc<dyn ToolExecutor>;
246
247    async fn history(&self) -> Result<Vec<ChatMessage>>;
248
249    /// Pops the last messages up until the last completion
250    ///
251    /// LLMs failing completion for various reasons is unfortunately a common occurrence
252    /// This gives a way to redrive the last completion in a generic way
253    async fn redrive(&self) -> Result<()>;
254
255    /// Tools that require feedback or approval (i.e. from a human) can use this to check if the
256    /// feedback is received
257    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback>;
258
259    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()>;
260}
261
262#[async_trait]
263impl AgentContext for Box<dyn AgentContext> {
264    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
265        (**self).next_completion().await
266    }
267
268    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
269        (**self).current_new_messages().await
270    }
271
272    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
273        (**self).add_messages(item).await
274    }
275
276    async fn add_message(&self, item: ChatMessage) -> Result<()> {
277        (**self).add_message(item).await
278    }
279
280    #[allow(deprecated)]
281    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
282        (**self).exec_cmd(cmd).await
283    }
284
285    fn executor(&self) -> &Arc<dyn ToolExecutor> {
286        (**self).executor()
287    }
288
289    async fn history(&self) -> Result<Vec<ChatMessage>> {
290        (**self).history().await
291    }
292
293    async fn redrive(&self) -> Result<()> {
294        (**self).redrive().await
295    }
296
297    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
298        (**self).has_received_feedback(tool_call).await
299    }
300
301    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
302        (**self).feedback_received(tool_call, feedback).await
303    }
304}
305
306#[async_trait]
307impl AgentContext for Arc<dyn AgentContext> {
308    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
309        (**self).next_completion().await
310    }
311
312    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
313        (**self).current_new_messages().await
314    }
315
316    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
317        (**self).add_messages(item).await
318    }
319
320    async fn add_message(&self, item: ChatMessage) -> Result<()> {
321        (**self).add_message(item).await
322    }
323
324    #[allow(deprecated)]
325    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
326        (**self).exec_cmd(cmd).await
327    }
328
329    fn executor(&self) -> &Arc<dyn ToolExecutor> {
330        (**self).executor()
331    }
332
333    async fn history(&self) -> Result<Vec<ChatMessage>> {
334        (**self).history().await
335    }
336
337    async fn redrive(&self) -> Result<()> {
338        (**self).redrive().await
339    }
340
341    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
342        (**self).has_received_feedback(tool_call).await
343    }
344
345    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
346        (**self).feedback_received(tool_call, feedback).await
347    }
348}
349
350#[async_trait]
351impl AgentContext for &dyn AgentContext {
352    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
353        (**self).next_completion().await
354    }
355
356    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
357        (**self).current_new_messages().await
358    }
359
360    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
361        (**self).add_messages(item).await
362    }
363
364    async fn add_message(&self, item: ChatMessage) -> Result<()> {
365        (**self).add_message(item).await
366    }
367
368    #[allow(deprecated)]
369    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
370        (**self).exec_cmd(cmd).await
371    }
372
373    fn executor(&self) -> &Arc<dyn ToolExecutor> {
374        (**self).executor()
375    }
376
377    async fn history(&self) -> Result<Vec<ChatMessage>> {
378        (**self).history().await
379    }
380
381    async fn redrive(&self) -> Result<()> {
382        (**self).redrive().await
383    }
384
385    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
386        (**self).has_received_feedback(tool_call).await
387    }
388
389    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
390        (**self).feedback_received(tool_call, feedback).await
391    }
392}
393
394/// Convenience implementation for empty agent context
395///
396/// Errors if tools attempt to execute commands
397#[async_trait]
398impl AgentContext for () {
399    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
400        Ok(None)
401    }
402
403    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
404        Ok(Vec::new())
405    }
406
407    async fn add_messages(&self, _item: Vec<ChatMessage>) -> Result<()> {
408        Ok(())
409    }
410
411    async fn add_message(&self, _item: ChatMessage) -> Result<()> {
412        Ok(())
413    }
414
415    async fn exec_cmd(&self, _cmd: &Command) -> Result<CommandOutput, CommandError> {
416        Err(CommandError::ExecutorError(anyhow::anyhow!(
417            "Empty agent context does not have a tool executor"
418        )))
419    }
420
421    fn executor(&self) -> &Arc<dyn ToolExecutor> {
422        unimplemented!("Empty agent context does not have a tool executor")
423    }
424
425    async fn history(&self) -> Result<Vec<ChatMessage>> {
426        Ok(Vec::new())
427    }
428
429    async fn redrive(&self) -> Result<()> {
430        Ok(())
431    }
432
433    async fn has_received_feedback(&self, _tool_call: &ToolCall) -> Option<ToolFeedback> {
434        Some(ToolFeedback::Approved { payload: None })
435    }
436
437    async fn feedback_received(
438        &self,
439        _tool_call: &ToolCall,
440        _feedback: &ToolFeedback,
441    ) -> Result<()> {
442        Ok(())
443    }
444}
445
446/// A backend for the agent context. A default is provided for Arc<Mutex<Vec<ChatMessage>>>
447///
448/// If you want to use for instance a database, implement this trait and pass it to the agent
449/// context when creating it.
450#[async_trait]
451pub trait MessageHistory: Send + Sync + std::fmt::Debug {
452    /// Returns the history of messages
453    async fn history(&self) -> Result<Vec<ChatMessage>>;
454
455    /// Add a message to the history
456    async fn push_owned(&self, item: ChatMessage) -> Result<()>;
457
458    /// Overwrite the history with the given items
459    async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()>;
460
461    /// Add a message to the history
462    async fn push(&self, item: &ChatMessage) -> Result<()> {
463        self.push_owned(item.clone()).await
464    }
465
466    /// Extend the history with the given items
467    async fn extend(&self, items: &[ChatMessage]) -> Result<()> {
468        self.extend_owned(items.to_vec()).await
469    }
470
471    /// Extend the history with the given items, taking ownership of them
472    async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
473        for item in items {
474            self.push_owned(item).await?;
475        }
476
477        Ok(())
478    }
479}
480
481#[async_trait]
482impl MessageHistory for Mutex<Vec<ChatMessage>> {
483    async fn history(&self) -> Result<Vec<ChatMessage>> {
484        Ok(self.lock().unwrap().clone())
485    }
486
487    async fn push_owned(&self, item: ChatMessage) -> Result<()> {
488        self.lock().unwrap().push(item);
489
490        Ok(())
491    }
492
493    async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
494        let mut lock = self.lock().unwrap();
495        *lock = items;
496
497        Ok(())
498    }
499}