// swiftide_integrations/openai/simple_prompt.rs

1//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
2//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt
3//! processing and generating responses as part of the Swiftide system.
4
5use async_openai::types::ChatCompletionRequestUserMessageArgs;
6use async_trait::async_trait;
7#[cfg(feature = "metrics")]
8use swiftide_core::metrics::emit_usage;
9use swiftide_core::{
10    SimplePrompt,
11    chat_completion::{Usage, errors::LanguageModelError},
12    prompt::Prompt,
13    util::debug_long_utf8,
14};
15
16use crate::openai::openai_error_to_language_model_error;
17
18use super::GenericOpenAI;
19use anyhow::Result;
20
21/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a
22/// response.
23#[async_trait]
24impl<
25    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
26> SimplePrompt for GenericOpenAI<C>
27{
28    /// Sends a prompt to the `OpenAI` API and returns the response content.
29    ///
30    /// # Parameters
31    /// - `prompt`: A string slice that holds the prompt to be sent to the `OpenAI` API.
32    ///
33    /// # Returns
34    /// - `Result<String>`: On success, returns the content of the response as a `String`. On
35    ///   failure, returns an error wrapped in a `Result`.
36    ///
37    /// # Errors
38    /// - Returns an error if the model is not set in the default options.
39    /// - Returns an error if the request to the `OpenAI` API fails.
40    /// - Returns an error if the response does not contain the expected content.
41    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
42    #[cfg_attr(
43        feature = "langfuse",
44        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
45    )]
46    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
47        // Retrieve the model from the default options, returning an error if not set.
48        let model = self
49            .default_options
50            .prompt_model
51            .as_ref()
52            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
53
54        // Build the request to be sent to the OpenAI API.
55        let request = self
56            .chat_completion_request_defaults()
57            .model(model)
58            .messages(vec![
59                ChatCompletionRequestUserMessageArgs::default()
60                    .content(prompt.render()?)
61                    .build()
62                    .map_err(LanguageModelError::permanent)?
63                    .into(),
64            ])
65            .build()
66            .map_err(LanguageModelError::permanent)?;
67
68        // Log the request for debugging purposes.
69        tracing::trace!(
70            model = &model,
71            messages = debug_long_utf8(
72                serde_json::to_string_pretty(&request.messages.last())
73                    .map_err(LanguageModelError::permanent)?,
74                100
75            ),
76            "[SimplePrompt] Request to openai"
77        );
78
79        // Send the request to the OpenAI API and await the response.
80        let mut response = self
81            .client
82            .chat()
83            .create(request.clone())
84            .await
85            .map_err(openai_error_to_language_model_error)?;
86
87        let message = response
88            .choices
89            .remove(0)
90            .message
91            .content
92            .take()
93            .ok_or_else(|| {
94                LanguageModelError::PermanentError("Expected content in response".into())
95            })?;
96
97        {
98            if let Some(usage) = response.usage.as_ref() {
99                let usage = Usage {
100                    prompt_tokens: usage.prompt_tokens,
101                    completion_tokens: usage.completion_tokens,
102                    total_tokens: usage.total_tokens,
103                };
104                if let Some(callback) = &self.on_usage {
105                    callback(&usage).await?;
106                }
107                if cfg!(feature = "langfuse") {
108                    tracing::debug!(
109                        langfuse.model = model,
110                        langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
111                        langfuse.output = %serde_json::to_string_pretty(&response).unwrap_or_default(),
112                        langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
113                    );
114                }
115                #[cfg(feature = "metrics")]
116                emit_usage(
117                    model,
118                    usage.prompt_tokens.into(),
119                    usage.completion_tokens.into(),
120                    usage.total_tokens.into(),
121                    self.metric_metadata.as_ref(),
122                );
123            } else {
124                tracing::warn!("Metrics enabled but no usage data found in response");
125            }
126        }
127
128        Ok(message)
129    }
130}