1use std::collections::{BTreeMap, VecDeque};
2use std::time::Duration;
3
4use serde::Deserialize;
5use serde_json::{json, Value};
6
7use crate::error::ApiError;
8use crate::types::{
9 ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
10 InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
11 MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
12 ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
13};
14
15use super::{Provider, ProviderFuture};
16
/// Default base URL for xAI's OpenAI-compatible API.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Default base URL for the OpenAI API.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Default base URL for Gemini's OpenAI-compatible endpoint.
pub const DEFAULT_GEMINI_BASE_URL: &str =
    "https://generativelanguage.googleapis.com/v1beta/openai";
/// Default base URL for the OpenRouter API.
pub const DEFAULT_OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";

// Response headers consulted, in this order, for a request identifier.
const REQUEST_ID_HEADER: &str = "request-id";
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";

// Retry policy applied by `OpenAiCompatClient::new` unless overridden.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
const DEFAULT_MAX_RETRIES: u32 = 2;

/// Static description of one OpenAI-compatible provider.
///
/// Holds the environment variables consulted for the provider's credentials and
/// base URL, plus the base URL used when no override is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Display name; also the key used by `credential_env_vars`.
    pub provider_name: &'static str,
    /// Environment variable holding the API key.
    pub api_key_env: &'static str,
    /// Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when `base_url_env` provides no override.
    pub default_base_url: &'static str,
}

// Credential variable lists reported in missing-credential errors, per preset.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
const GEMINI_ENV_VARS: &[&str] = &["GEMINI_API_KEY"];
const OPENROUTER_ENV_VARS: &[&str] = &["OPENROUTER_API_KEY"];

impl OpenAiCompatConfig {
    /// Assembles a config from its parts; shared by the provider presets below.
    const fn preset(
        provider_name: &'static str,
        api_key_env: &'static str,
        base_url_env: &'static str,
        default_base_url: &'static str,
    ) -> Self {
        Self {
            provider_name,
            api_key_env,
            base_url_env,
            default_base_url,
        }
    }

    /// Preset for xAI (`XAI_API_KEY`, `XAI_BASE_URL`).
    #[must_use]
    pub const fn xai() -> Self {
        Self::preset("xAI", "XAI_API_KEY", "XAI_BASE_URL", DEFAULT_XAI_BASE_URL)
    }

    /// Preset for OpenAI (`OPENAI_API_KEY`, `OPENAI_BASE_URL`).
    #[must_use]
    pub const fn openai() -> Self {
        Self::preset(
            "OpenAI",
            "OPENAI_API_KEY",
            "OPENAI_BASE_URL",
            DEFAULT_OPENAI_BASE_URL,
        )
    }

    /// Preset for Gemini (`GEMINI_API_KEY`, `GEMINI_BASE_URL`).
    #[must_use]
    pub const fn gemini() -> Self {
        Self::preset(
            "Gemini",
            "GEMINI_API_KEY",
            "GEMINI_BASE_URL",
            DEFAULT_GEMINI_BASE_URL,
        )
    }

    /// Preset for OpenRouter (`OPENROUTER_API_KEY`, `OPENROUTER_BASE_URL`).
    #[must_use]
    pub const fn openrouter() -> Self {
        Self::preset(
            "OpenRouter",
            "OPENROUTER_API_KEY",
            "OPENROUTER_BASE_URL",
            DEFAULT_OPENROUTER_BASE_URL,
        )
    }

    /// Environment variables that can supply credentials for this provider.
    ///
    /// The lookup is keyed by `provider_name`; any name other than the four
    /// presets yields an empty slice.
    #[must_use]
    pub fn credential_env_vars(self) -> &'static [&'static str] {
        const BY_PROVIDER: [(&str, &[&str]); 4] = [
            ("xAI", XAI_ENV_VARS),
            ("OpenAI", OPENAI_ENV_VARS),
            ("Gemini", GEMINI_ENV_VARS),
            ("OpenRouter", OPENROUTER_ENV_VARS),
        ];
        match BY_PROVIDER
            .iter()
            .find(|(name, _)| *name == self.provider_name)
        {
            Some(&(_, vars)) => vars,
            None => &[],
        }
    }
}
93
94#[derive(Debug, Clone)]
95pub struct OpenAiCompatClient {
96 http: reqwest::Client,
97 api_key: String,
98 base_url: String,
99 max_retries: u32,
100 initial_backoff: Duration,
101 max_backoff: Duration,
102}
103
104impl OpenAiCompatClient {
105 #[must_use]
106 pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
107 Self {
108 http: reqwest::Client::new(),
109 api_key: api_key.into(),
110 base_url: read_base_url(config),
111 max_retries: DEFAULT_MAX_RETRIES,
112 initial_backoff: DEFAULT_INITIAL_BACKOFF,
113 max_backoff: DEFAULT_MAX_BACKOFF,
114 }
115 }
116
117 pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
118 let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
119 return Err(ApiError::missing_credentials(
120 config.provider_name,
121 config.credential_env_vars(),
122 ));
123 };
124 Ok(Self::new(api_key, config))
125 }
126
127 #[must_use]
128 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
129 self.base_url = base_url.into();
130 self
131 }
132
133 #[must_use]
134 pub fn with_retry_policy(
135 mut self,
136 max_retries: u32,
137 initial_backoff: Duration,
138 max_backoff: Duration,
139 ) -> Self {
140 self.max_retries = max_retries;
141 self.initial_backoff = initial_backoff;
142 self.max_backoff = max_backoff;
143 self
144 }
145
146 pub async fn send_message(
147 &self,
148 request: &MessageRequest,
149 ) -> Result<MessageResponse, ApiError> {
150 let request = MessageRequest {
151 stream: false,
152 ..request.clone()
153 };
154 let response = self.send_with_retry(&request).await?;
155 let request_id = request_id_from_headers(response.headers());
156 let payload = response.json::<ChatCompletionResponse>().await?;
157 let mut normalized = normalize_response(&request.model, payload)?;
158 if normalized.request_id.is_none() {
159 normalized.request_id = request_id;
160 }
161 Ok(normalized)
162 }
163
164 pub async fn stream_message(
165 &self,
166 request: &MessageRequest,
167 ) -> Result<MessageStream, ApiError> {
168 let response = self
169 .send_with_retry(&request.clone().with_streaming())
170 .await?;
171 Ok(MessageStream {
172 request_id: request_id_from_headers(response.headers()),
173 response,
174 parser: OpenAiSseParser::new(),
175 pending: VecDeque::new(),
176 done: false,
177 state: StreamState::new(request.model.clone()),
178 })
179 }
180
181 async fn send_with_retry(
182 &self,
183 request: &MessageRequest,
184 ) -> Result<reqwest::Response, ApiError> {
185 let mut attempts = 0;
186
187 let last_error = loop {
188 attempts += 1;
189 let retryable_error = match self.send_raw_request(request).await {
190 Ok(response) => match expect_success(response).await {
191 Ok(response) => return Ok(response),
192 Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
193 Err(error) => return Err(error),
194 },
195 Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
196 Err(error) => return Err(error),
197 };
198
199 if attempts > self.max_retries {
200 break retryable_error;
201 }
202
203 tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
204 };
205
206 Err(ApiError::RetriesExhausted {
207 attempts,
208 last_error: Box::new(last_error),
209 })
210 }
211
212 async fn send_raw_request(
213 &self,
214 request: &MessageRequest,
215 ) -> Result<reqwest::Response, ApiError> {
216 let request_url = chat_completions_endpoint(&self.base_url);
217 self.http
218 .post(&request_url)
219 .header("content-type", "application/json")
220 .bearer_auth(&self.api_key)
221 .json(&build_chat_completion_request(request))
222 .send()
223 .await
224 .map_err(ApiError::from)
225 }
226
227 fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
228 let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
229 return Err(ApiError::BackoffOverflow {
230 attempt,
231 base_delay: self.initial_backoff,
232 });
233 };
234 Ok(self
235 .initial_backoff
236 .checked_mul(multiplier)
237 .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
238 }
239}
240
241impl Provider for OpenAiCompatClient {
242 type Stream = MessageStream;
243
244 fn send_message<'a>(
245 &'a self,
246 request: &'a MessageRequest,
247 ) -> ProviderFuture<'a, MessageResponse> {
248 Box::pin(async move { self.send_message(request).await })
249 }
250
251 fn stream_message<'a>(
252 &'a self,
253 request: &'a MessageRequest,
254 ) -> ProviderFuture<'a, Self::Stream> {
255 Box::pin(async move { self.stream_message(request).await })
256 }
257}
258
259#[derive(Debug)]
260pub struct MessageStream {
261 request_id: Option<String>,
262 response: reqwest::Response,
263 parser: OpenAiSseParser,
264 pending: VecDeque<StreamEvent>,
265 done: bool,
266 state: StreamState,
267}
268
269impl MessageStream {
270 #[must_use]
271 pub fn request_id(&self) -> Option<&str> {
272 self.request_id.as_deref()
273 }
274
275 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
276 loop {
277 if let Some(event) = self.pending.pop_front() {
278 return Ok(Some(event));
279 }
280
281 if self.done {
282 self.pending.extend(self.state.finish()?);
283 if let Some(event) = self.pending.pop_front() {
284 return Ok(Some(event));
285 }
286 return Ok(None);
287 }
288
289 match self.response.chunk().await? {
290 Some(chunk) => {
291 for parsed in self.parser.push(&chunk)? {
292 self.pending.extend(self.state.ingest_chunk(parsed)?);
293 }
294 }
295 None => {
296 self.done = true;
297 }
298 }
299 }
300 }
301}
302
303#[derive(Debug, Default)]
304struct OpenAiSseParser {
305 buffer: Vec<u8>,
306}
307
308impl OpenAiSseParser {
309 fn new() -> Self {
310 Self::default()
311 }
312
313 fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
314 self.buffer.extend_from_slice(chunk);
315 let mut events = Vec::new();
316
317 while let Some(frame) = next_sse_frame(&mut self.buffer) {
318 if let Some(event) = parse_sse_frame(&frame)? {
319 events.push(event);
320 }
321 }
322
323 Ok(events)
324 }
325}
326
/// Translates parsed OpenAI streaming chunks into the crate's `StreamEvent` sequence.
///
/// Assistant text is emitted as content block 0. The tool call with OpenAI
/// `index` N is emitted as content block N + 1 (see `ToolCallState::block_index`).
// Several independent "already emitted" flags are tracked, hence the clippy allow.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug)]
struct StreamState {
    /// Requested model name; used in `MessageStart` when the first chunk has no `model`.
    model: String,
    /// Whether `MessageStart` has been emitted.
    message_started: bool,
    /// Whether the text block (index 0) has been opened.
    text_started: bool,
    /// Whether the text block has been closed.
    text_finished: bool,
    /// Whether `finish` has already run; later calls return no events.
    finished: bool,
    /// Normalized finish reason from the last choice that reported one.
    stop_reason: Option<String>,
    /// Most recent usage report seen in the stream, if any.
    usage: Option<Usage>,
    /// Per-tool-call accumulation state, keyed by OpenAI tool call index.
    tool_calls: BTreeMap<u32, ToolCallState>,
}

impl StreamState {
    /// Creates an empty state for a stream requested with `model`.
    fn new(model: String) -> Self {
        Self {
            model,
            message_started: false,
            text_started: false,
            text_finished: false,
            finished: false,
            stop_reason: None,
            usage: None,
            tool_calls: BTreeMap::new(),
        }
    }

    /// Converts one parsed chunk into zero or more events.
    ///
    /// - The first chunk always produces `MessageStart`.
    /// - Non-empty text deltas open (once) and extend content block 0.
    /// - Tool call fragments are accumulated per OpenAI index. A tool block is
    ///   opened once its name is known, and new argument text is emitted as
    ///   `InputJsonDelta`.
    /// - A `tool_calls` finish reason closes every open tool block.
    fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
        let mut events = Vec::new();
        // The first chunk opens the message. Usage is zero here; real usage is
        // reported later in `MessageDelta` by `finish`.
        if !self.message_started {
            self.message_started = true;
            events.push(StreamEvent::MessageStart(MessageStartEvent {
                message: MessageResponse {
                    id: chunk.id.clone(),
                    kind: "message".to_string(),
                    role: "assistant".to_string(),
                    content: Vec::new(),
                    model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
                    stop_reason: None,
                    stop_sequence: None,
                    usage: Usage {
                        input_tokens: 0,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                        output_tokens: 0,
                    },
                    request_id: None,
                },
            }));
        }

        // Keep only the latest usage report; `finish` emits it.
        if let Some(usage) = chunk.usage {
            self.usage = Some(Usage {
                input_tokens: usage.prompt_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                output_tokens: usage.completion_tokens,
            });
        }

        for choice in chunk.choices {
            // Empty content deltas are ignored so they never open an empty text
            // block. The text block is only closed in `finish`.
            if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
                if !self.text_started {
                    self.text_started = true;
                    events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
                        index: 0,
                        content_block: OutputContentBlock::Text {
                            text: String::new(),
                        },
                    }));
                }
                events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
                    index: 0,
                    delta: ContentBlockDelta::TextDelta { text: content },
                }));
            }

            for tool_call in choice.delta.tool_calls {
                // Fragments of the same call share an OpenAI `index`; merge them.
                let state = self.tool_calls.entry(tool_call.index).or_default();
                state.apply(tool_call);
                let block_index = state.block_index();
                if !state.started {
                    if let Some(start_event) = state.start_event()? {
                        state.started = true;
                        events.push(StreamEvent::ContentBlockStart(start_event));
                    } else {
                        // No name yet, so the block cannot be opened. Arguments keep
                        // accumulating in `state` and are flushed as one delta once
                        // the block starts.
                        continue;
                    }
                }
                if let Some(delta_event) = state.delta_event() {
                    events.push(StreamEvent::ContentBlockDelta(delta_event));
                }
                if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
                    state.stopped = true;
                    events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                        index: block_index,
                    }));
                }
            }

            if let Some(finish_reason) = choice.finish_reason {
                self.stop_reason = Some(normalize_finish_reason(&finish_reason));
                // Also close started tool blocks that received no fragment in this
                // chunk; the per-call check above only covers calls present here.
                if finish_reason == "tool_calls" {
                    for state in self.tool_calls.values_mut() {
                        if state.started && !state.stopped {
                            state.stopped = true;
                            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                                index: state.block_index(),
                            }));
                        }
                    }
                }
            }
        }

        Ok(events)
    }

    /// Emits the closing events once the body is exhausted. Only the first call
    /// emits anything.
    ///
    /// Closes the text block and any open tool blocks, then (only if a message was
    /// started) emits `MessageDelta` with the stop reason (default `"end_turn"`)
    /// and usage (zeros if none was reported), followed by `MessageStop`.
    fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
        if self.finished {
            return Ok(Vec::new());
        }
        self.finished = true;

        let mut events = Vec::new();
        if self.text_started && !self.text_finished {
            self.text_finished = true;
            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                index: 0,
            }));
        }

        for state in self.tool_calls.values_mut() {
            // Last chance to open a tool block. `start_event` still returns `None`
            // when no name ever arrived, in which case the call is dropped.
            if !state.started {
                if let Some(start_event) = state.start_event()? {
                    state.started = true;
                    events.push(StreamEvent::ContentBlockStart(start_event));
                    if let Some(delta_event) = state.delta_event() {
                        events.push(StreamEvent::ContentBlockDelta(delta_event));
                    }
                }
            }
            if state.started && !state.stopped {
                state.stopped = true;
                events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                    index: state.block_index(),
                }));
            }
        }

        if self.message_started {
            events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
                delta: MessageDelta {
                    stop_reason: Some(
                        self.stop_reason
                            .clone()
                            .unwrap_or_else(|| "end_turn".to_string()),
                    ),
                    stop_sequence: None,
                },
                usage: self.usage.clone().unwrap_or(Usage {
                    input_tokens: 0,
                    cache_creation_input_tokens: 0,
                    cache_read_input_tokens: 0,
                    output_tokens: 0,
                }),
            }));
            events.push(StreamEvent::MessageStop(MessageStopEvent {}));
        }
        Ok(events)
    }
}
499
500#[derive(Debug, Default)]
501struct ToolCallState {
502 openai_index: u32,
503 id: Option<String>,
504 name: Option<String>,
505 arguments: String,
506 emitted_len: usize,
507 started: bool,
508 stopped: bool,
509}
510
511impl ToolCallState {
512 fn apply(&mut self, tool_call: DeltaToolCall) {
513 self.openai_index = tool_call.index;
514 if let Some(id) = tool_call.id {
515 self.id = Some(id);
516 }
517 if let Some(name) = tool_call.function.name {
518 self.name = Some(name);
519 }
520 if let Some(arguments) = tool_call.function.arguments {
521 self.arguments.push_str(&arguments);
522 }
523 }
524
525 const fn block_index(&self) -> u32 {
526 self.openai_index + 1
527 }
528
529 #[allow(clippy::unnecessary_wraps)]
530 fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
531 let Some(name) = self.name.clone() else {
532 return Ok(None);
533 };
534 let id = self
535 .id
536 .clone()
537 .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
538 Ok(Some(ContentBlockStartEvent {
539 index: self.block_index(),
540 content_block: OutputContentBlock::ToolUse {
541 id,
542 name,
543 input: json!({}),
544 },
545 }))
546 }
547
548 fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
549 if self.emitted_len >= self.arguments.len() {
550 return None;
551 }
552 let delta = self.arguments[self.emitted_len..].to_string();
553 self.emitted_len = self.arguments.len();
554 Some(ContentBlockDeltaEvent {
555 index: self.block_index(),
556 delta: ContentBlockDelta::InputJsonDelta {
557 partial_json: delta,
558 },
559 })
560 }
561}
562
563#[derive(Debug, Deserialize)]
564struct ChatCompletionResponse {
565 id: String,
566 model: String,
567 choices: Vec<ChatChoice>,
568 #[serde(default)]
569 usage: Option<OpenAiUsage>,
570}
571
572#[derive(Debug, Deserialize)]
573struct ChatChoice {
574 message: ChatMessage,
575 #[serde(default)]
576 finish_reason: Option<String>,
577}
578
579#[derive(Debug, Deserialize)]
580struct ChatMessage {
581 role: String,
582 #[serde(default)]
583 content: Option<String>,
584 #[serde(default)]
585 tool_calls: Vec<ResponseToolCall>,
586}
587
588#[derive(Debug, Deserialize)]
589struct ResponseToolCall {
590 id: String,
591 function: ResponseToolFunction,
592}
593
594#[derive(Debug, Deserialize)]
595struct ResponseToolFunction {
596 name: String,
597 arguments: String,
598}
599
600#[derive(Debug, Deserialize)]
601struct OpenAiUsage {
602 #[serde(default)]
603 prompt_tokens: u32,
604 #[serde(default)]
605 completion_tokens: u32,
606}
607
608#[derive(Debug, Deserialize)]
609struct ChatCompletionChunk {
610 id: String,
611 #[serde(default)]
612 model: Option<String>,
613 #[serde(default)]
614 choices: Vec<ChunkChoice>,
615 #[serde(default)]
616 usage: Option<OpenAiUsage>,
617}
618
619#[derive(Debug, Deserialize)]
620struct ChunkChoice {
621 delta: ChunkDelta,
622 #[serde(default)]
623 finish_reason: Option<String>,
624}
625
626#[derive(Debug, Default, Deserialize)]
627struct ChunkDelta {
628 #[serde(default)]
629 content: Option<String>,
630 #[serde(default)]
631 tool_calls: Vec<DeltaToolCall>,
632}
633
634#[derive(Debug, Deserialize)]
635struct DeltaToolCall {
636 #[serde(default)]
637 index: u32,
638 #[serde(default)]
639 id: Option<String>,
640 #[serde(default)]
641 function: DeltaFunction,
642}
643
644#[derive(Debug, Default, Deserialize)]
645struct DeltaFunction {
646 #[serde(default)]
647 name: Option<String>,
648 #[serde(default)]
649 arguments: Option<String>,
650}
651
/// Error payload of the form `{"error": {"type": ..., "message": ...}}`.
///
/// Parsed best-effort from non-success response bodies in `expect_success`.
#[derive(Debug, Deserialize)]
struct ErrorEnvelope {
    error: ErrorBody,
}

/// Fields extracted from an error payload; absent fields deserialize as `None`.
#[derive(Debug, Deserialize)]
struct ErrorBody {
    /// Provider error category (JSON key `type`).
    #[serde(rename = "type")]
    error_type: Option<String>,
    /// Human-readable error message.
    message: Option<String>,
}
663
664fn build_chat_completion_request(request: &MessageRequest) -> Value {
665 let mut messages = Vec::new();
666 if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
667 messages.push(json!({
668 "role": "system",
669 "content": system,
670 }));
671 }
672 for message in &request.messages {
673 messages.extend(translate_message(message));
674 }
675
676 let mut payload = json!({
677 "model": request.model,
678 "max_tokens": request.max_tokens,
679 "messages": messages,
680 "stream": request.stream,
681 });
682
683 if let Some(tools) = &request.tools {
684 payload["tools"] =
685 Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
686 }
687 if let Some(tool_choice) = &request.tool_choice {
688 payload["tool_choice"] = openai_tool_choice(tool_choice);
689 }
690
691 payload
692}
693
694fn translate_message(message: &InputMessage) -> Vec<Value> {
695 match message.role.as_str() {
696 "assistant" => {
697 let mut text = String::new();
698 let mut tool_calls = Vec::new();
699 for block in &message.content {
700 match block {
701 InputContentBlock::Text { text: value } => text.push_str(value),
702 InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
703 "id": id,
704 "type": "function",
705 "function": {
706 "name": name,
707 "arguments": input.to_string(),
708 }
709 })),
710 InputContentBlock::ToolResult { .. } => {}
711 }
712 }
713 if text.is_empty() && tool_calls.is_empty() {
714 Vec::new()
715 } else {
716 vec![json!({
717 "role": "assistant",
718 "content": (!text.is_empty()).then_some(text),
719 "tool_calls": tool_calls,
720 })]
721 }
722 }
723 _ => message
724 .content
725 .iter()
726 .filter_map(|block| match block {
727 InputContentBlock::Text { text } => Some(json!({
728 "role": "user",
729 "content": text,
730 })),
731 InputContentBlock::ToolResult {
732 tool_use_id,
733 content,
734 is_error,
735 } => Some(json!({
736 "role": "tool",
737 "tool_call_id": tool_use_id,
738 "content": flatten_tool_result_content(content),
739 "is_error": is_error,
740 })),
741 InputContentBlock::ToolUse { .. } => None,
742 })
743 .collect(),
744 }
745}
746
747fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
748 content
749 .iter()
750 .map(|block| match block {
751 ToolResultContentBlock::Text { text } => text.clone(),
752 ToolResultContentBlock::Json { value } => value.to_string(),
753 })
754 .collect::<Vec<_>>()
755 .join("\n")
756}
757
758fn openai_tool_definition(tool: &ToolDefinition) -> Value {
759 json!({
760 "type": "function",
761 "function": {
762 "name": tool.name,
763 "description": tool.description,
764 "parameters": tool.input_schema,
765 }
766 })
767}
768
769fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
770 match tool_choice {
771 ToolChoice::Auto => Value::String("auto".to_string()),
772 ToolChoice::Any => Value::String("required".to_string()),
773 ToolChoice::Tool { name } => json!({
774 "type": "function",
775 "function": { "name": name },
776 }),
777 }
778}
779
780fn normalize_response(
781 model: &str,
782 response: ChatCompletionResponse,
783) -> Result<MessageResponse, ApiError> {
784 let choice = response
785 .choices
786 .into_iter()
787 .next()
788 .ok_or(ApiError::InvalidSseFrame(
789 "chat completion response missing choices",
790 ))?;
791 let mut content = Vec::new();
792 if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
793 content.push(OutputContentBlock::Text { text });
794 }
795 for tool_call in choice.message.tool_calls {
796 content.push(OutputContentBlock::ToolUse {
797 id: tool_call.id,
798 name: tool_call.function.name,
799 input: parse_tool_arguments(&tool_call.function.arguments),
800 });
801 }
802
803 Ok(MessageResponse {
804 id: response.id,
805 kind: "message".to_string(),
806 role: choice.message.role,
807 content,
808 model: response.model.if_empty_then(model.to_string()),
809 stop_reason: choice
810 .finish_reason
811 .map(|value| normalize_finish_reason(&value)),
812 stop_sequence: None,
813 usage: Usage {
814 input_tokens: response
815 .usage
816 .as_ref()
817 .map_or(0, |usage| usage.prompt_tokens),
818 cache_creation_input_tokens: 0,
819 cache_read_input_tokens: 0,
820 output_tokens: response
821 .usage
822 .as_ref()
823 .map_or(0, |usage| usage.completion_tokens),
824 },
825 request_id: None,
826 })
827}
828
829fn parse_tool_arguments(arguments: &str) -> Value {
830 serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
831}
832
/// Removes the first complete SSE frame from `buffer` and returns its text.
///
/// A frame ends at a blank line, written either as `\n\n` or `\r\n\r\n`. The
/// separator that occurs *earliest* in the buffer wins, so a stream that mixes
/// line endings is still split frame by frame. The separator is consumed but not
/// included in the result, and invalid UTF-8 is replaced lossily. Returns `None`,
/// leaving `buffer` untouched, while no complete frame is buffered.
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    let lf = buffer
        .windows(2)
        .position(|window| window == b"\n\n")
        .map(|position| (position, 2));
    let crlf = buffer
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .map(|position| (position, 4));
    // The two separators can never start at the same byte ('\n' vs '\r'), so the
    // minimum is unambiguous.
    let (position, separator_len) = lf.into_iter().chain(crlf).min_by_key(|&(pos, _)| pos)?;

    let frame: Vec<u8> = buffer.drain(..position + separator_len).collect();
    Some(String::from_utf8_lossy(&frame[..position]).into_owned())
}
850
851fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
852 let trimmed = frame.trim();
853 if trimmed.is_empty() {
854 return Ok(None);
855 }
856
857 let mut data_lines = Vec::new();
858 for line in trimmed.lines() {
859 if line.starts_with(':') {
860 continue;
861 }
862 if let Some(data) = line.strip_prefix("data:") {
863 data_lines.push(data.trim_start());
864 }
865 }
866 if data_lines.is_empty() {
867 return Ok(None);
868 }
869 let payload = data_lines.join("\n");
870 if payload == "[DONE]" {
871 return Ok(None);
872 }
873 serde_json::from_str(&payload)
874 .map(Some)
875 .map_err(ApiError::from)
876}
877
878fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
879 match std::env::var(key) {
880 Ok(value) if !value.is_empty() => Ok(Some(value)),
881 Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
882 Err(error) => Err(ApiError::from(error)),
883 }
884}
885
886#[must_use]
887pub fn has_api_key(key: &str) -> bool {
888 read_env_non_empty(key)
889 .ok()
890 .and_then(std::convert::identity)
891 .is_some()
892}
893
894#[must_use]
895pub fn read_base_url(config: OpenAiCompatConfig) -> String {
896 std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
897}
898
/// Turns a base URL or full endpoint into the chat-completions endpoint URL.
///
/// Trailing slashes are stripped. `/chat/completions` is appended unless the URL
/// already ends with it.
fn chat_completions_endpoint(base_url: &str) -> String {
    const SUFFIX: &str = "/chat/completions";
    let base = base_url.trim_end_matches('/');
    match base.strip_suffix(SUFFIX) {
        Some(_) => base.to_owned(),
        None => format!("{base}{SUFFIX}"),
    }
}
907
908fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
909 headers
910 .get(REQUEST_ID_HEADER)
911 .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
912 .and_then(|value| value.to_str().ok())
913 .map(ToOwned::to_owned)
914}
915
916async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
917 let status = response.status();
918 if status.is_success() {
919 return Ok(response);
920 }
921
922 let body = response.text().await.unwrap_or_default();
923 let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
924 let retryable = is_retryable_status(status);
925
926 Err(ApiError::Api {
927 status,
928 error_type: parsed_error
929 .as_ref()
930 .and_then(|error| error.error.error_type.clone()),
931 message: parsed_error
932 .as_ref()
933 .and_then(|error| error.error.message.clone()),
934 body,
935 retryable,
936 })
937}
938
939const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
940 matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
941}
942
/// Maps an OpenAI `finish_reason` onto the stop-reason vocabulary used by this crate.
///
/// - `"stop"` becomes `"end_turn"`.
/// - `"tool_calls"` and the legacy `"function_call"` become `"tool_use"`.
/// - `"length"` (output cut off by the token limit) becomes `"max_tokens"`.
/// - Anything else (e.g. `"content_filter"`) passes through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    match value {
        "stop" => "end_turn",
        "tool_calls" | "function_call" => "tool_use",
        "length" => "max_tokens",
        other => other,
    }
    .to_string()
}
951
/// Helper for substituting a fallback when a `String` is empty.
trait StringExt {
    /// Returns `self` unless it is empty, in which case `fallback` is returned.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        if !self.is_empty() {
            return self;
        }
        fallback
    }
}
965
// Unit tests for request translation, tool-choice mapping, argument parsing,
// credential lookup, endpoint building, and stop-reason normalization.
#[cfg(test)]
mod tests {
    use super::{
        build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
        openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
    };
    use crate::error::ApiError;
    use crate::types::{
        InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
        ToolResultContentBlock,
    };
    use serde_json::json;
    use std::sync::{Mutex, OnceLock};

    // A user message mixing text and a tool result expands into separate `user`
    // and `tool` messages after the system prompt. Tools and tool_choice use the
    // OpenAI shapes.
    #[test]
    fn request_translation_uses_openai_compatible_shape() {
        let payload = build_chat_completion_request(&MessageRequest {
            model: "grok-3".to_string(),
            max_tokens: 64,
            messages: vec![InputMessage {
                role: "user".to_string(),
                content: vec![
                    InputContentBlock::Text {
                        text: "hello".to_string(),
                    },
                    InputContentBlock::ToolResult {
                        tool_use_id: "tool_1".to_string(),
                        content: vec![ToolResultContentBlock::Json {
                            value: json!({"ok": true}),
                        }],
                        is_error: false,
                    },
                ],
            }],
            system: Some("be helpful".to_string()),
            tools: Some(vec![ToolDefinition {
                name: "weather".to_string(),
                description: Some("Get weather".to_string()),
                input_schema: json!({"type": "object"}),
            }]),
            tool_choice: Some(ToolChoice::Auto),
            stream: false,
        });

        assert_eq!(payload["messages"][0]["role"], json!("system"));
        assert_eq!(payload["messages"][1]["role"], json!("user"));
        assert_eq!(payload["messages"][2]["role"], json!("tool"));
        assert_eq!(payload["tools"][0]["type"], json!("function"));
        assert_eq!(payload["tool_choice"], json!("auto"));
    }

    // `Any` maps to "required"; a named tool maps to a function selector object.
    #[test]
    fn tool_choice_translation_supports_required_function() {
        assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
        assert_eq!(
            openai_tool_choice(&ToolChoice::Tool {
                name: "weather".to_string(),
            }),
            json!({"type": "function", "function": {"name": "weather"}})
        );
    }

    // Valid JSON is parsed; anything else is preserved under a "raw" key.
    #[test]
    fn parses_tool_arguments_fallback() {
        assert_eq!(
            parse_tool_arguments("{\"city\":\"Paris\"}"),
            json!({"city": "Paris"})
        );
        assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"}));
    }

    // The missing-credentials error names the provider from the config.
    // Holds `env_lock` because it mutates process environment variables.
    #[test]
    fn missing_xai_api_key_is_provider_specific() {
        let _lock = env_lock();
        std::env::remove_var("XAI_API_KEY");
        let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai())
            .expect_err("missing key should error");
        assert!(matches!(
            error,
            ApiError::MissingCredentials {
                provider: "xAI",
                ..
            }
        ));
    }

    // Bare base URLs, trailing slashes, and full endpoints all resolve to the same URL.
    #[test]
    fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
            "https://api.x.ai/v1/chat/completions"
        );
    }

    // Process-wide lock that serializes tests which mutate environment variables,
    // since the test harness runs tests concurrently by default.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock")
    }

    #[test]
    fn normalizes_stop_reasons() {
        assert_eq!(normalize_finish_reason("stop"), "end_turn");
        assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
    }
}