1use anyhow::{Context, Result};
7use futures_util::StreamExt;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use tokio::sync::mpsc;
11
/// One incremental event produced while streaming a model response.
///
/// Chunks are pushed through an `mpsc` channel by the provider-specific
/// streaming functions in this file and consumed downstream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
    /// A fragment of visible assistant text.
    Text(String),
    /// An extended-thinking block opened (emitted by the Anthropic path here).
    ThinkingStart,
    /// A fragment of the model's thinking text.
    ThinkingDelta(String),
    /// The thinking block closed; `summary` is a short excerpt of the
    /// accumulated thinking text, if any was produced.
    ThinkingEnd { summary: Option<String> },
    /// The model started a tool invocation; `index` correlates later
    /// argument deltas with this call.
    ToolCallStart {
        index: usize,
        id: String,
        name: String,
    },
    /// A fragment of the JSON argument string for the tool call at `index`.
    ToolCallDelta { index: usize, arguments: String },
    /// The response finished; no further chunks follow.
    Done,
    /// The provider reported an error; the stream ends after this.
    Error(String),
}
36
/// Provider-agnostic parameters for one streaming completion request.
#[derive(Debug, Clone)]
pub struct StreamRequest {
    // Provider identifier (e.g. "openai", "anthropic"); not read by the
    // functions in this file — presumably used by the caller for routing.
    pub provider: String,
    // API base URL; trailing slashes are tolerated.
    pub base_url: String,
    // Credential sent as a bearer token (OpenAI) or `x-api-key` header
    // (Anthropic); optional for keyless local servers.
    pub api_key: Option<String>,
    pub model: String,
    // Conversation so far; `content` may hold pre-serialized JSON that the
    // streaming functions pass through to the provider (see their docs).
    pub messages: Vec<StreamMessage>,
    // Tool definitions, already in the provider's expected JSON shape.
    pub tools: Vec<serde_json::Value>,
    // Extended-thinking token budget (Anthropic); `None` disables thinking.
    pub thinking_budget: Option<u32>,
}
49
/// One chat message: a role plus content that is either plain text or
/// pre-serialized JSON (the streaming functions attempt a JSON parse first).
#[derive(Debug, Clone)]
pub struct StreamMessage {
    pub role: String,
    pub content: String,
}
55
56pub async fn call_openai_streaming(
59 http: &reqwest::Client,
60 req: &StreamRequest,
61 tx: mpsc::Sender<StreamChunk>,
62) -> Result<()> {
63 let url = format!("{}/chat/completions", req.base_url.trim_end_matches('/'));
64
65 let messages: Vec<serde_json::Value> = req
66 .messages
67 .iter()
68 .map(|m| {
69 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&m.content) {
70 if parsed.is_object() && parsed.get("role").is_some() {
71 return parsed;
72 }
73 }
74 json!({ "role": m.role, "content": m.content })
75 })
76 .collect();
77
78 let mut body = json!({
79 "model": req.model,
80 "messages": messages,
81 "stream": true,
82 });
83
84 if !req.tools.is_empty() {
85 body["tools"] = json!(req.tools);
86 }
87
88 let mut builder = http.post(&url).json(&body);
89 if let Some(ref key) = req.api_key {
90 builder = builder.bearer_auth(key);
91 }
92
93 let resp = builder.send().await.context("HTTP request failed")?;
94
95 if !resp.status().is_success() {
96 let status = resp.status();
97 let text = resp.text().await.unwrap_or_default();
98 let _ = tx
99 .send(StreamChunk::Error(format!("{} — {}", status, text)))
100 .await;
101 return Ok(());
102 }
103
104 let mut stream = resp.bytes_stream();
106 let mut buffer = String::new();
107 let mut tool_calls: Vec<(String, String, String)> = Vec::new(); while let Some(chunk_result) = stream.next().await {
110 let chunk = chunk_result.context("Stream read error")?;
111 buffer.push_str(&String::from_utf8_lossy(&chunk));
112
113 while let Some(event_end) = buffer.find("\n\n") {
115 let event = buffer[..event_end].to_string();
116 buffer = buffer[event_end + 2..].to_string();
117
118 for line in event.lines() {
119 if let Some(data) = line.strip_prefix("data: ") {
120 if data == "[DONE]" {
121 let _ = tx.send(StreamChunk::Done).await;
122 return Ok(());
123 }
124
125 if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
126 if let Some(delta) = json["choices"][0]["delta"].as_object() {
128 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
130 let _ = tx.send(StreamChunk::Text(content.to_string())).await;
131 }
132
133 if let Some(tc_array) =
135 delta.get("tool_calls").and_then(|t| t.as_array())
136 {
137 for tc in tc_array {
138 let index = tc["index"].as_u64().unwrap_or(0) as usize;
139
140 while tool_calls.len() <= index {
142 tool_calls.push((
143 String::new(),
144 String::new(),
145 String::new(),
146 ));
147 }
148
149 if let Some(id) = tc["id"].as_str() {
151 tool_calls[index].0 = id.to_string();
152 }
153 if let Some(func) = tc.get("function") {
154 if let Some(name) = func["name"].as_str() {
155 tool_calls[index].1 = name.to_string();
156 let _ = tx
157 .send(StreamChunk::ToolCallStart {
158 index,
159 id: tool_calls[index].0.clone(),
160 name: name.to_string(),
161 })
162 .await;
163 }
164 if let Some(args) = func["arguments"].as_str() {
165 tool_calls[index].2.push_str(args);
166 let _ = tx
167 .send(StreamChunk::ToolCallDelta {
168 index,
169 arguments: args.to_string(),
170 })
171 .await;
172 }
173 }
174 }
175 }
176 }
177
178 if let Some(finish) = json["choices"][0]["finish_reason"].as_str() {
180 if finish == "stop" || finish == "tool_calls" {
181 let _ = tx.send(StreamChunk::Done).await;
182 return Ok(());
183 }
184 }
185 }
186 }
187 }
188 }
189 }
190
191 let _ = tx.send(StreamChunk::Done).await;
192 Ok(())
193}
194
195pub async fn call_anthropic_streaming(
201 http: &reqwest::Client,
202 req: &StreamRequest,
203 tx: mpsc::Sender<StreamChunk>,
204) -> Result<()> {
205 let url = format!("{}/v1/messages", req.base_url.trim_end_matches('/'));
206
207 let system = req
208 .messages
209 .iter()
210 .filter(|m| m.role == "system")
211 .map(|m| m.content.as_str())
212 .collect::<Vec<_>>()
213 .join("\n\n");
214
215 let messages: Vec<serde_json::Value> = req
216 .messages
217 .iter()
218 .filter(|m| m.role != "system")
219 .map(|m| {
220 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&m.content) {
221 if parsed.is_array() {
222 return json!({ "role": m.role, "content": parsed });
223 }
224 }
225 json!({ "role": m.role, "content": m.content })
226 })
227 .collect();
228
229 let max_tokens = if req.thinking_budget.is_some() {
232 16384 } else {
234 4096
235 };
236
237 let mut body = json!({
238 "model": req.model,
239 "max_tokens": max_tokens,
240 "messages": messages,
241 "stream": true,
242 });
243
244 if !system.is_empty() {
245 body["system"] = serde_json::Value::String(system);
246 }
247 if !req.tools.is_empty() {
248 body["tools"] = json!(req.tools);
249 }
250
251 if let Some(budget) = req.thinking_budget {
253 body["thinking"] = json!({
254 "type": "enabled",
255 "budget_tokens": budget
256 });
257 }
258
259 let api_key = req.api_key.as_deref().unwrap_or("");
260 let resp = http
261 .post(&url)
262 .header("x-api-key", api_key)
263 .header("anthropic-version", "2023-06-01")
264 .json(&body)
265 .send()
266 .await
267 .context("HTTP request to Anthropic failed")?;
268
269 if !resp.status().is_success() {
270 let status = resp.status();
271 let text = resp.text().await.unwrap_or_default();
272 let _ = tx
273 .send(StreamChunk::Error(format!("{} — {}", status, text)))
274 .await;
275 return Ok(());
276 }
277
278 let mut stream = resp.bytes_stream();
280 let mut buffer = String::new();
281 let mut current_tool_index = 0;
282 let mut in_thinking_block = false;
283 let mut thinking_content = String::new();
284
285 while let Some(chunk_result) = stream.next().await {
286 let chunk = chunk_result.context("Stream read error")?;
287 buffer.push_str(&String::from_utf8_lossy(&chunk));
288
289 while let Some(event_end) = buffer.find("\n\n") {
290 let event = buffer[..event_end].to_string();
291 buffer = buffer[event_end + 2..].to_string();
292
293 let mut event_type = String::new();
294 let mut event_data = String::new();
295
296 for line in event.lines() {
297 if let Some(typ) = line.strip_prefix("event: ") {
298 event_type = typ.to_string();
299 } else if let Some(data) = line.strip_prefix("data: ") {
300 event_data = data.to_string();
301 }
302 }
303
304 if event_data.is_empty() {
305 continue;
306 }
307
308 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&event_data) {
309 match event_type.as_str() {
310 "content_block_start" => {
311 if let Some(block) = json.get("content_block") {
312 match block["type"].as_str() {
313 Some("thinking") => {
314 in_thinking_block = true;
316 thinking_content.clear();
317 let _ = tx.send(StreamChunk::ThinkingStart).await;
318 }
319 Some("tool_use") => {
320 let id = block["id"].as_str().unwrap_or("").to_string();
321 let name = block["name"].as_str().unwrap_or("").to_string();
322 current_tool_index =
323 json["index"].as_u64().unwrap_or(0) as usize;
324 let _ = tx
325 .send(StreamChunk::ToolCallStart {
326 index: current_tool_index,
327 id,
328 name,
329 })
330 .await;
331 }
332 Some("text") => {
333 }
335 _ => {}
336 }
337 }
338 }
339 "content_block_delta" => {
340 if let Some(delta) = json.get("delta") {
341 match delta["type"].as_str() {
342 Some("thinking_delta") => {
343 if let Some(thinking) = delta["thinking"].as_str() {
345 thinking_content.push_str(thinking);
346 let _ = tx
347 .send(StreamChunk::ThinkingDelta(thinking.to_string()))
348 .await;
349 }
350 }
351 Some("text_delta") => {
352 if let Some(text) = delta["text"].as_str() {
353 let _ = tx.send(StreamChunk::Text(text.to_string())).await;
354 }
355 }
356 Some("input_json_delta") => {
357 if let Some(partial) = delta["partial_json"].as_str() {
358 let _ = tx
359 .send(StreamChunk::ToolCallDelta {
360 index: current_tool_index,
361 arguments: partial.to_string(),
362 })
363 .await;
364 }
365 }
366 _ => {}
367 }
368 }
369 }
370 "content_block_stop" => {
371 if in_thinking_block {
373 in_thinking_block = false;
374 let summary = if thinking_content.len() > 100 {
377 let truncated = &thinking_content[..100];
378 if let Some(period_pos) = truncated.find(". ") {
379 Some(truncated[..=period_pos].to_string())
380 } else {
381 Some(format!("{}...", truncated))
382 }
383 } else if !thinking_content.is_empty() {
384 Some(thinking_content.clone())
385 } else {
386 None
387 };
388 let _ = tx.send(StreamChunk::ThinkingEnd { summary }).await;
389 }
390 }
391 "message_stop" => {
392 let _ = tx.send(StreamChunk::Done).await;
393 return Ok(());
394 }
395 "error" => {
396 let msg = json["error"]["message"].as_str().unwrap_or("Unknown error");
397 let _ = tx.send(StreamChunk::Error(msg.to_string())).await;
398 return Ok(());
399 }
400 _ => {}
401 }
402 }
403 }
404 }
405
406 let _ = tx.send(StreamChunk::Done).await;
407 Ok(())
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal single-user-message request for the given provider.
    fn sample_request(provider: &str, model: &str, budget: Option<u32>, text: &str) -> StreamRequest {
        StreamRequest {
            provider: provider.to_string(),
            base_url: format!("https://api.{provider}.com"),
            api_key: Some("test-key".to_string()),
            model: model.to_string(),
            messages: vec![StreamMessage {
                role: "user".to_string(),
                content: text.to_string(),
            }],
            tools: Vec::new(),
            thinking_budget: budget,
        }
    }

    #[test]
    fn test_stream_chunk_serialization() {
        let serialized = serde_json::to_string(&StreamChunk::Text("hello".to_string())).unwrap();
        assert!(serialized.contains("Text"));
        assert!(serialized.contains("hello"));
    }

    #[test]
    fn test_thinking_chunk_serialization() {
        let encode = |chunk: &StreamChunk| serde_json::to_string(chunk).unwrap();

        assert!(encode(&StreamChunk::ThinkingStart).contains("ThinkingStart"));

        let delta_json = encode(&StreamChunk::ThinkingDelta("analyzing...".to_string()));
        assert!(delta_json.contains("ThinkingDelta"));
        assert!(delta_json.contains("analyzing"));

        let end_json = encode(&StreamChunk::ThinkingEnd {
            summary: Some("Done thinking".to_string()),
        });
        assert!(end_json.contains("ThinkingEnd"));
        assert!(end_json.contains("Done thinking"));
    }

    #[test]
    fn test_stream_request_creation() {
        let req = sample_request("openai", "gpt-4", None, "Hello");
        assert_eq!(req.model, "gpt-4");
    }

    #[test]
    fn test_stream_request_with_thinking() {
        let req = sample_request(
            "anthropic",
            "claude-sonnet-4-20250514",
            Some(10000),
            "Think about this deeply",
        );
        assert_eq!(req.thinking_budget, Some(10000));
    }
}