swiftide_core/chat_completion/
traits.rs

1use async_trait::async_trait;
2use dyn_clone::DynClone;
3
4use crate::{AgentContext, CommandOutput};
5
6use super::{
7    chat_completion_request::ChatCompletionRequest,
8    chat_completion_response::ChatCompletionResponse,
9    errors::{ChatCompletionError, ToolError},
10    ToolOutput, ToolSpec,
11};
12
13#[async_trait]
14pub trait ChatCompletion: Send + Sync + DynClone {
15    async fn complete(
16        &self,
17        request: &ChatCompletionRequest,
18    ) -> Result<ChatCompletionResponse, ChatCompletionError>;
19}
20
21#[async_trait]
22impl ChatCompletion for Box<dyn ChatCompletion> {
23    async fn complete(
24        &self,
25        request: &ChatCompletionRequest,
26    ) -> Result<ChatCompletionResponse, ChatCompletionError> {
27        (**self).complete(request).await
28    }
29}
30
31#[async_trait]
32impl ChatCompletion for &dyn ChatCompletion {
33    async fn complete(
34        &self,
35        request: &ChatCompletionRequest,
36    ) -> Result<ChatCompletionResponse, ChatCompletionError> {
37        (**self).complete(request).await
38    }
39}
40
41#[async_trait]
42impl<T> ChatCompletion for &T
43where
44    T: ChatCompletion + Clone + 'static,
45{
46    async fn complete(
47        &self,
48        request: &ChatCompletionRequest,
49    ) -> Result<ChatCompletionResponse, ChatCompletionError> {
50        (**self).complete(request).await
51    }
52}
53
54impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
55where
56    LLM: ChatCompletion + Clone + 'static,
57{
58    fn from(llm: &LLM) -> Self {
59        Box::new(llm.clone()) as Box<dyn ChatCompletion>
60    }
61}
62
63dyn_clone::clone_trait_object!(ChatCompletion);
64
65impl From<CommandOutput> for ToolOutput {
66    fn from(value: CommandOutput) -> Self {
67        ToolOutput::Text(value.output)
68    }
69}
70
71#[async_trait]
72pub trait Tool: Send + Sync + DynClone {
73    // tbd
74    async fn invoke(
75        &self,
76        agent_context: &dyn AgentContext,
77        raw_args: Option<&str>,
78    ) -> Result<ToolOutput, ToolError>;
79
80    fn name(&self) -> &'static str;
81
82    fn tool_spec(&self) -> ToolSpec;
83
84    fn boxed<'a>(self) -> Box<dyn Tool + 'a>
85    where
86        Self: Sized + 'a,
87    {
88        Box::new(self) as Box<dyn Tool>
89    }
90}
91
92#[async_trait]
93impl Tool for Box<dyn Tool> {
94    async fn invoke(
95        &self,
96        agent_context: &dyn AgentContext,
97        raw_args: Option<&str>,
98    ) -> Result<ToolOutput, ToolError> {
99        (**self).invoke(agent_context, raw_args).await
100    }
101    fn name(&self) -> &'static str {
102        (**self).name()
103    }
104    fn tool_spec(&self) -> ToolSpec {
105        (**self).tool_spec()
106    }
107}
108
109dyn_clone::clone_trait_object!(Tool);
110
111// impl<T> From<T> for Box<dyn Tool + '_>
112// where
113//     for<'b> T: Tool + 'b,
114// {
115//     fn from(value: T) -> Self {
116//         // dyn_clone::clone_box(&value)
117//         Box::new(value)
118//     }
119// }
120
121/// Tools are identified and unique by name
122/// These allow comparison and lookups
123impl PartialEq for Box<dyn Tool> {
124    fn eq(&self, other: &Self) -> bool {
125        self.name() == other.name()
126    }
127}
128impl Eq for Box<dyn Tool> {}
129impl std::hash::Hash for Box<dyn Tool> {
130    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
131        self.name().hash(state);
132    }
133}