// swiftide_integrations/openai/mod.rs
use async_openai::error::OpenAIError;
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::types::ReasoningEffort;
use derive_builder::Builder;
use std::sync::Arc;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod simple_prompt;

pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

/// Client for the hosted OpenAI API, backed by the default `OpenAIConfig`.
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

/// Client for OpenAI-compatible APIs, generic over the underlying
/// `async_openai` configuration (for example `OpenAIConfig` or `AzureConfig`).
#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The shared `async_openai` client.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options applied to every request unless overridden per call.
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    /// Token estimator used by `estimate_tokens`.
    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Whether streamed responses should also carry the accumulated full
    /// message in each chunk.
    #[builder(default = true)]
    pub stream_full: bool,

    /// Extra metadata attached to emitted metrics.
    #[cfg(feature = "metrics")]
    #[builder(default)]
    metric_metadata: Option<std::collections::HashMap<String, String>>,
}

/// Options applied to chat completion and embedding requests. All fields
/// are optional; unset fields fall back to the API defaults.
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    #[builder(default)]
    pub parallel_tool_calls: Option<bool>,

    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    #[builder(default)]
    pub temperature: Option<f32>,

    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    #[builder(default)]
    pub seed: Option<i64>,

    #[builder(default)]
    pub presence_penalty: Option<f32>,

    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    #[builder(default, setter(into))]
    pub user: Option<String>,

    /// Number of dimensions for embedding vectors; only used by embedding
    /// requests.
    #[builder(default)]
    pub dimensions: Option<u32>,
}

impl Options {
    /// Returns a builder for `Options`.
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }

    /// Overlays `other` onto `self`: every field that is set in `other`
    /// replaces the corresponding field in `self`; unset fields are kept.
    pub fn merge(&mut self, other: &Options) {
        if let Some(embed_model) = &other.embed_model {
            self.embed_model = Some(embed_model.clone());
        }
        if let Some(prompt_model) = &other.prompt_model {
            self.prompt_model = Some(prompt_model.clone());
        }
        if let Some(parallel_tool_calls) = other.parallel_tool_calls {
            self.parallel_tool_calls = Some(parallel_tool_calls);
        }
        if let Some(max_completion_tokens) = other.max_completion_tokens {
            self.max_completion_tokens = Some(max_completion_tokens);
        }
        if let Some(temperature) = other.temperature {
            self.temperature = Some(temperature);
        }
        if let Some(reasoning_effort) = &other.reasoning_effort {
            self.reasoning_effort = Some(reasoning_effort.clone());
        }
        if let Some(seed) = other.seed {
            self.seed = Some(seed);
        }
        if let Some(presence_penalty) = other.presence_penalty {
            self.presence_penalty = Some(presence_penalty);
        }
        if let Some(metadata) = &other.metadata {
            self.metadata = Some(metadata.clone());
        }
        if let Some(user) = &other.user {
            self.user = Some(user.clone());
        }
        // `dimensions` is merged like the other fields so that embedding
        // overrides are not silently dropped.
        if let Some(dimensions) = other.dimensions {
            self.dimensions = Some(dimensions);
        }
    }
}

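// A minimal usage sketch of `Options::merge` (illustrative only; the model
// name is a placeholder): fields set on the overlay replace the defaults,
// while unset fields are left untouched.
//
//     let mut defaults = Options::builder()
//         .prompt_model("gpt-4o")
//         .temperature(0.2)
//         .build()
//         .unwrap();
//     let overrides = Options::builder().temperature(0.7).build().unwrap();
//     defaults.merge(&overrides);
//     assert_eq!(defaults.temperature, Some(0.7));
//     assert_eq!(defaults.prompt_model, Some("gpt-4o".to_string()));
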
impl From<OptionsBuilder> for Options {
    fn from(value: OptionsBuilder) -> Self {
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl From<&mut OptionsBuilder> for Options {
    fn from(value: &mut OptionsBuilder) -> Self {
        // Delegate to the by-value conversion instead of duplicating the
        // field-by-field mapping.
        Options::from(value.clone())
    }
}

impl OpenAI {
    /// Returns a builder for an `OpenAI` client.
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}

impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
    GenericOpenAIBuilder<C>
{
    /// Sets a preconfigured `async_openai` client.
    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    /// Sets the default model used for embedding requests.
    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Enables or disables parallel tool calls for chat completions.
    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.parallel_tool_calls = parallel_tool_calls;
        } else {
            self.default_options = Some(Options {
                parallel_tool_calls,
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default model used for prompts and chat completions.
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Merges the given options into the defaults already set on the builder.
    pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self {
        if let Some(existing_options) = self.default_options.as_mut() {
            existing_options.merge(&options.into());
        } else {
            self.default_options = Some(options.into());
        }
        self
    }
}

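// A hedged sketch of typical builder usage; the model names are placeholders,
// not recommendations.
//
//     let openai = OpenAI::builder()
//         .default_prompt_model("gpt-4o")
//         .default_embed_model("text-embedding-3-small")
//         .default_options(Options::builder().temperature(0.0))
//         .build()
//         .unwrap();
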
impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
    /// Estimates the number of tokens in `value` using tiktoken.
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }

    /// Sets the default prompt model on an existing instance.
    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Sets the default embedding model on an existing instance.
    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            embed_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Returns the underlying `async_openai` client.
    pub fn client(&self) -> &Arc<async_openai::Client<C>> {
        &self.client
    }

    /// Returns the default options.
    pub fn options(&self) -> &Options {
        &self.default_options
    }

    /// Returns a mutable reference to the default options.
    pub fn options_mut(&mut self) -> &mut Options {
        &mut self.default_options
    }

    /// Builds chat completion request arguments pre-populated from the
    /// default options.
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;

        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }

        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }

        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }

        if let Some(reasoning_effort) = &options.reasoning_effort {
            args.reasoning_effort(reasoning_effort.clone());
        }

        if let Some(seed) = options.seed {
            args.seed(seed);
        }

        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }

        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        args
    }

    /// Builds embedding request arguments pre-populated from the default
    /// options.
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();

        let options = &self.default_options;

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }

        args
    }
}

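// A minimal sketch of adjusting defaults on an existing instance through
// `options_mut`; the model name is a placeholder.
//
//     let mut openai = OpenAI::builder()
//         .default_prompt_model("gpt-4o")
//         .build()
//         .unwrap();
//     openai.options_mut().temperature = Some(0.1);
//     openai.options_mut().seed = Some(42);
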
/// Maps an `async_openai` error onto swiftide's `LanguageModelError`,
/// classifying it as context-length-exceeded, transient, or permanent.
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            if api_error.code == Some("context_length_exceeded".to_string()) {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        // Network-level failures are transient and worth retrying.
        OpenAIError::Reqwest(e) => LanguageModelError::transient(e),
        // Deserialization failures can stem from transient upstream issues
        // such as a truncated response body, so treat them as retryable too.
        OpenAIError::JSONDeserialize(_) => LanguageModelError::transient(e),
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::StreamError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

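// A hedged sketch of how a caller might branch on the mapped error; the
// helper functions named here are hypothetical, not part of this crate.
//
//     match openai_error_to_language_model_error(err) {
//         LanguageModelError::ContextLengthExceeded(e) => shrink_prompt(e), // hypothetical
//         LanguageModelError::PermanentError(e) => give_up(e),             // hypothetical
//         transient => retry_with_backoff(transient),                      // hypothetical
//     }
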
#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError};

    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_context_length_exceeded_error() {
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::ContextLengthExceeded(_) => {}
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        let openai_error = OpenAIError::StreamError("Stream failed".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }
}