// swiftide_integrations/openai/mod.rs

//! This module provides integration with `OpenAI`'s API, enabling the use of language models and
//! embeddings within the Swiftide project. It includes the `OpenAI` struct for managing API clients
//! and default options for embedding and prompt models. The module is conditionally compiled based
//! on the "openai" feature flag.

use std::sync::Arc;

use derive_builder::Builder;

mod chat_completion;
mod embed;
mod simple_prompt;

// expose type aliases to simplify downstream use of the open ai builder invocations
pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

/// The `OpenAI` struct encapsulates an `OpenAI` client and default options for embedding and prompt
/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// # Example
///
/// ```no_run
/// # use swiftide_integrations::openai::OpenAI;
/// # use swiftide_integrations::openai::OpenAIConfig;
///
/// // Create an OpenAI client with default options. The client will use the OPENAI_API_KEY environment variable.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .build().unwrap();
///
/// // Create an OpenAI client with a custom api key.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .client(async_openai::Client::with_config(async_openai::config::OpenAIConfig::default().with_api_key("my-api-key")))
///     .build().unwrap();
/// ```
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
/// Builder for [`OpenAI`]; alias for the generated [`GenericOpenAIBuilder`] with the
/// standard [`OpenAIConfig`].
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
/// Generic client for `OpenAI` APIs, parameterized over the `async_openai`
/// configuration type (e.g. [`OpenAIConfig`] or [`AzureConfig`]).
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The `OpenAI` client, wrapped in an `Arc` for thread-safe reference counting.
    /// Defaults to a new instance of `async_openai::Client`.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options for embedding and prompt models.
    #[builder(default)]
    pub(crate) default_options: Options,

    // Tokenizer used by `estimate_tokens`; the builder default derives it from the
    // configured prompt model, falling back to "gpt-4" (see `default_tiktoken`).
    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder( default = self.default_tiktoken()))]
    pub(crate) tiktoken: TikToken,
}

/// The `Options` struct holds configuration options for the `OpenAI` client.
/// It includes optional fields for specifying the embedding and prompt models.
#[derive(Debug, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
    /// The default embedding model to use, if specified.
    #[builder(default)]
    pub embed_model: Option<String>,
    /// The default prompt model to use, if specified.
    #[builder(default)]
    pub prompt_model: Option<String>,

    #[builder(default = Some(true))]
    /// Option to enable or disable parallel tool calls for completions.
    ///
    /// At this moment, o1 and o3-mini do not support it and should be set to `None`.
    pub parallel_tool_calls: Option<bool>,
}

93impl Default for Options {
94    fn default() -> Self {
95        Self {
96            embed_model: None,
97            prompt_model: None,
98            parallel_tool_calls: Some(true),
99        }
100    }
101}
102
103impl Options {
104    /// Creates a new `OptionsBuilder` for constructing `Options` instances.
105    pub fn builder() -> OptionsBuilder {
106        OptionsBuilder::default()
107    }
108}
109
110impl OpenAI {
111    /// Creates a new `OpenAIBuilder` for constructing `OpenAI` instances.
112    pub fn builder() -> OpenAIBuilder {
113        OpenAIBuilder::default()
114    }
115}
116
117impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
118    GenericOpenAIBuilder<C>
119{
120    /// Sets the `OpenAI` client for the `OpenAI` instance.
121    ///
122    /// # Parameters
123    /// - `client`: The `OpenAI` client to set.
124    ///
125    /// # Returns
126    /// A mutable reference to the `OpenAIBuilder`.
127    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
128        self.client = Some(Arc::new(client));
129        self
130    }
131
132    /// Sets the default embedding model for the `OpenAI` instance.
133    ///
134    /// # Parameters
135    /// - `model`: The embedding model to set.
136    ///
137    /// # Returns
138    /// A mutable reference to the `OpenAIBuilder`.
139    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
140        if let Some(options) = self.default_options.as_mut() {
141            options.embed_model = Some(model.into());
142        } else {
143            self.default_options = Some(Options {
144                embed_model: Some(model.into()),
145                ..Default::default()
146            });
147        }
148        self
149    }
150
151    /// Enable or disable parallel tool calls for completions.
152    ///
153    /// Note that currently reasoning models do not support parallel tool calls
154    ///
155    /// Defaults to `true`
156    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
157        if let Some(options) = self.default_options.as_mut() {
158            options.parallel_tool_calls = parallel_tool_calls;
159        } else {
160            self.default_options = Some(Options {
161                parallel_tool_calls,
162                ..Default::default()
163            });
164        }
165        self
166    }
167
168    /// Sets the default prompt model for the `OpenAI` instance.
169    ///
170    /// # Parameters
171    /// - `model`: The prompt model to set.
172    ///
173    /// # Returns
174    /// A mutable reference to the `OpenAIBuilder`.
175    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
176        if let Some(options) = self.default_options.as_mut() {
177            options.prompt_model = Some(model.into());
178        } else {
179            self.default_options = Some(Options {
180                prompt_model: Some(model.into()),
181                ..Default::default()
182            });
183        }
184        self
185    }
186}
187impl<C: async_openai::config::Config + Default> GenericOpenAIBuilder<C> {
188    #[cfg(feature = "tiktoken")]
189    fn default_tiktoken(&self) -> TikToken {
190        let model = self
191            .default_options
192            .as_ref()
193            .and_then(|o| o.prompt_model.as_deref())
194            .unwrap_or("gpt-4");
195
196        TikToken::try_from_model(model).expect("Failed to build default model; infallible")
197    }
198}
199
impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
    /// Estimates the number of tokens for implementors of the `Estimatable` trait.
    ///
    /// I.e. `String`, `ChatMessage` etc
    ///
    /// # Errors
    ///
    /// Errors if tokenization fails in any way
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// The order in which the default models are configured must not matter.
    #[test]
    fn test_default_embed_and_prompt_model() {
        let check = |client: OpenAI| {
            assert_eq!(client.default_options.embed_model.as_deref(), Some("gpt-3"));
            assert_eq!(client.default_options.prompt_model.as_deref(), Some("gpt-4"));
        };

        // Embed model first, prompt model second.
        check(
            OpenAI::builder()
                .default_embed_model("gpt-3")
                .default_prompt_model("gpt-4")
                .build()
                .unwrap(),
        );

        // Prompt model first, embed model second.
        check(
            OpenAI::builder()
                .default_prompt_model("gpt-4")
                .default_embed_model("gpt-3")
                .build()
                .unwrap(),
        );
    }
}