swiftide_integrations/openai/
mod.rs1use derive_builder::Builder;
7use std::sync::Arc;
8
9mod chat_completion;
10mod embed;
11mod simple_prompt;
12
13pub use async_openai::config::AzureConfig;
15pub use async_openai::config::OpenAIConfig;
16
17#[cfg(feature = "tiktoken")]
18use crate::tiktoken::TikToken;
19#[cfg(feature = "tiktoken")]
20use anyhow::Result;
21#[cfg(feature = "tiktoken")]
22use swiftide_core::Estimatable;
23#[cfg(feature = "tiktoken")]
24use swiftide_core::EstimateTokens;
25
/// Convenience alias: a [`GenericOpenAI`] backed by the stock [`OpenAIConfig`].
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
/// Builder for [`OpenAI`]; obtain one via `OpenAI::builder()`.
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;
50
/// Client wrapper generic over any `async_openai` configuration (e.g.
/// [`OpenAIConfig`] or [`AzureConfig`]). Constructed through the derived
/// builder; see [`OpenAI`] for the plain-OpenAI alias.
#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    // Shared via `Arc` so clones of this struct reuse one underlying client.
    // The setter is custom (see `GenericOpenAIBuilder::client`) so callers can
    // pass a plain `async_openai::Client` and have it wrapped here.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    // Default models/flags applied when a request does not override them.
    #[builder(default)]
    pub(crate) default_options: Options,

    // Tokenizer used by `estimate_tokens`; defaults to one built from the
    // configured prompt model (falling back to "gpt-4" — see
    // `GenericOpenAIBuilder::default_tiktoken`).
    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder( default = self.default_tiktoken()))]
    pub(crate) tiktoken: TikToken,
}
73
/// Default option set carried by a [`GenericOpenAI`] client.
#[derive(Debug, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
    // Model used for embedding requests when none is given per call.
    #[builder(default)]
    pub embed_model: Option<String>,
    // Model used for prompt/chat requests when none is given per call.
    #[builder(default)]
    pub prompt_model: Option<String>,

    // Whether the API may execute tool calls in parallel. Defaults to
    // `Some(true)` both here and in the manual `Default` impl below.
    #[builder(default = Some(true))]
    pub parallel_tool_calls: Option<bool>,
}
92
93impl Default for Options {
94 fn default() -> Self {
95 Self {
96 embed_model: None,
97 prompt_model: None,
98 parallel_tool_calls: Some(true),
99 }
100 }
101}
102
impl Options {
    /// Returns a fresh [`OptionsBuilder`] (generated by `derive_builder`).
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }
}
109
impl OpenAI {
    /// Returns a fresh [`OpenAIBuilder`] for the stock OpenAI configuration.
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}
116
117impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
118 GenericOpenAIBuilder<C>
119{
120 pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
128 self.client = Some(Arc::new(client));
129 self
130 }
131
132 pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
140 if let Some(options) = self.default_options.as_mut() {
141 options.embed_model = Some(model.into());
142 } else {
143 self.default_options = Some(Options {
144 embed_model: Some(model.into()),
145 ..Default::default()
146 });
147 }
148 self
149 }
150
151 pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
157 if let Some(options) = self.default_options.as_mut() {
158 options.parallel_tool_calls = parallel_tool_calls;
159 } else {
160 self.default_options = Some(Options {
161 parallel_tool_calls,
162 ..Default::default()
163 });
164 }
165 self
166 }
167
168 pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
176 if let Some(options) = self.default_options.as_mut() {
177 options.prompt_model = Some(model.into());
178 } else {
179 self.default_options = Some(Options {
180 prompt_model: Some(model.into()),
181 ..Default::default()
182 });
183 }
184 self
185 }
186}
impl<C: async_openai::config::Config + Default> GenericOpenAIBuilder<C> {
    /// Builds the default tokenizer from the configured prompt model,
    /// falling back to "gpt-4" when no prompt model has been set.
    #[cfg(feature = "tiktoken")]
    fn default_tiktoken(&self) -> TikToken {
        let model = self
            .default_options
            .as_ref()
            .and_then(|o| o.prompt_model.as_deref())
            .unwrap_or("gpt-4");

        // NOTE(review): the expect() assumes `try_from_model` cannot fail.
        // That holds for the "gpt-4" fallback, but a user-supplied
        // prompt_model unknown to tiktoken would panic here during build() —
        // verify whether model names are validated upstream.
        TikToken::try_from_model(model).expect("Failed to build default model; infallible")
    }
}
199
impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
    /// Estimates how many tokens `value` occupies, using the client's
    /// configured tiktoken tokenizer.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying estimator.
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }
}
213
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_default_embed_and_prompt_model() {
        // Both setters must merge into the same Options regardless of the
        // order in which they are called.
        fn assert_models(client: &OpenAI) {
            assert_eq!(
                client.default_options.embed_model,
                Some("gpt-3".to_string())
            );
            assert_eq!(
                client.default_options.prompt_model,
                Some("gpt-4".to_string())
            );
        }

        let embed_first: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_models(&embed_first);

        let prompt_first: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_models(&prompt_first);
    }
}