swiftide_integrations/openai/simple_prompt.rs
//! This module implements the `SimplePrompt` trait for the `GenericOpenAI` client.
//! It defines an asynchronous method that sends a prompt to the `OpenAI` chat completions API
//! and returns the generated response, for use within the Swiftide system.
4use async_openai::types::ChatCompletionRequestUserMessageArgs;
5use async_trait::async_trait;
6use swiftide_core::{
7    SimplePrompt, chat_completion::errors::LanguageModelError, prompt::Prompt,
8    util::debug_long_utf8,
9};
10
11use crate::openai::openai_error_to_language_model_error;
12
13use super::GenericOpenAI;
14use anyhow::Result;
15
16/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a
17/// response.
18#[async_trait]
19impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
20    SimplePrompt for GenericOpenAI<C>
21{
22    /// Sends a prompt to the OpenAI API and returns the response content.
23    ///
24    /// # Parameters
25    /// - `prompt`: A string slice that holds the prompt to be sent to the OpenAI API.
26    ///
27    /// # Returns
28    /// - `Result<String>`: On success, returns the content of the response as a `String`. On
29    ///   failure, returns an error wrapped in a `Result`.
30    ///
31    /// # Errors
32    /// - Returns an error if the model is not set in the default options.
33    /// - Returns an error if the request to the OpenAI API fails.
34    /// - Returns an error if the response does not contain the expected content.
35    #[tracing::instrument(skip_all, err)]
36    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
37        // Retrieve the model from the default options, returning an error if not set.
38        let model = self
39            .default_options
40            .prompt_model
41            .as_ref()
42            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
43
44        // Build the request to be sent to the OpenAI API.
45        let request = self
46            .chat_completion_request_defaults()
47            .model(model)
48            .messages(vec![
49                ChatCompletionRequestUserMessageArgs::default()
50                    .content(prompt.render()?)
51                    .build()
52                    .map_err(LanguageModelError::permanent)?
53                    .into(),
54            ])
55            .build()
56            .map_err(LanguageModelError::permanent)?;
57
58        // Log the request for debugging purposes.
59        tracing::debug!(
60            model = &model,
61            messages = debug_long_utf8(
62                serde_json::to_string_pretty(&request.messages.first())
63                    .map_err(LanguageModelError::permanent)?,
64                100
65            ),
66            "[SimplePrompt] Request to openai"
67        );
68
69        // Send the request to the OpenAI API and await the response.
70        let response = self
71            .client
72            .chat()
73            .create(request)
74            .await
75            .map_err(openai_error_to_language_model_error)?
76            .choices
77            .remove(0)
78            .message
79            .content
80            .take()
81            .ok_or_else(|| {
82                LanguageModelError::PermanentError("Expected content in response".into())
83            })?;
84
85        // Log the response for debugging purposes.
86        tracing::debug!(
87            response = debug_long_utf8(&response, 100),
88            "[SimplePrompt] Response from openai"
89        );
90
91        // Extract and return the content of the response, returning an error if not found.
92        Ok(response)
93    }
94}