1use crate::request_headers;
2use crate::Key;
3use crate::Message;
4use crate::Provider;
5use async_stream::stream;
6use bytes::Bytes;
7use futures::Stream;
8use futures::StreamExt;
9use reqwest;
10use reqwest::Response;
11use serde::Deserialize;
12use serde::Serialize;
13use serde_json::Value;
14use std::error::Error;
15use std::pin::Pin;
16
17fn address(provider: &Provider) -> String {
18 let base_url = crate::openai_base_url(provider);
19 format!("{}/chat/completions", base_url)
20}
21
22async fn request_chat_completion(
23 provider: &Provider,
24 key: &Key,
25 model: &str,
26 stream: bool,
27 messages: &[Message],
28) -> Result<Response, Box<dyn Error + Send + Sync>> {
29 let address = address(provider);
30 let body = serde_json::json!({
31 "model": model,
32 "messages": messages,
33 "stream": stream,
34 });
35 let client = if provider == &Provider::Google {
36 reqwest::Client::builder().use_rustls_tls().build()?
40 } else {
41 reqwest::Client::new()
42 };
43 tracing::debug!("Requesting chat: {body}");
44 let resp = client
45 .post(address)
46 .headers(request_headers(key)?)
47 .json(&body)
48 .send()
49 .await?;
50 Ok(resp)
51}
52
/// One entry of [`ChatCompletion::choices`].
///
/// NOTE(review): the field names look like an OpenAI-style chat-completions response;
/// confirm the exact schema against the provider documentation.
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Index of this choice as reported in the response.
    pub index: u64,
    /// The message carried by this choice.
    pub message: Message,
    /// Optional `logprobs` value. Only a JSON string (or `null`/absent) deserializes
    /// into this field.
    ///
    /// NOTE(review): if a provider returns a JSON object here, deserialization of the
    /// whole response would fail — confirm.
    pub logprobs: Option<String>,
    /// Optional `finish_reason` string; not interpreted anywhere in this file.
    pub finish_reason: Option<String>,
}
60
/// Token counters attached to a [`ChatCompletion`] via its `usage` field.
///
/// All three counters are required when deserializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    /// Value of the `prompt_tokens` field.
    pub prompt_tokens: u64,
    /// Value of the `completion_tokens` field.
    pub completion_tokens: u64,
    /// Value of the `total_tokens` field (presumably prompt + completion; not checked here).
    pub total_tokens: u64,
}
67
/// Typed form of a complete (non-streaming) chat-completions response body.
///
/// This is the type produced by `ChatCompletionResponse::structured`. Fields typed as
/// `Option` may be absent or `null` in the JSON; every other field is required, so a
/// body lacking e.g. `usage` does not deserialize into this struct.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletion {
    /// Optional response identifier.
    pub id: Option<String>,
    /// Value of the `object` field.
    pub object: String,
    /// Value of the `created` field (presumably a Unix timestamp — confirm).
    pub created: u64,
    /// Value of the `model` field.
    pub model: String,
    /// Optional `system_fingerprint` value.
    pub system_fingerprint: Option<String>,
    /// The choices contained in the response.
    pub choices: Vec<Choice>,
    /// Optional `service_tier` value.
    pub service_tier: Option<String>,
    /// Token counters for the request.
    pub usage: Usage,
}
79
/// Serializable error payload consisting of an optional `object` tag and a `message`.
///
/// NOTE(review): nothing in the visible part of this file constructs or parses this
/// type; presumably it is used by callers elsewhere — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionError {
    /// Optional `object` field.
    pub object: Option<String>,
    /// The error message text.
    pub message: String,
}
85
86fn extract_error(body: &Value) -> String {
87 if let Some(error) = body.get("error") {
88 if let Some(message) = error.get("message") {
89 return message
90 .as_str()
91 .unwrap_or(body.to_string().as_str())
92 .to_string();
93 }
94 }
95 if let Some(message) = body.get("message") {
96 return message
97 .as_str()
98 .unwrap_or(body.to_string().as_str())
99 .to_string();
100 }
101 format!("Unknown error: {body}")
102}
103
/// Unparsed result of a non-streaming chat-completions request.
///
/// Holds the HTTP status code together with the raw body bytes; parsing is deferred to
/// the accessor methods on this type.
pub struct ChatCompletionResponse {
    /// HTTP status code of the response.
    status: u16,
    /// Raw response body.
    resp: Bytes,
}
119
120impl ChatCompletionResponse {
121 pub fn bytes(&self) -> &Bytes {
122 &self.resp
123 }
124 pub fn raw_value(&self) -> Result<Value, Box<dyn Error + Send + Sync>> {
125 Ok(serde_json::from_slice::<Value>(&self.resp)?)
126 }
127 pub fn structured(&self) -> Result<ChatCompletion, Box<dyn Error + Send + Sync>> {
128 let json = self.raw_value()?;
129 let text = json.to_string();
130 if text.is_empty() {
131 return Err(
132 format!("Received empty response with status code: {}", self.status).into(),
133 );
134 }
135 let json = match serde_json::from_str::<ChatCompletion>(&text) {
136 Ok(json) => json,
137 Err(_e) => match serde_json::from_str::<Value>(&text) {
138 Ok(error) => return Err(extract_error(&error).into()),
139 Err(e) => {
140 return Err(format!("Error parsing response: {} in text: '{}'", e, text).into())
141 }
142 },
143 };
144 Ok(json)
145 }
146}
147
148pub async fn chat_completion(
149 provider: &Provider,
150 key: &Key,
151 model: &str,
152 messages: &[Message],
153) -> Result<ChatCompletionResponse, Box<dyn Error + Send + Sync>> {
154 let stream = false;
155 let resp = request_chat_completion(provider, key, model, stream, messages).await?;
156 let status = resp.status();
157 let chat_completion_response = ChatCompletionResponse {
158 status: status.into(),
159 resp: resp.bytes().await?,
160 };
161 Ok(chat_completion_response)
162}
163
/// The `delta` payload of a [`ChunkChoice`] in a streamed response.
///
/// Both fields are optional, so a delta may carry either, both, or neither.
#[derive(Debug, Serialize, Deserialize)]
pub struct Delta {
    /// Optional `role` value.
    pub role: Option<String>,
    /// Optional `content` value (presumably the incremental text — confirm).
    pub content: Option<String>,
}
169
/// One entry of [`ChatCompletionChunk::choices`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ChunkChoice {
    /// Index of this choice as reported in the chunk.
    pub index: u64,
    /// The delta carried by this chunk for the choice.
    pub delta: Delta,
    /// Optional `finish_reason` string; not interpreted anywhere in this file.
    pub finish_reason: Option<String>,
}
176
/// Typed form of a single streamed chunk.
///
/// `process_line` parses the JSON payload of a `data: ` stream line into this type, and
/// it is the item type of the stream returned by `stream_chat_completion`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
    /// Optional chunk identifier.
    pub id: Option<String>,
    /// Value of the `object` field.
    pub object: String,
    /// Value of the `created` field (presumably a Unix timestamp — confirm).
    pub created: u64,
    /// Value of the `model` field.
    pub model: String,
    /// Optional `system_fingerprint` value.
    pub system_fingerprint: Option<String>,
    /// The choices contained in this chunk.
    pub choices: Vec<ChunkChoice>,
}
186
187fn process_line(line: &str) -> Option<ChatCompletionChunk> {
188 if line.is_empty() {
189 return None;
190 }
191
192 if let Some(json_str) = line.strip_prefix("data: ") {
193 if json_str == "[DONE]" {
194 return None;
195 }
196 match serde_json::from_str::<ChatCompletionChunk>(json_str) {
197 Ok(chunk) => Some(chunk),
198 Err(_) => None,
199 }
200 } else {
201 None
202 }
203}
204
205pub async fn stream_chat_completion(
206 provider: &Provider,
207 key: &Key,
208 model: &str,
209 messages: &[Message],
210) -> Result<Pin<Box<dyn Stream<Item = ChatCompletionChunk> + Send>>, Box<dyn Error + Send + Sync>> {
211 let resp = request_chat_completion(provider, key, model, true, messages).await?;
212
213 let stream = stream! {
214 let mut buffer = String::new();
215 let mut byte_stream = resp.bytes_stream();
216
217 while let Some(chunk) = byte_stream.next().await {
218 let chunk = match chunk {
219 Ok(c) => c,
220 Err(_) => break,
221 };
222
223 let mut current_text = String::from_utf8_lossy(&chunk).to_string();
224
225 if !buffer.is_empty() {
226 current_text = format!("{buffer}{current_text}");
227 buffer.clear();
228 }
229 let mut lines = current_text.split_inclusive('\n').peekable();
230
231 while let Some(line) = lines.next() {
232 let is_last_line = lines.peek().is_none() && !current_text.ends_with('\n');
233 if is_last_line {
234 buffer.push_str(line);
235 continue;
236 }
237 if let Some(chunk) = process_line(line) {
238 yield chunk;
239 }
240 }
241 }
242
243 if !buffer.is_empty() {
244 if let Some(chunk) = process_line(&buffer) {
245 yield chunk;
246 }
247 }
248 };
249
250 Ok(Box::pin(stream))
251}