1use anyhow::Result;
2use futures_util::StreamExt;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio_stream::Stream;
10use tracing::{debug, info};
11
12use crate::ai::model::Model;
13use crate::ai::{error::AiError, provider::AiProvider, types::*};
14
/// AI provider backed by the OpenRouter chat-completions API.
#[derive(Clone)]
pub struct OpenRouterProvider {
    // Reused HTTP client (connection pooling, shared timeout).
    client: Client,
    // OpenRouter API key, sent as a Bearer token on every request.
    api_key: String,
    // API root, e.g. "https://openrouter.ai/api/v1".
    base_url: String,
}
21
22impl OpenRouterProvider {
23 pub fn new(api_key: String) -> Self {
24 let client = Client::builder()
25 .timeout(Duration::from_secs(300))
26 .build()
27 .expect("Failed to create HTTP client");
28
29 Self {
30 client,
31 api_key,
32 base_url: "https://openrouter.ai/api/v1".to_string(),
33 }
34 }
35
36 fn get_openrouter_model_id(&self, model: &Model) -> Result<String, AiError> {
37 let model_id = match model {
38 Model::ClaudeSonnet45 => "anthropic/claude-sonnet-4.5",
39 Model::ClaudeOpus46 => "anthropic/claude-opus-4.6",
40 Model::ClaudeOpus45 => "anthropic/claude-opus-4.5",
41 Model::ClaudeHaiku45 => "anthropic/claude-haiku-4.5",
42
43 Model::Gemini3ProPreview => "google/gemini-3-pro-preview",
44 Model::Gemini3FlashPreview => "google/gemini-3-flash-preview",
45
46 Model::Gpt52 => "openai/gpt-5.2",
47 Model::Gpt51CodexMax => "openai/gpt-5.1-codex-max",
48 Model::GptOss120b => "openai/gpt-oss-120b",
49
50 Model::KimiK25 => "moonshotai/kimi-k2.5",
51 Model::GLM47 => "z-ai/glm-4.7",
52 Model::MinimaxM21 => "minimax/minimax-m2.1",
53
54 Model::Grok41Fast => "x-ai/grok-4.1-fast",
55 Model::GrokCodeFast1 => "x-ai/grok-code-fast-1",
56
57 Model::Qwen3Coder => "qwen/qwen3-coder",
58 Model::OpenRouterAuto => "openrouter/auto",
59 _ => {
60 return Err(AiError::Terminal(anyhow::anyhow!(
61 "Model {} is not supported in OpenRouter",
62 model.name()
63 )));
64 }
65 };
66 Ok(model_id.to_string())
67 }
68
69 fn convert_to_openrouter_messages(
70 &self,
71 messages: &[Message],
72 system_prompt: &str,
73 model: Model,
74 ) -> Result<Vec<OpenRouterMessage>, AiError> {
75 let mut openrouter_messages = Vec::new();
76
77 if !system_prompt.trim().is_empty() {
78 let cache_control = if model.supports_prompt_caching() {
79 Some(CacheControl::ephemeral())
80 } else {
81 None
82 };
83
84 let content = MessageContent::Array(vec![ContentPart::Text {
85 text: system_prompt.to_string(),
86 cache_control,
87 }]);
88
89 openrouter_messages.push(OpenRouterMessage {
90 role: "system".to_string(),
91 content: Some(content),
92 name: None,
93 tool_call_id: None,
94 tool_calls: None,
95 reasoning_details: None,
96 });
97 }
98
99 for msg in messages.iter() {
100 openrouter_messages.extend(message_to_openrouter(msg)?);
101 }
102
103 if model.supports_prompt_caching() {
104 let user_indices: Vec<usize> = openrouter_messages
105 .iter()
106 .enumerate()
107 .filter(|(_, m)| m.role == "user")
108 .map(|(i, _)| i)
109 .collect();
110
111 for &idx in user_indices.iter().rev().take(2) {
112 apply_cache_control_to_message(&mut openrouter_messages, idx);
113 }
114 }
115
116 Ok(openrouter_messages)
117 }
118}
119
120#[async_trait::async_trait]
121impl AiProvider for OpenRouterProvider {
122 fn name(&self) -> &'static str {
123 "OpenRouter"
124 }
125
126 fn supported_models(&self) -> HashSet<Model> {
127 HashSet::from([
128 Model::ClaudeSonnet45,
129 Model::ClaudeOpus46,
130 Model::ClaudeOpus45,
131 Model::ClaudeHaiku45,
132 Model::Gemini3ProPreview,
133 Model::Gemini3FlashPreview,
134 Model::Gpt52,
135 Model::Gpt51CodexMax,
136 Model::GptOss120b,
137 Model::GLM47,
138 Model::MinimaxM21,
139 Model::Grok41Fast,
140 Model::GrokCodeFast1,
141 Model::KimiK25,
142 Model::Qwen3Coder,
143 Model::OpenRouterAuto,
144 ])
145 }
146
147 async fn converse(
148 &self,
149 request: ConversationRequest,
150 ) -> Result<ConversationResponse, AiError> {
151 let model_id = self.get_openrouter_model_id(&request.model.model)?;
152 let messages = self.convert_to_openrouter_messages(
153 &request.messages,
154 &request.system_prompt,
155 request.model.model,
156 )?;
157
158 debug!(?model_id, "Using OpenRouter API");
159
160 let openrouter_request = OpenRouterRequest {
161 model: model_id,
162 messages,
163 max_tokens: request.model.max_tokens,
164 temperature: request.model.temperature,
165 top_p: request.model.top_p,
166 stop: if !request.stop_sequences.is_empty() {
167 Some(request.stop_sequences.clone())
168 } else {
169 None
170 },
171 stream: Some(false),
172 tools: if !request.tools.is_empty() {
173 Some(convert_tools_to_openrouter(
174 &request.tools,
175 request.model.model,
176 ))
177 } else {
178 None
179 },
180 tool_choice: if !request.tools.is_empty() {
181 Some(ToolChoice::Simple("auto".to_string()))
182 } else {
183 None
184 },
185 reasoning: match request.model.reasoning_budget {
186 ReasoningBudget::Off => None,
187 _ => Some(ReasoningConfig {
188 effort: Some(match request.model.reasoning_budget {
189 ReasoningBudget::Low => ReasoningEffort::Low,
190 ReasoningBudget::Medium => ReasoningEffort::Medium,
191 ReasoningBudget::High => ReasoningEffort::High,
192 ReasoningBudget::Max => ReasoningEffort::XHigh,
193 ReasoningBudget::Off => unreachable!(),
194 }),
195 }),
196 },
197 usage: Some(UsageConfig { include: true }),
198 };
199
200 let request_json =
201 serde_json::to_string(&openrouter_request).expect("OpenRouterRequest should serialize");
202 info!(request_json = %request_json, "Full OpenRouter request");
203
204 let response = self
205 .client
206 .post(format!("{}/chat/completions", self.base_url))
207 .header("Authorization", format!("Bearer {}", self.api_key))
208 .header("Content-Type", "application/json")
209 .header("HTTP-Referer", "https://tycode.dev")
210 .header("X-Title", "TyCode")
211 .json(&openrouter_request)
212 .send()
213 .await
214 .map_err(|e| {
215 debug!(?e, "OpenRouter API call failed");
216 AiError::Retryable(anyhow::anyhow!("Network error: {}", e))
217 })?;
218
219 tracing::info!("Response: {response:?}");
220
221 let status = response.status();
222 let response_text = response
223 .text()
224 .await
225 .map_err(|e| AiError::Retryable(anyhow::anyhow!("Failed to read response: {}", e)))?;
226
227 if !status.is_success() {
228 debug!(?status, ?response_text, "OpenRouter API returned error");
229
230 let error_text_lower = response_text.to_lowercase();
231 let is_input_too_long = status.as_u16() == 413
232 || ["too long"]
233 .iter()
234 .any(|keyword| error_text_lower.contains(keyword));
235
236 if is_input_too_long {
237 return Err(AiError::InputTooLong(anyhow::anyhow!(
238 "OpenRouter API error {}: {}",
239 status,
240 response_text
241 )));
242 }
243
244 let is_transient = status.as_u16() == 429
245 && (error_text_lower.contains("provider returned error")
246 || error_text_lower.contains("rate-limited upstream"));
247
248 if is_transient {
249 return Err(AiError::Transient(anyhow::anyhow!(
250 "OpenRouter API error {}: {}",
251 status,
252 response_text
253 )));
254 }
255
256 return Err(AiError::Terminal(anyhow::anyhow!(
257 "OpenRouter API error {}: {}",
258 status,
259 response_text
260 )));
261 }
262
263 let openrouter_response: OpenRouterResponse = serde_json::from_str(&response_text)
264 .map_err(|e| {
265 AiError::Terminal(anyhow::anyhow!(
266 "Failed to parse OpenRouter response: {} - Response: {}",
267 e,
268 response_text
269 ))
270 })?;
271
272 let choice = openrouter_response
273 .choices
274 .into_iter()
275 .next()
276 .ok_or_else(|| AiError::Terminal(anyhow::anyhow!("No choices in response")))?;
277
278 let usage = if let Some(usage) = openrouter_response.usage {
279 TokenUsage {
280 input_tokens: usage.prompt_tokens,
281 output_tokens: usage.completion_tokens,
282 total_tokens: usage.total_tokens,
283 cached_prompt_tokens: usage.prompt_details.map(|d| d.cached_tokens),
284 reasoning_tokens: usage.completion_details.map(|d| d.reasoning_tokens),
285 cache_creation_input_tokens: None,
286 }
287 } else {
288 TokenUsage::empty()
289 };
290
291 let stop_reason = match choice.finish_reason.as_deref() {
292 Some("stop") => StopReason::EndTurn,
293 Some("length") => StopReason::MaxTokens,
294 Some("tool_calls") => StopReason::ToolUse,
295 Some("content_filter") => StopReason::EndTurn,
296 Some("error") => StopReason::EndTurn,
297 _ => StopReason::EndTurn,
298 };
299
300 let content = extract_content_from_response(&choice.message)?;
301
302 Ok(ConversationResponse {
303 content,
304 usage,
305 stop_reason,
306 })
307 }
308
309 async fn converse_stream(
310 &self,
311 request: ConversationRequest,
312 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, AiError>> + Send>>, AiError> {
313 let model_id = self.get_openrouter_model_id(&request.model.model)?;
314 let messages = self.convert_to_openrouter_messages(
315 &request.messages,
316 &request.system_prompt,
317 request.model.model,
318 )?;
319
320 let openrouter_request = OpenRouterRequest {
321 model: model_id,
322 messages,
323 max_tokens: request.model.max_tokens,
324 temperature: request.model.temperature,
325 top_p: request.model.top_p,
326 stop: if !request.stop_sequences.is_empty() {
327 Some(request.stop_sequences.clone())
328 } else {
329 None
330 },
331 stream: Some(true),
332 tools: if !request.tools.is_empty() {
333 Some(convert_tools_to_openrouter(
334 &request.tools,
335 request.model.model,
336 ))
337 } else {
338 None
339 },
340 tool_choice: if !request.tools.is_empty() {
341 Some(ToolChoice::Simple("auto".to_string()))
342 } else {
343 None
344 },
345 reasoning: match request.model.reasoning_budget {
346 ReasoningBudget::Off => None,
347 _ => Some(ReasoningConfig {
348 effort: Some(match request.model.reasoning_budget {
349 ReasoningBudget::Low => ReasoningEffort::Low,
350 ReasoningBudget::Medium => ReasoningEffort::Medium,
351 ReasoningBudget::High => ReasoningEffort::High,
352 ReasoningBudget::Max => ReasoningEffort::XHigh,
353 ReasoningBudget::Off => unreachable!(),
354 }),
355 }),
356 },
357 usage: Some(UsageConfig { include: true }),
358 };
359
360 let response = self
361 .client
362 .post(format!("{}/chat/completions", self.base_url))
363 .header("Authorization", format!("Bearer {}", self.api_key))
364 .header("Content-Type", "application/json")
365 .header("HTTP-Referer", "https://tycode.dev")
366 .header("X-Title", "TyCode")
367 .json(&openrouter_request)
368 .send()
369 .await
370 .map_err(|e| AiError::Retryable(anyhow::anyhow!("Network error: {}", e)))?;
371
372 let status = response.status();
373 if !status.is_success() {
374 let response_text = response.text().await.map_err(|e| {
375 AiError::Retryable(anyhow::anyhow!("Failed to read response: {}", e))
376 })?;
377 return Err(AiError::Terminal(anyhow::anyhow!(
378 "OpenRouter API error {}: {}",
379 status,
380 response_text
381 )));
382 }
383
384 let byte_stream = response.bytes_stream();
385
386 let stream = async_stream::stream! {
387 let mut state = OpenRouterStreamAccumulator::default();
388 let mut line_buffer = String::new();
389
390 futures_util::pin_mut!(byte_stream);
391
392 while let Some(chunk_result) = byte_stream.next().await {
393 let Ok(chunk) = chunk_result else {
394 yield Err(AiError::Retryable(anyhow::anyhow!("Stream read error")));
395 return;
396 };
397 line_buffer.push_str(&String::from_utf8_lossy(&chunk));
398 for event in state.process_line_buffer(&mut line_buffer) {
399 yield Ok(event);
400 }
401 }
402
403 yield Ok(StreamEvent::MessageComplete { response: state.into_response() });
404 };
405
406 Ok(Box::pin(stream))
407 }
408
409 fn get_cost(&self, model: &Model) -> Cost {
410 match model {
411 Model::ClaudeOpus46 => Cost::new(5.0, 25.0, 6.25, 0.5),
412 Model::ClaudeOpus45 => Cost::new(5.0, 25.0, 6.25, 0.5),
413 Model::ClaudeSonnet45 => Cost::new(3.0, 15.0, 3.75, 0.3),
414 Model::ClaudeHaiku45 => Cost::new(1.0, 5.0, 1.25, 0.1),
415 Model::Gemini3ProPreview => Cost::new(2.0, 12.0, 0.0, 0.0),
416 Model::Gemini3FlashPreview => Cost::new(0.5, 3.0, 0.0, 0.0),
417 Model::Gpt52 => Cost::new(1.75, 14.0, 0.0, 0.0),
418 Model::Gpt51CodexMax => Cost::new(1.25, 10.0, 0.0, 0.0),
419 Model::GptOss120b => Cost::new(0.1, 0.5, 0.0, 0.0),
420 Model::GLM47 => Cost::new(0.40, 1.50, 0.0, 0.0),
421 Model::MinimaxM21 => Cost::new(0.30, 1.20, 0.0, 0.0),
422 Model::Grok41Fast => Cost::new(0.20, 0.50, 0.0, 0.0),
423 Model::GrokCodeFast1 => Cost::new(0.2, 1.5, 0.0, 0.0),
424 Model::KimiK25 => Cost::new(0.50, 2.80, 0.0, 0.0),
425 Model::Qwen3Coder => Cost::new(0.35, 1.5, 0.0, 0.0),
426 Model::OpenRouterAuto => Cost::new(3.0, 15.0, 3.75, 0.3),
427 _ => Cost::new(0.0, 0.0, 0.0, 0.0),
428 }
429 }
430}
431
/// Anthropic-style cache-control marker attached to content parts and tools
/// so OpenRouter can enable prompt caching on supporting models.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct CacheControl {
    r#type: String,
}

impl CacheControl {
    /// Marker with type "ephemeral" — the value this code always sends.
    fn ephemeral() -> Self {
        Self {
            r#type: "ephemeral".to_string(),
        }
    }
}

/// One element of array-form message content: text (optionally cacheable)
/// or an image reference. Serialized with a "type" discriminator field.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
enum ContentPart {
    #[serde(rename = "text")]
    Text {
        text: String,
        // Omitted from JSON when None so non-caching requests stay clean.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrlData },
}

/// Image payload; this code always builds `url` as a base64 data URL.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct ImageUrlData {
    url: String,
}

/// Message content — either a plain string or an array of typed parts.
/// `untagged` lets serde pick whichever representation matches on input.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
469
/// Request body for the OpenRouter `/chat/completions` endpoint.
/// Optional fields are omitted from the JSON entirely when `None`.
#[derive(Debug, Default, Serialize, Deserialize)]
struct OpenRouterRequest {
    pub model: String,
    pub messages: Vec<OpenRouterMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OpenRouterTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<ReasoningConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<UsageConfig>,
}

/// A single chat message in OpenRouter wire format.
/// `role` is one of "system", "user", "assistant", or "tool".
#[derive(Debug, Default, Serialize, Deserialize)]
struct OpenRouterMessage {
    pub role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    // Set on "tool" messages to link a result back to its originating call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    // Set on assistant messages that invoked tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OpenRouterToolCall>>,
    // Provider reasoning payload, round-tripped on subsequent turns.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<Vec<ReasoningDetail>>,
}

/// A tool definition; `type` is always "function" in this code.
#[derive(Debug, Default, Serialize, Deserialize)]
struct OpenRouterTool {
    pub r#type: String,
    pub function: FunctionObject,
    // Set on the last tool so the (stable) tool list can be prompt-cached.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}

/// Function metadata within a tool definition; `parameters` holds the
/// JSON schema for the function's arguments.
#[derive(Debug, Default, Serialize, Deserialize)]
struct FunctionObject {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

/// `tool_choice` value: a bare mode string (e.g. "auto") or a specific
/// function selection. `untagged` serializes each form directly.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    Simple(String),
    Function(ToolChoiceFunction),
}

/// Forces the model to call one specific function.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    pub r#type: String,
    pub function: FunctionName,
}

/// Name-only reference to a function, used by [`ToolChoiceFunction`].
#[derive(Debug, Serialize, Deserialize)]
struct FunctionName {
    pub name: String,
}

/// OpenRouter `reasoning` request object; only effort-based control is used.
#[derive(Debug, Serialize, Deserialize)]
struct ReasoningConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<ReasoningEffort>,
}

/// Reasoning effort levels; serialized lowercase ("low", …, "xhigh").
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    Medium,
    High,
    XHigh,
}

/// Requests that token-usage accounting be included in responses.
#[derive(Debug, Serialize, Deserialize)]
struct UsageConfig {
    pub include: bool,
}
565
566#[derive(Debug, Serialize, Deserialize)]
567struct OpenRouterResponse {
568 pub id: String,
569 pub choices: Vec<OpenRouterChoice>,
570 pub created: u64,
571 pub model: String,
572 pub object: String,
573 #[serde(skip_serializing_if = "Option::is_none")]
574 pub system_fingerprint: Option<String>,
575 #[serde(skip_serializing_if = "Option::is_none")]
576 pub usage: Option<OpenRouterUsage>,
577}
578
579#[derive(Debug, Serialize, Deserialize)]
580struct OpenRouterChoice {
581 #[serde(skip_serializing_if = "Option::is_none")]
582 pub finish_reason: Option<String>,
583 #[serde(skip_serializing_if = "Option::is_none")]
584 pub native_finish_reason: Option<String>,
585 pub message: OpenRouterMessageResponse,
586}
587
588#[derive(Debug, Serialize, Deserialize)]
589struct OpenRouterMessageResponse {
590 pub role: String,
591 #[serde(skip_serializing_if = "Option::is_none")]
592 pub content: Option<String>,
593 #[serde(skip_serializing_if = "Option::is_none")]
594 pub tool_calls: Option<Vec<OpenRouterToolCall>>,
595 #[serde(skip_serializing_if = "Option::is_none")]
596 pub reasoning_details: Option<Vec<ReasoningDetail>>,
597}
598
599#[derive(Debug, Serialize, Deserialize)]
600struct OpenRouterToolCall {
601 pub id: String,
602 pub r#type: String,
603 pub function: FunctionCall,
604}
605
606#[derive(Debug, Serialize, Deserialize)]
607struct FunctionCall {
608 pub name: String,
609 pub arguments: String,
610}
611
612#[derive(Debug, Serialize, Deserialize)]
613struct OpenRouterUsage {
614 #[serde(rename = "prompt_tokens")]
615 pub prompt_tokens: u32,
616 #[serde(rename = "completion_tokens")]
617 pub completion_tokens: u32,
618 #[serde(rename = "total_tokens")]
619 pub total_tokens: u32,
620 #[serde(rename = "prompt_tokens_details")]
621 pub prompt_details: Option<OpenRouterPromptDetails>,
622 #[serde(rename = "completion_tokens_details")]
623 pub completion_details: Option<OpenRouterCompletionDetails>,
624}
625
626#[derive(Debug, Serialize, Deserialize)]
627struct OpenRouterPromptDetails {
628 #[serde(rename = "cached_tokens")]
629 pub cached_tokens: u32,
630}
631
632#[derive(Debug, Serialize, Deserialize)]
633struct OpenRouterCompletionDetails {
634 #[serde(rename = "reasoning_tokens")]
635 pub reasoning_tokens: u32,
636}
637
/// One parsed SSE `data:` payload from the streaming endpoint.
#[derive(Debug, Deserialize)]
struct StreamChunk {
    #[serde(default)]
    choices: Vec<StreamChoice>,
    // Usage typically arrives on the final chunk.
    #[serde(default)]
    usage: Option<OpenRouterUsage>,
}

/// Incremental choice state within a stream chunk.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    #[serde(default)]
    delta: Option<StreamDelta>,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// The incremental content carried by one chunk.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCallDelta>>,
}

/// Partial tool-call data; `index` identifies which call the fragment
/// belongs to, since arguments arrive split across many chunks.
#[derive(Debug, Deserialize)]
struct StreamToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamFunctionDelta>,
}

/// Partial function data within a tool-call delta.
#[derive(Debug, Deserialize)]
struct StreamFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

/// Accumulates streamed deltas into a complete response.
#[derive(Default)]
struct OpenRouterStreamAccumulator {
    accumulated_text: String,
    accumulated_reasoning: String,
    // (id, name, concatenated-arguments-JSON) per tool call, indexed by
    // the stream's tool-call index.
    tool_calls: Vec<(String, String, String)>,
    finish_reason: Option<String>,
    usage: Option<OpenRouterUsage>,
}
689
690impl OpenRouterStreamAccumulator {
691 fn process_line_buffer(&mut self, line_buffer: &mut String) -> Vec<StreamEvent> {
692 let mut events = Vec::new();
693 while let Some(newline_pos) = line_buffer.find('\n') {
694 let line = line_buffer[..newline_pos].trim().to_string();
695 line_buffer.drain(..=newline_pos);
696 events.extend(self.process_sse_line(&line));
697 }
698 events
699 }
700
701 fn process_sse_line(&mut self, line: &str) -> Vec<StreamEvent> {
702 if line.is_empty() || line.starts_with(':') {
703 return vec![];
704 }
705
706 let data = match line.strip_prefix("data: ") {
707 Some(d) => d.trim(),
708 None => return vec![],
709 };
710
711 if data == "[DONE]" {
712 return vec![];
713 }
714
715 let chunk: StreamChunk = match serde_json::from_str(data) {
716 Ok(c) => c,
717 Err(e) => {
718 tracing::warn!("Failed to parse SSE chunk: {e:?}");
719 return vec![];
720 }
721 };
722
723 if let Some(u) = chunk.usage {
724 self.usage = Some(u);
725 }
726
727 let Some(choice) = chunk.choices.into_iter().next() else {
728 return vec![];
729 };
730
731 if let Some(reason) = choice.finish_reason {
732 self.finish_reason = Some(reason);
733 }
734
735 let Some(delta) = choice.delta else {
736 return vec![];
737 };
738
739 self.process_delta(delta)
740 }
741
742 fn process_delta(&mut self, delta: StreamDelta) -> Vec<StreamEvent> {
743 let mut events = Vec::new();
744
745 if let Some(content) = delta.content {
746 if !content.is_empty() {
747 self.accumulated_text.push_str(&content);
748 events.push(StreamEvent::TextDelta { text: content });
749 }
750 }
751
752 if let Some(reasoning) = delta.reasoning {
753 if !reasoning.is_empty() {
754 self.accumulated_reasoning.push_str(&reasoning);
755 events.push(StreamEvent::ReasoningDelta { text: reasoning });
756 }
757 }
758
759 if let Some(tc_deltas) = delta.tool_calls {
760 self.accumulate_tool_calls(tc_deltas);
761 }
762
763 events
764 }
765
766 fn accumulate_tool_calls(&mut self, tc_deltas: Vec<StreamToolCallDelta>) {
767 for tc in tc_deltas {
768 let idx = tc.index;
769 while self.tool_calls.len() <= idx {
770 self.tool_calls
771 .push((String::new(), String::new(), String::new()));
772 }
773 if let Some(id) = tc.id {
774 self.tool_calls[idx].0 = id;
775 }
776 let Some(func) = tc.function else { continue };
777 if let Some(name) = func.name {
778 self.tool_calls[idx].1 = name;
779 }
780 if let Some(args) = func.arguments {
781 self.tool_calls[idx].2.push_str(&args);
782 }
783 }
784 }
785
786 fn into_response(self) -> ConversationResponse {
787 let mut content_blocks = Vec::new();
788
789 if !self.accumulated_text.trim().is_empty() {
790 content_blocks.push(ContentBlock::Text(self.accumulated_text.trim().to_string()));
791 }
792
793 if !self.accumulated_reasoning.trim().is_empty() {
794 content_blocks.push(ContentBlock::ReasoningContent(ReasoningData {
795 text: self.accumulated_reasoning.trim().to_string(),
796 signature: None,
797 blob: None,
798 raw_json: None,
799 }));
800 }
801
802 for (id, name, args_str) in &self.tool_calls {
803 if !name.is_empty() {
804 let arguments = serde_json::from_str(args_str).unwrap_or(Value::Null);
805 content_blocks.push(ContentBlock::ToolUse(ToolUseData {
806 id: id.clone(),
807 name: name.clone(),
808 arguments,
809 }));
810 }
811 }
812
813 let token_usage = match self.usage {
814 Some(u) => TokenUsage {
815 input_tokens: u.prompt_tokens,
816 output_tokens: u.completion_tokens,
817 total_tokens: u.total_tokens,
818 cached_prompt_tokens: u.prompt_details.map(|d| d.cached_tokens),
819 reasoning_tokens: u.completion_details.map(|d| d.reasoning_tokens),
820 cache_creation_input_tokens: None,
821 },
822 None => TokenUsage::empty(),
823 };
824
825 let stop_reason = match self.finish_reason.as_deref() {
826 Some("stop") => StopReason::EndTurn,
827 Some("length") => StopReason::MaxTokens,
828 Some("tool_calls") => StopReason::ToolUse,
829 _ => StopReason::EndTurn,
830 };
831
832 ConversationResponse {
833 content: Content::from(content_blocks),
834 usage: token_usage,
835 stop_reason,
836 }
837 }
838}
839
/// A reasoning artifact as exposed by OpenRouter's `reasoning_details`,
/// discriminated by its "type" field ("reasoning.summary", "reasoning.text",
/// "reasoning.encrypted").
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum ReasoningDetail {
    // Human-readable summary of the reasoning.
    #[serde(rename = "reasoning.summary")]
    Summary {
        summary: String,
        id: Option<String>,
        format: Option<String>,
        index: Option<u32>,
    },
    // Raw reasoning text, optionally signed by the provider.
    #[serde(rename = "reasoning.text")]
    Text {
        text: String,
        signature: Option<String>,
        id: Option<String>,
        format: Option<String>,
        index: Option<u32>,
    },
    // Provider-encrypted payload; opaque to this code.
    #[serde(rename = "reasoning.encrypted")]
    Encrypted {
        data: String,
        id: Option<String>,
        format: Option<String>,
        index: Option<u32>,
    },
}
866
867fn apply_cache_control_to_message(messages: &mut [OpenRouterMessage], idx: usize) {
868 let Some(msg) = messages.get_mut(idx) else {
869 return;
870 };
871 let Some(MessageContent::Array(parts)) = &mut msg.content else {
872 return;
873 };
874 let Some(last_text) = parts
875 .iter_mut()
876 .rev()
877 .find(|p| matches!(p, ContentPart::Text { .. }))
878 else {
879 return;
880 };
881 if let ContentPart::Text { cache_control, .. } = last_text {
882 *cache_control = Some(CacheControl::ephemeral());
883 }
884}
885
886fn create_message_content(content: String) -> MessageContent {
887 MessageContent::Array(vec![ContentPart::Text {
888 text: content,
889 cache_control: None,
890 }])
891}
892
893fn create_tool_result_message(tool_result: &ToolResultData) -> OpenRouterMessage {
894 OpenRouterMessage {
895 role: "tool".to_string(),
896 content: Some(MessageContent::Array(vec![ContentPart::Text {
897 text: tool_result.content.trim().to_string(),
898 cache_control: None,
899 }])),
900 name: None,
901 tool_call_id: Some(tool_result.tool_use_id.clone()),
902 tool_calls: None,
903 reasoning_details: None,
904 }
905}
906
907fn process_user_message(message: &Message) -> Result<Vec<OpenRouterMessage>, AiError> {
908 let mut results = vec![];
909
910 for tool_result in message.content.tool_results() {
911 results.push(create_tool_result_message(tool_result));
912 }
913
914 let text = extract_text_content(&message.content);
915 let images: Vec<&ImageData> = message.content.images();
916
917 if text.is_empty() && images.is_empty() {
918 return Ok(results);
919 }
920
921 let mut parts = Vec::new();
922 if !text.is_empty() {
923 parts.push(ContentPart::Text {
924 text,
925 cache_control: None,
926 });
927 }
928 for img in images {
929 parts.push(ContentPart::ImageUrl {
930 image_url: ImageUrlData {
931 url: format!("data:{};base64,{}", img.media_type, img.data),
932 },
933 });
934 }
935
936 results.push(OpenRouterMessage {
937 role: "user".to_string(),
938 content: Some(MessageContent::Array(parts)),
939 name: None,
940 tool_call_id: None,
941 tool_calls: None,
942 reasoning_details: None,
943 });
944 Ok(results)
945}
946
947fn message_to_openrouter(message: &Message) -> Result<Vec<OpenRouterMessage>, AiError> {
948 match message.role {
949 MessageRole::User => process_user_message(message),
950 MessageRole::Assistant => process_assistant_message(message),
951 }
952}
953
954fn convert_tool_use_to_openrouter(tool_use: &ToolUseData) -> Result<OpenRouterToolCall, AiError> {
955 let arguments = serde_json::to_string(&tool_use.arguments).map_err(|e| {
956 AiError::Terminal(anyhow::anyhow!("Failed to serialize tool arguments: {}", e))
957 })?;
958
959 Ok(OpenRouterToolCall {
960 id: tool_use.id.clone(),
961 r#type: "function".to_string(),
962 function: FunctionCall {
963 name: tool_use.name.clone(),
964 arguments,
965 },
966 })
967}
968
969fn process_assistant_message(message: &Message) -> Result<Vec<OpenRouterMessage>, AiError> {
970 let mut content_parts = Vec::new();
971 let mut reasoning_details: Option<Vec<ReasoningDetail>> = None;
972 let mut tool_calls = Vec::new();
973
974 for block in message.content.blocks() {
975 match block {
976 ContentBlock::Text(text) => {
977 if !text.trim().is_empty() {
978 content_parts.push(text.trim().to_string());
979 }
980 }
981 ContentBlock::ReasoningContent(reason) => {
982 if let Some(raw_json) = &reason.raw_json {
983 reasoning_details = serde_json::from_value(raw_json.clone()).ok();
984 } else {
985 tracing::warn!(?reason, "No raw json found in reasoning. This count happen if switching providers mid conversation");
986 }
987 }
988 ContentBlock::ToolUse(tool_use) => {
989 tool_calls.push(convert_tool_use_to_openrouter(tool_use)?);
990 }
991 ContentBlock::ToolResult(_) | ContentBlock::Image(_) => continue,
992 }
993 }
994
995 let content_text = if content_parts.is_empty() {
996 "<no response>".to_string()
997 } else {
998 content_parts.join("\n")
999 };
1000 let content = create_message_content(content_text);
1001
1002 let message = OpenRouterMessage {
1003 role: "assistant".to_string(),
1004 content: Some(content),
1005 name: None,
1006 tool_call_id: None,
1007 tool_calls: if tool_calls.is_empty() {
1008 None
1009 } else {
1010 Some(tool_calls)
1011 },
1012 reasoning_details,
1013 };
1014
1015 Ok(vec![message])
1016}
1017
1018fn extract_text_content(content: &Content) -> String {
1019 let mut text_parts = Vec::new();
1020
1021 for block in content.blocks() {
1022 match block {
1023 ContentBlock::Text(text) => {
1024 if !text.trim().is_empty() {
1025 text_parts.push(text.trim().to_string());
1026 }
1027 }
1028 ContentBlock::ReasoningContent(reasoning) => {
1029 if !reasoning.text.trim().is_empty() {
1030 text_parts.push(format!("[Reasoning: {}]", reasoning.text.trim()));
1031 }
1032 }
1033 ContentBlock::ToolUse(_) | ContentBlock::ToolResult(_) | ContentBlock::Image(_) => {
1034 continue;
1035 }
1036 }
1037 }
1038
1039 text_parts.join("\n")
1040}
1041
1042fn convert_tools_to_openrouter(tools: &[ToolDefinition], model: Model) -> Vec<OpenRouterTool> {
1043 let mut result: Vec<OpenRouterTool> = tools
1044 .iter()
1045 .map(|tool| OpenRouterTool {
1046 r#type: "function".to_string(),
1047 function: FunctionObject {
1048 name: tool.name.clone(),
1049 description: Some(tool.description.clone()),
1050 parameters: Some(tool.input_schema.clone()),
1051 strict: Some(true),
1052 },
1053 cache_control: None,
1054 })
1055 .collect();
1056
1057 if model.supports_prompt_caching() {
1058 if let Some(last) = result.last_mut() {
1059 last.cache_control = Some(CacheControl::ephemeral());
1060 }
1061 }
1062
1063 result
1064}
1065
1066fn extract_content_from_response(message: &OpenRouterMessageResponse) -> Result<Content, AiError> {
1067 let mut content_blocks = Vec::new();
1068
1069 if let Some(content) = &message.content {
1070 if !content.trim().is_empty() {
1071 content_blocks.push(ContentBlock::Text(content.trim().to_string()));
1072 }
1073 }
1074
1075 if let Some(tool_calls) = &message.tool_calls {
1076 for tool_call in tool_calls {
1077 if let Ok(arguments) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
1078 let tool_use_data = ToolUseData {
1079 id: tool_call.id.clone(),
1080 name: tool_call.function.name.clone(),
1081 arguments,
1082 };
1083 content_blocks.push(ContentBlock::ToolUse(tool_use_data));
1084 } else {
1085 return Err(AiError::Terminal(anyhow::anyhow!(
1086 "Failed to parse tool call arguments: {}",
1087 tool_call.function.arguments
1088 )));
1089 }
1090 }
1091 }
1092
1093 if let Some(reasoning_details) = &message.reasoning_details {
1094 let raw_json = serde_json::to_value(&reasoning_details)?;
1095
1096 let mut text_parts = Vec::new();
1097 let mut signature: Option<String> = None;
1098
1099 for detail in reasoning_details {
1100 match detail {
1101 ReasoningDetail::Text {
1102 text,
1103 signature: sig,
1104 ..
1105 } => {
1106 text_parts.push(text.clone());
1107 if signature.is_none() {
1108 signature = sig.clone();
1109 }
1110 }
1111 ReasoningDetail::Summary { summary, .. } => {
1112 text_parts.push(summary.clone());
1113 }
1114 ReasoningDetail::Encrypted { .. } => {}
1115 }
1116 }
1117
1118 if !text_parts.is_empty() {
1119 content_blocks.push(ContentBlock::ReasoningContent(ReasoningData {
1120 text: text_parts.join("\n"),
1121 signature,
1122 blob: None,
1123 raw_json: Some(raw_json),
1124 }));
1125 }
1126 }
1127
1128 Ok(Content::from(content_blocks))
1129}
1130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::tests::{
        test_hello_world, test_reasoning_conversation, test_reasoning_with_tools, test_tool_usage,
    };

    /// Builds a provider with an empty API key; the live tests below are
    /// `#[ignore]`d and expect a real key to be supplied locally.
    async fn create_openrouter_provider() -> anyhow::Result<OpenRouterProvider> {
        Ok(OpenRouterProvider::new(String::new()))
    }

    #[tokio::test]
    #[ignore = "requires OpenRouter API key"]
    async fn test_openrouter_hello_world() {
        let provider = create_openrouter_provider().await.unwrap_or_else(|e| {
            debug!(?e, "Failed to create OpenRouter provider");
            panic!("Failed to create OpenRouter provider: {e:?}");
        });

        if let Err(e) = test_hello_world(provider).await {
            debug!(?e, "OpenRouter hello world test failed");
            panic!("OpenRouter hello world test failed: {e:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires OpenRouter API key"]
    async fn test_openrouter_reasoning_conversation() {
        let provider = create_openrouter_provider().await.unwrap_or_else(|e| {
            debug!(?e, "Failed to create OpenRouter provider");
            panic!("Failed to create OpenRouter provider: {e:?}");
        });

        if let Err(e) = test_reasoning_conversation(provider).await {
            debug!(?e, "OpenRouter reasoning conversation test failed");
            panic!("OpenRouter reasoning conversation test failed: {e:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires OpenRouter API key"]
    async fn test_openrouter_tool_usage() {
        let provider = create_openrouter_provider().await.unwrap_or_else(|e| {
            debug!(?e, "Failed to create OpenRouter provider");
            panic!("Failed to create OpenRouter provider: {e:?}");
        });

        if let Err(e) = test_tool_usage(provider).await {
            debug!(?e, "OpenRouter tool usage test failed");
            panic!("OpenRouter tool usage test failed: {e:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires OpenRouter API key"]
    async fn test_openrouter_reasoning_with_tools() {
        let provider = create_openrouter_provider().await.unwrap_or_else(|e| {
            debug!(?e, "Failed to create OpenRouter provider");
            panic!("Failed to create OpenRouter provider: {e:?}");
        });

        if let Err(e) = test_reasoning_with_tools(provider).await {
            debug!(?e, "OpenRouter reasoning with tools test failed");
            panic!("OpenRouter reasoning with tools test failed: {e:?}");
        }
    }
}