swiftide_agents/tasks/
impls.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use swiftide_core::{
5    ChatCompletion, Command, CommandError, CommandOutput, SimplePrompt, ToolExecutor,
6    chat_completion::{ChatCompletionRequest, ChatCompletionResponse, errors::LanguageModelError},
7    prompt::Prompt,
8};
9use tokio::sync::Mutex;
10
11use crate::{Agent, errors::AgentError};
12
13use super::node::{NodeArg, NodeId, TaskNode};
14
15// TODO: Consider removing this and providing docs instead
16#[derive(Clone, Debug)]
17pub struct TaskAgent(Arc<Mutex<Agent>>);
18
19impl From<Agent> for TaskAgent {
20    fn from(agent: Agent) -> Self {
21        TaskAgent(Arc::new(Mutex::new(agent)))
22    }
23}
24
25/// A 'default' implementation for an agent where there is no output
26///
27/// TODO: Make this nicer :))
28#[async_trait]
29impl TaskNode for TaskAgent {
30    type Input = Prompt;
31
32    type Output = ();
33
34    type Error = AgentError;
35
36    async fn evaluate(
37        &self,
38        _node_id: &NodeId<
39            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
40        >,
41        input: &Self::Input,
42    ) -> Result<Self::Output, Self::Error> {
43        self.0.lock().await.query(input.clone()).await
44    }
45}
46
47#[async_trait]
48impl TaskNode for Box<dyn SimplePrompt> {
49    type Input = Prompt;
50
51    type Output = String;
52
53    type Error = LanguageModelError;
54
55    async fn evaluate(
56        &self,
57        _node_id: &NodeId<
58            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
59        >,
60        input: &Self::Input,
61    ) -> Result<Self::Output, Self::Error> {
62        // TODO: Prompt should be borrowed
63        self.prompt(input.clone()).await
64    }
65}
66
67#[async_trait]
68impl TaskNode for Arc<dyn SimplePrompt> {
69    type Input = Prompt;
70
71    type Output = String;
72
73    type Error = LanguageModelError;
74
75    async fn evaluate(
76        &self,
77        _node_id: &NodeId<
78            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
79        >,
80        input: &Self::Input,
81    ) -> Result<Self::Output, Self::Error> {
82        // TODO: Prompt should be borrowed
83        self.prompt(input.clone()).await
84    }
85}
86
87#[async_trait]
88impl TaskNode for Box<dyn ChatCompletion> {
89    type Input = ChatCompletionRequest;
90
91    type Output = ChatCompletionResponse;
92
93    type Error = LanguageModelError;
94
95    async fn evaluate(
96        &self,
97        _node_id: &NodeId<
98            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
99        >,
100        input: &Self::Input,
101    ) -> Result<Self::Output, Self::Error> {
102        self.complete(input).await
103    }
104}
105
106#[async_trait]
107impl TaskNode for Arc<dyn ChatCompletion> {
108    type Input = ChatCompletionRequest;
109
110    type Output = ChatCompletionResponse;
111
112    type Error = LanguageModelError;
113
114    async fn evaluate(
115        &self,
116        _node_id: &NodeId<
117            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
118        >,
119        input: &Self::Input,
120    ) -> Result<Self::Output, Self::Error> {
121        self.complete(input).await
122    }
123}
124
125#[async_trait]
126impl TaskNode for Box<dyn ToolExecutor> {
127    type Input = Command;
128
129    type Output = CommandOutput;
130
131    type Error = CommandError;
132
133    async fn evaluate(
134        &self,
135        _node_id: &NodeId<
136            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
137        >,
138        input: &Self::Input,
139    ) -> Result<Self::Output, Self::Error> {
140        self.exec_cmd(input).await
141    }
142}
143
144#[async_trait]
145impl TaskNode for Arc<dyn ToolExecutor> {
146    type Input = Command;
147
148    type Output = CommandOutput;
149
150    type Error = CommandError;
151
152    async fn evaluate(
153        &self,
154        _node_id: &NodeId<
155            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
156        >,
157        input: &Self::Input,
158    ) -> Result<Self::Output, Self::Error> {
159        self.exec_cmd(input).await
160    }
161}
162
163// Note: This only works for function pointers, not closures.
164#[async_trait]
165impl<I: NodeArg, O: NodeArg, E: std::error::Error + Send + Sync + 'static> TaskNode
166    for fn(&I) -> Result<O, E>
167{
168    type Input = I;
169
170    type Output = O;
171
172    type Error = E;
173
174    async fn evaluate(
175        &self,
176        _node_id: &NodeId<
177            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
178        >,
179        input: &Self::Input,
180    ) -> Result<Self::Output, Self::Error> {
181        (self)(input)
182    }
183}