//! xAI (Grok) chat completions client — steer_core/api/xai/client.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::api::util::normalize_chat_url;
13use crate::app::SystemContext;
14use crate::app::conversation::{AssistantContent, Message as AppMessage, ToolResult, UserContent};
15use crate::config::model::{ModelId, ModelParameters};
16use steer_tools::ToolSchema;
17
/// Default endpoint for xAI's OpenAI-compatible chat completions API;
/// overridable via `XAIClient::with_base_url`.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
19
/// HTTP client for the xAI (Grok) chat completions API.
#[derive(Clone)]
pub struct XAIClient {
    // Pre-configured with the Authorization header and a request timeout.
    http_client: reqwest::Client,
    // Fully-qualified chat-completions URL (default or caller-supplied).
    base_url: String,
}
25
// xAI-specific message format (similar to OpenAI but with some differences).
// `tag = "role"` serializes each variant as `{"role": "system" | "user" |
// "assistant" | "tool", ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    /// System prompt message.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user message.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model turn replayed back to the API; may carry text, tool calls, or both.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, correlated via `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
55
// xAI function calling format: a tool's declaration in the request.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    // JSON Schema describing the function's parameters.
    parameters: serde_json::Value,
}
63
// xAI tool format: wraps a function declaration with its tool type.
#[derive(Debug, Serialize, Deserialize)]
struct XAITool {
    // Serialized as "type"; this client always sends "function".
    #[serde(rename = "type")]
    tool_type: String, // "function"
    function: XAIFunction,
}
71
// xAI tool call as emitted by the model inside an assistant message.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    // Unique id; the eventual tool-result message references it.
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}
80
/// Function name plus raw argument payload of a tool call.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunctionCall {
    name: String,
    arguments: String, // JSON string; parsed by the response converter
}
86
/// Value of the `reasoning_effort` request parameter, serialized lowercase.
/// xAI exposes only low/high (no medium tier).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}
93
/// Extra options for streaming requests (this client currently sends `None`).
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}
99
/// `tool_choice` request field: either a mode string or a specific function
/// the model must call. `untagged` makes serde pick the variant purely from
/// the JSON shape.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    String(String), // "auto", "required", "none"
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}
110
/// Target function name for `ToolChoice::Specific`.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}
115
/// `response_format` request field; `json_schema` carries an optional schema
/// payload when structured output is requested.
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}
123
/// Parameters for xAI's live-search feature. Every field is optional and
/// omitted from the JSON when `None` (this client currently sends `None`
/// for the whole struct).
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}
139
/// `web_search_options` request field (this client currently sends `None`).
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
147
/// Full request body for the chat completions endpoint. Mirrors xAI's
/// OpenAI-compatible schema; every optional field is dropped from the JSON
/// when `None`, so the wire payload stays minimal.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    // Model identifier string, e.g. "grok-4-0709".
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    // Some(true) selects the SSE streaming response path.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
199
/// Top-level body of a non-streaming completion response.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    // The converter only consumes the first choice.
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}
216
/// One completion choice from a non-streaming response.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}
223
/// Assistant message inside a response choice.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    // Model thinking text; when present it is surfaced as a Thought block.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
232
/// Per-modality breakdown of prompt token usage.
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    #[serde(rename = "cached_tokens")]
    cached: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "image_tokens")]
    image: usize,
    #[serde(rename = "text_tokens")]
    text: usize,
}
244
/// Breakdown of completion token usage (reasoning, audio, prediction).
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    #[serde(rename = "reasoning_tokens")]
    reasoning: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "accepted_prediction_tokens")]
    accepted_prediction: usize,
    #[serde(rename = "rejected_prediction_tokens")]
    rejected_prediction: usize,
}
256
/// Token accounting attached to a completion response.
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    // Number of live-search sources consulted, when search was used.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}
269
/// Debug/diagnostic payload optionally attached to a response.
#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}
281
/// One parsed SSE `data:` payload from a streaming response.
#[derive(Debug, Deserialize)]
struct XAIStreamChunk {
    // Deserialized for completeness but not read by the stream converter.
    #[expect(dead_code)]
    id: String,
    choices: Vec<XAIStreamChoice>,
}
288
/// One choice within a streamed chunk; only `delta` is consumed.
#[derive(Debug, Deserialize)]
struct XAIStreamChoice {
    #[expect(dead_code)]
    index: u32,
    delta: XAIStreamDelta,
    #[expect(dead_code)]
    finish_reason: Option<String>,
}
297
298#[derive(Debug, Deserialize)]
299struct XAIStreamDelta {
300    #[serde(skip_serializing_if = "Option::is_none")]
301    content: Option<String>,
302    #[serde(skip_serializing_if = "Option::is_none")]
303    tool_calls: Option<Vec<XAIStreamToolCall>>,
304    #[serde(skip_serializing_if = "Option::is_none")]
305    reasoning_content: Option<String>,
306}
307
308#[derive(Debug, Deserialize)]
309struct XAIStreamToolCall {
310    index: usize,
311    #[serde(skip_serializing_if = "Option::is_none")]
312    id: Option<String>,
313    #[serde(skip_serializing_if = "Option::is_none")]
314    function: Option<XAIStreamFunction>,
315}
316
317#[derive(Debug, Deserialize)]
318struct XAIStreamFunction {
319    #[serde(skip_serializing_if = "Option::is_none")]
320    name: Option<String>,
321    #[serde(skip_serializing_if = "Option::is_none")]
322    arguments: Option<String>,
323}
324
325impl XAIClient {
326    pub fn new(api_key: String) -> Result<Self, ApiError> {
327        Self::with_base_url(api_key, None)
328    }
329
330    pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
331        let mut headers = header::HeaderMap::new();
332        headers.insert(
333            header::AUTHORIZATION,
334            header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
335                ApiError::AuthenticationFailed {
336                    provider: "xai".to_string(),
337                    details: format!("Invalid API key: {e}"),
338                }
339            })?,
340        );
341
342        let client = reqwest::Client::builder()
343            .default_headers(headers)
344            .timeout(std::time::Duration::from_secs(300)) // 5 minute timeout
345            .build()
346            .map_err(ApiError::Network)?;
347
348        let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);
349
350        Ok(Self {
351            http_client: client,
352            base_url,
353        })
354    }
355
356    fn convert_messages(
357        messages: Vec<AppMessage>,
358        system: Option<SystemContext>,
359    ) -> Vec<XAIMessage> {
360        let mut xai_messages = Vec::new();
361
362        // Add system message if provided
363        if let Some(system_content) = system.and_then(|context| context.render()) {
364            xai_messages.push(XAIMessage::System {
365                content: system_content,
366                name: None,
367            });
368        }
369
370        // Convert our messages to xAI format
371        for message in messages {
372            match &message.data {
373                crate::app::conversation::MessageData::User { content, .. } => {
374                    // Convert UserContent to text
375                    let combined_text = content
376                        .iter()
377                        .map(|user_content| match user_content {
378                            UserContent::Text { text } => text.clone(),
379                            UserContent::CommandExecution {
380                                command,
381                                stdout,
382                                stderr,
383                                exit_code,
384                            } => UserContent::format_command_execution_as_xml(
385                                command, stdout, stderr, *exit_code,
386                            ),
387                        })
388                        .collect::<Vec<_>>()
389                        .join("\n");
390
391                    // Only add the message if it has content after filtering
392                    if !combined_text.trim().is_empty() {
393                        xai_messages.push(XAIMessage::User {
394                            content: combined_text,
395                            name: None,
396                        });
397                    }
398                }
399                crate::app::conversation::MessageData::Assistant { content, .. } => {
400                    // Convert AssistantContent to xAI format
401                    let mut text_parts = Vec::new();
402                    let mut tool_calls = Vec::new();
403
404                    for content_block in content {
405                        match content_block {
406                            AssistantContent::Text { text } => {
407                                text_parts.push(text.clone());
408                            }
409                            AssistantContent::ToolCall { tool_call, .. } => {
410                                tool_calls.push(XAIToolCall {
411                                    id: tool_call.id.clone(),
412                                    tool_type: "function".to_string(),
413                                    function: XAIFunctionCall {
414                                        name: tool_call.name.clone(),
415                                        arguments: tool_call.parameters.to_string(),
416                                    },
417                                });
418                            }
419                            AssistantContent::Thought { .. } => {
420                                // xAI doesn't support thinking blocks in requests, only in responses
421                            }
422                        }
423                    }
424
425                    // Build the assistant message
426                    let content = if text_parts.is_empty() {
427                        None
428                    } else {
429                        Some(text_parts.join("\n"))
430                    };
431
432                    let tool_calls_opt = if tool_calls.is_empty() {
433                        None
434                    } else {
435                        Some(tool_calls)
436                    };
437
438                    xai_messages.push(XAIMessage::Assistant {
439                        content,
440                        tool_calls: tool_calls_opt,
441                        name: None,
442                    });
443                }
444                crate::app::conversation::MessageData::Tool {
445                    tool_use_id,
446                    result,
447                    ..
448                } => {
449                    // Convert ToolResult to xAI format
450                    let content_text = if let ToolResult::Error(e) = result {
451                        format!("Error: {e}")
452                    } else {
453                        let text = result.llm_format();
454                        if text.trim().is_empty() {
455                            "(No output)".to_string()
456                        } else {
457                            text
458                        }
459                    };
460
461                    xai_messages.push(XAIMessage::Tool {
462                        content: content_text,
463                        tool_call_id: tool_use_id.clone(),
464                        name: None,
465                    });
466                }
467            }
468        }
469
470        xai_messages
471    }
472
473    fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
474        tools
475            .into_iter()
476            .map(|tool| XAITool {
477                tool_type: "function".to_string(),
478                function: XAIFunction {
479                    name: tool.name,
480                    description: tool.description,
481                    parameters: tool.input_schema.as_value().clone(),
482                },
483            })
484            .collect()
485    }
486}
487
488#[async_trait]
489impl Provider for XAIClient {
490    fn name(&self) -> &'static str {
491        "xai"
492    }
493
494    async fn complete(
495        &self,
496        model_id: &ModelId,
497        messages: Vec<AppMessage>,
498        system: Option<SystemContext>,
499        tools: Option<Vec<ToolSchema>>,
500        call_options: Option<ModelParameters>,
501        token: CancellationToken,
502    ) -> Result<CompletionResponse, ApiError> {
503        let xai_messages = Self::convert_messages(messages, system);
504        let xai_tools = tools.map(Self::convert_tools);
505
506        // Extract thinking support and map optional effort
507        let (supports_thinking, reasoning_effort) = call_options
508            .as_ref()
509            .and_then(|opts| opts.thinking_config)
510            .map_or((false, None), |tc| {
511                let effort = tc.effort.map(|e| match e {
512                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
513                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, // xAI has Low/High only
514                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
515                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
516                });
517                (tc.enabled, effort)
518            });
519
520        // grok-4 supports thinking by default but not the reasoning_effort parameter
521        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
522            reasoning_effort.or(Some(ReasoningEffort::High))
523        } else {
524            None
525        };
526
527        let request = CompletionRequest {
528            model: model_id.id.clone(), // Use the model ID string
529            messages: xai_messages,
530            deferred: None,
531            frequency_penalty: None,
532            logit_bias: None,
533            logprobs: None,
534            max_completion_tokens: Some(32768),
535            max_tokens: None,
536            n: None,
537            parallel_tool_calls: None,
538            presence_penalty: None,
539            reasoning_effort,
540            response_format: None,
541            search_parameters: None,
542            seed: None,
543            stop: None,
544            stream: None,
545            stream_options: None,
546            temperature: call_options
547                .as_ref()
548                .and_then(|o| o.temperature)
549                .or(Some(1.0)),
550            tool_choice: None,
551            tools: xai_tools,
552            top_logprobs: None,
553            top_p: call_options.as_ref().and_then(|o| o.top_p),
554            user: None,
555            web_search_options: None,
556        };
557
558        let response = self
559            .http_client
560            .post(&self.base_url)
561            .json(&request)
562            .send()
563            .await
564            .map_err(ApiError::Network)?;
565
566        if !response.status().is_success() {
567            let status = response.status();
568            let error_text = response.text().await.unwrap_or_else(|_| String::new());
569
570            debug!(
571                target: "grok::complete",
572                "Grok API error - Status: {}, Body: {}",
573                status,
574                error_text
575            );
576
577            return match status.as_u16() {
578                429 => Err(ApiError::RateLimited {
579                    provider: self.name().to_string(),
580                    details: error_text,
581                }),
582                400 => Err(ApiError::InvalidRequest {
583                    provider: self.name().to_string(),
584                    details: error_text,
585                }),
586                401 => Err(ApiError::AuthenticationFailed {
587                    provider: self.name().to_string(),
588                    details: error_text,
589                }),
590                _ => Err(ApiError::ServerError {
591                    provider: self.name().to_string(),
592                    status_code: status.as_u16(),
593                    details: error_text,
594                }),
595            };
596        }
597
598        let response_text = tokio::select! {
599            () = token.cancelled() => {
600                debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
601                return Err(ApiError::Cancelled { provider: self.name().to_string() });
602            }
603            text_res = response.text() => {
604                text_res?
605            }
606        };
607
608        let xai_response: XAICompletionResponse =
609            serde_json::from_str(&response_text).map_err(|e| {
610                error!(
611                    target: "xai::complete",
612                    "Failed to parse response: {}, Body: {}",
613                    e,
614                    response_text
615                );
616                ApiError::ResponseParsingError {
617                    provider: self.name().to_string(),
618                    details: format!("Error: {e}, Body: {response_text}"),
619                }
620            })?;
621
622        // Convert xAI response to our CompletionResponse
623        if let Some(choice) = xai_response.choices.first() {
624            let mut content_blocks = Vec::new();
625
626            // Add reasoning content (thinking) first if present
627            if let Some(reasoning) = &choice.message.reasoning_content
628                && !reasoning.trim().is_empty()
629            {
630                content_blocks.push(AssistantContent::Thought {
631                    thought: crate::app::conversation::ThoughtContent::Simple {
632                        text: reasoning.clone(),
633                    },
634                });
635            }
636
637            // Add regular content
638            if let Some(content) = &choice.message.content
639                && !content.trim().is_empty()
640            {
641                content_blocks.push(AssistantContent::Text {
642                    text: content.clone(),
643                });
644            }
645
646            // Add tool calls
647            if let Some(tool_calls) = &choice.message.tool_calls {
648                for tool_call in tool_calls {
649                    // Parse the arguments JSON string
650                    let parameters = serde_json::from_str(&tool_call.function.arguments)
651                        .unwrap_or(serde_json::Value::Null);
652
653                    content_blocks.push(AssistantContent::ToolCall {
654                        tool_call: steer_tools::ToolCall {
655                            id: tool_call.id.clone(),
656                            name: tool_call.function.name.clone(),
657                            parameters,
658                        },
659                        thought_signature: None,
660                    });
661                }
662            }
663
664            Ok(crate::api::provider::CompletionResponse {
665                content: content_blocks,
666            })
667        } else {
668            Err(ApiError::NoChoices {
669                provider: self.name().to_string(),
670            })
671        }
672    }
673
674    async fn stream_complete(
675        &self,
676        model_id: &ModelId,
677        messages: Vec<AppMessage>,
678        system: Option<SystemContext>,
679        tools: Option<Vec<ToolSchema>>,
680        call_options: Option<ModelParameters>,
681        token: CancellationToken,
682    ) -> Result<CompletionStream, ApiError> {
683        let xai_messages = Self::convert_messages(messages, system);
684        let xai_tools = tools.map(Self::convert_tools);
685
686        let (supports_thinking, reasoning_effort) = call_options
687            .as_ref()
688            .and_then(|opts| opts.thinking_config)
689            .map_or((false, None), |tc| {
690                let effort = tc.effort.map(|e| match e {
691                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
692                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
693                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
694                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
695                });
696                (tc.enabled, effort)
697            });
698
699        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
700            reasoning_effort.or(Some(ReasoningEffort::High))
701        } else {
702            None
703        };
704
705        let request = CompletionRequest {
706            model: model_id.id.clone(),
707            messages: xai_messages,
708            deferred: None,
709            frequency_penalty: None,
710            logit_bias: None,
711            logprobs: None,
712            max_completion_tokens: Some(32768),
713            max_tokens: None,
714            n: None,
715            parallel_tool_calls: None,
716            presence_penalty: None,
717            reasoning_effort,
718            response_format: None,
719            search_parameters: None,
720            seed: None,
721            stop: None,
722            stream: Some(true),
723            stream_options: None,
724            temperature: call_options
725                .as_ref()
726                .and_then(|o| o.temperature)
727                .or(Some(1.0)),
728            tool_choice: None,
729            tools: xai_tools,
730            top_logprobs: None,
731            top_p: call_options.as_ref().and_then(|o| o.top_p),
732            user: None,
733            web_search_options: None,
734        };
735
736        let response = self
737            .http_client
738            .post(&self.base_url)
739            .json(&request)
740            .send()
741            .await
742            .map_err(ApiError::Network)?;
743
744        if !response.status().is_success() {
745            let status = response.status();
746            let error_text = response.text().await.unwrap_or_else(|_| String::new());
747
748            debug!(
749                target: "xai::stream",
750                "xAI API error - Status: {}, Body: {}",
751                status,
752                error_text
753            );
754
755            return match status.as_u16() {
756                429 => Err(ApiError::RateLimited {
757                    provider: self.name().to_string(),
758                    details: error_text,
759                }),
760                400 => Err(ApiError::InvalidRequest {
761                    provider: self.name().to_string(),
762                    details: error_text,
763                }),
764                401 => Err(ApiError::AuthenticationFailed {
765                    provider: self.name().to_string(),
766                    details: error_text,
767                }),
768                _ => Err(ApiError::ServerError {
769                    provider: self.name().to_string(),
770                    status_code: status.as_u16(),
771                    details: error_text,
772                }),
773            };
774        }
775
776        let byte_stream = response.bytes_stream();
777        let sse_stream = parse_sse_stream(byte_stream);
778
779        Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
780    }
781}
782
783impl XAIClient {
784    fn convert_xai_stream(
785        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
786        + Unpin
787        + Send
788        + 'static,
789        token: CancellationToken,
790    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
791        struct ToolCallAccumulator {
792            id: String,
793            name: String,
794            args: String,
795        }
796
797        async_stream::stream! {
798            let mut content: Vec<AssistantContent> = Vec::new();
799            let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
800            let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
801            let mut tool_calls_started: std::collections::HashSet<usize> =
802                std::collections::HashSet::new();
803            let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
804            loop {
805                if token.is_cancelled() {
806                    yield StreamChunk::Error(StreamError::Cancelled);
807                    break;
808                }
809
810                let event_result = tokio::select! {
811                    biased;
812                    () = token.cancelled() => {
813                        yield StreamChunk::Error(StreamError::Cancelled);
814                        break;
815                    }
816                    event = sse_stream.next() => event
817                };
818
819                let Some(event_result) = event_result else {
820                    break;
821                };
822
823                let event = match event_result {
824                    Ok(e) => e,
825                    Err(e) => {
826                        yield StreamChunk::Error(StreamError::SseParse(e));
827                        break;
828                    }
829                };
830
831                if event.data == "[DONE]" {
832                    let tool_calls = std::mem::take(&mut tool_calls);
833                    let mut final_content = Vec::new();
834
835                    for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
836                    {
837                        if let Some(index) = tool_index {
838                            let Some(tool_call) = tool_calls.get(&index) else {
839                                continue;
840                            };
841                            if tool_call.id.is_empty() || tool_call.name.is_empty() {
842                                debug!(
843                                    target: "xai::stream",
844                                    "Skipping tool call with missing id/name: id='{}' name='{}'",
845                                    tool_call.id,
846                                    tool_call.name
847                                );
848                                continue;
849                            }
850                            let parameters = serde_json::from_str(&tool_call.args)
851                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
852                            final_content.push(AssistantContent::ToolCall {
853                                tool_call: steer_tools::ToolCall {
854                                    id: tool_call.id.clone(),
855                                    name: tool_call.name.clone(),
856                                    parameters,
857                                },
858                                thought_signature: None,
859                            });
860                        } else {
861                            final_content.push(block);
862                        }
863                    }
864
865                    yield StreamChunk::MessageComplete(CompletionResponse { content: final_content });
866                    break;
867                }
868
869                let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
870                    Ok(c) => c,
871                    Err(e) => {
872                        debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
873                        continue;
874                    }
875                };
876
877                if let Some(choice) = chunk.choices.first() {
878                    if let Some(text_delta) = &choice.delta.content {
879                        if let Some(AssistantContent::Text { text }) = content.last_mut() { text.push_str(text_delta) } else {
880                            content.push(AssistantContent::Text {
881                                text: text_delta.clone(),
882                            });
883                            tool_call_indices.push(None);
884                        }
885                        yield StreamChunk::TextDelta(text_delta.clone());
886                    }
887
888                    if let Some(thinking_delta) = &choice.delta.reasoning_content {
889                        if let Some(AssistantContent::Thought {
890                                thought: crate::app::conversation::ThoughtContent::Simple { text },
891                            }) = content.last_mut() { text.push_str(thinking_delta) } else {
892                            content.push(AssistantContent::Thought {
893                                thought: crate::app::conversation::ThoughtContent::Simple {
894                                    text: thinking_delta.clone(),
895                                },
896                            });
897                            tool_call_indices.push(None);
898                        }
899                        yield StreamChunk::ThinkingDelta(thinking_delta.clone());
900                    }
901
902                    if let Some(tcs) = &choice.delta.tool_calls {
903                        for tc in tcs {
904                            let entry = tool_calls.entry(tc.index).or_insert_with(|| {
905                                ToolCallAccumulator {
906                                    id: String::new(),
907                                    name: String::new(),
908                                    args: String::new(),
909                                }
910                            });
911                            let mut started_now = false;
912                            let mut flushed_now = false;
913
914                            if let Some(id) = &tc.id
915                                && !id.is_empty() {
916                                    entry.id.clone_from(id);
917                                }
918                            if let Some(func) = &tc.function
919                                && let Some(name) = &func.name
920                                    && !name.is_empty() {
921                                        entry.name.clone_from(name);
922                                    }
923
924                            if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
925                                let pos = content.len();
926                                content.push(AssistantContent::ToolCall {
927                                    tool_call: steer_tools::ToolCall {
928                                        id: entry.id.clone(),
929                                        name: entry.name.clone(),
930                                        parameters: serde_json::Value::String(entry.args.clone()),
931                                    },
932                                    thought_signature: None,
933                                });
934                                tool_call_indices.push(Some(tc.index));
935                                e.insert(pos);
936                            }
937
938                            if !entry.id.is_empty()
939                                && !entry.name.is_empty()
940                                && !tool_calls_started.contains(&tc.index)
941                            {
942                                tool_calls_started.insert(tc.index);
943                                started_now = true;
944                                yield StreamChunk::ToolUseStart {
945                                    id: entry.id.clone(),
946                                    name: entry.name.clone(),
947                                };
948                            }
949
950                            if let Some(func) = &tc.function
951                                && let Some(args) = &func.arguments {
952                                    entry.args.push_str(args);
953                                    if tool_calls_started.contains(&tc.index) {
954                                        if started_now {
955                                            if !entry.args.is_empty() {
956                                                yield StreamChunk::ToolUseInputDelta {
957                                                    id: entry.id.clone(),
958                                                    delta: entry.args.clone(),
959                                                };
960                                                flushed_now = true;
961                                            }
962                                        } else if !args.is_empty() {
963                                            yield StreamChunk::ToolUseInputDelta {
964                                                id: entry.id.clone(),
965                                                delta: args.clone(),
966                                            };
967                                        }
968                                    }
969                                }
970
971                            if started_now && !flushed_now && !entry.args.is_empty() {
972                                yield StreamChunk::ToolUseInputDelta {
973                                    id: entry.id.clone(),
974                                    delta: entry.args.clone(),
975                                };
976                            }
977                        }
978                    }
979                }
980            }
981        }
982    }
983}