// swiftide_core/chat_completion/traits.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use dyn_clone::DynClone;
4use futures_util::Stream;
5use std::{borrow::Cow, pin::Pin, sync::Arc};
6
7use crate::AgentContext;
8
9use super::{
10    ToolCall, ToolOutput, ToolSpec,
11    chat_completion_request::ChatCompletionRequest,
12    chat_completion_response::ChatCompletionResponse,
13    errors::{LanguageModelError, ToolError},
14};
15
16pub type ChatCompletionStream =
17    Pin<Box<dyn Stream<Item = Result<ChatCompletionResponse, LanguageModelError>> + Send>>;
18#[async_trait]
19pub trait ChatCompletion: Send + Sync + DynClone {
20    async fn complete(
21        &self,
22        request: &ChatCompletionRequest,
23    ) -> Result<ChatCompletionResponse, LanguageModelError>;
24
25    /// Stream the completion response. If it's not supported, it will return a single
26    /// response
27    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
28        Box::pin(tokio_stream::iter(vec![self.complete(request).await]))
29    }
30}
31
32#[async_trait]
33impl ChatCompletion for Box<dyn ChatCompletion> {
34    async fn complete(
35        &self,
36        request: &ChatCompletionRequest,
37    ) -> Result<ChatCompletionResponse, LanguageModelError> {
38        (**self).complete(request).await
39    }
40
41    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
42        (**self).complete_stream(request).await
43    }
44}
45
46#[async_trait]
47impl ChatCompletion for &dyn ChatCompletion {
48    async fn complete(
49        &self,
50        request: &ChatCompletionRequest,
51    ) -> Result<ChatCompletionResponse, LanguageModelError> {
52        (**self).complete(request).await
53    }
54
55    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
56        (**self).complete_stream(request).await
57    }
58}
59
60#[async_trait]
61impl<T> ChatCompletion for &T
62where
63    T: ChatCompletion + Clone + 'static,
64{
65    async fn complete(
66        &self,
67        request: &ChatCompletionRequest,
68    ) -> Result<ChatCompletionResponse, LanguageModelError> {
69        (**self).complete(request).await
70    }
71
72    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
73        (**self).complete_stream(request).await
74    }
75}
76
77impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
78where
79    LLM: ChatCompletion + Clone + 'static,
80{
81    fn from(llm: &LLM) -> Self {
82        Box::new(llm.clone()) as Box<dyn ChatCompletion>
83    }
84}
85
86dyn_clone::clone_trait_object!(ChatCompletion);
87
88/// The `Tool` trait is the main interface for chat completion and agent tools.
89///
90/// `swiftide-macros` provides a set of macros to generate implementations of this trait. If you
91/// need more control over the implementation, you can implement the trait manually.
92///
93/// The `ToolSpec` is what will end up with the LLM. A builder is provided. The `name` is expected
94/// to be unique, and is used to identify the tool. It should be the same as the name in the
95/// `ToolSpec`.
96#[async_trait]
97pub trait Tool: Send + Sync + DynClone {
98    // tbd
99    async fn invoke(
100        &self,
101        agent_context: &dyn AgentContext,
102        tool_call: &ToolCall,
103    ) -> Result<ToolOutput, ToolError>;
104
105    fn name(&self) -> Cow<'_, str>;
106
107    fn tool_spec(&self) -> ToolSpec;
108
109    fn boxed<'a>(self) -> Box<dyn Tool + 'a>
110    where
111        Self: Sized + 'a,
112    {
113        Box::new(self) as Box<dyn Tool>
114    }
115}
116
117/// A toolbox is a collection of tools
118///
119/// It can be a list, an mcp client, or anything else we can think of.
120///
121/// This allows agents to not know their tools when they are created, and to get them at runtime.
122///
123/// It also allows for tools to be dynamically loaded and unloaded, etc.
124#[async_trait]
125pub trait ToolBox: Send + Sync + DynClone {
126    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>>;
127
128    fn name(&self) -> Cow<'_, str> {
129        Cow::Borrowed("Unnamed ToolBox")
130    }
131
132    fn boxed<'a>(self) -> Box<dyn ToolBox + 'a>
133    where
134        Self: Sized + 'a,
135    {
136        Box::new(self) as Box<dyn ToolBox>
137    }
138}
139
140#[async_trait]
141impl ToolBox for Vec<Box<dyn Tool>> {
142    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
143        Ok(self.clone())
144    }
145}
146
147#[async_trait]
148impl ToolBox for Box<dyn ToolBox> {
149    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
150        (**self).available_tools().await
151    }
152}
153
154#[async_trait]
155impl ToolBox for Arc<dyn ToolBox> {
156    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
157        (**self).available_tools().await
158    }
159}
160
161#[async_trait]
162impl ToolBox for &dyn ToolBox {
163    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
164        (**self).available_tools().await
165    }
166}
167
168#[async_trait]
169impl ToolBox for &[Box<dyn Tool>] {
170    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
171        Ok(self.to_vec())
172    }
173}
174
175#[async_trait]
176impl ToolBox for [Box<dyn Tool>] {
177    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
178        Ok(self.to_vec())
179    }
180}
181
182dyn_clone::clone_trait_object!(ToolBox);
183
184#[async_trait]
185impl Tool for Box<dyn Tool> {
186    async fn invoke(
187        &self,
188        agent_context: &dyn AgentContext,
189        tool_call: &ToolCall,
190    ) -> Result<ToolOutput, ToolError> {
191        (**self).invoke(agent_context, tool_call).await
192    }
193    fn name(&self) -> Cow<'_, str> {
194        (**self).name()
195    }
196    fn tool_spec(&self) -> ToolSpec {
197        (**self).tool_spec()
198    }
199}
200
201dyn_clone::clone_trait_object!(Tool);
202
203/// Tools are identified and unique by name
204/// These allow comparison and lookups
205impl PartialEq for Box<dyn Tool> {
206    fn eq(&self, other: &Self) -> bool {
207        self.name() == other.name()
208    }
209}
210impl Eq for Box<dyn Tool> {}
211impl std::hash::Hash for Box<dyn Tool> {
212    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
213        self.name().hash(state);
214    }
215}