swiftide_integrations/openai/
simple_prompt.rs

1//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
2//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt
3//! processing and generating responses as part of the Swiftide system.
4use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
5use async_trait::async_trait;
6use swiftide_core::{
7    chat_completion::errors::LanguageModelError, prompt::Prompt, util::debug_long_utf8,
8    SimplePrompt,
9};
10
11use crate::openai::openai_error_to_language_model_error;
12
13use super::GenericOpenAI;
14use anyhow::Result;
15
16/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a
17/// response.
18#[async_trait]
19impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
20    SimplePrompt for GenericOpenAI<C>
21{
22    /// Sends a prompt to the OpenAI API and returns the response content.
23    ///
24    /// # Parameters
25    /// - `prompt`: A string slice that holds the prompt to be sent to the OpenAI API.
26    ///
27    /// # Returns
28    /// - `Result<String>`: On success, returns the content of the response as a `String`. On
29    ///   failure, returns an error wrapped in a `Result`.
30    ///
31    /// # Errors
32    /// - Returns an error if the model is not set in the default options.
33    /// - Returns an error if the request to the OpenAI API fails.
34    /// - Returns an error if the response does not contain the expected content.
35    #[tracing::instrument(skip_all, err)]
36    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
37        // Retrieve the model from the default options, returning an error if not set.
38        let model = self
39            .default_options
40            .prompt_model
41            .as_ref()
42            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
43
44        // Build the request to be sent to the OpenAI API.
45        let request = CreateChatCompletionRequestArgs::default()
46            .model(model)
47            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
48                .content(prompt.render()?)
49                .build()
50                .map_err(LanguageModelError::permanent)?
51                .into()])
52            .build()
53            .map_err(LanguageModelError::permanent)?;
54
55        // Log the request for debugging purposes.
56        tracing::debug!(
57            model = &model,
58            messages = debug_long_utf8(
59                serde_json::to_string_pretty(&request.messages.first())
60                    .map_err(LanguageModelError::permanent)?,
61                100
62            ),
63            "[SimplePrompt] Request to openai"
64        );
65
66        // Send the request to the OpenAI API and await the response.
67        let response = self
68            .client
69            .chat()
70            .create(request)
71            .await
72            .map_err(openai_error_to_language_model_error)?
73            .choices
74            .remove(0)
75            .message
76            .content
77            .take()
78            .ok_or_else(|| {
79                LanguageModelError::PermanentError("Expected content in response".into())
80            })?;
81
82        // Log the response for debugging purposes.
83        tracing::debug!(
84            response = debug_long_utf8(&response, 100),
85            "[SimplePrompt] Response from openai"
86        );
87
88        // Extract and return the content of the response, returning an error if not found.
89        Ok(response)
90    }
91}