swiftide_integrations/openai/simple_prompt.rs
//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt
//! processing and generating responses as part of the Swiftide system.
use async_openai::types::ChatCompletionRequestUserMessageArgs;
use async_trait::async_trait;
use swiftide_core::{
    SimplePrompt, chat_completion::errors::LanguageModelError, prompt::Prompt,
    util::debug_long_utf8,
};

use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::Result;

/// Implementation of the `SimplePrompt` trait: sends a single prompt to the `OpenAI` API and
/// returns the model's response as plain text.
#[async_trait]
impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
    SimplePrompt for GenericOpenAI<C>
{
    /// Sends a prompt to the OpenAI API and returns the response content.
    ///
    /// # Parameters
    /// - `prompt`: The `Prompt` to render and send to the OpenAI API.
    ///
    /// # Returns
    /// - `Result<String, LanguageModelError>`: On success, the content of the response as a
    ///   `String`; on failure, a `LanguageModelError`.
    ///
    /// # Errors
    /// - Returns an error if the prompt model is not set in the default options.
    /// - Returns an error if the request to the OpenAI API fails.
    /// - Returns an error if the response does not contain the expected content.
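    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only); it assumes the crate's `OpenAI` builder
    /// exposes a `default_prompt_model` setter, that `Prompt` converts from `&str`, and that
    /// the API key is available in the environment.
    ///
    /// ```no_run
    /// # use swiftide_core::SimplePrompt;
    /// # async fn example() -> anyhow::Result<()> {
    /// let client = swiftide_integrations::openai::OpenAI::builder()
    ///     .default_prompt_model("gpt-4o-mini")
    ///     .build()?;
    ///
    /// let answer = client
    ///     .prompt("Summarize what the SimplePrompt trait does.".into())
    ///     .await?;
    /// println!("{answer}");
    /// # Ok(())
    /// # }
    /// ```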
    #[tracing::instrument(skip_all, err)]
    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
        // Retrieve the model from the default options, returning an error if not set.
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

        // Build the request to be sent to the OpenAI API.
        let request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(vec![
                ChatCompletionRequestUserMessageArgs::default()
                    .content(prompt.render()?)
                    .build()
                    .map_err(LanguageModelError::permanent)?
                    .into(),
            ])
            .build()
            .map_err(LanguageModelError::permanent)?;

        // Log the request for debugging purposes.
        tracing::debug!(
            model = &model,
            messages = debug_long_utf8(
                serde_json::to_string_pretty(&request.messages.first())
                    .map_err(LanguageModelError::permanent)?,
                100
            ),
            "[SimplePrompt] Request to openai"
        );

        // Send the request to the OpenAI API and await the response.
        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?
            .choices
            .remove(0)
            .message
            .content
            .take()
            .ok_or_else(|| {
                LanguageModelError::PermanentError("Expected content in response".into())
            })?;

        // Log the response for debugging purposes.
        tracing::debug!(
            response = debug_long_utf8(&response, 100),
            "[SimplePrompt] Response from openai"
        );

        // Return the extracted response content.
        Ok(response)
    }
}