swiftide_integrations/openai/mod.rs

//! This module provides integration with `OpenAI`'s API, enabling the use of language models and
//! embeddings within the Swiftide project. It includes the `OpenAI` struct for managing API clients
//! and default options for embedding and prompt models. The module is conditionally compiled based
//! on the "openai" feature flag.

use async_openai::error::OpenAIError;
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::types::ReasoningEffort;
use derive_builder::Builder;
use std::pin::Pin;
use std::sync::Arc;
use swiftide_core::chat_completion::Usage;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod simple_prompt;
mod structured_prompt;

// Expose config type aliases to simplify downstream use of the `OpenAI` builder.
pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

/// The `OpenAI` struct encapsulates an `OpenAI` client and default options for embedding and prompt
/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// # Example
///
/// ```no_run
/// # use swiftide_integrations::openai::{OpenAI, Options};
/// # use swiftide_integrations::openai::OpenAIConfig;
///
/// // Create an OpenAI client with default options. The client will use the OPENAI_API_KEY environment variable.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .build().unwrap();
///
/// // Create an OpenAI client with a custom api key.
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .client(async_openai::Client::with_config(async_openai::config::OpenAIConfig::default().with_api_key("my-api-key")))
///     .build().unwrap();
///
/// // Create an OpenAI client with custom options
/// let openai = OpenAI::builder()
///     .default_embed_model("text-embedding-3-small")
///     .default_prompt_model("gpt-4")
///     .default_options(
///         Options::builder()
///           .temperature(1.0)
///           .parallel_tool_calls(false)
///           .user("MyUserId")
///     )
///     .build().unwrap();
/// ```
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

#[derive(Builder, Clone)]
#[builder(setter(into, strip_option))]
/// Generic client for `OpenAI` APIs.
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The `OpenAI` client, wrapped in an `Arc` for thread-safe reference counting.
    /// Defaults to a new instance of `async_openai::Client`.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options for embedding and prompt models.
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Convenience option to stream the full response accumulated so far instead of just the
    /// delta. Defaults to `true`, so callers don't have to reconstruct the response from deltas.
    /// Disable it to receive only the delta when performance matters. This only has an effect
    /// when streaming is enabled.
    #[builder(default = true)]
    pub stream_full: bool,

    #[cfg(feature = "metrics")]
    #[builder(default)]
    /// Optional metadata to attach to metrics emitted by this client.
    metric_metadata: Option<std::collections::HashMap<String, String>>,

    /// A callback function that is called when usage information is available.
    #[builder(default, setter(custom))]
    #[allow(clippy::type_complexity)]
    on_usage: Option<
        Arc<
            dyn for<'a> Fn(
                    &'a Usage,
                ) -> Pin<
                    Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>,
                > + Send
                + Sync,
        >,
    >,
}

impl<C: async_openai::config::Config + Default + std::fmt::Debug> std::fmt::Debug
    for GenericOpenAI<C>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GenericOpenAI")
            .field("client", &self.client)
            .field("default_options", &self.default_options)
            .field("stream_full", &self.stream_full)
            .finish_non_exhaustive()
    }
}

/// The `Options` struct holds configuration options for the `OpenAI` client.
/// It includes optional fields for the embedding and prompt models, as well as per-request
/// settings such as temperature, penalties, and the end-user identifier.
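///
/// # Example
///
/// A minimal sketch of building `Options` via the builder; the model name is a placeholder:
///
/// ```
/// # use swiftide_integrations::openai::Options;
/// let options = Options::builder()
///     .prompt_model("gpt-4o")
///     .temperature(0.0)
///     .build()
///     .unwrap();
/// assert_eq!(options.temperature, Some(0.0));
/// ```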
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    /// The default embedding model to use, if specified.
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,
    /// The default prompt model to use, if specified.
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    #[builder(default)]
    /// Option to enable or disable parallel tool calls for completions.
    ///
    /// At this moment, o1 and o3-mini do not support it and should be set to `None`.
    pub parallel_tool_calls: Option<bool>,

    /// Maximum number of tokens to generate in the completion.
    ///
    /// By default, the limit is disabled
    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    /// Temperature setting for the model.
    #[builder(default)]
    pub temperature: Option<f32>,

    /// Reasoning effort for reasoning models.
    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    /// This feature is in Beta. If specified, our system will make a best effort to sample
    /// deterministically, such that repeated requests with the same seed and parameters should
    /// return the same result. Determinism is not guaranteed, and you should refer to the
    /// `system_fingerprint` response parameter to monitor changes in the backend.
    #[builder(default)]
    pub seed: Option<i64>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
    /// appear in the text so far, increasing the model’s likelihood to talk about new topics.
    #[builder(default)]
    pub presence_penalty: Option<f32>,

    /// Developer-defined tags and values used for filtering completions in the dashboard.
    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    /// A unique identifier representing your end-user, which can help `OpenAI` to monitor and
    /// detect abuse.
    #[builder(default, setter(into))]
    pub user: Option<String>,

    #[builder(default)]
    /// The number of dimensions the resulting output embeddings should have. Only supported in
    /// text-embedding-3 and later models.
    pub dimensions: Option<u32>,
}

impl Options {
    /// Creates a new `OptionsBuilder` for constructing `Options` instances.
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }

    /// Merges `other` into `self`; any field set on `other` overrides the value on `self`.
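    ///
    /// # Example
    ///
    /// A small sketch of the override behavior; the model name is a placeholder:
    ///
    /// ```
    /// # use swiftide_integrations::openai::Options;
    /// let mut base = Options::builder().prompt_model("gpt-4o").build().unwrap();
    /// let overrides = Options::builder().temperature(0.2).build().unwrap();
    /// base.merge(&overrides);
    /// assert_eq!(base.prompt_model.as_deref(), Some("gpt-4o"));
    /// assert_eq!(base.temperature, Some(0.2));
    /// ```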
    pub fn merge(&mut self, other: &Options) {
        if let Some(embed_model) = &other.embed_model {
            self.embed_model = Some(embed_model.clone());
        }
        if let Some(prompt_model) = &other.prompt_model {
            self.prompt_model = Some(prompt_model.clone());
        }
        if let Some(parallel_tool_calls) = other.parallel_tool_calls {
            self.parallel_tool_calls = Some(parallel_tool_calls);
        }
        if let Some(max_completion_tokens) = other.max_completion_tokens {
            self.max_completion_tokens = Some(max_completion_tokens);
        }
        if let Some(temperature) = other.temperature {
            self.temperature = Some(temperature);
        }
        if let Some(reasoning_effort) = &other.reasoning_effort {
            self.reasoning_effort = Some(reasoning_effort.clone());
        }
        if let Some(seed) = other.seed {
            self.seed = Some(seed);
        }
        if let Some(presence_penalty) = other.presence_penalty {
            self.presence_penalty = Some(presence_penalty);
        }
        if let Some(metadata) = &other.metadata {
            self.metadata = Some(metadata.clone());
        }
        if let Some(user) = &other.user {
            self.user = Some(user.clone());
        }
        if let Some(dimensions) = other.dimensions {
            self.dimensions = Some(dimensions);
        }
    }
}

impl From<OptionsBuilder> for Options {
    fn from(value: OptionsBuilder) -> Self {
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl From<&mut OptionsBuilder> for Options {
    fn from(value: &mut OptionsBuilder) -> Self {
        value.clone().into()
    }
}

impl OpenAI {
    /// Creates a new `OpenAIBuilder` for constructing `OpenAI` instances.
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}

impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
    GenericOpenAIBuilder<C>
{
    /// Adds a callback function that will be called when usage information is available.
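    ///
    /// # Example
    ///
    /// A minimal sketch of registering a callback; the model name is a placeholder and the
    /// closure body is where you would record or log the reported usage:
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::OpenAI;
    /// let openai = OpenAI::builder()
    ///     .default_prompt_model("gpt-4o")
    ///     .on_usage(|usage| {
    ///         // Forward `usage` to your own accounting or metrics here.
    ///         let _ = usage;
    ///         Ok(())
    ///     })
    ///     .build()
    ///     .unwrap();
    /// ```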
    pub fn on_usage<F>(&mut self, func: F) -> &mut Self
    where
        F: Fn(&Usage) -> anyhow::Result<()> + Send + Sync + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage) })
        })));

        self
    }

    /// Adds an asynchronous callback function that will be called when usage information is
    /// available.
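    ///
    /// The callback receives a reference to the [`Usage`] and must return a pinned, boxed
    /// future, typically constructed with `Box::pin(async move { .. })`. For a synchronous
    /// callback, prefer [`Self::on_usage`].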
    pub fn on_usage_async<F>(&mut self, func: F) -> &mut Self
    where
        F: for<'a> Fn(
                &'a Usage,
            )
                -> Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>>
            + Send
            + Sync
            + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage).await })
        })));

        self
    }

    /// Sets the `OpenAI` client for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `client`: The `OpenAI` client to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    /// Sets the default embedding model for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `model`: The embedding model to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Sets the `user` field used by `OpenAI` to monitor and detect abuse.
    pub fn for_end_user(&mut self, user: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.user = Some(user.into());
        } else {
            self.default_options = Some(Options {
                user: Some(user.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Enable or disable parallel tool calls for completions.
    ///
    /// Note that currently reasoning models do not support parallel tool calls
    ///
    /// Defaults to `true`
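    ///
    /// A minimal sketch; the model name is a placeholder:
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::OpenAI;
    /// let openai = OpenAI::builder()
    ///     .default_prompt_model("o3-mini")
    ///     // Reasoning models do not support parallel tool calls, so leave it unset.
    ///     .parallel_tool_calls(None)
    ///     .build()
    ///     .unwrap();
    /// ```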
    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.parallel_tool_calls = parallel_tool_calls;
        } else {
            self.default_options = Some(Options {
                parallel_tool_calls,
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default prompt model for the `OpenAI` instance.
    ///
    /// # Parameters
    /// - `model`: The prompt model to set.
    ///
    /// # Returns
    /// A mutable reference to the `OpenAIBuilder`.
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default options to use for requests to the `OpenAI` API.
    ///
    /// Merges with any existing options
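    ///
    /// # Example
    ///
    /// A minimal sketch; the prompt model set earlier is kept and only the temperature is added:
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::{OpenAI, Options};
    /// let openai = OpenAI::builder()
    ///     .default_prompt_model("gpt-4o")
    ///     .default_options(Options::builder().temperature(0.0))
    ///     .build()
    ///     .unwrap();
    /// ```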
    pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self {
        if let Some(existing_options) = self.default_options.as_mut() {
            existing_options.merge(&options.into());
        } else {
            self.default_options = Some(options.into());
        }
        self
    }
}

impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
    /// Estimates the number of tokens for implementors of the `Estimatable` trait.
    ///
    /// E.g. `String`, `ChatMessage`, etc.
    ///
    /// # Errors
    ///
    /// Errors if tokenization fails in any way
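    ///
    /// # Example
    ///
    /// A minimal sketch (requires the `tiktoken` feature):
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::OpenAI;
    /// # async fn example() -> anyhow::Result<()> {
    /// let openai = OpenAI::builder().build()?;
    /// let tokens = openai.estimate_tokens(String::from("hello world")).await?;
    /// println!("estimated tokens: {tokens}");
    /// # Ok(())
    /// # }
    /// ```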
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }

    /// Sets the default prompt model on an already constructed client.
    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Sets the default embedding model on an already constructed client.
    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            embed_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Retrieve a reference to the inner `OpenAI` client.
    pub fn client(&self) -> &Arc<async_openai::Client<C>> {
        &self.client
    }

    /// Retrieve a reference to the default options for the `OpenAI` instance.
    pub fn options(&self) -> &Options {
        &self.default_options
    }

    /// Retrieve a mutable reference to the default options for the `OpenAI` instance.
    pub fn options_mut(&mut self) -> &mut Options {
        &mut self.default_options
    }

    /// Builds chat completion request arguments pre-populated with the client's default options.
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;

        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }

        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }

        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }

        if let Some(reasoning_effort) = &options.reasoning_effort {
            args.reasoning_effort(reasoning_effort.clone());
        }

        if let Some(seed) = options.seed {
            args.seed(seed);
        }

        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }

        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        args
    }

    /// Builds embedding request arguments pre-populated with the client's default options.
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();

        let options = &self.default_options;

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }

        args
    }
}

/// Maps an [`OpenAIError`] onto a [`LanguageModelError`], so callers can distinguish transient
/// failures (worth retrying) from permanent ones and handle context-length-exceeded errors
/// separately.
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            // If the response is an ApiError, it could be a context length exceeded error
            if api_error.code == Some("context_length_exceeded".to_string()) {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        OpenAIError::Reqwest(e) => {
            // async_openai passes any network errors as reqwest errors, so we just assume they are
            // recoverable
            LanguageModelError::transient(e)
        }
        OpenAIError::JSONDeserialize(_) => {
            // OpenAI generated a non-json response, probably a temporary problem on their side
            // (i.e. reverse proxy can't find an available backend)
            LanguageModelError::transient(e)
        }
        OpenAIError::StreamError(e) => {
            // Note that this will _retry_ the stream; we have to assume the stream had only just
            // started if a 429 happens. For future readers: the client's underlying streaming
            // crate (eventsource) has a backoff mechanism of its own.
            if e.contains("Too Many Requests") {
                LanguageModelError::transient(e)
            } else {
                LanguageModelError::permanent(e)
            }
        }
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError};

    /// test default embed model
    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_context_length_exceeded_error() {
        // Create an API error with the context_length_exceeded code
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as ContextLengthExceeded
        match result {
            LanguageModelError::ContextLengthExceeded(_) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        // Create a generic API error (not context length exceeded)
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        // Create a file save error
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        // Create a file read error
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        // Create a stream error
        let openai_error = OpenAIError::StreamError("Stream failed".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        // Create an invalid argument error
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        // Verify it's categorized as PermanentError
        match result {
            LanguageModelError::PermanentError(_) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }
}