Skip to main content

rustyclaw_core/
streaming.rs

//! Streaming provider support for OpenAI-compatible and Anthropic APIs.
//!
//! This module adds SSE (Server-Sent Events) streaming to provider calls,
//! allowing real-time token delivery to the TUI. Responses are decoded
//! incrementally and forwarded as [`StreamChunk`] values over a `tokio` mpsc channel.

6use anyhow::{Context, Result};
7use futures_util::StreamExt;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use tokio::sync::mpsc;
11
12/// A streaming chunk from the model
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub enum StreamChunk {
15    /// Text content delta
16    Text(String),
17    /// Extended thinking started (Anthropic)
18    ThinkingStart,
19    /// Extended thinking content delta (Anthropic)
20    ThinkingDelta(String),
21    /// Extended thinking finished, includes summary if provided
22    ThinkingEnd { summary: Option<String> },
23    /// Tool call started
24    ToolCallStart {
25        index: usize,
26        id: String,
27        name: String,
28    },
29    /// Tool call arguments delta
30    ToolCallDelta { index: usize, arguments: String },
31    /// Stream finished
32    Done,
33    /// Error occurred
34    Error(String),
35}
36
37/// Request parameters for streaming calls
38#[derive(Debug, Clone)]
39pub struct StreamRequest {
40    pub provider: String,
41    pub base_url: String,
42    pub api_key: Option<String>,
43    pub model: String,
44    pub messages: Vec<StreamMessage>,
45    pub tools: Vec<serde_json::Value>,
46    /// Budget tokens for extended thinking (Anthropic only)
47    pub thinking_budget: Option<u32>,
48}
49
/// A single chat message: a role plus its content.
///
/// `content` is usually plain text, but the streaming functions also accept
/// serialized JSON. A JSON object containing a `role` key is forwarded as-is
/// to OpenAI-compatible endpoints, and a JSON array is sent as structured
/// content blocks to Anthropic.
#[derive(Debug, Clone)]
pub struct StreamMessage {
    /// Chat role, e.g. `"system"`, `"user"` or `"assistant"`.
    pub role: String,
    /// Message body: plain text or serialized JSON, as described above.
    pub content: String,
}
55
56/// Call OpenAI-compatible endpoint with streaming.
57/// Sends chunks to the provided channel.
58pub async fn call_openai_streaming(
59    http: &reqwest::Client,
60    req: &StreamRequest,
61    tx: mpsc::Sender<StreamChunk>,
62) -> Result<()> {
63    let url = format!("{}/chat/completions", req.base_url.trim_end_matches('/'));
64
65    let messages: Vec<serde_json::Value> = req
66        .messages
67        .iter()
68        .map(|m| {
69            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&m.content) {
70                if parsed.is_object() && parsed.get("role").is_some() {
71                    return parsed;
72                }
73            }
74            json!({ "role": m.role, "content": m.content })
75        })
76        .collect();
77
78    let mut body = json!({
79        "model": req.model,
80        "messages": messages,
81        "stream": true,
82    });
83
84    if !req.tools.is_empty() {
85        body["tools"] = json!(req.tools);
86    }
87
88    let mut builder = http.post(&url).json(&body);
89    if let Some(ref key) = req.api_key {
90        builder = builder.bearer_auth(key);
91    }
92
93    let resp = builder.send().await.context("HTTP request failed")?;
94
95    if !resp.status().is_success() {
96        let status = resp.status();
97        let text = resp.text().await.unwrap_or_default();
98        let _ = tx.send(StreamChunk::Error(format!("{} — {}", status, text))).await;
99        return Ok(());
100    }
101
102    // Parse SSE stream
103    let mut stream = resp.bytes_stream();
104    let mut buffer = String::new();
105    let mut tool_calls: Vec<(String, String, String)> = Vec::new(); // (id, name, args)
106
107    while let Some(chunk_result) = stream.next().await {
108        let chunk = chunk_result.context("Stream read error")?;
109        buffer.push_str(&String::from_utf8_lossy(&chunk));
110
111        // Process complete SSE events
112        while let Some(event_end) = buffer.find("\n\n") {
113            let event = buffer[..event_end].to_string();
114            buffer = buffer[event_end + 2..].to_string();
115
116            for line in event.lines() {
117                if let Some(data) = line.strip_prefix("data: ") {
118                    if data == "[DONE]" {
119                        let _ = tx.send(StreamChunk::Done).await;
120                        return Ok(());
121                    }
122
123                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
124                        // Extract content delta
125                        if let Some(delta) = json["choices"][0]["delta"].as_object() {
126                            // Text content
127                            if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
128                                let _ = tx.send(StreamChunk::Text(content.to_string())).await;
129                            }
130
131                            // Tool calls
132                            if let Some(tc_array) = delta.get("tool_calls").and_then(|t| t.as_array()) {
133                                for tc in tc_array {
134                                    let index = tc["index"].as_u64().unwrap_or(0) as usize;
135                                    
136                                    // Ensure tool_calls vec is big enough
137                                    while tool_calls.len() <= index {
138                                        tool_calls.push((String::new(), String::new(), String::new()));
139                                    }
140
141                                    // Tool call start
142                                    if let Some(id) = tc["id"].as_str() {
143                                        tool_calls[index].0 = id.to_string();
144                                    }
145                                    if let Some(func) = tc.get("function") {
146                                        if let Some(name) = func["name"].as_str() {
147                                            tool_calls[index].1 = name.to_string();
148                                            let _ = tx.send(StreamChunk::ToolCallStart {
149                                                index,
150                                                id: tool_calls[index].0.clone(),
151                                                name: name.to_string(),
152                                            }).await;
153                                        }
154                                        if let Some(args) = func["arguments"].as_str() {
155                                            tool_calls[index].2.push_str(args);
156                                            let _ = tx.send(StreamChunk::ToolCallDelta {
157                                                index,
158                                                arguments: args.to_string(),
159                                            }).await;
160                                        }
161                                    }
162                                }
163                            }
164                        }
165
166                        // Check for finish reason
167                        if let Some(finish) = json["choices"][0]["finish_reason"].as_str() {
168                            if finish == "stop" || finish == "tool_calls" {
169                                let _ = tx.send(StreamChunk::Done).await;
170                                return Ok(());
171                            }
172                        }
173                    }
174                }
175            }
176        }
177    }
178
179    let _ = tx.send(StreamChunk::Done).await;
180    Ok(())
181}
182
183/// Call Anthropic endpoint with streaming.
184///
185/// Supports extended thinking via the `thinking_budget` field in the request.
186/// When thinking is enabled, sends `ThinkingStart`, `ThinkingDelta`, and
187/// `ThinkingEnd` chunks so the TUI can display a thinking indicator.
188pub async fn call_anthropic_streaming(
189    http: &reqwest::Client,
190    req: &StreamRequest,
191    tx: mpsc::Sender<StreamChunk>,
192) -> Result<()> {
193    let url = format!("{}/v1/messages", req.base_url.trim_end_matches('/'));
194
195    let system = req
196        .messages
197        .iter()
198        .filter(|m| m.role == "system")
199        .map(|m| m.content.as_str())
200        .collect::<Vec<_>>()
201        .join("\n\n");
202
203    let messages: Vec<serde_json::Value> = req
204        .messages
205        .iter()
206        .filter(|m| m.role != "system")
207        .map(|m| {
208            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&m.content) {
209                if parsed.is_array() {
210                    return json!({ "role": m.role, "content": parsed });
211                }
212            }
213            json!({ "role": m.role, "content": m.content })
214        })
215        .collect();
216
217    // Determine max_tokens based on whether thinking is enabled
218    // Extended thinking requires higher max_tokens to accommodate thinking + response
219    let max_tokens = if req.thinking_budget.is_some() {
220        16384 // Allow room for thinking + response
221    } else {
222        4096
223    };
224
225    let mut body = json!({
226        "model": req.model,
227        "max_tokens": max_tokens,
228        "messages": messages,
229        "stream": true,
230    });
231
232    if !system.is_empty() {
233        body["system"] = serde_json::Value::String(system);
234    }
235    if !req.tools.is_empty() {
236        body["tools"] = json!(req.tools);
237    }
238
239    // Add thinking configuration if budget is specified
240    if let Some(budget) = req.thinking_budget {
241        body["thinking"] = json!({
242            "type": "enabled",
243            "budget_tokens": budget
244        });
245    }
246
247    let api_key = req.api_key.as_deref().unwrap_or("");
248    let resp = http
249        .post(&url)
250        .header("x-api-key", api_key)
251        .header("anthropic-version", "2023-06-01")
252        .json(&body)
253        .send()
254        .await
255        .context("HTTP request to Anthropic failed")?;
256
257    if !resp.status().is_success() {
258        let status = resp.status();
259        let text = resp.text().await.unwrap_or_default();
260        let _ = tx.send(StreamChunk::Error(format!("{} — {}", status, text))).await;
261        return Ok(());
262    }
263
264    // Parse Anthropic SSE stream
265    let mut stream = resp.bytes_stream();
266    let mut buffer = String::new();
267    let mut current_tool_index = 0;
268    let mut in_thinking_block = false;
269    let mut thinking_content = String::new();
270
271    while let Some(chunk_result) = stream.next().await {
272        let chunk = chunk_result.context("Stream read error")?;
273        buffer.push_str(&String::from_utf8_lossy(&chunk));
274
275        while let Some(event_end) = buffer.find("\n\n") {
276            let event = buffer[..event_end].to_string();
277            buffer = buffer[event_end + 2..].to_string();
278
279            let mut event_type = String::new();
280            let mut event_data = String::new();
281
282            for line in event.lines() {
283                if let Some(typ) = line.strip_prefix("event: ") {
284                    event_type = typ.to_string();
285                } else if let Some(data) = line.strip_prefix("data: ") {
286                    event_data = data.to_string();
287                }
288            }
289
290            if event_data.is_empty() {
291                continue;
292            }
293
294            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&event_data) {
295                match event_type.as_str() {
296                    "content_block_start" => {
297                        if let Some(block) = json.get("content_block") {
298                            match block["type"].as_str() {
299                                Some("thinking") => {
300                                    // Extended thinking block started
301                                    in_thinking_block = true;
302                                    thinking_content.clear();
303                                    let _ = tx.send(StreamChunk::ThinkingStart).await;
304                                }
305                                Some("tool_use") => {
306                                    let id = block["id"].as_str().unwrap_or("").to_string();
307                                    let name = block["name"].as_str().unwrap_or("").to_string();
308                                    current_tool_index = json["index"].as_u64().unwrap_or(0) as usize;
309                                    let _ = tx.send(StreamChunk::ToolCallStart {
310                                        index: current_tool_index,
311                                        id,
312                                        name,
313                                    }).await;
314                                }
315                                Some("text") => {
316                                    // Regular text block - nothing special to do on start
317                                }
318                                _ => {}
319                            }
320                        }
321                    }
322                    "content_block_delta" => {
323                        if let Some(delta) = json.get("delta") {
324                            match delta["type"].as_str() {
325                                Some("thinking_delta") => {
326                                    // Extended thinking content streaming
327                                    if let Some(thinking) = delta["thinking"].as_str() {
328                                        thinking_content.push_str(thinking);
329                                        let _ = tx.send(StreamChunk::ThinkingDelta(thinking.to_string())).await;
330                                    }
331                                }
332                                Some("text_delta") => {
333                                    if let Some(text) = delta["text"].as_str() {
334                                        let _ = tx.send(StreamChunk::Text(text.to_string())).await;
335                                    }
336                                }
337                                Some("input_json_delta") => {
338                                    if let Some(partial) = delta["partial_json"].as_str() {
339                                        let _ = tx.send(StreamChunk::ToolCallDelta {
340                                            index: current_tool_index,
341                                            arguments: partial.to_string(),
342                                        }).await;
343                                    }
344                                }
345                                _ => {}
346                            }
347                        }
348                    }
349                    "content_block_stop" => {
350                        // A content block finished
351                        if in_thinking_block {
352                            in_thinking_block = false;
353                            // Generate a brief summary from the thinking content
354                            // (first ~100 chars or first sentence, whichever is shorter)
355                            let summary = if thinking_content.len() > 100 {
356                                let truncated = &thinking_content[..100];
357                                if let Some(period_pos) = truncated.find(". ") {
358                                    Some(truncated[..=period_pos].to_string())
359                                } else {
360                                    Some(format!("{}...", truncated))
361                                }
362                            } else if !thinking_content.is_empty() {
363                                Some(thinking_content.clone())
364                            } else {
365                                None
366                            };
367                            let _ = tx.send(StreamChunk::ThinkingEnd { summary }).await;
368                        }
369                    }
370                    "message_stop" => {
371                        let _ = tx.send(StreamChunk::Done).await;
372                        return Ok(());
373                    }
374                    "error" => {
375                        let msg = json["error"]["message"]
376                            .as_str()
377                            .unwrap_or("Unknown error");
378                        let _ = tx.send(StreamChunk::Error(msg.to_string())).await;
379                        return Ok(());
380                    }
381                    _ => {}
382                }
383            }
384        }
385    }
386
387    let _ = tx.send(StreamChunk::Done).await;
388    Ok(())
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `chunk` to JSON and asserts that every `fragment` occurs in it.
    fn assert_serialized_contains(chunk: StreamChunk, fragments: &[&str]) {
        let encoded = serde_json::to_string(&chunk).unwrap();
        for fragment in fragments {
            assert!(encoded.contains(*fragment));
        }
    }

    /// Builds a request with one user message, no tools and a fixed test key.
    fn single_message_request(
        provider: &str,
        base_url: &str,
        model: &str,
        prompt: &str,
        thinking_budget: Option<u32>,
    ) -> StreamRequest {
        StreamRequest {
            provider: provider.to_string(),
            base_url: base_url.to_string(),
            api_key: Some("test-key".to_string()),
            model: model.to_string(),
            messages: vec![StreamMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
            tools: Vec::new(),
            thinking_budget,
        }
    }

    #[test]
    fn test_stream_chunk_serialization() {
        assert_serialized_contains(StreamChunk::Text("hello".to_string()), &["Text", "hello"]);
    }

    #[test]
    fn test_thinking_chunk_serialization() {
        assert_serialized_contains(StreamChunk::ThinkingStart, &["ThinkingStart"]);
        assert_serialized_contains(
            StreamChunk::ThinkingDelta("analyzing...".to_string()),
            &["ThinkingDelta", "analyzing"],
        );
        assert_serialized_contains(
            StreamChunk::ThinkingEnd {
                summary: Some("Done thinking".to_string()),
            },
            &["ThinkingEnd", "Done thinking"],
        );
    }

    #[test]
    fn test_stream_request_creation() {
        let req = single_message_request(
            "openai",
            "https://api.openai.com",
            "gpt-4",
            "Hello",
            None,
        );
        assert_eq!(req.model, "gpt-4");
    }

    #[test]
    fn test_stream_request_with_thinking() {
        let req = single_message_request(
            "anthropic",
            "https://api.anthropic.com",
            "claude-sonnet-4-20250514",
            "Think about this deeply",
            Some(10000),
        );
        assert_eq!(req.thinking_budget, Some(10000));
    }
}