swiftide_integrations/openai/mod.rs

//! This module provides integration with `OpenAI`'s API, enabling the use of language models and
//! embeddings within the Swiftide project. It includes the `OpenAI` struct for managing API clients
//! and default options for embedding and prompt models. The module is conditionally compiled based
//! on the "openai" feature flag.

use async_openai::error::OpenAIError;
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::types::ReasoningEffort;
use derive_builder::Builder;
use std::sync::Arc;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod simple_prompt;

// Re-export `async_openai` config types to simplify downstream builder invocations
pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

/// The `OpenAI` struct encapsulates an `OpenAI` client and default options for embedding and prompt
/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// # Example
///
/// ```no_run
/// # use swiftide_integrations::openai::{OpenAI, Options};
/// # use swiftide_integrations::openai::OpenAIConfig;
///
/// // Create an OpenAI client with default options. The client will use the OPENAI_API_KEY environment variable.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .build().unwrap();
///
/// // Create an OpenAI client with a custom API key.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .client(async_openai::Client::with_config(async_openai::config::OpenAIConfig::default().with_api_key("my-api-key")))
///     .build().unwrap();
///
/// // Create an OpenAI client with custom options
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .default_options(
///         Options::builder()
///           .temperature(1.0)
///           .parallel_tool_calls(false)
///           .user("MyUserId")
///     )
///     .build().unwrap();
/// ```
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
/// Generic client for `OpenAI` APIs.
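///
/// Beyond the `OpenAI` alias above, the client can be instantiated with any `async_openai`
/// config. A minimal sketch using the re-exported `AzureConfig` (assumes the Azure deployment
/// details are set on the config):
///
/// ```no_run
/// # use swiftide_integrations::openai::{AzureConfig, GenericOpenAIBuilder};
/// let azure = GenericOpenAIBuilder::<AzureConfig>::default()
///     .client(async_openai::Client::with_config(AzureConfig::default()))
///     .default_prompt_model("gpt-4")
///     .build()
///     .unwrap();
/// ```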
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The `OpenAI` client, wrapped in an `Arc` for thread-safe reference counting.
    /// Defaults to a new instance of `async_openai::Client`.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options for embedding and prompt models.
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Convenience option to stream the full response. Defaults to `true`, because nobody has time
    /// to reconstruct the delta. Disabling this makes the streamed content return only the delta,
    /// for when performance matters. This only has an effect when streaming is enabled.
    #[builder(default = true)]
    pub stream_full: bool,
}

/// The `Options` struct holds configuration options for the `OpenAI` client.
/// It includes optional fields for specifying the embedding and prompt models.
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    /// The default embedding model to use, if specified.
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,
    /// The default prompt model to use, if specified.
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    #[builder(default)]
    /// Option to enable or disable parallel tool calls for completions.
    ///
    /// At the time of writing, o1 and o3-mini do not support parallel tool calls; leave this as
    /// `None` for those models.
    pub parallel_tool_calls: Option<bool>,

    /// Maximum number of tokens to generate in the completion.
    ///
    /// By default, the limit is disabled.
    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    /// Temperature setting for the model.
    #[builder(default)]
    pub temperature: Option<f32>,

    /// Reasoning effort for reasoning models.
    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    /// This feature is in Beta. If specified, our system will make a best effort to sample
    /// deterministically, such that repeated requests with the same seed and parameters should
    /// return the same result. Determinism is not guaranteed, and you should refer to the
    /// `system_fingerprint` response parameter to monitor changes in the backend.
    #[builder(default)]
    pub seed: Option<i64>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
    /// appear in the text so far, increasing the model’s likelihood to talk about new topics.
    #[builder(default)]
    pub presence_penalty: Option<f32>,

    /// Developer-defined tags and values used for filtering completions in the dashboard.
    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    /// A unique identifier representing your end-user, which can help `OpenAI` to monitor and
    /// detect abuse.
    #[builder(default, setter(into))]
    pub user: Option<String>,

    #[builder(default)]
    /// The number of dimensions the resulting output embeddings should have. Only supported in
    /// text-embedding-3 and later models.
    pub dimensions: Option<u32>,
}

impl Options {
    /// Creates a new `OptionsBuilder` for constructing `Options` instances.
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }

    /// Merges `other` into these options. Fields that are set on `other` override the
    /// corresponding fields on `self`; unset fields are left untouched.
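    ///
    /// A short sketch of the override semantics:
    ///
    /// ```
    /// # use swiftide_integrations::openai::Options;
    /// let mut base = Options::builder()
    ///     .prompt_model("gpt-4")
    ///     .temperature(0.2)
    ///     .build()
    ///     .unwrap();
    /// let overrides = Options::builder().temperature(0.7).build().unwrap();
    /// base.merge(&overrides);
    ///
    /// // Set fields win; unset fields keep their previous values.
    /// assert_eq!(base.temperature, Some(0.7));
    /// assert_eq!(base.prompt_model, Some("gpt-4".to_string()));
    /// ```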
    pub fn merge(&mut self, other: &Options) {
        if let Some(embed_model) = &other.embed_model {
            self.embed_model = Some(embed_model.clone());
        }
        if let Some(prompt_model) = &other.prompt_model {
            self.prompt_model = Some(prompt_model.clone());
        }
        if let Some(parallel_tool_calls) = other.parallel_tool_calls {
            self.parallel_tool_calls = Some(parallel_tool_calls);
        }
        if let Some(max_completion_tokens) = other.max_completion_tokens {
            self.max_completion_tokens = Some(max_completion_tokens);
        }
        if let Some(temperature) = other.temperature {
            self.temperature = Some(temperature);
        }
        if let Some(reasoning_effort) = &other.reasoning_effort {
            self.reasoning_effort = Some(reasoning_effort.clone());
        }
        if let Some(seed) = other.seed {
            self.seed = Some(seed);
        }
        if let Some(presence_penalty) = other.presence_penalty {
            self.presence_penalty = Some(presence_penalty);
        }
        if let Some(metadata) = &other.metadata {
            self.metadata = Some(metadata.clone());
        }
        if let Some(user) = &other.user {
            self.user = Some(user.clone());
        }
        if let Some(dimensions) = other.dimensions {
            self.dimensions = Some(dimensions);
        }
    }
}

impl From<OptionsBuilder> for Options {
    fn from(value: OptionsBuilder) -> Self {
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl From<&mut OptionsBuilder> for Options {
    fn from(value: &mut OptionsBuilder) -> Self {
        value.clone().into()
    }
}

impl OpenAI {
    /// Creates a new `OpenAIBuilder` for constructing `OpenAI` instances.
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}

impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
    GenericOpenAIBuilder<C>
{
    /// Sets the `OpenAI` client for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `client`: The `OpenAI` client to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    /// Sets the default embedding model for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `model`: The embedding model to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Enable or disable parallel tool calls for completions.
    ///
    /// Note that currently reasoning models do not support parallel tool calls.
    ///
    /// Defaults to `true`.
    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.parallel_tool_calls = parallel_tool_calls;
        } else {
            self.default_options = Some(Options {
                parallel_tool_calls,
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default prompt model for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `model`: The prompt model to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default options to use for requests to the `OpenAI` API.
    ///
    /// Merges with any existing options.
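    ///
    /// A minimal sketch of the merge behaviour:
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::{OpenAI, Options};
    /// let openai = OpenAI::builder()
    ///     .default_prompt_model("gpt-4")
    ///     // Merged into the options above; the prompt model is preserved.
    ///     .default_options(Options::builder().temperature(0.7))
    ///     .build()
    ///     .unwrap();
    /// ```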
    pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self {
        if let Some(existing_options) = self.default_options.as_mut() {
            existing_options.merge(&options.into());
        } else {
            self.default_options = Some(options.into());
        }
        self
    }
}

impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
    /// Estimates the number of tokens for implementors of the `Estimatable` trait,
    /// e.g. `String`, `ChatMessage`, etc.
    ///
    /// # Errors
    ///
    /// Errors if tokenization fails in any way.
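    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an `OpenAI` client is already configured:
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::OpenAI;
    /// # async fn example(openai: &OpenAI) -> anyhow::Result<()> {
    /// let tokens = openai.estimate_tokens("hello world".to_string()).await?;
    /// println!("estimated tokens: {tokens}");
    /// # Ok(())
    /// # }
    /// ```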
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }

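    /// Sets the default prompt model on an already constructed instance, replacing any
    /// previously configured prompt model.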
    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

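    /// Sets the default embedding model on an already constructed instance, replacing any
    /// previously configured embedding model.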
    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            embed_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Retrieve a reference to the inner `OpenAI` client.
    pub fn client(&self) -> &Arc<async_openai::Client<C>> {
        &self.client
    }

    /// Retrieve a reference to the default options for the `OpenAI` instance.
    pub fn options(&self) -> &Options {
        &self.default_options
    }

    /// Retrieve a mutable reference to the default options for the `OpenAI` instance.
    pub fn options_mut(&mut self) -> &mut Options {
        &mut self.default_options
    }

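    /// Builds a chat completion request with the configured default options applied; the
    /// model and messages are filled in by the caller.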
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;

        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }

        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }

        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }

        if let Some(reasoning_effort) = &options.reasoning_effort {
            args.reasoning_effort(reasoning_effort.clone());
        }

        if let Some(seed) = options.seed {
            args.seed(seed);
        }

        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }

        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        args
    }

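    /// Builds an embedding request with the configured default options applied; the model
    /// and input are filled in by the caller.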
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();

        let options = &self.default_options;

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }

        args
    }
}

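/// Maps an [`OpenAIError`] onto [`LanguageModelError`] retry semantics: context length
/// errors are reported as such, network and deserialization failures are treated as
/// transient, and everything else is considered permanent.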
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            // If the response is an ApiError, it could be a context length exceeded error
            if api_error.code == Some("context_length_exceeded".to_string()) {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        OpenAIError::Reqwest(e) => {
            // async_openai surfaces network errors as reqwest errors, so we assume they are
            // recoverable
            LanguageModelError::transient(e)
        }
        OpenAIError::JSONDeserialize(_) => {
            // OpenAI returned a non-JSON response, probably a temporary problem on their side
            // (e.g. a reverse proxy can't find an available backend)
            LanguageModelError::transient(e)
        }
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::StreamError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError};

    /// Tests the default embed and prompt models, in both builder orders.
    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_context_length_exceeded_error() {
        // Create an API error with the context_length_exceeded code
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as ContextLengthExceeded
        match result {
            LanguageModelError::ContextLengthExceeded(_) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        // Create a generic API error (not context length exceeded)
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        // Create a file save error
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        // Create a file read error
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        // Create a stream error
        let openai_error = OpenAIError::StreamError("Stream failed".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        // Create an invalid argument error
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }
}