// rusty_openai/openai_api/completion.rs
1use crate::{error_handling::OpenAIResult, openai::OpenAI, setters};
2use serde::Serialize;
3use serde_json::Value;
4
5/// Represents the response format for chat completions
6#[derive(Debug, Serialize)]
7#[serde(untagged)]
8pub enum ResponseFormat {
9    /// Simple JSON response format
10    Json {
11        #[serde(rename = "type")]
12        format_type: String,
13    },
14    /// JSON Schema response format with validation
15    JsonSchema {
16        #[serde(rename = "type")]
17        format_type: String,
18        json_schema: JsonSchema,
19        strict: bool,
20    },
21}
22
/// Represents a JSON Schema for structured outputs.
///
/// Sent as the `json_schema` object of a `response_format` with
/// `"type": "json_schema"`.
#[derive(Debug, Serialize)]
pub struct JsonSchema {
    /// Name of the schema (the `name` field of the `json_schema` object).
    pub name: String,
    /// The actual schema definition, carried as a raw [`serde_json::Value`]
    /// so callers can supply any JSON Schema document.
    pub schema: Value,
}
31
/// [`CompletionsApi`] struct to interact with the chat completions endpoint of the API.
///
/// Thin wrapper around a borrowed [`OpenAI`] client; the inner reference is
/// `pub(crate)`, so instances are handed out by the client itself.
pub struct CompletionsApi<'a>(pub(crate) &'a OpenAI<'a>);
34
/// Struct representing a request for chat completions.
///
/// All optional fields default to `None` (via `Default`) and are omitted from
/// the serialized request body when unset, so only explicitly-set options are
/// sent to the API.
#[derive(Default, Serialize)]
pub struct ChatCompletionRequest {
    /// Model name to be used for the chat completion
    model: String,

    /// History of messages in the conversation, as raw JSON values
    messages: Vec<Value>,

    /// Maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,

    /// Sampling temperature
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,

    /// Nucleus sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,

    /// Number of completions to generate for each prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u64>,

    /// Whether to stream back partial progress
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,

    /// Sequences at which to stop generation
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,

    /// Presence penalty to apply
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f64>,

    /// Frequency penalty to apply
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f64>,

    /// Bias for logits, as a raw JSON map
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<Value>,

    /// End-user identifier forwarded to the API
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,

    /// Response format for structured outputs
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
}
88
89impl ChatCompletionRequest {
90    /// Create a new instance of [`ChatCompletionRequest`].
91    #[inline(always)]
92    pub fn new(model: String, messages: Vec<Value>) -> Self {
93        Self {
94            model,
95            messages,
96            ..Default::default()
97        }
98    }
99
100    /// Create a new instance with JSON response format
101    pub fn new_json(model: String, messages: Vec<Value>) -> Self {
102        Self {
103            model,
104            messages,
105            response_format: Some(ResponseFormat::Json {
106                format_type: "json".to_string(),
107            }),
108            ..Default::default()
109        }
110    }
111
112    /// Create a new instance with JSON Schema response format
113    pub fn new_json_schema(model: String, messages: Vec<Value>, schema_name: String, schema: Value) -> Self {
114        Self {
115            model,
116            messages,
117            response_format: Some(ResponseFormat::JsonSchema {
118                format_type: "json_schema".to_string(),
119                json_schema: JsonSchema {
120                    name: schema_name,
121                    schema,
122                },
123                strict: true,
124            }),
125            ..Default::default()
126        }
127    }
128
129    // Fluent setter methods to set each option on the request.
130
131    setters! {
132        max_tokens: u64,
133        temperature: f64,
134        top_p: f64,
135        n: u64,
136        stream: bool,
137        stop: Vec<String>,
138        presence_penalty: f64,
139        frequency_penalty: f64,
140        logit_bias: Value,
141        user: String,
142        response_format: ResponseFormat,
143    }
144}
145
146impl<'a> CompletionsApi<'a> {
147    /// Create a chat completion using the provided request parameters.
148    ///
149    /// # Arguments
150    ///
151    /// * `request` - A [`ChatCompletionRequest`] containing the parameters for the completion.
152    ///
153    /// # Returns
154    ///
155    /// A Result containing the JSON response as [`serde_json::Value`] on success, or an [`OpenAIError`][crate::error_handling::OpenAIError] on failure.
156    pub async fn create(&self, request: ChatCompletionRequest) -> OpenAIResult<Value> {
157        // Send a POST request to the chat completions endpoint with the request body.
158        self.0.post_json("/chat/completions", &request).await
159    }
160}