swiftide_agents/tasks/
impls.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use swiftide_core::{
5 ChatCompletion, Command, CommandError, CommandOutput, SimplePrompt, ToolExecutor,
6 chat_completion::{ChatCompletionRequest, ChatCompletionResponse, errors::LanguageModelError},
7 prompt::Prompt,
8};
9use tokio::sync::Mutex;
10
11use crate::{Agent, errors::AgentError};
12
13use super::node::{NodeArg, NodeId, TaskNode};
14
15#[derive(Clone, Debug)]
17pub struct TaskAgent(Arc<Mutex<Agent>>);
18
19impl From<Agent> for TaskAgent {
20 fn from(agent: Agent) -> Self {
21 TaskAgent(Arc::new(Mutex::new(agent)))
22 }
23}
24
25#[async_trait]
29impl TaskNode for TaskAgent {
30 type Input = Prompt;
31
32 type Output = ();
33
34 type Error = AgentError;
35
36 async fn evaluate(
37 &self,
38 _node_id: &NodeId<
39 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
40 >,
41 input: &Self::Input,
42 ) -> Result<Self::Output, Self::Error> {
43 self.0.lock().await.query(input.clone()).await
44 }
45}
46
47#[async_trait]
48impl TaskNode for Box<dyn SimplePrompt> {
49 type Input = Prompt;
50
51 type Output = String;
52
53 type Error = LanguageModelError;
54
55 async fn evaluate(
56 &self,
57 _node_id: &NodeId<
58 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
59 >,
60 input: &Self::Input,
61 ) -> Result<Self::Output, Self::Error> {
62 self.prompt(input.clone()).await
64 }
65}
66
67#[async_trait]
68impl TaskNode for Arc<dyn SimplePrompt> {
69 type Input = Prompt;
70
71 type Output = String;
72
73 type Error = LanguageModelError;
74
75 async fn evaluate(
76 &self,
77 _node_id: &NodeId<
78 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
79 >,
80 input: &Self::Input,
81 ) -> Result<Self::Output, Self::Error> {
82 self.prompt(input.clone()).await
84 }
85}
86
87#[async_trait]
88impl TaskNode for Box<dyn ChatCompletion> {
89 type Input = ChatCompletionRequest;
90
91 type Output = ChatCompletionResponse;
92
93 type Error = LanguageModelError;
94
95 async fn evaluate(
96 &self,
97 _node_id: &NodeId<
98 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
99 >,
100 input: &Self::Input,
101 ) -> Result<Self::Output, Self::Error> {
102 self.complete(input).await
103 }
104}
105
106#[async_trait]
107impl TaskNode for Arc<dyn ChatCompletion> {
108 type Input = ChatCompletionRequest;
109
110 type Output = ChatCompletionResponse;
111
112 type Error = LanguageModelError;
113
114 async fn evaluate(
115 &self,
116 _node_id: &NodeId<
117 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
118 >,
119 input: &Self::Input,
120 ) -> Result<Self::Output, Self::Error> {
121 self.complete(input).await
122 }
123}
124
125#[async_trait]
126impl TaskNode for Box<dyn ToolExecutor> {
127 type Input = Command;
128
129 type Output = CommandOutput;
130
131 type Error = CommandError;
132
133 async fn evaluate(
134 &self,
135 _node_id: &NodeId<
136 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
137 >,
138 input: &Self::Input,
139 ) -> Result<Self::Output, Self::Error> {
140 self.exec_cmd(input).await
141 }
142}
143
144#[async_trait]
145impl TaskNode for Arc<dyn ToolExecutor> {
146 type Input = Command;
147
148 type Output = CommandOutput;
149
150 type Error = CommandError;
151
152 async fn evaluate(
153 &self,
154 _node_id: &NodeId<
155 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
156 >,
157 input: &Self::Input,
158 ) -> Result<Self::Output, Self::Error> {
159 self.exec_cmd(input).await
160 }
161}
162
163#[async_trait]
165impl<I: NodeArg, O: NodeArg, E: std::error::Error + Send + Sync + 'static> TaskNode
166 for fn(&I) -> Result<O, E>
167{
168 type Input = I;
169
170 type Output = O;
171
172 type Error = E;
173
174 async fn evaluate(
175 &self,
176 _node_id: &NodeId<
177 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
178 >,
179 input: &Self::Input,
180 ) -> Result<Self::Output, Self::Error> {
181 (self)(input)
182 }
183}