Skip to main content

steer_core/api/xai/
client.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11    CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::api::util::{map_http_status_to_api_error, normalize_chat_url};
15use crate::app::SystemContext;
16use crate::app::conversation::{
17    AssistantContent, ImageSource, Message as AppMessage, ToolResult, UserContent,
18};
19use crate::config::model::{ModelId, ModelParameters};
20use steer_tools::ToolSchema;
21
/// Default xAI chat-completions endpoint; used when no base-URL override is supplied.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
23
/// HTTP client for the xAI chat-completions API.
///
/// Cloning is cheap: `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct XAIClient {
    // Pre-configured with a bearer `Authorization` default header (see `with_base_url`).
    http_client: reqwest::Client,
    // Fully-resolved chat-completions endpoint URL.
    base_url: String,
}
29
// xAI-specific message format (similar to OpenAI but with some differences)
/// A single chat message in xAI wire format, tagged by `role`
/// ("system" | "user" | "assistant" | "tool").
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: XAIUserContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Text content; omitted when the assistant turn carries only tool calls.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Tool {
        // Output of the tool invocation identified by `tool_call_id`.
        content: String,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}

/// User message content: either a bare string or a list of typed parts.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum XAIUserContent {
    Text(String),
    Parts(Vec<XAIUserContentPart>),
}

/// One part of a multi-part user message: plain text or an image reference.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum XAIUserContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: XAIImageUrl },
}

/// Image reference; `url` may be a public URL or a data URL (see `user_image_part`).
#[derive(Debug, Serialize, Deserialize)]
struct XAIImageUrl {
    url: String,
}

// xAI function calling format
/// Tool definition: a named function with a JSON-schema parameter description.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    // JSON schema describing the function's parameters.
    parameters: serde_json::Value,
}

// xAI tool format
#[derive(Debug, Serialize, Deserialize)]
struct XAITool {
    #[serde(rename = "type")]
    tool_type: String, // "function"
    function: XAIFunction,
}

// xAI tool call
/// A tool invocation requested by the assistant in a response.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}

#[derive(Debug, Serialize, Deserialize)]
struct XAIFunctionCall {
    name: String,
    arguments: String, // JSON string
}
111
/// Reasoning effort accepted by xAI. Only `low`/`high` exist; callers map
/// Medium/XHigh down to these two (see `complete`/`stream_complete`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}

/// Extra options for streaming requests.
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    // When true, the server appends a final chunk carrying token usage.
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}

/// Tool-choice directive: either a mode string or a specific function to call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    String(String), // "auto", "required", "none"
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}

#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}

/// Structured-output directive (e.g. JSON mode with an optional schema).
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}

/// Options for search-augmented requests. Currently never populated by this
/// client; kept for wire-format completeness.
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}

/// Web-search tuning options. Currently never populated by this client.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
172
/// Request body for `POST /v1/chat/completions`.
///
/// Optional fields are omitted from the JSON when `None`, so the payload only
/// carries the parameters this client actually sets.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    // This client sets both max_completion_tokens and max_tokens to the same value.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    // Only meaningful alongside `stream: Some(true)` (see `stream_complete`).
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
224
/// Full (non-streaming) chat-completion response body.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    // Citations for search-augmented requests; shape is not relied on here.
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}

/// One completion candidate; only the first is consumed by this client.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}

/// Assistant message payload inside a `Choice`.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    // Model "thinking" text; mapped to an AssistantContent::Thought block.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}

/// Breakdown of prompt-side token usage.
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    #[serde(rename = "cached_tokens")]
    cached: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "image_tokens")]
    image: usize,
    #[serde(rename = "text_tokens")]
    text: usize,
}

/// Breakdown of completion-side token usage.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    #[serde(rename = "reasoning_tokens")]
    reasoning: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "accepted_prediction_tokens")]
    accepted_prediction: usize,
    #[serde(rename = "rejected_prediction_tokens")]
    rejected_prediction: usize,
}

/// Token-usage totals; mapped to the provider-agnostic `TokenUsage`.
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}

/// Server-side debug payload. Not consumed by this client; presumably present
/// only when debug output is requested — TODO confirm against xAI docs.
#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}
306
307#[derive(Debug, Deserialize)]
308struct XAIStreamChunk {
309    #[expect(dead_code)]
310    id: String,
311    choices: Vec<XAIStreamChoice>,
312    #[serde(skip_serializing_if = "Option::is_none")]
313    usage: Option<XAIUsage>,
314}
315
316#[derive(Debug, Deserialize)]
317struct XAIStreamChoice {
318    #[expect(dead_code)]
319    index: u32,
320    delta: XAIStreamDelta,
321    #[expect(dead_code)]
322    finish_reason: Option<String>,
323}
324
325#[derive(Debug, Deserialize)]
326struct XAIStreamDelta {
327    #[serde(skip_serializing_if = "Option::is_none")]
328    content: Option<String>,
329    #[serde(skip_serializing_if = "Option::is_none")]
330    tool_calls: Option<Vec<XAIStreamToolCall>>,
331    #[serde(skip_serializing_if = "Option::is_none")]
332    reasoning_content: Option<String>,
333}
334
335#[derive(Debug, Deserialize)]
336struct XAIStreamToolCall {
337    index: usize,
338    #[serde(skip_serializing_if = "Option::is_none")]
339    id: Option<String>,
340    #[serde(skip_serializing_if = "Option::is_none")]
341    function: Option<XAIStreamFunction>,
342}
343
344#[derive(Debug, Deserialize)]
345struct XAIStreamFunction {
346    #[serde(skip_serializing_if = "Option::is_none")]
347    name: Option<String>,
348    #[serde(skip_serializing_if = "Option::is_none")]
349    arguments: Option<String>,
350}
351
impl XAIClient {
    /// Creates a client targeting the default xAI endpoint.
    ///
    /// # Errors
    /// See [`XAIClient::with_base_url`].
    pub fn new(api_key: String) -> Result<Self, ApiError> {
        Self::with_base_url(api_key, None)
    }

    /// Creates a client with an optional base-URL override.
    ///
    /// Builds a `reqwest` client with a bearer-token default header and a
    /// 5-minute request timeout; the URL is normalized via `normalize_chat_url`.
    ///
    /// # Errors
    /// - `ApiError::AuthenticationFailed` if the key is not a valid header value.
    /// - `ApiError::Network` if the HTTP client cannot be built.
    pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
                ApiError::AuthenticationFailed {
                    provider: "xai".to_string(),
                    details: format!("Invalid API key: {e}"),
                }
            })?,
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(std::time::Duration::from_secs(300)) // 5 minute timeout
            .build()
            .map_err(ApiError::Network)?;

        let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);

        Ok(Self {
            http_client: client,
            base_url,
        })
    }

    /// Converts an app-level image block into an xAI `image_url` content part.
    ///
    /// # Errors
    /// Session-file images yield `ApiError::UnsupportedFeature`: the remote API
    /// cannot read local files, so only data URLs and public URLs are forwarded.
    fn user_image_part(
        image: &crate::app::conversation::ImageContent,
    ) -> Result<XAIUserContentPart, ApiError> {
        let image_url = match &image.source {
            ImageSource::DataUrl { data_url } => data_url.clone(),
            ImageSource::Url { url } => url.clone(),
            ImageSource::SessionFile { relative_path } => {
                return Err(ApiError::UnsupportedFeature {
                    provider: "xai".to_string(),
                    feature: "image input source".to_string(),
                    details: format!(
                        "xAI chat API cannot access session file '{}' directly; use data URLs or public URLs",
                        relative_path
                    ),
                });
            }
        };

        Ok(XAIUserContentPart::ImageUrl {
            image_url: XAIImageUrl { url: image_url },
        })
    }

    /// Converts app-level conversation messages (plus an optional system
    /// context) into the xAI wire format.
    ///
    /// - User messages become text/image parts; a lone text part collapses to
    ///   a bare string.
    /// - Assistant messages split into joined text and tool calls; images are
    ///   downgraded to text placeholders, thoughts are dropped.
    /// - Tool results become `role: "tool"` messages keyed by `tool_call_id`.
    ///
    /// # Errors
    /// Propagates `ApiError::UnsupportedFeature` from unsupported user image
    /// sources (see `user_image_part`).
    fn convert_messages(
        messages: Vec<AppMessage>,
        system: Option<SystemContext>,
    ) -> Result<Vec<XAIMessage>, ApiError> {
        let mut xai_messages = Vec::new();

        // Add system message if provided
        if let Some(system_content) = system.and_then(|context| context.render()) {
            xai_messages.push(XAIMessage::System {
                content: system_content,
                name: None,
            });
        }

        // Convert our messages to xAI format
        for message in messages {
            match &message.data {
                crate::app::conversation::MessageData::User { content, .. } => {
                    let mut content_parts = Vec::new();

                    for user_content in content {
                        match user_content {
                            UserContent::Text { text } => {
                                content_parts.push(XAIUserContentPart::Text { text: text.clone() });
                            }
                            UserContent::Image { image } => {
                                content_parts.push(Self::user_image_part(image)?);
                            }
                            UserContent::CommandExecution {
                                command,
                                stdout,
                                stderr,
                                exit_code,
                            } => {
                                // Command executions are rendered as an XML-ish text blob.
                                content_parts.push(XAIUserContentPart::Text {
                                    text: UserContent::format_command_execution_as_xml(
                                        command, stdout, stderr, *exit_code,
                                    ),
                                });
                            }
                        }
                    }

                    // Only add the message if it has content after filtering
                    if !content_parts.is_empty() {
                        // A single text part is sent as a plain string, not a parts array.
                        let content = match content_parts.as_slice() {
                            [XAIUserContentPart::Text { text }] => {
                                XAIUserContent::Text(text.clone())
                            }
                            _ => XAIUserContent::Parts(content_parts),
                        };

                        xai_messages.push(XAIMessage::User {
                            content,
                            name: None,
                        });
                    }
                }
                crate::app::conversation::MessageData::Assistant { content, .. } => {
                    // Convert AssistantContent to xAI format
                    let mut text_parts = Vec::new();
                    let mut tool_calls = Vec::new();

                    for content_block in content {
                        match content_block {
                            AssistantContent::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            AssistantContent::Image { image } => {
                                // Assistant messages cannot carry image parts in this
                                // format, so degrade to a text placeholder (or skip on error).
                                match Self::user_image_part(image) {
                                    Ok(XAIUserContentPart::ImageUrl { image_url }) => {
                                        text_parts.push(format!("[Image URL: {}]", image_url.url));
                                    }
                                    Ok(XAIUserContentPart::Text { text }) => {
                                        text_parts.push(text);
                                    }
                                    Err(err) => {
                                        debug!(
                                            target: "xai::convert_messages",
                                            "Skipping unsupported assistant image block: {}",
                                            err
                                        );
                                    }
                                }
                            }
                            AssistantContent::ToolCall { tool_call, .. } => {
                                // xAI expects tool arguments as a JSON string.
                                tool_calls.push(XAIToolCall {
                                    id: tool_call.id.clone(),
                                    tool_type: "function".to_string(),
                                    function: XAIFunctionCall {
                                        name: tool_call.name.clone(),
                                        arguments: tool_call.parameters.to_string(),
                                    },
                                });
                            }
                            AssistantContent::Thought { .. } => {

                                // xAI doesn't support thinking blocks in requests, only in responses
                            }
                        }
                    }

                    // Build the assistant message
                    let content = if text_parts.is_empty() {
                        None
                    } else {
                        Some(text_parts.join("\n"))
                    };

                    let tool_calls_opt = if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    };

                    xai_messages.push(XAIMessage::Assistant {
                        content,
                        tool_calls: tool_calls_opt,
                        name: None,
                    });
                }
                crate::app::conversation::MessageData::Tool {
                    tool_use_id,
                    result,
                    ..
                } => {
                    // Convert ToolResult to xAI format
                    let content_text = if let ToolResult::Error(e) = result {
                        format!("Error: {e}")
                    } else {
                        let text = result.llm_format();
                        // Empty output still needs a non-empty body for the API.
                        if text.trim().is_empty() {
                            "(No output)".to_string()
                        } else {
                            text
                        }
                    };

                    xai_messages.push(XAIMessage::Tool {
                        content: content_text,
                        tool_call_id: tool_use_id.clone(),
                        name: None,
                    });
                }
            }
        }

        Ok(xai_messages)
    }

    /// Converts tool schemas into xAI function-tool definitions.
    fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
        tools
            .into_iter()
            .map(|tool| XAITool {
                tool_type: "function".to_string(),
                function: XAIFunction {
                    name: tool.name,
                    description: tool.description,
                    parameters: tool.input_schema.as_value().clone(),
                },
            })
            .collect()
    }
}
570
/// Clamps a `usize` token count into `u32`, saturating at `u32::MAX`.
fn saturating_u32(value: usize) -> u32 {
    match u32::try_from(value) {
        Ok(converted) => converted,
        Err(_) => u32::MAX,
    }
}
574
575fn map_xai_usage(usage: &XAIUsage) -> TokenUsage {
576    TokenUsage::new(
577        saturating_u32(usage.prompt_tokens),
578        saturating_u32(usage.completion_tokens),
579        saturating_u32(usage.total_tokens),
580    )
581}
582
583fn convert_xai_completion_response(
584    xai_response: XAICompletionResponse,
585) -> Result<CompletionResponse, ApiError> {
586    let usage = xai_response.usage.as_ref().map(map_xai_usage);
587
588    let Some(choice) = xai_response.choices.first() else {
589        return Err(ApiError::NoChoices {
590            provider: "xai".to_string(),
591        });
592    };
593
594    let mut content_blocks = Vec::new();
595
596    // Add reasoning content (thinking) first if present
597    if let Some(reasoning) = &choice.message.reasoning_content
598        && !reasoning.trim().is_empty()
599    {
600        content_blocks.push(AssistantContent::Thought {
601            thought: crate::app::conversation::ThoughtContent::Simple {
602                text: reasoning.clone(),
603            },
604        });
605    }
606
607    // Add regular content
608    if let Some(content) = &choice.message.content
609        && !content.trim().is_empty()
610    {
611        content_blocks.push(AssistantContent::Text {
612            text: content.clone(),
613        });
614    }
615
616    // Add tool calls
617    if let Some(tool_calls) = &choice.message.tool_calls {
618        for tool_call in tool_calls {
619            // Parse the arguments JSON string
620            let parameters = serde_json::from_str(&tool_call.function.arguments)
621                .unwrap_or(serde_json::Value::Null);
622
623            content_blocks.push(AssistantContent::ToolCall {
624                tool_call: steer_tools::ToolCall {
625                    id: tool_call.id.clone(),
626                    name: tool_call.function.name.clone(),
627                    parameters,
628                },
629                thought_signature: None,
630            });
631        }
632    }
633
634    Ok(CompletionResponse {
635        content: content_blocks,
636        usage,
637    })
638}
639
640#[async_trait]
641impl Provider for XAIClient {
642    fn name(&self) -> &'static str {
643        "xai"
644    }
645
646    async fn complete(
647        &self,
648        model_id: &ModelId,
649        messages: Vec<AppMessage>,
650        system: Option<SystemContext>,
651        tools: Option<Vec<ToolSchema>>,
652        call_options: Option<ModelParameters>,
653        token: CancellationToken,
654    ) -> Result<CompletionResponse, ApiError> {
655        let xai_messages = Self::convert_messages(messages, system)?;
656        let xai_tools = tools.map(Self::convert_tools);
657
658        // Extract thinking support and map optional effort
659        let (supports_thinking, reasoning_effort) = call_options
660            .as_ref()
661            .and_then(|opts| opts.thinking_config)
662            .map_or((false, None), |tc| {
663                let effort = tc.effort.map(|e| match e {
664                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
665                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, // xAI has Low/High only
666                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
667                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
668                });
669                (tc.enabled, effort)
670            });
671
672        // grok-4 supports thinking by default but not the reasoning_effort parameter
673        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
674            reasoning_effort.or(Some(ReasoningEffort::High))
675        } else {
676            None
677        };
678
679        let max_output_tokens = call_options
680            .as_ref()
681            .and_then(|o| o.max_output_tokens)
682            .or(Some(32_768));
683
684        let request = CompletionRequest {
685            model: model_id.id.clone(), // Use the model ID string
686            messages: xai_messages,
687            deferred: None,
688            frequency_penalty: None,
689            logit_bias: None,
690            logprobs: None,
691            max_completion_tokens: max_output_tokens,
692            max_tokens: max_output_tokens,
693            n: None,
694            parallel_tool_calls: None,
695            presence_penalty: None,
696            reasoning_effort,
697            response_format: None,
698            search_parameters: None,
699            seed: None,
700            stop: None,
701            stream: None,
702            stream_options: None,
703            temperature: call_options
704                .as_ref()
705                .and_then(|o| o.temperature)
706                .or(Some(1.0)),
707            tool_choice: None,
708            tools: xai_tools,
709            top_logprobs: None,
710            top_p: call_options.as_ref().and_then(|o| o.top_p),
711            user: None,
712            web_search_options: None,
713        };
714
715        let response = self
716            .http_client
717            .post(&self.base_url)
718            .json(&request)
719            .send()
720            .await
721            .map_err(ApiError::Network)?;
722
723        if !response.status().is_success() {
724            let status = response.status();
725            let error_text = response.text().await.unwrap_or_else(|_| String::new());
726
727            debug!(
728                target: "grok::complete",
729                "Grok API error - Status: {}, Body: {}",
730                status,
731                error_text
732            );
733
734            return Err(map_http_status_to_api_error(
735                self.name(),
736                status.as_u16(),
737                error_text,
738            ));
739        }
740
741        let response_text = tokio::select! {
742            () = token.cancelled() => {
743                debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
744                return Err(ApiError::Cancelled { provider: self.name().to_string() });
745            }
746            text_res = response.text() => {
747                text_res?
748            }
749        };
750
751        let xai_response: XAICompletionResponse =
752            serde_json::from_str(&response_text).map_err(|e| {
753                error!(
754                    target: "xai::complete",
755                    "Failed to parse response: {}, Body: {}",
756                    e,
757                    response_text
758                );
759                ApiError::ResponseParsingError {
760                    provider: self.name().to_string(),
761                    details: format!("Error: {e}, Body: {response_text}"),
762                }
763            })?;
764
765        convert_xai_completion_response(xai_response).map_err(|err| match err {
766            ApiError::NoChoices { .. } => ApiError::NoChoices {
767                provider: self.name().to_string(),
768            },
769            other => other,
770        })
771    }
772
773    async fn stream_complete(
774        &self,
775        model_id: &ModelId,
776        messages: Vec<AppMessage>,
777        system: Option<SystemContext>,
778        tools: Option<Vec<ToolSchema>>,
779        call_options: Option<ModelParameters>,
780        token: CancellationToken,
781    ) -> Result<CompletionStream, ApiError> {
782        let xai_messages = Self::convert_messages(messages, system)?;
783        let xai_tools = tools.map(Self::convert_tools);
784
785        let (supports_thinking, reasoning_effort) = call_options
786            .as_ref()
787            .and_then(|opts| opts.thinking_config)
788            .map_or((false, None), |tc| {
789                let effort = tc.effort.map(|e| match e {
790                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
791                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
792                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
793                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
794                });
795                (tc.enabled, effort)
796            });
797
798        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
799            reasoning_effort.or(Some(ReasoningEffort::High))
800        } else {
801            None
802        };
803
804        let max_output_tokens = call_options
805            .as_ref()
806            .and_then(|o| o.max_output_tokens)
807            .or(Some(32_768));
808
809        let request = CompletionRequest {
810            model: model_id.id.clone(),
811            messages: xai_messages,
812            deferred: None,
813            frequency_penalty: None,
814            logit_bias: None,
815            logprobs: None,
816            max_completion_tokens: max_output_tokens,
817            max_tokens: max_output_tokens,
818            n: None,
819            parallel_tool_calls: None,
820            presence_penalty: None,
821            reasoning_effort,
822            response_format: None,
823            search_parameters: None,
824            seed: None,
825            stop: None,
826            stream: Some(true),
827            stream_options: Some(StreamOptions {
828                include_usage: Some(true),
829            }),
830            temperature: call_options
831                .as_ref()
832                .and_then(|o| o.temperature)
833                .or(Some(1.0)),
834            tool_choice: None,
835            tools: xai_tools,
836            top_logprobs: None,
837            top_p: call_options.as_ref().and_then(|o| o.top_p),
838            user: None,
839            web_search_options: None,
840        };
841
842        let response = self
843            .http_client
844            .post(&self.base_url)
845            .json(&request)
846            .send()
847            .await
848            .map_err(ApiError::Network)?;
849
850        if !response.status().is_success() {
851            let status = response.status();
852            let error_text = response.text().await.unwrap_or_else(|_| String::new());
853
854            debug!(
855                target: "xai::stream",
856                "xAI API error - Status: {}, Body: {}",
857                status,
858                error_text
859            );
860
861            return Err(map_http_status_to_api_error(
862                self.name(),
863                status.as_u16(),
864                error_text,
865            ));
866        }
867
868        let byte_stream = response.bytes_stream();
869        let sse_stream = parse_sse_stream(byte_stream);
870
871        Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
872    }
873}
874
impl XAIClient {
    /// Converts a parsed SSE event stream from the xAI chat-completions API
    /// into the provider-agnostic `StreamChunk` stream.
    ///
    /// Text, reasoning ("thinking"), and tool-call deltas are accumulated
    /// across chunks; incremental `StreamChunk`s are yielded as data arrives,
    /// and a final `MessageComplete` (carrying the last-seen usage) is yielded
    /// when the `[DONE]` sentinel is received. Cancellation via `token`
    /// short-circuits the stream with `StreamError::Cancelled`.
    fn convert_xai_stream(
        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
        + Unpin
        + Send
        + 'static,
        token: CancellationToken,
    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
        // Gathers the pieces of one streamed tool call: id/name may arrive in
        // any delta, and the JSON argument string is appended chunk by chunk.
        struct ToolCallAccumulator {
            id: String,
            name: String,
            args: String,
        }

        async_stream::stream! {
            // Ordered assistant content blocks built up as deltas arrive.
            let mut content: Vec<AssistantContent> = Vec::new();
            // Parallel to `content`: Some(provider tool-call index) for
            // tool-call placeholder blocks, None for text/thought blocks.
            let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
            // Provider tool-call index -> accumulated id/name/args.
            let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
            // Indices for which a ToolUseStart chunk has already been emitted.
            let mut tool_calls_started: std::collections::HashSet<usize> =
                std::collections::HashSet::new();
            // Provider tool-call index -> position of its placeholder in `content`.
            let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
            // Usage may appear on any chunk; the most recent one wins.
            let mut latest_usage: Option<TokenUsage> = None;
            loop {
                if token.is_cancelled() {
                    yield StreamChunk::Error(StreamError::Cancelled);
                    break;
                }

                // `biased` so cancellation is observed before the next event.
                let event_result = tokio::select! {
                    biased;
                    () = token.cancelled() => {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }
                    event = sse_stream.next() => event
                };

                // Stream exhausted without [DONE]: end without a final message.
                let Some(event_result) = event_result else {
                    break;
                };

                let event = match event_result {
                    Ok(e) => e,
                    Err(e) => {
                        yield StreamChunk::Error(StreamError::SseParse(e));
                        break;
                    }
                };

                if event.data == "[DONE]" {
                    // Finalize: rewrite tool-call placeholders in `content`
                    // with fully parsed tool calls; pass other blocks through.
                    let tool_calls = std::mem::take(&mut tool_calls);
                    let mut final_content = Vec::new();

                    for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
                    {
                        if let Some(index) = tool_index {
                            let Some(tool_call) = tool_calls.get(&index) else {
                                continue;
                            };
                            // Drop tool calls that never received an id/name.
                            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                                debug!(
                                    target: "xai::stream",
                                    "Skipping tool call with missing id/name: id='{}' name='{}'",
                                    tool_call.id,
                                    tool_call.name
                                );
                                continue;
                            }
                            // Malformed/empty argument JSON degrades to `{}`.
                            let parameters = serde_json::from_str(&tool_call.args)
                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
                            final_content.push(AssistantContent::ToolCall {
                                tool_call: steer_tools::ToolCall {
                                    id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    parameters,
                                },
                                thought_signature: None,
                            });
                        } else {
                            final_content.push(block);
                        }
                    }

                    yield StreamChunk::MessageComplete(CompletionResponse {
                        content: final_content,
                        usage: latest_usage,
                    });
                    break;
                }

                // Unparseable chunks are logged and skipped, not fatal.
                let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
                    Ok(c) => c,
                    Err(e) => {
                        debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
                        continue;
                    }
                };

                if let Some(usage) = chunk.usage.as_ref() {
                    latest_usage = Some(map_xai_usage(usage));
                }

                if let Some(choice) = chunk.choices.first() {
                    // Text delta: extend the trailing text block or open one.
                    if let Some(text_delta) = &choice.delta.content {
                        if let Some(AssistantContent::Text { text }) = content.last_mut() { text.push_str(text_delta) } else {
                            content.push(AssistantContent::Text {
                                text: text_delta.clone(),
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::TextDelta(text_delta.clone());
                    }

                    // Reasoning delta: extend the trailing thought block or open one.
                    if let Some(thinking_delta) = &choice.delta.reasoning_content {
                        if let Some(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple { text },
                            }) = content.last_mut() { text.push_str(thinking_delta) } else {
                            content.push(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple {
                                    text: thinking_delta.clone(),
                                },
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::ThinkingDelta(thinking_delta.clone());
                    }

                    if let Some(tcs) = &choice.delta.tool_calls {
                        for tc in tcs {
                            let entry = tool_calls.entry(tc.index).or_insert_with(|| {
                                ToolCallAccumulator {
                                    id: String::new(),
                                    name: String::new(),
                                    args: String::new(),
                                }
                            });
                            let mut started_now = false;
                            let mut flushed_now = false;

                            // id/name can arrive in any delta; keep the latest
                            // non-empty value seen.
                            if let Some(id) = &tc.id
                                && !id.is_empty() {
                                    entry.id.clone_from(id);
                                }
                            if let Some(func) = &tc.function
                                && let Some(name) = &func.name
                                    && !name.is_empty() {
                                        entry.name.clone_from(name);
                                    }

                            // First sighting of this index: reserve a placeholder
                            // block in `content` (rewritten on [DONE] above).
                            if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
                                let pos = content.len();
                                content.push(AssistantContent::ToolCall {
                                    tool_call: steer_tools::ToolCall {
                                        id: entry.id.clone(),
                                        name: entry.name.clone(),
                                        parameters: serde_json::Value::String(entry.args.clone()),
                                    },
                                    thought_signature: None,
                                });
                                tool_call_indices.push(Some(tc.index));
                                e.insert(pos);
                            }

                            // Announce the tool call exactly once, as soon as
                            // both id and name are known.
                            if !entry.id.is_empty()
                                && !entry.name.is_empty()
                                && !tool_calls_started.contains(&tc.index)
                            {
                                tool_calls_started.insert(tc.index);
                                started_now = true;
                                yield StreamChunk::ToolUseStart {
                                    id: entry.id.clone(),
                                    name: entry.name.clone(),
                                };
                            }

                            if let Some(func) = &tc.function
                                && let Some(args) = &func.arguments {
                                    entry.args.push_str(args);
                                    if tool_calls_started.contains(&tc.index) {
                                        if started_now {
                                            // Just started: flush everything
                                            // accumulated so far, not only this delta.
                                            if !entry.args.is_empty() {
                                                yield StreamChunk::ToolUseInputDelta {
                                                    id: entry.id.clone(),
                                                    delta: entry.args.clone(),
                                                };
                                                flushed_now = true;
                                            }
                                        } else if !args.is_empty() {
                                            yield StreamChunk::ToolUseInputDelta {
                                                id: entry.id.clone(),
                                                delta: args.clone(),
                                            };
                                        }
                                    }
                                }

                            // Started this delta without an argument payload:
                            // flush any args that were buffered before the start.
                            if started_now && !flushed_now && !entry.args.is_empty() {
                                yield StreamChunk::ToolUseInputDelta {
                                    id: entry.id.clone(),
                                    delta: entry.args.clone(),
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}
1084
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::provider::StreamChunk;
    use crate::api::sse::SseEvent;
    use crate::app::conversation::{
        AssistantContent, ImageContent, ImageSource, Message, MessageData, UserContent,
    };
    use futures::StreamExt;
    use futures::stream;
    use std::pin::pin;
    use tokio_util::sync::CancellationToken;

    /// Wraps `content` in a single user message with fixed id/timestamp.
    fn user_message(content: Vec<UserContent>) -> Vec<Message> {
        vec![Message {
            data: MessageData::User { content },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }]
    }

    /// A user message mixing text and a data-URL image must convert into a
    /// multi-part xAI user message with both parts preserved.
    #[test]
    fn test_convert_messages_includes_data_url_image_part() {
        let picture = ImageContent {
            mime_type: "image/png".to_string(),
            source: ImageSource::DataUrl {
                data_url: "data:image/png;base64,Zm9v".to_string(),
            },
            width: None,
            height: None,
            bytes: None,
            sha256: None,
        };
        let input = user_message(vec![
            UserContent::Text {
                text: "describe".to_string(),
            },
            UserContent::Image { image: picture },
        ]);

        let converted = XAIClient::convert_messages(input, None).expect("convert messages");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            XAIMessage::User { content, .. } => match content {
                XAIUserContent::Parts(parts) => {
                    assert_eq!(parts.len(), 2);
                    assert!(matches!(
                        &parts[0],
                        XAIUserContentPart::Text { text } if text == "describe"
                    ));
                    assert!(matches!(
                        &parts[1],
                        XAIUserContentPart::ImageUrl { image_url }
                            if image_url.url == "data:image/png;base64,Zm9v"
                    ));
                }
                other => panic!("Expected parts content, got {other:?}"),
            },
            other => panic!("Expected user message, got {other:?}"),
        }
    }

    /// Session-file image sources are not supported by the xAI client and must
    /// surface as `ApiError::UnsupportedFeature` with a descriptive payload.
    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let picture = ImageContent {
            mime_type: "image/png".to_string(),
            source: ImageSource::SessionFile {
                relative_path: "session-1/image.png".to_string(),
            },
            width: None,
            height: None,
            bytes: None,
            sha256: None,
        };
        let input = user_message(vec![UserContent::Image { image: picture }]);

        let err =
            XAIClient::convert_messages(input, None).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "xai");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    /// Token counts from the xAI usage payload map field-for-field.
    #[test]
    fn test_map_xai_usage() {
        let raw = XAIUsage {
            prompt_tokens: 15,
            completion_tokens: 9,
            total_tokens: 24,
            num_sources_used: None,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
        let mapped = map_xai_usage(&raw);
        assert_eq!(mapped, TokenUsage::new(15, 9, 24));
    }

    /// A non-streaming completion response carries its usage through
    /// conversion and keeps the assistant text content.
    #[test]
    fn test_non_stream_completion_maps_usage() {
        let response = XAICompletionResponse {
            id: "resp_1".to_string(),
            object: "chat.completion".to_string(),
            created: 1,
            model: "grok-test".to_string(),
            choices: vec![Choice {
                index: 0,
                message: AssistantMessage {
                    content: Some("hello".to_string()),
                    tool_calls: None,
                    reasoning_content: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Some(XAIUsage {
                prompt_tokens: 6,
                completion_tokens: 4,
                total_tokens: 10,
                num_sources_used: None,
                prompt_tokens_details: None,
                completion_tokens_details: None,
            }),
            system_fingerprint: None,
            citations: None,
            debug_output: None,
        };

        let converted = convert_xai_completion_response(response).expect("response should map");

        assert_eq!(converted.usage, Some(TokenUsage::new(6, 4, 10)));
        assert!(matches!(
            converted.content.first(),
            Some(AssistantContent::Text { text }) if text == "hello"
        ));
    }

    /// Usage reported on a trailing (choices-empty) SSE chunk must be attached
    /// to the final `MessageComplete`, alongside the accumulated text.
    #[tokio::test]
    async fn test_convert_xai_stream_captures_final_usage() {
        let payloads = [
            r#"{"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#,
            r#"{"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":5,"total_tokens":17}}"#,
            "[DONE]",
        ];
        let events: Vec<Result<SseEvent, SseParseError>> = payloads
            .iter()
            .map(|payload| {
                Ok(SseEvent {
                    event_type: None,
                    data: payload.to_string(),
                    id: None,
                })
            })
            .collect();

        let cancel = CancellationToken::new();
        let mut chunks = pin!(XAIClient::convert_xai_stream(stream::iter(events), cancel));

        let first_delta = chunks.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let complete = chunks.next().await.unwrap();
        if let StreamChunk::MessageComplete(response) = complete {
            assert_eq!(response.usage, Some(TokenUsage::new(12, 5, 17)));
            assert!(matches!(
                response.content.first(),
                Some(AssistantContent::Text { text }) if text == "Hello"
            ));
        } else {
            panic!("Expected MessageComplete");
        }
    }
}