1use anyhow::{Context, Result};
7use futures_util::StreamExt;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use tokio::sync::mpsc;
11
/// A single unit of output produced while streaming a model response.
///
/// Chunks are delivered over an `mpsc` channel by the provider-specific
/// streaming functions in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
    /// A fragment of assistant-visible response text.
    Text(String),
    /// Marks the start of an extended-thinking block.
    ThinkingStart,
    /// A fragment of thinking/reasoning text.
    ThinkingDelta(String),
    /// Marks the end of a thinking block, with an optional short summary
    /// derived from the accumulated thinking text.
    ThinkingEnd { summary: Option<String> },
    /// A new tool call has started at position `index` in the response.
    ToolCallStart {
        index: usize,
        id: String,
        name: String,
    },
    /// A fragment of the (JSON-encoded) arguments for the tool call at `index`.
    ToolCallDelta { index: usize, arguments: String },
    /// The stream finished normally; no further chunks follow.
    Done,
    /// The provider reported an error; no further chunks follow.
    Error(String),
}
36
/// Everything needed to issue one streaming chat-completion request.
#[derive(Debug, Clone)]
pub struct StreamRequest {
    /// Provider tag (e.g. "openai", "anthropic") — selection happens in the caller.
    pub provider: String,
    /// API base URL; trailing slashes are tolerated (trimmed before use).
    pub base_url: String,
    /// Bearer / x-api-key credential; `None` for unauthenticated endpoints.
    pub api_key: Option<String>,
    pub model: String,
    pub messages: Vec<StreamMessage>,
    /// Provider-format tool definitions, passed through verbatim.
    pub tools: Vec<serde_json::Value>,
    /// Anthropic extended-thinking token budget; `None` disables thinking.
    pub thinking_budget: Option<u32>,
}
49
/// One chat message. `content` may be plain text, or a JSON-encoded
/// provider-native message/content payload (detected at request-build time).
#[derive(Debug, Clone)]
pub struct StreamMessage {
    pub role: String,
    pub content: String,
}
55
56pub async fn call_openai_streaming(
59 http: &reqwest::Client,
60 req: &StreamRequest,
61 tx: mpsc::Sender<StreamChunk>,
62) -> Result<()> {
63 let url = format!("{}/chat/completions", req.base_url.trim_end_matches('/'));
64
65 let messages: Vec<serde_json::Value> = req
66 .messages
67 .iter()
68 .map(|m| {
69 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&m.content) {
70 if parsed.is_object() && parsed.get("role").is_some() {
71 return parsed;
72 }
73 }
74 json!({ "role": m.role, "content": m.content })
75 })
76 .collect();
77
78 let mut body = json!({
79 "model": req.model,
80 "messages": messages,
81 "stream": true,
82 });
83
84 if !req.tools.is_empty() {
85 body["tools"] = json!(req.tools);
86 }
87
88 let mut builder = http.post(&url).json(&body);
89 if let Some(ref key) = req.api_key {
90 builder = builder.bearer_auth(key);
91 }
92
93 let resp = builder.send().await.context("HTTP request failed")?;
94
95 if !resp.status().is_success() {
96 let status = resp.status();
97 let text = resp.text().await.unwrap_or_default();
98 let _ = tx.send(StreamChunk::Error(format!("{} — {}", status, text))).await;
99 return Ok(());
100 }
101
102 let mut stream = resp.bytes_stream();
104 let mut buffer = String::new();
105 let mut tool_calls: Vec<(String, String, String)> = Vec::new(); while let Some(chunk_result) = stream.next().await {
108 let chunk = chunk_result.context("Stream read error")?;
109 buffer.push_str(&String::from_utf8_lossy(&chunk));
110
111 while let Some(event_end) = buffer.find("\n\n") {
113 let event = buffer[..event_end].to_string();
114 buffer = buffer[event_end + 2..].to_string();
115
116 for line in event.lines() {
117 if let Some(data) = line.strip_prefix("data: ") {
118 if data == "[DONE]" {
119 let _ = tx.send(StreamChunk::Done).await;
120 return Ok(());
121 }
122
123 if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
124 if let Some(delta) = json["choices"][0]["delta"].as_object() {
126 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
128 let _ = tx.send(StreamChunk::Text(content.to_string())).await;
129 }
130
131 if let Some(tc_array) = delta.get("tool_calls").and_then(|t| t.as_array()) {
133 for tc in tc_array {
134 let index = tc["index"].as_u64().unwrap_or(0) as usize;
135
136 while tool_calls.len() <= index {
138 tool_calls.push((String::new(), String::new(), String::new()));
139 }
140
141 if let Some(id) = tc["id"].as_str() {
143 tool_calls[index].0 = id.to_string();
144 }
145 if let Some(func) = tc.get("function") {
146 if let Some(name) = func["name"].as_str() {
147 tool_calls[index].1 = name.to_string();
148 let _ = tx.send(StreamChunk::ToolCallStart {
149 index,
150 id: tool_calls[index].0.clone(),
151 name: name.to_string(),
152 }).await;
153 }
154 if let Some(args) = func["arguments"].as_str() {
155 tool_calls[index].2.push_str(args);
156 let _ = tx.send(StreamChunk::ToolCallDelta {
157 index,
158 arguments: args.to_string(),
159 }).await;
160 }
161 }
162 }
163 }
164 }
165
166 if let Some(finish) = json["choices"][0]["finish_reason"].as_str() {
168 if finish == "stop" || finish == "tool_calls" {
169 let _ = tx.send(StreamChunk::Done).await;
170 return Ok(());
171 }
172 }
173 }
174 }
175 }
176 }
177 }
178
179 let _ = tx.send(StreamChunk::Done).await;
180 Ok(())
181}
182
183pub async fn call_anthropic_streaming(
189 http: &reqwest::Client,
190 req: &StreamRequest,
191 tx: mpsc::Sender<StreamChunk>,
192) -> Result<()> {
193 let url = format!("{}/v1/messages", req.base_url.trim_end_matches('/'));
194
195 let system = req
196 .messages
197 .iter()
198 .filter(|m| m.role == "system")
199 .map(|m| m.content.as_str())
200 .collect::<Vec<_>>()
201 .join("\n\n");
202
203 let messages: Vec<serde_json::Value> = req
204 .messages
205 .iter()
206 .filter(|m| m.role != "system")
207 .map(|m| {
208 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&m.content) {
209 if parsed.is_array() {
210 return json!({ "role": m.role, "content": parsed });
211 }
212 }
213 json!({ "role": m.role, "content": m.content })
214 })
215 .collect();
216
217 let max_tokens = if req.thinking_budget.is_some() {
220 16384 } else {
222 4096
223 };
224
225 let mut body = json!({
226 "model": req.model,
227 "max_tokens": max_tokens,
228 "messages": messages,
229 "stream": true,
230 });
231
232 if !system.is_empty() {
233 body["system"] = serde_json::Value::String(system);
234 }
235 if !req.tools.is_empty() {
236 body["tools"] = json!(req.tools);
237 }
238
239 if let Some(budget) = req.thinking_budget {
241 body["thinking"] = json!({
242 "type": "enabled",
243 "budget_tokens": budget
244 });
245 }
246
247 let api_key = req.api_key.as_deref().unwrap_or("");
248 let resp = http
249 .post(&url)
250 .header("x-api-key", api_key)
251 .header("anthropic-version", "2023-06-01")
252 .json(&body)
253 .send()
254 .await
255 .context("HTTP request to Anthropic failed")?;
256
257 if !resp.status().is_success() {
258 let status = resp.status();
259 let text = resp.text().await.unwrap_or_default();
260 let _ = tx.send(StreamChunk::Error(format!("{} — {}", status, text))).await;
261 return Ok(());
262 }
263
264 let mut stream = resp.bytes_stream();
266 let mut buffer = String::new();
267 let mut current_tool_index = 0;
268 let mut in_thinking_block = false;
269 let mut thinking_content = String::new();
270
271 while let Some(chunk_result) = stream.next().await {
272 let chunk = chunk_result.context("Stream read error")?;
273 buffer.push_str(&String::from_utf8_lossy(&chunk));
274
275 while let Some(event_end) = buffer.find("\n\n") {
276 let event = buffer[..event_end].to_string();
277 buffer = buffer[event_end + 2..].to_string();
278
279 let mut event_type = String::new();
280 let mut event_data = String::new();
281
282 for line in event.lines() {
283 if let Some(typ) = line.strip_prefix("event: ") {
284 event_type = typ.to_string();
285 } else if let Some(data) = line.strip_prefix("data: ") {
286 event_data = data.to_string();
287 }
288 }
289
290 if event_data.is_empty() {
291 continue;
292 }
293
294 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&event_data) {
295 match event_type.as_str() {
296 "content_block_start" => {
297 if let Some(block) = json.get("content_block") {
298 match block["type"].as_str() {
299 Some("thinking") => {
300 in_thinking_block = true;
302 thinking_content.clear();
303 let _ = tx.send(StreamChunk::ThinkingStart).await;
304 }
305 Some("tool_use") => {
306 let id = block["id"].as_str().unwrap_or("").to_string();
307 let name = block["name"].as_str().unwrap_or("").to_string();
308 current_tool_index = json["index"].as_u64().unwrap_or(0) as usize;
309 let _ = tx.send(StreamChunk::ToolCallStart {
310 index: current_tool_index,
311 id,
312 name,
313 }).await;
314 }
315 Some("text") => {
316 }
318 _ => {}
319 }
320 }
321 }
322 "content_block_delta" => {
323 if let Some(delta) = json.get("delta") {
324 match delta["type"].as_str() {
325 Some("thinking_delta") => {
326 if let Some(thinking) = delta["thinking"].as_str() {
328 thinking_content.push_str(thinking);
329 let _ = tx.send(StreamChunk::ThinkingDelta(thinking.to_string())).await;
330 }
331 }
332 Some("text_delta") => {
333 if let Some(text) = delta["text"].as_str() {
334 let _ = tx.send(StreamChunk::Text(text.to_string())).await;
335 }
336 }
337 Some("input_json_delta") => {
338 if let Some(partial) = delta["partial_json"].as_str() {
339 let _ = tx.send(StreamChunk::ToolCallDelta {
340 index: current_tool_index,
341 arguments: partial.to_string(),
342 }).await;
343 }
344 }
345 _ => {}
346 }
347 }
348 }
349 "content_block_stop" => {
350 if in_thinking_block {
352 in_thinking_block = false;
353 let summary = if thinking_content.len() > 100 {
356 let truncated = &thinking_content[..100];
357 if let Some(period_pos) = truncated.find(". ") {
358 Some(truncated[..=period_pos].to_string())
359 } else {
360 Some(format!("{}...", truncated))
361 }
362 } else if !thinking_content.is_empty() {
363 Some(thinking_content.clone())
364 } else {
365 None
366 };
367 let _ = tx.send(StreamChunk::ThinkingEnd { summary }).await;
368 }
369 }
370 "message_stop" => {
371 let _ = tx.send(StreamChunk::Done).await;
372 return Ok(());
373 }
374 "error" => {
375 let msg = json["error"]["message"]
376 .as_str()
377 .unwrap_or("Unknown error");
378 let _ = tx.send(StreamChunk::Error(msg.to_string())).await;
379 return Ok(());
380 }
381 _ => {}
382 }
383 }
384 }
385 }
386
387 let _ = tx.send(StreamChunk::Done).await;
388 Ok(())
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `chunk` and asserts the JSON contains every `needle`.
    fn assert_json_contains(chunk: &StreamChunk, needles: &[&str]) {
        let encoded = serde_json::to_string(chunk).unwrap();
        for needle in needles {
            assert!(encoded.contains(needle));
        }
    }

    /// Builds a minimal one-message request for use in construction tests.
    fn sample_request(
        provider: &str,
        base_url: &str,
        model: &str,
        content: &str,
        thinking_budget: Option<u32>,
    ) -> StreamRequest {
        StreamRequest {
            provider: provider.to_string(),
            base_url: base_url.to_string(),
            api_key: Some("test-key".to_string()),
            model: model.to_string(),
            messages: vec![StreamMessage {
                role: "user".to_string(),
                content: content.to_string(),
            }],
            tools: vec![],
            thinking_budget,
        }
    }

    #[test]
    fn test_stream_chunk_serialization() {
        assert_json_contains(&StreamChunk::Text("hello".to_string()), &["Text", "hello"]);
    }

    #[test]
    fn test_thinking_chunk_serialization() {
        assert_json_contains(&StreamChunk::ThinkingStart, &["ThinkingStart"]);
        assert_json_contains(
            &StreamChunk::ThinkingDelta("analyzing...".to_string()),
            &["ThinkingDelta", "analyzing"],
        );
        assert_json_contains(
            &StreamChunk::ThinkingEnd { summary: Some("Done thinking".to_string()) },
            &["ThinkingEnd", "Done thinking"],
        );
    }

    #[test]
    fn test_stream_request_creation() {
        let req = sample_request("openai", "https://api.openai.com", "gpt-4", "Hello", None);
        assert_eq!(req.model, "gpt-4");
    }

    #[test]
    fn test_stream_request_with_thinking() {
        let req = sample_request(
            "anthropic",
            "https://api.anthropic.com",
            "claude-sonnet-4-20250514",
            "Think about this deeply",
            Some(10000),
        );
        assert_eq!(req.thinking_budget, Some(10000));
    }
}