// swiftide_integrations/openai/mod.rs
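//! Integration with the OpenAI API via the [`async_openai`] crate, providing
//! chat completion, embedding, and simple prompt implementations for swiftide.
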
use async_openai::error::OpenAIError;
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::types::ReasoningEffort;
use derive_builder::Builder;
use std::sync::Arc;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod simple_prompt;

pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

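/// [`GenericOpenAI`] specialized for the standard [`OpenAIConfig`].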
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
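/// Builder for [`OpenAI`].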
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

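/// A generic OpenAI-compatible client over any `async_openai` configuration
/// (e.g. [`OpenAIConfig`], [`AzureConfig`]). Holds a shared client plus the
/// default [`Options`] applied to every chat completion and embedding request.
///
/// A minimal construction sketch; the model names are illustrative, not crate
/// defaults:
///
/// ```no_run
/// # use swiftide_integrations::openai::OpenAI;
/// let openai = OpenAI::builder()
///     .default_prompt_model("gpt-4o")
///     .default_embed_model("text-embedding-3-small")
///     .build()
///     .unwrap();
/// ```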
#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The underlying `async_openai` client, shared across clones.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options applied to every request; see [`Options`].
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    /// Tokenizer used for token estimation.
    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Whether to stream the full response; enabled by default.
    #[builder(default = true)]
    pub stream_full: bool,
}

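/// Default options applied to chat completion and embedding requests. All
/// fields are optional; unset fields are simply omitted from the request.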
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    #[builder(default)]
    pub parallel_tool_calls: Option<bool>,

    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    #[builder(default)]
    pub temperature: Option<f32>,

    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    #[builder(default)]
    pub seed: Option<i64>,

    #[builder(default)]
    pub presence_penalty: Option<f32>,

    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    #[builder(default, setter(into))]
    pub user: Option<String>,

    #[builder(default)]
    pub dimensions: Option<u32>,
}

impl Options {
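    /// Returns a builder for [`Options`].
    ///
    /// A small sketch; the values are illustrative:
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::Options;
    /// let options: Options = Options::builder().temperature(0.2).seed(42).into();
    /// ```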
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }

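    /// Merges `other` into `self`; any field set on `other` overrides the
    /// value on `self`.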
    pub fn merge(&mut self, other: &Options) {
        if let Some(embed_model) = &other.embed_model {
            self.embed_model = Some(embed_model.clone());
        }
        if let Some(prompt_model) = &other.prompt_model {
            self.prompt_model = Some(prompt_model.clone());
        }
        if let Some(parallel_tool_calls) = other.parallel_tool_calls {
            self.parallel_tool_calls = Some(parallel_tool_calls);
        }
        if let Some(max_completion_tokens) = other.max_completion_tokens {
            self.max_completion_tokens = Some(max_completion_tokens);
        }
        if let Some(temperature) = other.temperature {
            self.temperature = Some(temperature);
        }
        if let Some(reasoning_effort) = &other.reasoning_effort {
            self.reasoning_effort = Some(reasoning_effort.clone());
        }
        if let Some(seed) = other.seed {
            self.seed = Some(seed);
        }
        if let Some(presence_penalty) = other.presence_penalty {
            self.presence_penalty = Some(presence_penalty);
        }
        if let Some(metadata) = &other.metadata {
            self.metadata = Some(metadata.clone());
        }
        if let Some(user) = &other.user {
            self.user = Some(user.clone());
        }
        if let Some(dimensions) = other.dimensions {
            self.dimensions = Some(dimensions);
        }
    }
}

impl From<OptionsBuilder> for Options {
    fn from(value: OptionsBuilder) -> Self {
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl From<&mut OptionsBuilder> for Options {
    fn from(value: &mut OptionsBuilder) -> Self {
        value.clone().into()
    }
}

impl OpenAI {
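    /// Returns a builder for the standard [`OpenAI`] client.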
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}

impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
    GenericOpenAIBuilder<C>
{
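    /// Sets a custom `async_openai` client, wrapping it in an [`Arc`]. This is
    /// also the hook for pointing at OpenAI-compatible endpoints. A sketch; the
    /// API base and model name are illustrative:
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::{OpenAI, OpenAIConfig};
    /// let client = async_openai::Client::with_config(
    ///     OpenAIConfig::default().with_api_base("http://localhost:11434/v1"),
    /// );
    /// let openai = OpenAI::builder()
    ///     .client(client)
    ///     .default_prompt_model("llama3")
    ///     .build()
    ///     .unwrap();
    /// ```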
    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    /// Sets the default model used for embeddings.
    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Enables or disables parallel tool calls on chat completion requests.
    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.parallel_tool_calls = parallel_tool_calls;
        } else {
            self.default_options = Some(Options {
                parallel_tool_calls,
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default model used for prompts and chat completions.
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default options, merging them over any options set earlier.
    pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self {
        if let Some(existing_options) = self.default_options.as_mut() {
            existing_options.merge(&options.into());
        } else {
            self.default_options = Some(options.into());
        }
        self
    }
}

impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
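    /// Estimates the number of tokens in `value` using the configured
    /// tiktoken tokenizer. Only available with the `tiktoken` feature.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying tokenizer fails.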
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }

    /// Sets the default prompt model on the configured options.
    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Sets the default embed model on the configured options.
    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            embed_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    /// Returns a reference to the inner `async_openai` client.
    pub fn client(&self) -> &Arc<async_openai::Client<C>> {
        &self.client
    }

    /// Returns the default options.
    pub fn options(&self) -> &Options {
        &self.default_options
    }

    /// Returns a mutable reference to the default options.
    pub fn options_mut(&mut self) -> &mut Options {
        &mut self.default_options
    }

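    /// Builds [`CreateChatCompletionRequestArgs`] pre-populated with the
    /// configured default options. The model and messages are set by the
    /// caller on top of these defaults.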
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;

        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }

        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }

        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }

        if let Some(reasoning_effort) = &options.reasoning_effort {
            args.reasoning_effort(reasoning_effort.clone());
        }

        if let Some(seed) = options.seed {
            args.seed(seed);
        }

        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }

        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        args
    }

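    /// Builds [`CreateEmbeddingRequestArgs`] pre-populated with the configured
    /// default `user` and `dimensions`.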
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();

        let options = &self.default_options;

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }

        args
    }
}

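/// Maps an [`OpenAIError`] onto swiftide's [`LanguageModelError`]:
/// `context_length_exceeded` API errors become `ContextLengthExceeded`,
/// network and JSON deserialization failures are transient (retryable),
/// and everything else is permanent.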
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            if api_error.code.as_deref() == Some("context_length_exceeded") {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        // Network errors may resolve on retry.
        OpenAIError::Reqwest(e) => LanguageModelError::transient(e),
        // A malformed response may be a transient upstream hiccup.
        OpenAIError::JSONDeserialize(_) => LanguageModelError::transient(e),
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::StreamError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError};

    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        // The same defaults should hold regardless of setter order.
        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_context_length_exceeded_error() {
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::ContextLengthExceeded(_) => {}
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        let openai_error = OpenAIError::StreamError("Stream failed".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }
}