// swiftide_integrations/openai/mod.rs
use anyhow::Context as _;
use async_openai::error::{OpenAIError, StreamError};
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::types::ReasoningEffort;
use derive_builder::Builder;
use reqwest::StatusCode;
use reqwest_eventsource::Error as EventSourceError;
use serde_json::Value;
use std::pin::Pin;
use std::sync::Arc;
use swiftide_core::chat_completion::Usage;
use swiftide_core::chat_completion::errors::LanguageModelError;

mod chat_completion;
mod embed;
mod responses_api;
mod simple_prompt;
mod structured_prompt;
pub use async_openai::config::AzureConfig;
pub use async_openai::config::OpenAIConfig;

#[cfg(feature = "tiktoken")]
use crate::tiktoken::TikToken;
#[cfg(feature = "tiktoken")]
use anyhow::Result;
#[cfg(feature = "tiktoken")]
use swiftide_core::Estimatable;
#[cfg(feature = "tiktoken")]
use swiftide_core::EstimateTokens;

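/// Client for the hosted OpenAI API, using the default [`OpenAIConfig`].
///
/// A minimal builder sketch; the model names are illustrative choices, not
/// defaults shipped by this module:
///
/// ```no_run
/// # use swiftide_integrations::openai::OpenAI;
/// let openai = OpenAI::builder()
///     .default_prompt_model("gpt-4o")
///     .default_embed_model("text-embedding-3-small")
///     .build()
///     .unwrap();
/// ```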
pub type OpenAI = GenericOpenAI<OpenAIConfig>;
pub type OpenAIBuilder = GenericOpenAIBuilder<OpenAIConfig>;

/// A client for OpenAI-compatible APIs, generic over the `async_openai`
/// configuration type so it also covers Azure and self-hosted endpoints.
#[derive(Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct GenericOpenAI<
    C: async_openai::config::Config + Default = async_openai::config::OpenAIConfig,
> {
    /// The underlying `async_openai` client.
    #[builder(
        default = "Arc::new(async_openai::Client::<C>::default())",
        setter(custom)
    )]
    client: Arc<async_openai::Client<C>>,

    /// Default options applied to prompt, chat completion, and embedding requests.
    #[builder(default, setter(custom))]
    pub(crate) default_options: Options,

    #[cfg(feature = "tiktoken")]
    #[cfg_attr(feature = "tiktoken", builder(default))]
    pub(crate) tiktoken: TikToken,

    /// Whether streamed chunks carry the full accumulated response or only the delta.
    #[builder(default = true)]
    pub stream_full: bool,

    #[cfg(feature = "metrics")]
    #[builder(default)]
    metric_metadata: Option<std::collections::HashMap<String, String>>,

    /// Use OpenAI's Responses API instead of Chat Completions.
    #[builder(default)]
    pub(crate) use_responses_api: bool,

    /// Optional async callback invoked with the token [`Usage`] of each request.
    #[builder(default, setter(custom))]
    #[allow(clippy::type_complexity)]
    on_usage: Option<
        Arc<
            dyn for<'a> Fn(
                    &'a Usage,
                ) -> Pin<
                    Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>,
                > + Send
                + Sync,
        >,
    >,
}

impl<C: async_openai::config::Config + Default + std::fmt::Debug> std::fmt::Debug
    for GenericOpenAI<C>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GenericOpenAI")
            .field("client", &self.client)
            .field("default_options", &self.default_options)
            .field("stream_full", &self.stream_full)
            .field("use_responses_api", &self.use_responses_api)
            .finish_non_exhaustive()
    }
}

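/// Request defaults shared across chat completion, prompt, and embedding
/// calls. Every field is optional; unset fields fall back to the
/// `async_openai` request defaults.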
#[derive(Debug, Clone, Builder, Default)]
#[builder(setter(strip_option))]
pub struct Options {
    /// Model used for embeddings.
    #[builder(default, setter(into))]
    pub embed_model: Option<String>,
    /// Model used for prompts and chat completions.
    #[builder(default, setter(into))]
    pub prompt_model: Option<String>,

    /// Whether the model may call multiple tools in parallel.
    #[builder(default)]
    pub parallel_tool_calls: Option<bool>,

    /// Upper bound on tokens generated in a completion.
    #[builder(default)]
    pub max_completion_tokens: Option<u32>,

    #[builder(default)]
    pub temperature: Option<f32>,

    /// Reasoning effort for reasoning-capable models.
    #[builder(default, setter(into))]
    pub reasoning_effort: Option<ReasoningEffort>,

    /// Seed for (best-effort) deterministic sampling.
    #[builder(default)]
    pub seed: Option<i64>,

    #[builder(default)]
    pub presence_penalty: Option<f32>,

    /// Arbitrary metadata forwarded with chat completion requests.
    #[builder(default, setter(into))]
    pub metadata: Option<serde_json::Value>,

    /// End-user identifier forwarded to OpenAI for abuse monitoring.
    #[builder(default, setter(into))]
    pub user: Option<String>,

    /// Number of dimensions for embedding vectors (embedding requests only).
    #[builder(default)]
    pub dimensions: Option<u32>,
}

impl Options {
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }

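    /// Overlay `other` onto `self`: any field set on `other` wins, unset
    /// fields keep their current value. Illustrative (assumes the crate's
    /// public `openai` module path):
    ///
    /// ```
    /// # use swiftide_integrations::openai::Options;
    /// let mut base = Options { temperature: Some(0.0), ..Default::default() };
    /// let overlay = Options { temperature: Some(0.7), ..Default::default() };
    /// base.merge(&overlay);
    /// assert_eq!(base.temperature, Some(0.7));
    /// ```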
    pub fn merge(&mut self, other: &Options) {
        if let Some(embed_model) = &other.embed_model {
            self.embed_model = Some(embed_model.clone());
        }
        if let Some(prompt_model) = &other.prompt_model {
            self.prompt_model = Some(prompt_model.clone());
        }
        if let Some(parallel_tool_calls) = other.parallel_tool_calls {
            self.parallel_tool_calls = Some(parallel_tool_calls);
        }
        if let Some(max_completion_tokens) = other.max_completion_tokens {
            self.max_completion_tokens = Some(max_completion_tokens);
        }
        if let Some(temperature) = other.temperature {
            self.temperature = Some(temperature);
        }
        if let Some(reasoning_effort) = &other.reasoning_effort {
            self.reasoning_effort = Some(reasoning_effort.clone());
        }
        if let Some(seed) = other.seed {
            self.seed = Some(seed);
        }
        if let Some(presence_penalty) = other.presence_penalty {
            self.presence_penalty = Some(presence_penalty);
        }
        if let Some(metadata) = &other.metadata {
            self.metadata = Some(metadata.clone());
        }
        if let Some(user) = &other.user {
            self.user = Some(user.clone());
        }
        if let Some(dimensions) = other.dimensions {
            self.dimensions = Some(dimensions);
        }
    }
}

impl From<OptionsBuilder> for Options {
    fn from(value: OptionsBuilder) -> Self {
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

impl From<&mut OptionsBuilder> for Options {
    fn from(value: &mut OptionsBuilder) -> Self {
        let value = value.clone();
        Self {
            embed_model: value.embed_model.flatten(),
            prompt_model: value.prompt_model.flatten(),
            parallel_tool_calls: value.parallel_tool_calls.flatten(),
            max_completion_tokens: value.max_completion_tokens.flatten(),
            temperature: value.temperature.flatten(),
            reasoning_effort: value.reasoning_effort.flatten(),
            presence_penalty: value.presence_penalty.flatten(),
            seed: value.seed.flatten(),
            metadata: value.metadata.flatten(),
            user: value.user.flatten(),
            dimensions: value.dimensions.flatten(),
        }
    }
}

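/// Force `additionalProperties: false` on a tool's JSON schema. OpenAI's
/// strict function-calling mode rejects schemas that leave
/// `additionalProperties` unset or true, so this is normalized before tool
/// definitions are sent.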
pub(crate) fn ensure_tool_schema_additional_properties_false(
    parameters: &mut Value,
) -> anyhow::Result<()> {
    let object = parameters
        .as_object_mut()
        .context("tool schema must be a JSON object")?;

    object.insert("additionalProperties".to_string(), Value::Bool(false));

    Ok(())
}

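/// Ensure the schema's `required` array lists every key in `properties`,
/// as strict mode expects all properties to be required. If the schema has
/// no `properties`, an empty `required` array is still inserted so the
/// field is always present.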
pub(crate) fn ensure_tool_schema_required_matches_properties(
    parameters: &mut Value,
) -> anyhow::Result<()> {
    let object = parameters
        .as_object_mut()
        .context("tool schema must be a JSON object")?;

    let property_names: Vec<String> = if let Some(Value::Object(map)) = object.get("properties") {
        map.keys().cloned().collect()
    } else {
        object
            .entry("required".to_string())
            .or_insert_with(|| Value::Array(Vec::new()));
        return Ok(());
    };

    let required_entry = object
        .entry("required".to_string())
        .or_insert_with(|| Value::Array(Vec::new()));

    let required_array = required_entry
        .as_array_mut()
        .context("tool schema 'required' must be an array")?;

    for name in property_names {
        let name_ref = name.as_str();
        let already_present = required_array
            .iter()
            .any(|value| value.as_str().is_some_and(|s| s == name_ref));

        if !already_present {
            required_array.push(Value::String(name));
        }
    }

    Ok(())
}

impl OpenAI {
    pub fn builder() -> OpenAIBuilder {
        OpenAIBuilder::default()
    }
}

impl<C: async_openai::config::Config + Default + Sync + Send + std::fmt::Debug>
    GenericOpenAIBuilder<C>
{
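    /// Register a callback that receives the token [`Usage`] of every
    /// completed request. A minimal sketch, assuming `Usage` implements
    /// `Debug` (swap in real accounting as needed):
    ///
    /// ```no_run
    /// # use swiftide_integrations::openai::OpenAI;
    /// let openai = OpenAI::builder()
    ///     .on_usage(|usage| {
    ///         // Illustrative only: log the usage report.
    ///         println!("token usage: {usage:?}");
    ///         Ok(())
    ///     })
    ///     .build()
    ///     .unwrap();
    /// ```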
    pub fn on_usage<F>(&mut self, func: F) -> &mut Self
    where
        F: Fn(&Usage) -> anyhow::Result<()> + Send + Sync + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage) })
        })));

        self
    }

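    /// Async variant of [`Self::on_usage`]: the callback returns a boxed
    /// future that is awaited after each request completes.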
    pub fn on_usage_async<F>(&mut self, func: F) -> &mut Self
    where
        F: for<'a> Fn(
                &'a Usage,
            )
                -> Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>>
            + Send
            + Sync
            + 'static,
    {
        let func = Arc::new(func);
        self.on_usage = Some(Some(Arc::new(move |usage: &Usage| {
            let func = func.clone();
            Box::pin(async move { func(usage).await })
        })));

        self
    }
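
    /// Set a preconfigured `async_openai` client, for example one built with
    /// a custom API base or HTTP client.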
    pub fn client(&mut self, client: async_openai::Client<C>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    pub fn for_end_user(&mut self, user: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.user = Some(user.into());
        } else {
            self.default_options = Some(Options {
                user: Some(user.into()),
                ..Default::default()
            });
        }
        self
    }

    pub fn parallel_tool_calls(&mut self, parallel_tool_calls: Option<bool>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.parallel_tool_calls = parallel_tool_calls;
        } else {
            self.default_options = Some(Options {
                parallel_tool_calls,
                ..Default::default()
            });
        }
        self
    }

    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    pub fn default_options(&mut self, options: impl Into<Options>) -> &mut Self {
        if let Some(existing_options) = self.default_options.as_mut() {
            existing_options.merge(&options.into());
        } else {
            self.default_options = Some(options.into());
        }
        self
    }
}

impl<C: async_openai::config::Config + Default> GenericOpenAI<C> {
    #[cfg(feature = "tiktoken")]
    pub async fn estimate_tokens(&self, value: impl Estimatable) -> Result<usize> {
        self.tiktoken.estimate(value).await
    }

    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            embed_model: Some(model.into()),
            ..self.default_options.clone()
        };
        self
    }

    pub fn client(&self) -> &Arc<async_openai::Client<C>> {
        &self.client
    }

    pub fn options(&self) -> &Options {
        &self.default_options
    }

    pub fn options_mut(&mut self) -> &mut Options {
        &mut self.default_options
    }

    pub fn is_responses_api_enabled(&self) -> bool {
        self.use_responses_api
    }

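    /// Builds chat completion request args pre-populated from
    /// `default_options`; callers still set the model and messages per
    /// request.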
    fn chat_completion_request_defaults(&self) -> CreateChatCompletionRequestArgs {
        let mut args = CreateChatCompletionRequestArgs::default();

        let options = &self.default_options;

        if let Some(parallel_tool_calls) = options.parallel_tool_calls {
            args.parallel_tool_calls(parallel_tool_calls);
        }

        if let Some(max_tokens) = options.max_completion_tokens {
            args.max_completion_tokens(max_tokens);
        }

        if let Some(temperature) = options.temperature {
            args.temperature(temperature);
        }

        if let Some(reasoning_effort) = &options.reasoning_effort {
            args.reasoning_effort(reasoning_effort.clone());
        }

        if let Some(seed) = options.seed {
            args.seed(seed);
        }

        if let Some(presence_penalty) = options.presence_penalty {
            args.presence_penalty(presence_penalty);
        }

        if let Some(metadata) = &options.metadata {
            args.metadata(metadata.clone());
        }

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        args
    }

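    /// Builds embedding request args pre-populated from `default_options`
    /// (`user` and `dimensions`); callers still set the model and input.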
    fn embed_request_defaults(&self) -> CreateEmbeddingRequestArgs {
        let mut args = CreateEmbeddingRequestArgs::default();

        let options = &self.default_options;

        if let Some(user) = &options.user {
            args.user(user.clone());
        }

        if let Some(dimensions) = options.dimensions {
            args.dimensions(dimensions);
        }

        args
    }
}

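/// Maps an [`OpenAIError`] onto swiftide's [`LanguageModelError`] taxonomy:
/// context-length overflows get a dedicated variant, network and response
/// deserialization failures are transient (retryable), rate-limited streams
/// are transient, and everything else is permanent.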
pub fn openai_error_to_language_model_error(e: OpenAIError) -> LanguageModelError {
    match e {
        OpenAIError::ApiError(api_error) => {
            if api_error.code == Some("context_length_exceeded".to_string()) {
                LanguageModelError::context_length_exceeded(OpenAIError::ApiError(api_error))
            } else {
                LanguageModelError::permanent(OpenAIError::ApiError(api_error))
            }
        }
        OpenAIError::Reqwest(e) => {
            // Network-level failures are worth retrying.
            LanguageModelError::transient(e)
        }
        OpenAIError::JSONDeserialize(_, _) => {
            // A malformed response body is usually a provider-side hiccup
            // rather than a bad request, so treat it as transient.
            LanguageModelError::transient(e)
        }
        OpenAIError::StreamError(stream_error) => {
            if is_rate_limited_stream_error(&stream_error) {
                LanguageModelError::transient(OpenAIError::StreamError(stream_error))
            } else {
                LanguageModelError::permanent(OpenAIError::StreamError(stream_error))
            }
        }
        OpenAIError::FileSaveError(_)
        | OpenAIError::FileReadError(_)
        | OpenAIError::InvalidArgument(_) => LanguageModelError::permanent(e),
    }
}

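/// A stream error counts as rate limiting only when the underlying event
/// source reports HTTP 429 (`Too Many Requests`).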
fn is_rate_limited_stream_error(error: &StreamError) -> bool {
    match error {
        StreamError::ReqwestEventSource(inner) => match inner {
            EventSourceError::InvalidStatusCode(status, _) => {
                *status == StatusCode::TOO_MANY_REQUESTS
            }
            EventSourceError::Transport(source) => {
                source.status() == Some(StatusCode::TOO_MANY_REQUESTS)
            }
            _ => false,
        },
        StreamError::UnknownEvent(_) => false,
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_openai::error::{ApiError, OpenAIError, StreamError};
    use eventsource_stream::Event;

    #[test]
    fn test_default_embed_and_prompt_model() {
        let openai: OpenAI = OpenAI::builder()
            .default_embed_model("gpt-3")
            .default_prompt_model("gpt-4")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );

        let openai: OpenAI = OpenAI::builder()
            .default_prompt_model("gpt-4")
            .default_embed_model("gpt-3")
            .build()
            .unwrap();
        assert_eq!(
            openai.default_options.prompt_model,
            Some("gpt-4".to_string())
        );
        assert_eq!(
            openai.default_options.embed_model,
            Some("gpt-3".to_string())
        );
    }

    #[test]
    fn test_use_responses_api_flag() {
        let openai: OpenAI = OpenAI::builder().use_responses_api(true).build().unwrap();

        assert!(openai.is_responses_api_enabled());
    }

    #[test]
    fn test_context_length_exceeded_error() {
        let api_error = ApiError {
            message: "This model's maximum context length is 8192 tokens".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("messages".to_string()),
            code: Some("context_length_exceeded".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::ContextLengthExceeded(_) => {}
            _ => panic!("Expected ContextLengthExceeded error, got {result:?}"),
        }
    }

    #[test]
    fn test_api_error_permanent() {
        let api_error = ApiError {
            message: "Invalid API key".to_string(),
            r#type: Some("invalid_request_error".to_string()),
            param: Some("api_key".to_string()),
            code: Some("invalid_api_key".to_string()),
        };

        let openai_error = OpenAIError::ApiError(api_error);
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_save_error_is_permanent() {
        let openai_error = OpenAIError::FileSaveError("Failed to save file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_file_read_error_is_permanent() {
        let openai_error = OpenAIError::FileReadError("Failed to read file".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_stream_error_is_permanent() {
        let openai_error = OpenAIError::StreamError(StreamError::UnknownEvent(Event::default()));
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[test]
    fn test_invalid_argument_is_permanent() {
        let openai_error = OpenAIError::InvalidArgument("Invalid argument".to_string());
        let result = openai_error_to_language_model_error(openai_error);

        match result {
            LanguageModelError::PermanentError(_) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }
}