swiftide_integrations/openai/simple_prompt.rs
//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt
//! processing and generating responses as part of the Swiftide system.
use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use async_trait::async_trait;
use swiftide_core::{
    chat_completion::errors::LanguageModelError, prompt::Prompt, util::debug_long_utf8,
    SimplePrompt,
};

use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::Result;

/// Implements the `SimplePrompt` trait for `GenericOpenAI`, sending a rendered prompt to the
/// `OpenAI` chat completions API and returning the model's reply as plain text.
#[async_trait]
impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
    SimplePrompt for GenericOpenAI<C>
{
    /// Sends a prompt to the OpenAI API and returns the response content.
    ///
    /// # Parameters
    /// - `prompt`: The [`Prompt`] to render and send to the OpenAI API.
    ///
    /// # Returns
    /// - `Result<String, LanguageModelError>`: On success, the content of the model's response;
    ///   on failure, a `LanguageModelError` describing what went wrong.
    ///
    /// # Errors
    /// - Returns an error if the prompt model is not set in the default options.
    /// - Returns an error if the request to the OpenAI API fails.
    /// - Returns an error if the response does not contain the expected content.
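    ///
    /// # Example
    ///
    /// A minimal sketch of calling this implementation through the `SimplePrompt` trait. The
    /// builder call, the `default_prompt_model` option, and the model name are assumptions for
    /// illustration; adjust them to the actual `OpenAI` builder surface if it differs.
    ///
    /// ```ignore
    /// use swiftide_core::SimplePrompt;
    ///
    /// # async fn example() -> anyhow::Result<()> {
    /// // Hypothetical builder setup; the model name is only illustrative.
    /// let client = swiftide_integrations::openai::OpenAI::builder()
    ///     .default_prompt_model("gpt-4o-mini")
    ///     .build()?;
    ///
    /// // Assumes `Prompt` can be converted from a `&str`; otherwise construct it explicitly.
    /// let answer = client.prompt("Say hello in one word".into()).await?;
    /// println!("{answer}");
    /// # Ok(())
    /// # }
    /// ```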
    #[tracing::instrument(skip_all, err)]
    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
        // Retrieve the model from the default options, returning an error if not set.
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

        // Build the request to be sent to the OpenAI API.
        let request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content(prompt.render()?)
                .build()
                .map_err(LanguageModelError::permanent)?
                .into()])
            .build()
            .map_err(LanguageModelError::permanent)?;

        // Log the request for debugging purposes.
        tracing::debug!(
            model = &model,
            messages = debug_long_utf8(
                serde_json::to_string_pretty(&request.messages.first())
                    .map_err(LanguageModelError::permanent)?,
                100
            ),
            "[SimplePrompt] Request to openai"
        );

        // Send the request to the OpenAI API and extract the content of the first choice.
        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?
            .choices
            .remove(0)
            .message
            .content
            .take()
            .ok_or_else(|| {
                LanguageModelError::PermanentError("Expected content in response".into())
            })?;

        // Log the response for debugging purposes.
        tracing::debug!(
            response = debug_long_utf8(&response, 100),
            "[SimplePrompt] Response from openai"
        );

        // Return the extracted response content.
        Ok(response)
    }
}