swiftide_integrations/openai/
simple_prompt.rs

1//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
2//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt
3//! processing and generating responses as part of the Swiftide system.
4
5use async_openai::types::ChatCompletionRequestUserMessageArgs;
6use async_trait::async_trait;
7#[cfg(feature = "metrics")]
8use swiftide_core::metrics::emit_usage;
9use swiftide_core::{
10    SimplePrompt,
11    chat_completion::{Usage, errors::LanguageModelError},
12    prompt::Prompt,
13    util::debug_long_utf8,
14};
15
16use crate::openai::openai_error_to_language_model_error;
17
18use super::GenericOpenAI;
19use anyhow::Result;
20
21/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a
22/// response.
23#[async_trait]
24impl<
25    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
26> SimplePrompt for GenericOpenAI<C>
27{
28    /// Sends a prompt to the `OpenAI` API and returns the response content.
29    ///
30    /// # Parameters
31    /// - `prompt`: A string slice that holds the prompt to be sent to the `OpenAI` API.
32    ///
33    /// # Returns
34    /// - `Result<String>`: On success, returns the content of the response as a `String`. On
35    ///   failure, returns an error wrapped in a `Result`.
36    ///
37    /// # Errors
38    /// - Returns an error if the model is not set in the default options.
39    /// - Returns an error if the request to the `OpenAI` API fails.
40    /// - Returns an error if the response does not contain the expected content.
41    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
42    #[cfg_attr(
43        feature = "langfuse",
44        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
45    )]
46    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
47        // Retrieve the model from the default options, returning an error if not set.
48        let model = self
49            .default_options
50            .prompt_model
51            .as_ref()
52            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
53
54        // Build the request to be sent to the OpenAI API.
55        let request = self
56            .chat_completion_request_defaults()
57            .model(model)
58            .messages(vec![
59                ChatCompletionRequestUserMessageArgs::default()
60                    .content(prompt.render()?)
61                    .build()
62                    .map_err(LanguageModelError::permanent)?
63                    .into(),
64            ])
65            .build()
66            .map_err(LanguageModelError::permanent)?;
67
68        // Log the request for debugging purposes.
69        tracing::trace!(
70            model = &model,
71            messages = debug_long_utf8(
72                serde_json::to_string_pretty(&request.messages.last())
73                    .map_err(LanguageModelError::permanent)?,
74                100
75            ),
76            "[SimplePrompt] Request to openai"
77        );
78
79        // Send the request to the OpenAI API and await the response.
80        let mut response = self
81            .client
82            .chat()
83            .create(request.clone())
84            .await
85            .map_err(openai_error_to_language_model_error)?;
86
87        if cfg!(feature = "langfuse") {
88            let usage = response.usage.clone().unwrap_or_default();
89            tracing::debug!(
90                langfuse.model = model,
91                langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
92                langfuse.output = %serde_json::to_string_pretty(&response).unwrap_or_default(),
93                langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
94            );
95        }
96
97        let message = response
98            .choices
99            .remove(0)
100            .message
101            .content
102            .take()
103            .ok_or_else(|| {
104                LanguageModelError::PermanentError("Expected content in response".into())
105            })?;
106
107        {
108            if let Some(usage) = response.usage.as_ref() {
109                if let Some(callback) = &self.on_usage {
110                    let usage = Usage {
111                        prompt_tokens: usage.prompt_tokens,
112                        completion_tokens: usage.completion_tokens,
113                        total_tokens: usage.total_tokens,
114                    };
115                    callback(&usage).await?;
116                }
117                #[cfg(feature = "metrics")]
118                emit_usage(
119                    model,
120                    usage.prompt_tokens.into(),
121                    usage.completion_tokens.into(),
122                    usage.total_tokens.into(),
123                    self.metric_metadata.as_ref(),
124                );
125            } else {
126                tracing::warn!("Metrics enabled but no usage data found in response");
127            }
128        }
129
130        // Emit Langfuse event with the response details.
131
132        // Extract and return the content of the response, returning an error if not found.
133        Ok(message)
134    }
135}