transformrs/
chat.rs

1use crate::request_headers;
2use crate::Key;
3use crate::Message;
4use crate::Provider;
5use async_stream::stream;
6use bytes::Bytes;
7use futures::Stream;
8use futures::StreamExt;
9use reqwest;
10use reqwest::Response;
11use serde::Deserialize;
12use serde::Serialize;
13use serde_json::Value;
14use std::error::Error;
15use std::pin::Pin;
16
17fn address(provider: &Provider) -> String {
18    let base_url = crate::openai_base_url(provider);
19    format!("{}/chat/completions", base_url)
20}
21
22async fn request_chat_completion(
23    provider: &Provider,
24    key: &Key,
25    model: &str,
26    stream: bool,
27    messages: &[Message],
28) -> Result<Response, Box<dyn Error + Send + Sync>> {
29    let address = address(provider);
30    let body = serde_json::json!({
31        "model": model,
32        "messages": messages,
33        "stream": stream,
34    });
35    let client = if provider == &Provider::Google {
36        // Without this, the request will fail with 400 INVALID_ARGUMENT.
37        // According to the docs, a 400 error is returned when the request body
38        // is malformed.  Why rustls tls fixes this, I do not know.
39        reqwest::Client::builder().use_rustls_tls().build()?
40    } else {
41        reqwest::Client::new()
42    };
43    tracing::debug!("Requesting chat: {body}");
44    let resp = client
45        .post(address)
46        .headers(request_headers(key)?)
47        .json(&body)
48        .send()
49        .await?;
50    Ok(resp)
51}
52
/// One element of [`ChatCompletion::choices`].
///
/// Fields typed as `Option` may be missing or null in the response.
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Index of this choice as reported in the response.
    pub index: u64,
    /// The message carried by this choice.
    pub message: Message,
    /// Optional `logprobs` value, accepted only as a string.
    // NOTE(review): OpenAI documents `logprobs` as an object; a non-null object
    // would fail to deserialize into `String` — confirm against providers in use.
    pub logprobs: Option<String>,
    /// Optional `finish_reason` string as reported in the response.
    pub finish_reason: Option<String>,
}
60
/// Token counters attached to a [`ChatCompletion`] through its `usage` field.
///
/// All three counters are required when deserializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    /// Presumably the number of tokens in the prompt — confirm against the provider's API docs.
    pub prompt_tokens: u64,
    /// Presumably the number of tokens generated for the completion — confirm.
    pub completion_tokens: u64,
    /// Presumably the sum of prompt and completion tokens — confirm.
    pub total_tokens: u64,
}
67
/// Structured form of a chat completion response.
///
/// This is the type returned by `ChatCompletionResponse::structured`. Fields
/// typed as `Option` may be missing or null in the response; every other field
/// must be present for deserialization to succeed.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletion {
    /// Optional response identifier.
    pub id: Option<String>,
    /// Value of the response's `object` field.
    pub object: String,
    /// Value of the response's `created` field (presumably a Unix timestamp — confirm).
    pub created: u64,
    /// Value of the response's `model` field.
    pub model: String,
    /// Optional `system_fingerprint` string.
    pub system_fingerprint: Option<String>,
    /// The choices contained in the response.
    pub choices: Vec<Choice>,
    /// Optional `service_tier` string.
    pub service_tier: Option<String>,
    /// Token counters for the request; required.
    pub usage: Usage,
}
79
/// Error payload with an optional `object` and a required `message`.
///
/// Not referenced by the other items visible in this file; error messages are
/// extracted from untyped JSON by `extract_error` instead.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionError {
    /// Optional `object` string of the error payload.
    pub object: Option<String>,
    /// The error message text.
    pub message: String,
}
85
86fn extract_error(body: &Value) -> String {
87    if let Some(error) = body.get("error") {
88        if let Some(message) = error.get("message") {
89            return message
90                .as_str()
91                .unwrap_or(body.to_string().as_str())
92                .to_string();
93        }
94    }
95    if let Some(message) = body.get("message") {
96        return message
97            .as_str()
98            .unwrap_or(body.to_string().as_str())
99            .to_string();
100    }
101    format!("Unknown error: {body}")
102}
103
/// Response from an OpenAI-compatible chat completions endpoint.
///
/// Holds the HTTP status code and the unparsed response body. The body can be
/// read as raw bytes, as an untyped `serde_json::Value`, or as a structured
/// [`ChatCompletion`]. Exposing the unstructured forms gives clients access to
/// fields that might be added in the future, and lets them handle edge cases
/// or do custom processing.
///
/// You might wonder while reading: why not keep it simple? That is a good
/// point. DeepSeek R1 had a great observation about this: "API design is like
/// walking the tightrope. The challenge is to build constraints that empower,
/// not confine."
pub struct ChatCompletionResponse {
    /// HTTP status code of the response.
    status: u16,
    /// Unparsed response body.
    resp: Bytes,
}
119
120impl ChatCompletionResponse {
121    pub fn bytes(&self) -> &Bytes {
122        &self.resp
123    }
124    pub fn raw_value(&self) -> Result<Value, Box<dyn Error + Send + Sync>> {
125        Ok(serde_json::from_slice::<Value>(&self.resp)?)
126    }
127    pub fn structured(&self) -> Result<ChatCompletion, Box<dyn Error + Send + Sync>> {
128        let json = self.raw_value()?;
129        let text = json.to_string();
130        if text.is_empty() {
131            return Err(
132                format!("Received empty response with status code: {}", self.status).into(),
133            );
134        }
135        let json = match serde_json::from_str::<ChatCompletion>(&text) {
136            Ok(json) => json,
137            Err(_e) => match serde_json::from_str::<Value>(&text) {
138                Ok(error) => return Err(extract_error(&error).into()),
139                Err(e) => {
140                    return Err(format!("Error parsing response: {} in text: '{}'", e, text).into())
141                }
142            },
143        };
144        Ok(json)
145    }
146}
147
148pub async fn chat_completion(
149    provider: &Provider,
150    key: &Key,
151    model: &str,
152    messages: &[Message],
153) -> Result<ChatCompletionResponse, Box<dyn Error + Send + Sync>> {
154    let stream = false;
155    let resp = request_chat_completion(provider, key, model, stream, messages).await?;
156    let status = resp.status();
157    let chat_completion_response = ChatCompletionResponse {
158        status: status.into(),
159        resp: resp.bytes().await?,
160    };
161    Ok(chat_completion_response)
162}
163
/// The `delta` payload of a [`ChunkChoice`].
///
/// Both fields are optional and may be missing or null in a given chunk.
#[derive(Debug, Serialize, Deserialize)]
pub struct Delta {
    /// Optional `role` string.
    pub role: Option<String>,
    /// Optional `content` string (presumably the incremental text — confirm).
    pub content: Option<String>,
}
169
/// One element of [`ChatCompletionChunk::choices`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ChunkChoice {
    /// Index of this choice as reported in the chunk.
    pub index: u64,
    /// The delta carried by this chunk for this choice.
    pub delta: Delta,
    /// Optional `finish_reason` string.
    pub finish_reason: Option<String>,
}
176
/// One parsed event of a streamed chat completion.
///
/// This is the item type yielded by `stream_chat_completion` and produced by
/// `process_line`. Fields typed as `Option` may be missing or null; every
/// other field must be present for deserialization to succeed.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
    /// Optional chunk identifier.
    pub id: Option<String>,
    /// Value of the chunk's `object` field.
    pub object: String,
    /// Value of the chunk's `created` field (presumably a Unix timestamp — confirm).
    pub created: u64,
    /// Value of the chunk's `model` field.
    pub model: String,
    /// Optional `system_fingerprint` string.
    pub system_fingerprint: Option<String>,
    /// The choices contained in this chunk.
    pub choices: Vec<ChunkChoice>,
}
186
187fn process_line(line: &str) -> Option<ChatCompletionChunk> {
188    if line.is_empty() {
189        return None;
190    }
191
192    if let Some(json_str) = line.strip_prefix("data: ") {
193        if json_str == "[DONE]" {
194            return None;
195        }
196        match serde_json::from_str::<ChatCompletionChunk>(json_str) {
197            Ok(chunk) => Some(chunk),
198            Err(_) => None,
199        }
200    } else {
201        None
202    }
203}
204
205pub async fn stream_chat_completion(
206    provider: &Provider,
207    key: &Key,
208    model: &str,
209    messages: &[Message],
210) -> Result<Pin<Box<dyn Stream<Item = ChatCompletionChunk> + Send>>, Box<dyn Error + Send + Sync>> {
211    let resp = request_chat_completion(provider, key, model, true, messages).await?;
212
213    let stream = stream! {
214        let mut buffer = String::new();
215        let mut byte_stream = resp.bytes_stream();
216
217        while let Some(chunk) = byte_stream.next().await {
218            let chunk = match chunk {
219                Ok(c) => c,
220                Err(_) => break,
221            };
222
223            let mut current_text = String::from_utf8_lossy(&chunk).to_string();
224
225            if !buffer.is_empty() {
226                current_text = format!("{buffer}{current_text}");
227                buffer.clear();
228            }
229            let mut lines = current_text.split_inclusive('\n').peekable();
230
231            while let Some(line) = lines.next() {
232                let is_last_line = lines.peek().is_none() && !current_text.ends_with('\n');
233                if is_last_line {
234                    buffer.push_str(line);
235                    continue;
236                }
237                if let Some(chunk) = process_line(line) {
238                    yield chunk;
239                }
240            }
241        }
242
243        if !buffer.is_empty() {
244            if let Some(chunk) = process_line(&buffer) {
245                yield chunk;
246            }
247        }
248    };
249
250    Ok(Box::pin(stream))
251}