swiftide_integrations/openai/mod.rs

//! This module provides integration with `OpenAI`'s API, enabling the use of language models and
//! embeddings within the Swiftide project. It includes the `OpenAI` struct for managing API clients
//! and default options for embedding and prompt models. The module is conditionally compiled based
//! on the "openai" feature flag.

use async_openai::error::OpenAIError;
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::types::ReasoningEffort;
use derive_builder::Builder;
use std::sync::Arc;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod simple_prompt;

// Expose type aliases to simplify downstream use of the OpenAI builder invocations.
pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

/// The `OpenAI` struct encapsulates an `OpenAI` client and default options for embedding and prompt
/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// # Example
///
/// ```no_run
/// # use swiftide_integrations::openai::{OpenAI, Options};
/// # use swiftide_integrations::openai::OpenAIConfig;
///
/// // Create an OpenAI client with default options. The client will use the OPENAI_API_KEY environment variable.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .build().unwrap();
///
/// // Create an OpenAI client with a custom API key.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .client(async_openai::Client::with_config(async_openai::config::OpenAIConfig::default().with_api_key("my-api-key")))
///     .build().unwrap();
///
/// // Create an OpenAI client with custom options
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .default_options(
///         Options::builder()
///           .temperature(1.0)
///           .parallel_tool_calls(false)
///           .user("MyUserId")
///     )
///     .build().unwrap();
/// ```
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
/// Generic client for `OpenAI` APIs.
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The `OpenAI` client, wrapped in an `Arc` for thread-safe reference counting.
    /// Defaults to a new instance of `async_openai::Client`.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options for embedding and prompt models.
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Convenience option to stream the full response. Defaults to `true`, because nobody has
    /// time to reconstruct the delta. Disabling this makes the streamed content return only the
    /// delta, for when performance matters. This only has an effect when streaming is enabled.
    #[builder(default = true)]
    pub stream_full: bool,

    #[cfg(feature = "metrics")]
    #[builder(default)]
    /// Optional metadata to attach to metrics emitted by this client.
    metric_metadata: Option<std::collections::HashMap<String, String>>,
}

/// The `Options` struct holds configuration options for the `OpenAI` client.
/// It includes optional fields for specifying the embedding and prompt models.
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    /// The default embedding model to use, if specified.
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,
    /// The default prompt model to use, if specified.
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    #[builder(default)]
    /// Option to enable or disable parallel tool calls for completions.
    ///
    /// At the moment, o1 and o3-mini do not support parallel tool calls; for those models this
    /// should be set to `None`.
    pub parallel_tool_calls: Option<bool>,

    /// Maximum number of tokens to generate in the completion.
    ///
    /// By default, the limit is disabled.
    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    /// Temperature setting for the model.
    #[builder(default)]
    pub temperature: Option<f32>,

    /// Reasoning effort for reasoning models.
    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    /// This feature is in Beta. If specified, our system will make a best effort to sample
    /// deterministically, such that repeated requests with the same seed and parameters should
    /// return the same result. Determinism is not guaranteed, and you should refer to the
    /// `system_fingerprint` response parameter to monitor changes in the backend.
    #[builder(default)]
    pub seed: Option<i64>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
    /// appear in the text so far, increasing the model's likelihood to talk about new topics.
    #[builder(default)]
    pub presence_penalty: Option<f32>,

    /// Developer-defined tags and values used for filtering completions in the dashboard.
    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    /// A unique identifier representing your end-user, which can help `OpenAI` to monitor and
    /// detect abuse.
    #[builder(default, setter(into))]
    pub user: Option<String>,

    #[builder(default)]
    /// The number of dimensions the resulting output embeddings should have. Only supported in
    /// text-embedding-3 and later models.
    pub dimensions: Option<u32>,
}

impl Options {
    /// Creates a new `OptionsBuilder` for constructing `Options` instances.
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }

    /// Merges these options with another set of options; `Some` values in `other` win.
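    ///
    /// A minimal sketch of the merge behavior (the option values are illustrative):
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::Options;
    /// let mut base = Options::builder().temperature(0.2).build().unwrap();
    /// let overrides = Options::builder().temperature(1.0).user("MyUserId").build().unwrap();
    /// base.merge(&overrides);
    /// assert_eq!(base.temperature, Some(1.0));
    /// assert_eq!(base.user, Some("MyUserId".to_string()));
    /// ```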
    pub fn merge(&mut self, other: &Options) {
        if let Some(embed_model) = &other.embed_model {
            self.embed_model = Some(embed_model.clone());
        }
        if let Some(prompt_model) = &other.prompt_model {
            self.prompt_model = Some(prompt_model.clone());
        }
        if let Some(parallel_tool_calls) = other.parallel_tool_calls {
            self.parallel_tool_calls = Some(parallel_tool_calls);
        }
        if let Some(max_completion_tokens) = other.max_completion_tokens {
            self.max_completion_tokens = Some(max_completion_tokens);
        }
        if let Some(temperature) = other.temperature {
            self.temperature = Some(temperature);
        }
        if let Some(reasoning_effort) = &other.reasoning_effort {
            self.reasoning_effort = Some(reasoning_effort.clone());
        }
        if let Some(seed) = other.seed {
            self.seed = Some(seed);
        }
        if let Some(presence_penalty) = other.presence_penalty {
            self.presence_penalty = Some(presence_penalty);
        }
        if let Some(metadata) = &other.metadata {
            self.metadata = Some(metadata.clone());
        }
        if let Some(user) = &other.user {
            self.user = Some(user.clone());
        }
        if let Some(dimensions) = other.dimensions {
            self.dimensions = Some(dimensions);
        }
    }
}

impl From<OptionsBuilder> for Options {
    fn from(value: OptionsBuilder) -> Self {
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl From<&mut OptionsBuilder> for Options {
    fn from(value: &mut OptionsBuilder) -> Self {
        let value = value.clone();
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl OpenAI {
    /// Creates a new `OpenAIBuilder` for constructing `OpenAI` instances.
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}

impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
    GenericOpenAIBuilder<C>
{
    /// Sets the `OpenAI` client for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `client`: The `OpenAI` client to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    /// Sets the default embedding model for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `model`: The embedding model to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Enable or disable parallel tool calls for completions.
    ///
    /// Note that reasoning models currently do not support parallel tool calls.
    ///
    /// Defaults to `true`.
    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.parallel_tool_calls = parallel_tool_calls;
        } else {
            self.default_options = Some(Options {
                parallel_tool_calls,
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default prompt model for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `model`: The prompt model to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default options to use for requests to the `OpenAI` API.
    ///
    /// Merges with any existing options.
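    ///
    /// A minimal sketch (the option values are illustrative):
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::{OpenAI, Options};
    /// let openai = OpenAI::builder()
    ///     .default_prompt_model("gpt-4")
    ///     .default_options(Options::builder().temperature(0.7).seed(42))
    ///     .build().unwrap();
    /// ```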
    pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self {
        if let Some(existing_options) = self.default_options.as_mut() {
            existing_options.merge(&options.into());
        } else {
            self.default_options = Some(options.into());
        }
        self
    }
}

impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
    /// Estimates the number of tokens for implementors of the `Estimatable` trait,
    /// e.g. `String`, `ChatMessage`, etc.
    ///
    /// # Errors
    ///
    /// Errors if tokenization fails in any way.
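    ///
    /// A usage sketch (the model name and input text are illustrative):
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::OpenAI;
    /// # async fn example() -> anyhow::Result<()> {
    /// let openai = OpenAI::builder().default_prompt_model("gpt-4").build()?;
    /// let tokens = openai.estimate_tokens("How many tokens is this?".to_string()).await?;
    /// # Ok(())
    /// # }
    /// ```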
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }

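    /// Sets the default prompt model on an already-built client, returning a mutable
    /// reference for chaining.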
    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

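    /// Sets the default embedding model on an already-built client, returning a mutable
    /// reference for chaining.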
    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            embed_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Retrieve a reference to the inner `OpenAI` client.
    pub fn client(&self) -> &Arc<async_openai::Client<C>> {
        &self.client
    }

    /// Retrieve a reference to the default options for the `OpenAI` instance.
    pub fn options(&self) -> &Options {
        &self.default_options
    }

    /// Retrieve a mutable reference to the default options for the `OpenAI` instance.
    pub fn options_mut(&mut self) -> &mut Options {
        &mut self.default_options
    }

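    /// Builds a `CreateChatCompletionRequestArgs` pre-populated from the default options.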
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;

        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }

        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }

        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }

        if let Some(reasoning_effort) = &options.reasoning_effort {
            args.reasoning_effort(reasoning_effort.clone());
        }

        if let Some(seed) = options.seed {
            args.seed(seed);
        }

        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }

        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        args
    }

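    /// Builds a `CreateEmbeddingRequestArgs` pre-populated from the default options.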
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();

        let options = &self.default_options;

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }

        args
    }
}

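/// Maps an `OpenAIError` onto a `LanguageModelError`.
///
/// API errors carrying the `context_length_exceeded` code become
/// `ContextLengthExceeded`, network and JSON-deserialization failures are
/// treated as transient, and everything else is considered permanent.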
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            // If the response is an ApiError, it could be a context length exceeded error
            if api_error.code == Some("context_length_exceeded".to_string()) {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        OpenAIError::Reqwest(e) => {
            // async_openai passes any network errors as reqwest errors, so we just assume they are
            // recoverable
            LanguageModelError::transient(e)
        }
        OpenAIError::JSONDeserialize(_) => {
            // OpenAI generated a non-json response, probably a temporary problem on their side
            // (i.e. a reverse proxy can't find an available backend)
            LanguageModelError::transient(e)
        }
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::StreamError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError};

    /// Tests that default embed and prompt models are set regardless of builder call order.
    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_context_length_exceeded_error() {
        // Create an API error with the context_length_exceeded code
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as ContextLengthExceeded
        match result {
            LanguageModelError::ContextLengthExceeded(_) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        // Create a generic API error (not context length exceeded)
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        // Create a file save error
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        // Create a file read error
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        // Create a stream error
        let openai_error = OpenAIError::StreamError("Stream failed".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        // Create an invalid argument error
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }
}