swiftide_integrations/openai/mod.rs
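
//! OpenAI integration built on [`async_openai`]: a cloneable client that is
//! generic over the `async_openai` configuration type, with per-request
//! defaults supplied through [`Options`].
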
use async_openai::error::OpenAIError;
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::types::ReasoningEffort;
use derive_builder::Builder;
use std::pin::Pin;
use std::sync::Arc;
use swiftide_core::chat_completion::Usage;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod simple_prompt;
mod structured_prompt;

pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

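/// Client for the standard OpenAI API, specialised over [`OpenAIConfig`].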
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
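/// Builder for [`OpenAI`].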
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

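/// A client for OpenAI-compatible APIs, generic over the `async_openai`
/// configuration type (e.g. [`OpenAIConfig`] or [`AzureConfig`]).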
#[derive(Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The underlying `async_openai` client, shared behind an `Arc`.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options applied to every request unless overridden.
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    /// Tokenizer for token estimation, available with the `tiktoken` feature.
    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Whether streamed chunks carry the full accumulated response rather
    /// than only the delta. Defaults to `true`.
    #[builder(default = true)]
    pub stream_full: bool,

    /// Extra labels attached to emitted metrics, available with the
    /// `metrics` feature.
    #[cfg(feature = "metrics")]
    #[builder(default)]
    metric_metadata: Option<std::collections::HashMap<String, String>>,

    /// Optional async callback invoked with the token [`Usage`] of each
    /// request.
    #[builder(default, setter(custom))]
    #[allow(clippy::type_complexity)]
    on_usage: Option<
        Arc<
            dyn for<'a> Fn(
                    &'a Usage,
                ) -> Pin<
                    Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>,
                > + Send
                + Sync,
        >,
    >,
}

impl<C: async_openai::config::Config + Default + std::fmt::Debug> std::fmt::Debug
    for GenericOpenAI<C>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GenericOpenAI")
            .field("client", &self.client)
            .field("default_options", &self.default_options)
            .field("stream_full", &self.stream_full)
            .finish_non_exhaustive()
    }
}

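/// Request defaults shared by the chat completion and embedding
/// implementations. All fields are optional; unset fields are omitted from
/// outgoing requests so the API's own defaults apply.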
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    /// Default model used for embeddings.
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,
    /// Default model used for prompts and chat completions.
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    /// Whether the model may execute tool calls in parallel.
    #[builder(default)]
    pub parallel_tool_calls: Option<bool>,

    /// Upper bound on the number of tokens generated for a completion.
    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    /// Sampling temperature.
    #[builder(default)]
    pub temperature: Option<f32>,

    /// Reasoning effort hint for reasoning-capable models.
    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    /// Seed for best-effort deterministic sampling.
    #[builder(default)]
    pub seed: Option<i64>,

    /// Presence penalty applied during sampling.
    #[builder(default)]
    pub presence_penalty: Option<f32>,

    /// Arbitrary metadata forwarded with chat completion requests.
    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    /// End-user identifier forwarded to the API.
    #[builder(default, setter(into))]
    pub user: Option<String>,

    /// Number of dimensions for embedding output (embedding requests only).
    #[builder(default)]
    pub dimensions: Option<u32>,
}

impl Options {
    /// Returns a fresh [`OptionsBuilder`].
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }

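    /// Overlays `other` onto `self`: any field set on `other` replaces the
    /// current value, while fields left `None` on `other` are kept.
    ///
    /// A minimal sketch of the semantics (the model name is illustrative, and
    /// the path assumes the `swiftide_integrations` crate layout):
    ///
    /// ```
    /// use swiftide_integrations::openai::Options;
    ///
    /// let mut base: Options = Options::builder().prompt_model("gpt-4o").into();
    /// let overlay: Options = Options::builder().temperature(0.2).into();
    ///
    /// base.merge(&overlay);
    ///
    /// assert_eq!(base.prompt_model.as_deref(), Some("gpt-4o"));
    /// assert_eq!(base.temperature, Some(0.2));
    /// ```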
    pub fn merge(&mut self, other: &Options) {
        if let Some(embed_model) = &other.embed_model {
            self.embed_model = Some(embed_model.clone());
        }
        if let Some(prompt_model) = &other.prompt_model {
            self.prompt_model = Some(prompt_model.clone());
        }
        if let Some(parallel_tool_calls) = other.parallel_tool_calls {
            self.parallel_tool_calls = Some(parallel_tool_calls);
        }
        if let Some(max_completion_tokens) = other.max_completion_tokens {
            self.max_completion_tokens = Some(max_completion_tokens);
        }
        if let Some(temperature) = other.temperature {
            self.temperature = Some(temperature);
        }
        if let Some(reasoning_effort) = &other.reasoning_effort {
            self.reasoning_effort = Some(reasoning_effort.clone());
        }
        if let Some(seed) = other.seed {
            self.seed = Some(seed);
        }
        if let Some(presence_penalty) = other.presence_penalty {
            self.presence_penalty = Some(presence_penalty);
        }
        if let Some(metadata) = &other.metadata {
            self.metadata = Some(metadata.clone());
        }
        if let Some(user) = &other.user {
            self.user = Some(user.clone());
        }
        if let Some(dimensions) = other.dimensions {
            self.dimensions = Some(dimensions);
        }
    }
}

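/// Converts a builder directly into [`Options`] without calling `build()`;
/// fields never set on the builder simply remain `None`.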
impl From<OptionsBuilder> for Options {
    fn from(value: OptionsBuilder) -> Self {
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

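/// Same conversion for the `&mut` builder that derive_builder's chained
/// setters return.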
impl From<&mut OptionsBuilder> for Options {
    fn from(value: &mut OptionsBuilder) -> Self {
        let value = value.clone();
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl OpenAI {
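    /// Creates a builder for a client against the standard OpenAI API.
    ///
    /// A minimal usage sketch (model names are illustrative; the path assumes
    /// the `swiftide_integrations` crate layout):
    ///
    /// ```no_run
    /// use swiftide_integrations::openai::OpenAI;
    ///
    /// let client = OpenAI::builder()
    ///     .default_prompt_model("gpt-4o")
    ///     .default_embed_model("text-embedding-3-small")
    ///     .build()
    ///     .unwrap();
    /// ```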
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}

impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
    GenericOpenAIBuilder<C>
{
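    /// Registers a synchronous callback that receives the token [`Usage`] of
    /// each request; it is adapted into the stored async callback.
    ///
    /// A registration sketch (assumes `Usage` implements `Debug`; the path
    /// and model name are illustrative):
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::OpenAI;
    /// let client = OpenAI::builder()
    ///     .default_prompt_model("gpt-4o")
    ///     .on_usage(|usage| {
    ///         println!("{usage:?}");
    ///         Ok(())
    ///     })
    ///     .build()
    ///     .unwrap();
    /// ```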
    pub fn on_usage<F>(&mut self, func: F) -> &mut Self
    where
        F: Fn(&Usage) -> anyhow::Result<()> + Send + Sync + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage) })
        })));

        self
    }

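    /// Registers an asynchronous callback that receives the token [`Usage`]
    /// of each request.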
    pub fn on_usage_async<F>(&mut self, func: F) -> &mut Self
    where
        F: for<'a> Fn(
                &'a Usage,
            ) -> Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>>
            + Send
            + Sync
            + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage).await })
        })));

        self
    }

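    /// Sets a custom `async_openai` client, e.g. one with its own API key,
    /// base URL, or HTTP client.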
    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

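    /// Sets the default model used for embedding requests.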
    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

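    /// Sets the end-user identifier forwarded to the API with every request.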
    pub fn for_end_user(&mut self, user: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.user = Some(user.into());
        } else {
            self.default_options = Some(Options {
                user: Some(user.into()),
                ..Default::default()
            });
        }
        self
    }

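    /// Sets whether the model may execute tool calls in parallel; `None`
    /// leaves the API default in place.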
    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.parallel_tool_calls = parallel_tool_calls;
        } else {
            self.default_options = Some(Options {
                parallel_tool_calls,
                ..Default::default()
            });
        }
        self
    }

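    /// Sets the default model used for prompts and chat completions.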
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

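    /// Sets the default [`Options`] wholesale, merging over any defaults
    /// already present on the builder (fields set here win).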
    pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self {
        if let Some(existing_options) = self.default_options.as_mut() {
            existing_options.merge(&options.into());
        } else {
            self.default_options = Some(options.into());
        }
        self
    }
}

impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
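    /// Estimates the number of tokens in `value` with the configured
    /// tiktoken tokenizer.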
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }

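    /// Sets the default prompt model on an already-built client.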
    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

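    /// Sets the default embed model on an already-built client.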
    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            embed_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

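    /// Returns the underlying `async_openai` client.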
    pub fn client(&self) -> &Arc<async_openai::Client<C>> {
        &self.client
    }

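    /// Returns the default request options.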
    pub fn options(&self) -> &Options {
        &self.default_options
    }

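    /// Returns the default request options mutably.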
    pub fn options_mut(&mut self) -> &mut Options {
        &mut self.default_options
    }

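    /// Pre-populates chat completion request args from the default
    /// [`Options`]; anything left unset falls through to the API defaults.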
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;

        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }

        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }

        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }

        if let Some(reasoning_effort) = &options.reasoning_effort {
            args.reasoning_effort(reasoning_effort.clone());
        }

        if let Some(seed) = options.seed {
            args.seed(seed);
        }

        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }

        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        args
    }

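    /// Pre-populates embedding request args from the default [`Options`].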
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();

        let options = &self.default_options;

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }

        args
    }
}

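/// Maps an [`OpenAIError`] onto [`LanguageModelError`]: context-length
/// violations get their own variant; network, deserialization, and
/// rate-limited stream errors are treated as transient (retryable); and
/// everything else is permanent.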
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            if api_error.code.as_deref() == Some("context_length_exceeded") {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        OpenAIError::Reqwest(e) => LanguageModelError::transient(e),
        OpenAIError::JSONDeserialize(_) => LanguageModelError::transient(e),
        OpenAIError::StreamError(e) => {
            if e.contains("Too Many Requests") {
                LanguageModelError::transient(e)
            } else {
                LanguageModelError::permanent(e)
            }
        }
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError};

    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_context_length_exceeded_error() {
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::ContextLengthExceeded(_) => {}
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        let openai_error = OpenAIError::StreamError("Stream failed".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }
}