swiftide_integrations/openai/
simple_prompt.rs

1//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
2//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt
3//! processing and generating responses as part of the Swiftide system.
4
5use async_openai::types::ChatCompletionRequestUserMessageArgs;
6use async_trait::async_trait;
7#[cfg(feature = "metrics")]
8use swiftide_core::metrics::emit_usage;
9use swiftide_core::{
10    SimplePrompt, chat_completion::errors::LanguageModelError, prompt::Prompt,
11    util::debug_long_utf8,
12};
13
14use crate::openai::openai_error_to_language_model_error;
15
16use super::GenericOpenAI;
17use anyhow::Result;
18
19/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a
20/// response.
21#[async_trait]
22impl<
23    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
24> SimplePrompt for GenericOpenAI<C>
25{
26    /// Sends a prompt to the OpenAI API and returns the response content.
27    ///
28    /// # Parameters
29    /// - `prompt`: A string slice that holds the prompt to be sent to the OpenAI API.
30    ///
31    /// # Returns
32    /// - `Result<String>`: On success, returns the content of the response as a `String`. On
33    ///   failure, returns an error wrapped in a `Result`.
34    ///
35    /// # Errors
36    /// - Returns an error if the model is not set in the default options.
37    /// - Returns an error if the request to the OpenAI API fails.
38    /// - Returns an error if the response does not contain the expected content.
39    #[tracing::instrument(skip_all, err)]
40    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
41        // Retrieve the model from the default options, returning an error if not set.
42        let model = self
43            .default_options
44            .prompt_model
45            .as_ref()
46            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
47
48        // Build the request to be sent to the OpenAI API.
49        let request = self
50            .chat_completion_request_defaults()
51            .model(model)
52            .messages(vec![
53                ChatCompletionRequestUserMessageArgs::default()
54                    .content(prompt.render()?)
55                    .build()
56                    .map_err(LanguageModelError::permanent)?
57                    .into(),
58            ])
59            .build()
60            .map_err(LanguageModelError::permanent)?;
61
62        // Log the request for debugging purposes.
63        tracing::debug!(
64            model = &model,
65            messages = debug_long_utf8(
66                serde_json::to_string_pretty(&request.messages.first())
67                    .map_err(LanguageModelError::permanent)?,
68                100
69            ),
70            "[SimplePrompt] Request to openai"
71        );
72
73        // Send the request to the OpenAI API and await the response.
74        let mut response = self
75            .client
76            .chat()
77            .create(request)
78            .await
79            .map_err(openai_error_to_language_model_error)?;
80
81        let message = response
82            .choices
83            .remove(0)
84            .message
85            .content
86            .take()
87            .ok_or_else(|| {
88                LanguageModelError::PermanentError("Expected content in response".into())
89            })?;
90
91        #[cfg(feature = "metrics")]
92        {
93            if let Some(usage) = response.usage.as_ref() {
94                emit_usage(
95                    model,
96                    usage.prompt_tokens.into(),
97                    usage.completion_tokens.into(),
98                    usage.total_tokens.into(),
99                    self.metric_metadata.as_ref(),
100                );
101            } else {
102                tracing::warn!("Metrics enabled but no usage data found in response");
103            }
104        }
105
106        // Log the response for debugging purposes.
107        tracing::debug!(
108            message = debug_long_utf8(&message, 100),
109            "[SimplePrompt] Response from openai"
110        );
111
112        // Extract and return the content of the response, returning an error if not found.
113        Ok(message)
114    }
115}