Skip to main content

wraith_api/providers/
openai_compat.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::time::Duration;
3
4use serde::Deserialize;
5use serde_json::{json, Value};
6
7use crate::error::ApiError;
8use crate::types::{
9    ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
10    InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
11    MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
12    ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
13};
14
15use super::{Provider, ProviderFuture};
16
/// Base URL used for xAI when no override is configured.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Base URL used for OpenAI when no override is configured.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Base URL used for Gemini's OpenAI-compatible endpoint when no override is configured.
pub const DEFAULT_GEMINI_BASE_URL: &str =
    "https://generativelanguage.googleapis.com/v1beta/openai";
/// Base URL used for OpenRouter when no override is configured.
pub const DEFAULT_OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
/// Response header consulted first when extracting a request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Response header consulted when [`REQUEST_ID_HEADER`] is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
/// Delay before the first retry; doubled for each later retry.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
/// Upper bound applied to every retry delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
/// Number of retries performed after the initial attempt.
const DEFAULT_MAX_RETRIES: u32 = 2;
27
/// Static description of one OpenAI-compatible provider.
///
/// All fields are `'static` strings, so the value is `Copy` and can be built in
/// `const` context.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Human-readable provider name; also the key used to look up credential variables.
    pub provider_name: &'static str,
    /// Environment variable that holds the API key.
    pub api_key_env: &'static str,
    /// Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when the override variable is not set.
    pub default_base_url: &'static str,
}
35
/// Credential environment variables reported for xAI.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
/// Credential environment variables reported for OpenAI.
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
/// Credential environment variables reported for Gemini.
const GEMINI_ENV_VARS: &[&str] = &["GEMINI_API_KEY"];
/// Credential environment variables reported for OpenRouter.
const OPENROUTER_ENV_VARS: &[&str] = &["OPENROUTER_API_KEY"];
40
41impl OpenAiCompatConfig {
42    #[must_use]
43    pub const fn xai() -> Self {
44        Self {
45            provider_name: "xAI",
46            api_key_env: "XAI_API_KEY",
47            base_url_env: "XAI_BASE_URL",
48            default_base_url: DEFAULT_XAI_BASE_URL,
49        }
50    }
51
52    #[must_use]
53    pub const fn openai() -> Self {
54        Self {
55            provider_name: "OpenAI",
56            api_key_env: "OPENAI_API_KEY",
57            base_url_env: "OPENAI_BASE_URL",
58            default_base_url: DEFAULT_OPENAI_BASE_URL,
59        }
60    }
61
62    #[must_use]
63    pub const fn gemini() -> Self {
64        Self {
65            provider_name: "Gemini",
66            api_key_env: "GEMINI_API_KEY",
67            base_url_env: "GEMINI_BASE_URL",
68            default_base_url: DEFAULT_GEMINI_BASE_URL,
69        }
70    }
71
72    #[must_use]
73    pub const fn openrouter() -> Self {
74        Self {
75            provider_name: "OpenRouter",
76            api_key_env: "OPENROUTER_API_KEY",
77            base_url_env: "OPENROUTER_BASE_URL",
78            default_base_url: DEFAULT_OPENROUTER_BASE_URL,
79        }
80    }
81
82    #[must_use]
83    pub fn credential_env_vars(self) -> &'static [&'static str] {
84        match self.provider_name {
85            "xAI" => XAI_ENV_VARS,
86            "OpenAI" => OPENAI_ENV_VARS,
87            "Gemini" => GEMINI_ENV_VARS,
88            "OpenRouter" => OPENROUTER_ENV_VARS,
89            _ => &[],
90        }
91    }
92}
93
94#[derive(Debug, Clone)]
95pub struct OpenAiCompatClient {
96    http: reqwest::Client,
97    api_key: String,
98    base_url: String,
99    max_retries: u32,
100    initial_backoff: Duration,
101    max_backoff: Duration,
102}
103
104impl OpenAiCompatClient {
105    #[must_use]
106    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
107        Self {
108            http: reqwest::Client::new(),
109            api_key: api_key.into(),
110            base_url: read_base_url(config),
111            max_retries: DEFAULT_MAX_RETRIES,
112            initial_backoff: DEFAULT_INITIAL_BACKOFF,
113            max_backoff: DEFAULT_MAX_BACKOFF,
114        }
115    }
116
117    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
118        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
119            return Err(ApiError::missing_credentials(
120                config.provider_name,
121                config.credential_env_vars(),
122            ));
123        };
124        Ok(Self::new(api_key, config))
125    }
126
127    #[must_use]
128    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
129        self.base_url = base_url.into();
130        self
131    }
132
133    #[must_use]
134    pub fn with_retry_policy(
135        mut self,
136        max_retries: u32,
137        initial_backoff: Duration,
138        max_backoff: Duration,
139    ) -> Self {
140        self.max_retries = max_retries;
141        self.initial_backoff = initial_backoff;
142        self.max_backoff = max_backoff;
143        self
144    }
145
146    pub async fn send_message(
147        &self,
148        request: &MessageRequest,
149    ) -> Result<MessageResponse, ApiError> {
150        let request = MessageRequest {
151            stream: false,
152            ..request.clone()
153        };
154        let response = self.send_with_retry(&request).await?;
155        let request_id = request_id_from_headers(response.headers());
156        let payload = response.json::<ChatCompletionResponse>().await?;
157        let mut normalized = normalize_response(&request.model, payload)?;
158        if normalized.request_id.is_none() {
159            normalized.request_id = request_id;
160        }
161        Ok(normalized)
162    }
163
164    pub async fn stream_message(
165        &self,
166        request: &MessageRequest,
167    ) -> Result<MessageStream, ApiError> {
168        let response = self
169            .send_with_retry(&request.clone().with_streaming())
170            .await?;
171        Ok(MessageStream {
172            request_id: request_id_from_headers(response.headers()),
173            response,
174            parser: OpenAiSseParser::new(),
175            pending: VecDeque::new(),
176            done: false,
177            state: StreamState::new(request.model.clone()),
178        })
179    }
180
181    async fn send_with_retry(
182        &self,
183        request: &MessageRequest,
184    ) -> Result<reqwest::Response, ApiError> {
185        let mut attempts = 0;
186
187        let last_error = loop {
188            attempts += 1;
189            let retryable_error = match self.send_raw_request(request).await {
190                Ok(response) => match expect_success(response).await {
191                    Ok(response) => return Ok(response),
192                    Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
193                    Err(error) => return Err(error),
194                },
195                Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
196                Err(error) => return Err(error),
197            };
198
199            if attempts > self.max_retries {
200                break retryable_error;
201            }
202
203            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
204        };
205
206        Err(ApiError::RetriesExhausted {
207            attempts,
208            last_error: Box::new(last_error),
209        })
210    }
211
212    async fn send_raw_request(
213        &self,
214        request: &MessageRequest,
215    ) -> Result<reqwest::Response, ApiError> {
216        let request_url = chat_completions_endpoint(&self.base_url);
217        self.http
218            .post(&request_url)
219            .header("content-type", "application/json")
220            .bearer_auth(&self.api_key)
221            .json(&build_chat_completion_request(request))
222            .send()
223            .await
224            .map_err(ApiError::from)
225    }
226
227    fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
228        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
229            return Err(ApiError::BackoffOverflow {
230                attempt,
231                base_delay: self.initial_backoff,
232            });
233        };
234        Ok(self
235            .initial_backoff
236            .checked_mul(multiplier)
237            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
238    }
239}
240
241impl Provider for OpenAiCompatClient {
242    type Stream = MessageStream;
243
244    fn send_message<'a>(
245        &'a self,
246        request: &'a MessageRequest,
247    ) -> ProviderFuture<'a, MessageResponse> {
248        Box::pin(async move { self.send_message(request).await })
249    }
250
251    fn stream_message<'a>(
252        &'a self,
253        request: &'a MessageRequest,
254    ) -> ProviderFuture<'a, Self::Stream> {
255        Box::pin(async move { self.stream_message(request).await })
256    }
257}
258
259#[derive(Debug)]
260pub struct MessageStream {
261    request_id: Option<String>,
262    response: reqwest::Response,
263    parser: OpenAiSseParser,
264    pending: VecDeque<StreamEvent>,
265    done: bool,
266    state: StreamState,
267}
268
269impl MessageStream {
270    #[must_use]
271    pub fn request_id(&self) -> Option<&str> {
272        self.request_id.as_deref()
273    }
274
275    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
276        loop {
277            if let Some(event) = self.pending.pop_front() {
278                return Ok(Some(event));
279            }
280
281            if self.done {
282                self.pending.extend(self.state.finish()?);
283                if let Some(event) = self.pending.pop_front() {
284                    return Ok(Some(event));
285                }
286                return Ok(None);
287            }
288
289            match self.response.chunk().await? {
290                Some(chunk) => {
291                    for parsed in self.parser.push(&chunk)? {
292                        self.pending.extend(self.state.ingest_chunk(parsed)?);
293                    }
294                }
295                None => {
296                    self.done = true;
297                }
298            }
299        }
300    }
301}
302
303#[derive(Debug, Default)]
304struct OpenAiSseParser {
305    buffer: Vec<u8>,
306}
307
308impl OpenAiSseParser {
309    fn new() -> Self {
310        Self::default()
311    }
312
313    fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
314        self.buffer.extend_from_slice(chunk);
315        let mut events = Vec::new();
316
317        while let Some(frame) = next_sse_frame(&mut self.buffer) {
318            if let Some(event) = parse_sse_frame(&frame)? {
319                events.push(event);
320            }
321        }
322
323        Ok(events)
324    }
325}
326
327#[allow(clippy::struct_excessive_bools)]
328#[derive(Debug)]
329struct StreamState {
330    model: String,
331    message_started: bool,
332    text_started: bool,
333    text_finished: bool,
334    finished: bool,
335    stop_reason: Option<String>,
336    usage: Option<Usage>,
337    tool_calls: BTreeMap<u32, ToolCallState>,
338}
339
340impl StreamState {
341    fn new(model: String) -> Self {
342        Self {
343            model,
344            message_started: false,
345            text_started: false,
346            text_finished: false,
347            finished: false,
348            stop_reason: None,
349            usage: None,
350            tool_calls: BTreeMap::new(),
351        }
352    }
353
354    fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
355        let mut events = Vec::new();
356        if !self.message_started {
357            self.message_started = true;
358            events.push(StreamEvent::MessageStart(MessageStartEvent {
359                message: MessageResponse {
360                    id: chunk.id.clone(),
361                    kind: "message".to_string(),
362                    role: "assistant".to_string(),
363                    content: Vec::new(),
364                    model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
365                    stop_reason: None,
366                    stop_sequence: None,
367                    usage: Usage {
368                        input_tokens: 0,
369                        cache_creation_input_tokens: 0,
370                        cache_read_input_tokens: 0,
371                        output_tokens: 0,
372                    },
373                    request_id: None,
374                },
375            }));
376        }
377
378        if let Some(usage) = chunk.usage {
379            self.usage = Some(Usage {
380                input_tokens: usage.prompt_tokens,
381                cache_creation_input_tokens: 0,
382                cache_read_input_tokens: 0,
383                output_tokens: usage.completion_tokens,
384            });
385        }
386
387        for choice in chunk.choices {
388            if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
389                if !self.text_started {
390                    self.text_started = true;
391                    events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
392                        index: 0,
393                        content_block: OutputContentBlock::Text {
394                            text: String::new(),
395                        },
396                    }));
397                }
398                events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
399                    index: 0,
400                    delta: ContentBlockDelta::TextDelta { text: content },
401                }));
402            }
403
404            for tool_call in choice.delta.tool_calls {
405                let state = self.tool_calls.entry(tool_call.index).or_default();
406                state.apply(tool_call);
407                let block_index = state.block_index();
408                if !state.started {
409                    if let Some(start_event) = state.start_event()? {
410                        state.started = true;
411                        events.push(StreamEvent::ContentBlockStart(start_event));
412                    } else {
413                        continue;
414                    }
415                }
416                if let Some(delta_event) = state.delta_event() {
417                    events.push(StreamEvent::ContentBlockDelta(delta_event));
418                }
419                if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
420                    state.stopped = true;
421                    events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
422                        index: block_index,
423                    }));
424                }
425            }
426
427            if let Some(finish_reason) = choice.finish_reason {
428                self.stop_reason = Some(normalize_finish_reason(&finish_reason));
429                if finish_reason == "tool_calls" {
430                    for state in self.tool_calls.values_mut() {
431                        if state.started && !state.stopped {
432                            state.stopped = true;
433                            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
434                                index: state.block_index(),
435                            }));
436                        }
437                    }
438                }
439            }
440        }
441
442        Ok(events)
443    }
444
445    fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
446        if self.finished {
447            return Ok(Vec::new());
448        }
449        self.finished = true;
450
451        let mut events = Vec::new();
452        if self.text_started && !self.text_finished {
453            self.text_finished = true;
454            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
455                index: 0,
456            }));
457        }
458
459        for state in self.tool_calls.values_mut() {
460            if !state.started {
461                if let Some(start_event) = state.start_event()? {
462                    state.started = true;
463                    events.push(StreamEvent::ContentBlockStart(start_event));
464                    if let Some(delta_event) = state.delta_event() {
465                        events.push(StreamEvent::ContentBlockDelta(delta_event));
466                    }
467                }
468            }
469            if state.started && !state.stopped {
470                state.stopped = true;
471                events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
472                    index: state.block_index(),
473                }));
474            }
475        }
476
477        if self.message_started {
478            events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
479                delta: MessageDelta {
480                    stop_reason: Some(
481                        self.stop_reason
482                            .clone()
483                            .unwrap_or_else(|| "end_turn".to_string()),
484                    ),
485                    stop_sequence: None,
486                },
487                usage: self.usage.clone().unwrap_or(Usage {
488                    input_tokens: 0,
489                    cache_creation_input_tokens: 0,
490                    cache_read_input_tokens: 0,
491                    output_tokens: 0,
492                }),
493            }));
494            events.push(StreamEvent::MessageStop(MessageStopEvent {}));
495        }
496        Ok(events)
497    }
498}
499
/// Accumulated state for one streamed tool call.
#[derive(Debug, Default)]
struct ToolCallState {
    /// Tool-call index as reported by the provider.
    openai_index: u32,
    /// Tool-call id, once received.
    id: Option<String>,
    /// Function name, once received.
    name: Option<String>,
    /// All argument fragments received so far, concatenated.
    arguments: String,
    /// Number of bytes of `arguments` already emitted as deltas.
    emitted_len: usize,
    /// Whether the block-start event has been emitted.
    started: bool,
    /// Whether the block-stop event has been emitted.
    stopped: bool,
}
510
511impl ToolCallState {
512    fn apply(&mut self, tool_call: DeltaToolCall) {
513        self.openai_index = tool_call.index;
514        if let Some(id) = tool_call.id {
515            self.id = Some(id);
516        }
517        if let Some(name) = tool_call.function.name {
518            self.name = Some(name);
519        }
520        if let Some(arguments) = tool_call.function.arguments {
521            self.arguments.push_str(&arguments);
522        }
523    }
524
525    const fn block_index(&self) -> u32 {
526        self.openai_index + 1
527    }
528
529    #[allow(clippy::unnecessary_wraps)]
530    fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
531        let Some(name) = self.name.clone() else {
532            return Ok(None);
533        };
534        let id = self
535            .id
536            .clone()
537            .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
538        Ok(Some(ContentBlockStartEvent {
539            index: self.block_index(),
540            content_block: OutputContentBlock::ToolUse {
541                id,
542                name,
543                input: json!({}),
544            },
545        }))
546    }
547
548    fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
549        if self.emitted_len >= self.arguments.len() {
550            return None;
551        }
552        let delta = self.arguments[self.emitted_len..].to_string();
553        self.emitted_len = self.arguments.len();
554        Some(ContentBlockDeltaEvent {
555            index: self.block_index(),
556            delta: ContentBlockDelta::InputJsonDelta {
557                partial_json: delta,
558            },
559        })
560    }
561}
562
563#[derive(Debug, Deserialize)]
564struct ChatCompletionResponse {
565    id: String,
566    model: String,
567    choices: Vec<ChatChoice>,
568    #[serde(default)]
569    usage: Option<OpenAiUsage>,
570}
571
572#[derive(Debug, Deserialize)]
573struct ChatChoice {
574    message: ChatMessage,
575    #[serde(default)]
576    finish_reason: Option<String>,
577}
578
579#[derive(Debug, Deserialize)]
580struct ChatMessage {
581    role: String,
582    #[serde(default)]
583    content: Option<String>,
584    #[serde(default)]
585    tool_calls: Vec<ResponseToolCall>,
586}
587
588#[derive(Debug, Deserialize)]
589struct ResponseToolCall {
590    id: String,
591    function: ResponseToolFunction,
592}
593
594#[derive(Debug, Deserialize)]
595struct ResponseToolFunction {
596    name: String,
597    arguments: String,
598}
599
600#[derive(Debug, Deserialize)]
601struct OpenAiUsage {
602    #[serde(default)]
603    prompt_tokens: u32,
604    #[serde(default)]
605    completion_tokens: u32,
606}
607
608#[derive(Debug, Deserialize)]
609struct ChatCompletionChunk {
610    id: String,
611    #[serde(default)]
612    model: Option<String>,
613    #[serde(default)]
614    choices: Vec<ChunkChoice>,
615    #[serde(default)]
616    usage: Option<OpenAiUsage>,
617}
618
619#[derive(Debug, Deserialize)]
620struct ChunkChoice {
621    delta: ChunkDelta,
622    #[serde(default)]
623    finish_reason: Option<String>,
624}
625
626#[derive(Debug, Default, Deserialize)]
627struct ChunkDelta {
628    #[serde(default)]
629    content: Option<String>,
630    #[serde(default)]
631    tool_calls: Vec<DeltaToolCall>,
632}
633
634#[derive(Debug, Deserialize)]
635struct DeltaToolCall {
636    #[serde(default)]
637    index: u32,
638    #[serde(default)]
639    id: Option<String>,
640    #[serde(default)]
641    function: DeltaFunction,
642}
643
644#[derive(Debug, Default, Deserialize)]
645struct DeltaFunction {
646    #[serde(default)]
647    name: Option<String>,
648    #[serde(default)]
649    arguments: Option<String>,
650}
651
652#[derive(Debug, Deserialize)]
653struct ErrorEnvelope {
654    error: ErrorBody,
655}
656
657#[derive(Debug, Deserialize)]
658struct ErrorBody {
659    #[serde(rename = "type")]
660    error_type: Option<String>,
661    message: Option<String>,
662}
663
664fn build_chat_completion_request(request: &MessageRequest) -> Value {
665    let mut messages = Vec::new();
666    if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
667        messages.push(json!({
668            "role": "system",
669            "content": system,
670        }));
671    }
672    for message in &request.messages {
673        messages.extend(translate_message(message));
674    }
675
676    let mut payload = json!({
677        "model": request.model,
678        "max_tokens": request.max_tokens,
679        "messages": messages,
680        "stream": request.stream,
681    });
682
683    if let Some(tools) = &request.tools {
684        payload["tools"] =
685            Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
686    }
687    if let Some(tool_choice) = &request.tool_choice {
688        payload["tool_choice"] = openai_tool_choice(tool_choice);
689    }
690
691    payload
692}
693
694fn translate_message(message: &InputMessage) -> Vec<Value> {
695    match message.role.as_str() {
696        "assistant" => {
697            let mut text = String::new();
698            let mut tool_calls = Vec::new();
699            for block in &message.content {
700                match block {
701                    InputContentBlock::Text { text: value } => text.push_str(value),
702                    InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
703                        "id": id,
704                        "type": "function",
705                        "function": {
706                            "name": name,
707                            "arguments": input.to_string(),
708                        }
709                    })),
710                    InputContentBlock::ToolResult { .. } => {}
711                }
712            }
713            if text.is_empty() && tool_calls.is_empty() {
714                Vec::new()
715            } else {
716                vec![json!({
717                    "role": "assistant",
718                    "content": (!text.is_empty()).then_some(text),
719                    "tool_calls": tool_calls,
720                })]
721            }
722        }
723        _ => message
724            .content
725            .iter()
726            .filter_map(|block| match block {
727                InputContentBlock::Text { text } => Some(json!({
728                    "role": "user",
729                    "content": text,
730                })),
731                InputContentBlock::ToolResult {
732                    tool_use_id,
733                    content,
734                    is_error,
735                } => Some(json!({
736                    "role": "tool",
737                    "tool_call_id": tool_use_id,
738                    "content": flatten_tool_result_content(content),
739                    "is_error": is_error,
740                })),
741                InputContentBlock::ToolUse { .. } => None,
742            })
743            .collect(),
744    }
745}
746
747fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
748    content
749        .iter()
750        .map(|block| match block {
751            ToolResultContentBlock::Text { text } => text.clone(),
752            ToolResultContentBlock::Json { value } => value.to_string(),
753        })
754        .collect::<Vec<_>>()
755        .join("\n")
756}
757
758fn openai_tool_definition(tool: &ToolDefinition) -> Value {
759    json!({
760        "type": "function",
761        "function": {
762            "name": tool.name,
763            "description": tool.description,
764            "parameters": tool.input_schema,
765        }
766    })
767}
768
769fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
770    match tool_choice {
771        ToolChoice::Auto => Value::String("auto".to_string()),
772        ToolChoice::Any => Value::String("required".to_string()),
773        ToolChoice::Tool { name } => json!({
774            "type": "function",
775            "function": { "name": name },
776        }),
777    }
778}
779
780fn normalize_response(
781    model: &str,
782    response: ChatCompletionResponse,
783) -> Result<MessageResponse, ApiError> {
784    let choice = response
785        .choices
786        .into_iter()
787        .next()
788        .ok_or(ApiError::InvalidSseFrame(
789            "chat completion response missing choices",
790        ))?;
791    let mut content = Vec::new();
792    if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
793        content.push(OutputContentBlock::Text { text });
794    }
795    for tool_call in choice.message.tool_calls {
796        content.push(OutputContentBlock::ToolUse {
797            id: tool_call.id,
798            name: tool_call.function.name,
799            input: parse_tool_arguments(&tool_call.function.arguments),
800        });
801    }
802
803    Ok(MessageResponse {
804        id: response.id,
805        kind: "message".to_string(),
806        role: choice.message.role,
807        content,
808        model: response.model.if_empty_then(model.to_string()),
809        stop_reason: choice
810            .finish_reason
811            .map(|value| normalize_finish_reason(&value)),
812        stop_sequence: None,
813        usage: Usage {
814            input_tokens: response
815                .usage
816                .as_ref()
817                .map_or(0, |usage| usage.prompt_tokens),
818            cache_creation_input_tokens: 0,
819            cache_read_input_tokens: 0,
820            output_tokens: response
821                .usage
822                .as_ref()
823                .map_or(0, |usage| usage.completion_tokens),
824        },
825        request_id: None,
826    })
827}
828
829fn parse_tool_arguments(arguments: &str) -> Value {
830    serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
831}
832
/// Removes and returns the first complete SSE frame from `buffer`.
///
/// A frame ends at a blank line, written either as `\n\n` or `\r\n\r\n`.
/// Whichever separator occurs first in the buffer wins, so a stream that mixes
/// line endings is still split frame by frame. The separator is consumed but not
/// included in the returned text; invalid UTF-8 is replaced lossily.
///
/// Returns `None`, leaving `buffer` untouched, when no complete frame is present.
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    let lf_separator = buffer
        .windows(2)
        .position(|window| window == b"\n\n")
        .map(|position| (position, 2));
    let crlf_separator = buffer
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .map(|position| (position, 4));

    // The two separators can never start at the same offset, so the minimum is
    // unambiguous.
    let (position, separator_len) = lf_separator
        .into_iter()
        .chain(crlf_separator)
        .min_by_key(|&(position, _)| position)?;

    let frame = String::from_utf8_lossy(&buffer[..position]).into_owned();
    buffer.drain(..position + separator_len);
    Some(frame)
}
850
851fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
852    let trimmed = frame.trim();
853    if trimmed.is_empty() {
854        return Ok(None);
855    }
856
857    let mut data_lines = Vec::new();
858    for line in trimmed.lines() {
859        if line.starts_with(':') {
860            continue;
861        }
862        if let Some(data) = line.strip_prefix("data:") {
863            data_lines.push(data.trim_start());
864        }
865    }
866    if data_lines.is_empty() {
867        return Ok(None);
868    }
869    let payload = data_lines.join("\n");
870    if payload == "[DONE]" {
871        return Ok(None);
872    }
873    serde_json::from_str(&payload)
874        .map(Some)
875        .map_err(ApiError::from)
876}
877
878fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
879    match std::env::var(key) {
880        Ok(value) if !value.is_empty() => Ok(Some(value)),
881        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
882        Err(error) => Err(ApiError::from(error)),
883    }
884}
885
886#[must_use]
887pub fn has_api_key(key: &str) -> bool {
888    read_env_non_empty(key)
889        .ok()
890        .and_then(std::convert::identity)
891        .is_some()
892}
893
894#[must_use]
895pub fn read_base_url(config: OpenAiCompatConfig) -> String {
896    std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
897}
898
/// Derives the chat completions URL from `base_url`.
///
/// Trailing slashes are removed first; `/chat/completions` is then appended
/// unless the URL already ends with it.
fn chat_completions_endpoint(base_url: &str) -> String {
    const ENDPOINT_SUFFIX: &str = "/chat/completions";

    let root = base_url.trim_end_matches('/');
    let mut endpoint = String::from(root);
    if !root.ends_with(ENDPOINT_SUFFIX) {
        endpoint.push_str(ENDPOINT_SUFFIX);
    }
    endpoint
}
907
908fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
909    headers
910        .get(REQUEST_ID_HEADER)
911        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
912        .and_then(|value| value.to_str().ok())
913        .map(ToOwned::to_owned)
914}
915
916async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
917    let status = response.status();
918    if status.is_success() {
919        return Ok(response);
920    }
921
922    let body = response.text().await.unwrap_or_default();
923    let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
924    let retryable = is_retryable_status(status);
925
926    Err(ApiError::Api {
927        status,
928        error_type: parsed_error
929            .as_ref()
930            .and_then(|error| error.error.error_type.clone()),
931        message: parsed_error
932            .as_ref()
933            .and_then(|error| error.error.message.clone()),
934        body,
935        retryable,
936    })
937}
938
939const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
940    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
941}
942
/// Maps a chat completions finish reason to the stop reason used by this crate.
///
/// `"stop"` becomes `"end_turn"` and `"tool_calls"` becomes `"tool_use"`; every
/// other value is returned unchanged.
fn normalize_finish_reason(value: &str) -> String {
    let normalized = if value == "stop" {
        "end_turn"
    } else if value == "tool_calls" {
        "tool_use"
    } else {
        value
    };
    String::from(normalized)
}
951
/// Extension for substituting a fallback when a string is empty.
trait StringExt {
    /// Returns `self` unless it is empty, in which case `fallback` is returned.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        Some(self)
            .filter(|candidate| !candidate.is_empty())
            .unwrap_or(fallback)
    }
}
965
#[cfg(test)]
mod tests {
    //! Unit tests for request translation, tool-choice mapping, tool-argument
    //! parsing, credential lookup, endpoint construction and finish-reason
    //! normalization.

    use super::{
        build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
        openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
    };
    use crate::error::ApiError;
    use crate::types::{
        InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
        ToolResultContentBlock,
    };
    use serde_json::json;
    use std::sync::{Mutex, OnceLock, PoisonError};

    /// A request with a system prompt, a user text block and a tool result is
    /// translated into `system`, `user` and `tool` messages (in that order),
    /// with tools emitted as `function` entries and `Auto` as `"auto"`.
    #[test]
    fn request_translation_uses_openai_compatible_shape() {
        let payload = build_chat_completion_request(&MessageRequest {
            model: "grok-3".to_string(),
            max_tokens: 64,
            messages: vec![InputMessage {
                role: "user".to_string(),
                content: vec![
                    InputContentBlock::Text {
                        text: "hello".to_string(),
                    },
                    InputContentBlock::ToolResult {
                        tool_use_id: "tool_1".to_string(),
                        content: vec![ToolResultContentBlock::Json {
                            value: json!({"ok": true}),
                        }],
                        is_error: false,
                    },
                ],
            }],
            system: Some("be helpful".to_string()),
            tools: Some(vec![ToolDefinition {
                name: "weather".to_string(),
                description: Some("Get weather".to_string()),
                input_schema: json!({"type": "object"}),
            }]),
            tool_choice: Some(ToolChoice::Auto),
            stream: false,
        });

        assert_eq!(payload["messages"][0]["role"], json!("system"));
        assert_eq!(payload["messages"][1]["role"], json!("user"));
        assert_eq!(payload["messages"][2]["role"], json!("tool"));
        assert_eq!(payload["tools"][0]["type"], json!("function"));
        assert_eq!(payload["tool_choice"], json!("auto"));
    }

    /// `Any` maps to `"required"` and a named tool maps to a `function` object.
    #[test]
    fn tool_choice_translation_supports_required_function() {
        assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
        assert_eq!(
            openai_tool_choice(&ToolChoice::Tool {
                name: "weather".to_string(),
            }),
            json!({"type": "function", "function": {"name": "weather"}})
        );
    }

    /// Valid JSON arguments are parsed; anything else is wrapped under `raw`.
    #[test]
    fn parses_tool_arguments_fallback() {
        assert_eq!(
            parse_tool_arguments("{\"city\":\"Paris\"}"),
            json!({"city": "Paris"})
        );
        assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"}));
    }

    /// Building an xAI client without `XAI_API_KEY` reports missing
    /// credentials that name the xAI provider.
    #[test]
    fn missing_xai_api_key_is_provider_specific() {
        let _lock = env_lock();

        // Remember any ambient key so this test does not permanently strip it
        // from the process environment shared with the rest of the test binary.
        let previous = std::env::var_os("XAI_API_KEY");
        std::env::remove_var("XAI_API_KEY");
        let result = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai());
        // Restore before asserting so a failed assertion cannot skip the cleanup.
        if let Some(value) = previous {
            std::env::set_var("XAI_API_KEY", value);
        }

        let error = result.expect_err("missing key should error");
        assert!(matches!(
            error,
            ApiError::MissingCredentials {
                provider: "xAI",
                ..
            }
        ));
    }

    /// A bare base URL, a base URL with a trailing slash, and a full endpoint
    /// URL all resolve to the same chat-completions endpoint.
    #[test]
    fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
            "https://api.x.ai/v1/chat/completions"
        );
    }

    /// Serializes tests in this module that mutate process environment variables.
    ///
    /// The mutex protects no data, so a poisoned lock (left behind by a test
    /// that panicked while holding it) is safe to reuse. Recovering the guard
    /// keeps one failing test from cascading into every later env-dependent test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Known finish reasons are renamed; unknown ones pass through unchanged.
    #[test]
    fn normalizes_stop_reasons() {
        assert_eq!(normalize_finish_reason("stop"), "end_turn");
        assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
        assert_eq!(normalize_finish_reason("length"), "length");
    }
}