Skip to main content

steer_core/api/xai/
client.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11    CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::api::util::normalize_chat_url;
15use crate::app::SystemContext;
16use crate::app::conversation::{
17    AssistantContent, ImageSource, Message as AppMessage, ToolResult, UserContent,
18};
19use crate::config::model::{ModelId, ModelParameters};
20use steer_tools::ToolSchema;
21
22const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
23
/// HTTP client for the xAI (Grok) chat completions API.
///
/// `Clone` is cheap: `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct XAIClient {
    // Pre-configured with the Authorization header and a request timeout
    // (see `with_base_url`).
    http_client: reqwest::Client,
    // Fully resolved chat-completions endpoint URL.
    base_url: String,
}
29
// xAI-specific message format (similar to OpenAI but with some differences).
// Serialized with a lowercase `role` tag ("system"/"user"/"assistant"/"tool"),
// so the enum variant selects the wire-level role.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        // Either a bare string or a list of typed parts (text / image_url).
        content: XAIUserContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Content may be absent when the assistant turn is tool calls only.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Tool {
        // Tool output text; `tool_call_id` links it back to the assistant's call.
        content: String,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
59
/// User message content: either a plain string or a list of typed parts.
/// `untagged` means serde picks whichever shape matches on deserialize.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum XAIUserContent {
    Text(String),
    Parts(Vec<XAIUserContentPart>),
}

/// One element of a multi-part user message, discriminated by a `type` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum XAIUserContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: XAIImageUrl },
}

/// Image reference: a public URL or a data URL (see `user_image_part`).
#[derive(Debug, Serialize, Deserialize)]
struct XAIImageUrl {
    url: String,
}
80
// xAI function calling format: the function a tool exposes to the model.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    // JSON Schema describing the function's parameters.
    parameters: serde_json::Value,
}

// xAI tool format: a wrapper whose `type` is always "function" here.
#[derive(Debug, Serialize, Deserialize)]
struct XAITool {
    #[serde(rename = "type")]
    tool_type: String, // "function"
    function: XAIFunction,
}

// xAI tool call emitted by the model in an assistant message.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}

/// The invoked function plus its arguments as a raw JSON string
/// (parsed later in `convert_xai_completion_response`).
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunctionCall {
    name: String,
    arguments: String, // JSON string
}
111
/// Reasoning effort accepted by xAI: only `low` and `high`
/// (app-level Medium/XHigh are collapsed to High by the callers).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}

/// Streaming options; `include_usage` asks the API to append a final
/// usage chunk to the SSE stream.
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}
124
/// Tool-choice directive: either a mode string or a specific function pick.
/// `untagged` lets either shape serialize without a wrapper.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    String(String), // "auto", "required", "none"
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}

/// Names the single function to force when `ToolChoice::Specific` is used.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}

/// Response format request, e.g. JSON mode with an optional schema.
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}
148
/// Live-search parameters for the xAI API; all fields optional and omitted
/// from the request body when unset. Currently never populated by this client
/// (see `CompletionRequest` construction in `complete`/`stream_complete`).
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}

/// Web-search options; also unused by this client at present.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
172
/// Request body for the xAI chat completions endpoint.
///
/// Mirrors the OpenAI-compatible schema; every optional field is omitted from
/// the serialized JSON when `None`, so the wire payload stays minimal. Only a
/// subset is ever set by this client (model, messages, max_completion_tokens,
/// reasoning_effort, temperature, top_p, tools, and the stream fields).
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    // Preferred over the legacy `max_tokens` field by this client.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    // `stream`/`stream_options` are set only by `stream_complete`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
224
/// Top-level non-streaming response from the xAI API.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    // Only the first choice is consumed (see `convert_xai_completion_response`).
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}

/// One completion candidate.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}

/// Assistant message payload of a choice.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    // Model "thinking" text, present for reasoning-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
257
/// Breakdown of prompt-side token counts.
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    #[serde(rename = "cached_tokens")]
    cached: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "image_tokens")]
    image: usize,
    #[serde(rename = "text_tokens")]
    text: usize,
}

/// Breakdown of completion-side token counts.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    #[serde(rename = "reasoning_tokens")]
    reasoning: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "accepted_prediction_tokens")]
    accepted_prediction: usize,
    #[serde(rename = "rejected_prediction_tokens")]
    rejected_prediction: usize,
}

/// Token accounting reported by xAI; mapped to the provider-neutral
/// `TokenUsage` by `map_xai_usage` (detail fields are currently unused there).
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}
294
/// Optional server-side debug payload echoed back by the API.
/// Deserialized for completeness; not inspected by this client.
#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}
306
307#[derive(Debug, Deserialize)]
308struct XAIStreamChunk {
309    #[expect(dead_code)]
310    id: String,
311    choices: Vec<XAIStreamChoice>,
312    #[serde(skip_serializing_if = "Option::is_none")]
313    usage: Option<XAIUsage>,
314}
315
316#[derive(Debug, Deserialize)]
317struct XAIStreamChoice {
318    #[expect(dead_code)]
319    index: u32,
320    delta: XAIStreamDelta,
321    #[expect(dead_code)]
322    finish_reason: Option<String>,
323}
324
325#[derive(Debug, Deserialize)]
326struct XAIStreamDelta {
327    #[serde(skip_serializing_if = "Option::is_none")]
328    content: Option<String>,
329    #[serde(skip_serializing_if = "Option::is_none")]
330    tool_calls: Option<Vec<XAIStreamToolCall>>,
331    #[serde(skip_serializing_if = "Option::is_none")]
332    reasoning_content: Option<String>,
333}
334
335#[derive(Debug, Deserialize)]
336struct XAIStreamToolCall {
337    index: usize,
338    #[serde(skip_serializing_if = "Option::is_none")]
339    id: Option<String>,
340    #[serde(skip_serializing_if = "Option::is_none")]
341    function: Option<XAIStreamFunction>,
342}
343
344#[derive(Debug, Deserialize)]
345struct XAIStreamFunction {
346    #[serde(skip_serializing_if = "Option::is_none")]
347    name: Option<String>,
348    #[serde(skip_serializing_if = "Option::is_none")]
349    arguments: Option<String>,
350}
351
impl XAIClient {
    /// Creates a client targeting the default xAI chat completions endpoint.
    pub fn new(api_key: String) -> Result<Self, ApiError> {
        Self::with_base_url(api_key, None)
    }

    /// Creates a client, optionally pointing at a custom base URL
    /// (normalized against `DEFAULT_API_URL`).
    ///
    /// # Errors
    /// Returns `ApiError::AuthenticationFailed` if the API key contains bytes
    /// that cannot form a valid HTTP header value, or `ApiError::Network` if
    /// the underlying HTTP client cannot be built.
    pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            // Header values must be valid ASCII; a malformed key surfaces as
            // an authentication failure rather than a panic.
            header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
                ApiError::AuthenticationFailed {
                    provider: "xai".to_string(),
                    details: format!("Invalid API key: {e}"),
                }
            })?,
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(std::time::Duration::from_secs(300)) // 5 minute timeout
            .build()
            .map_err(ApiError::Network)?;

        let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);

        Ok(Self {
            http_client: client,
            base_url,
        })
    }

    /// Converts an app image block into an xAI `image_url` content part.
    ///
    /// Data URLs and remote URLs pass through unchanged. Session-file
    /// references are rejected because the xAI API cannot read local files.
    fn user_image_part(
        image: &crate::app::conversation::ImageContent,
    ) -> Result<XAIUserContentPart, ApiError> {
        let image_url = match &image.source {
            ImageSource::DataUrl { data_url } => data_url.clone(),
            ImageSource::Url { url } => url.clone(),
            ImageSource::SessionFile { relative_path } => {
                return Err(ApiError::UnsupportedFeature {
                    provider: "xai".to_string(),
                    feature: "image input source".to_string(),
                    details: format!(
                        "xAI chat API cannot access session file '{}' directly; use data URLs or public URLs",
                        relative_path
                    ),
                });
            }
        };

        Ok(XAIUserContentPart::ImageUrl {
            image_url: XAIImageUrl { url: image_url },
        })
    }

    /// Converts app-level conversation history (plus an optional rendered
    /// system prompt) into the xAI wire format.
    ///
    /// Role mapping: User content becomes text/image parts (command executions
    /// are flattened to XML text), Assistant content becomes text plus
    /// `tool_calls`, and Tool results become `role: tool` messages keyed by
    /// `tool_call_id`.
    ///
    /// # Errors
    /// Propagates `ApiError::UnsupportedFeature` from `user_image_part` for
    /// user-message images with unsupported sources.
    fn convert_messages(
        messages: Vec<AppMessage>,
        system: Option<SystemContext>,
    ) -> Result<Vec<XAIMessage>, ApiError> {
        let mut xai_messages = Vec::new();

        // Add system message if provided
        if let Some(system_content) = system.and_then(|context| context.render()) {
            xai_messages.push(XAIMessage::System {
                content: system_content,
                name: None,
            });
        }

        // Convert our messages to xAI format
        for message in messages {
            match &message.data {
                crate::app::conversation::MessageData::User { content, .. } => {
                    let mut content_parts = Vec::new();

                    for user_content in content {
                        match user_content {
                            UserContent::Text { text } => {
                                content_parts.push(XAIUserContentPart::Text { text: text.clone() });
                            }
                            UserContent::Image { image } => {
                                // Unsupported image sources abort the whole
                                // conversion for user messages (unlike the
                                // best-effort assistant path below).
                                content_parts.push(Self::user_image_part(image)?);
                            }
                            UserContent::CommandExecution {
                                command,
                                stdout,
                                stderr,
                                exit_code,
                            } => {
                                content_parts.push(XAIUserContentPart::Text {
                                    text: UserContent::format_command_execution_as_xml(
                                        command, stdout, stderr, *exit_code,
                                    ),
                                });
                            }
                        }
                    }

                    // Only add the message if it has content after filtering
                    if !content_parts.is_empty() {
                        // A single text part collapses to the plain-string
                        // content form; anything else uses the parts array.
                        let content = match content_parts.as_slice() {
                            [XAIUserContentPart::Text { text }] => {
                                XAIUserContent::Text(text.clone())
                            }
                            _ => XAIUserContent::Parts(content_parts),
                        };

                        xai_messages.push(XAIMessage::User {
                            content,
                            name: None,
                        });
                    }
                }
                crate::app::conversation::MessageData::Assistant { content, .. } => {
                    // Convert AssistantContent to xAI format
                    let mut text_parts = Vec::new();
                    let mut tool_calls = Vec::new();

                    for content_block in content {
                        match content_block {
                            AssistantContent::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            AssistantContent::Image { image } => {
                                // Assistant messages carry only text on this
                                // API, so images are downgraded to a textual
                                // placeholder; unsupported sources are
                                // logged and skipped rather than failing.
                                match Self::user_image_part(image) {
                                    Ok(XAIUserContentPart::ImageUrl { image_url }) => {
                                        text_parts.push(format!("[Image URL: {}]", image_url.url));
                                    }
                                    Ok(XAIUserContentPart::Text { text }) => {
                                        text_parts.push(text);
                                    }
                                    Err(err) => {
                                        debug!(
                                            target: "xai::convert_messages",
                                            "Skipping unsupported assistant image block: {}",
                                            err
                                        );
                                    }
                                }
                            }
                            AssistantContent::ToolCall { tool_call, .. } => {
                                tool_calls.push(XAIToolCall {
                                    id: tool_call.id.clone(),
                                    tool_type: "function".to_string(),
                                    function: XAIFunctionCall {
                                        // Arguments go over the wire as a JSON string.
                                        name: tool_call.name.clone(),
                                        arguments: tool_call.parameters.to_string(),
                                    },
                                });
                            }
                            AssistantContent::Thought { .. } => {

                                // xAI doesn't support thinking blocks in requests, only in responses
                            }
                        }
                    }

                    // Build the assistant message
                    let content = if text_parts.is_empty() {
                        None
                    } else {
                        Some(text_parts.join("\n"))
                    };

                    let tool_calls_opt = if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    };

                    xai_messages.push(XAIMessage::Assistant {
                        content,
                        tool_calls: tool_calls_opt,
                        name: None,
                    });
                }
                crate::app::conversation::MessageData::Tool {
                    tool_use_id,
                    result,
                    ..
                } => {
                    // Convert ToolResult to xAI format
                    let content_text = if let ToolResult::Error(e) = result {
                        format!("Error: {e}")
                    } else {
                        let text = result.llm_format();
                        // The API rejects/mishandles empty tool content, so
                        // substitute an explicit placeholder.
                        if text.trim().is_empty() {
                            "(No output)".to_string()
                        } else {
                            text
                        }
                    };

                    xai_messages.push(XAIMessage::Tool {
                        content: content_text,
                        tool_call_id: tool_use_id.clone(),
                        name: None,
                    });
                }
            }
        }

        Ok(xai_messages)
    }

    /// Maps provider-neutral tool schemas to xAI `function` tools.
    fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
        tools
            .into_iter()
            .map(|tool| XAITool {
                tool_type: "function".to_string(),
                function: XAIFunction {
                    name: tool.name,
                    description: tool.description,
                    parameters: tool.input_schema.as_value().clone(),
                },
            })
            .collect()
    }
}
570
/// Narrows a `usize` to `u32`, clamping to `u32::MAX` instead of truncating.
fn saturating_u32(value: usize) -> u32 {
    if value > u32::MAX as usize {
        u32::MAX
    } else {
        value as u32
    }
}
574
575fn map_xai_usage(usage: &XAIUsage) -> TokenUsage {
576    TokenUsage::new(
577        saturating_u32(usage.prompt_tokens),
578        saturating_u32(usage.completion_tokens),
579        saturating_u32(usage.total_tokens),
580    )
581}
582
583fn convert_xai_completion_response(
584    xai_response: XAICompletionResponse,
585) -> Result<CompletionResponse, ApiError> {
586    let usage = xai_response.usage.as_ref().map(map_xai_usage);
587
588    let Some(choice) = xai_response.choices.first() else {
589        return Err(ApiError::NoChoices {
590            provider: "xai".to_string(),
591        });
592    };
593
594    let mut content_blocks = Vec::new();
595
596    // Add reasoning content (thinking) first if present
597    if let Some(reasoning) = &choice.message.reasoning_content
598        && !reasoning.trim().is_empty()
599    {
600        content_blocks.push(AssistantContent::Thought {
601            thought: crate::app::conversation::ThoughtContent::Simple {
602                text: reasoning.clone(),
603            },
604        });
605    }
606
607    // Add regular content
608    if let Some(content) = &choice.message.content
609        && !content.trim().is_empty()
610    {
611        content_blocks.push(AssistantContent::Text {
612            text: content.clone(),
613        });
614    }
615
616    // Add tool calls
617    if let Some(tool_calls) = &choice.message.tool_calls {
618        for tool_call in tool_calls {
619            // Parse the arguments JSON string
620            let parameters = serde_json::from_str(&tool_call.function.arguments)
621                .unwrap_or(serde_json::Value::Null);
622
623            content_blocks.push(AssistantContent::ToolCall {
624                tool_call: steer_tools::ToolCall {
625                    id: tool_call.id.clone(),
626                    name: tool_call.function.name.clone(),
627                    parameters,
628                },
629                thought_signature: None,
630            });
631        }
632    }
633
634    Ok(CompletionResponse {
635        content: content_blocks,
636        usage,
637    })
638}
639
640#[async_trait]
641impl Provider for XAIClient {
642    fn name(&self) -> &'static str {
643        "xai"
644    }
645
646    async fn complete(
647        &self,
648        model_id: &ModelId,
649        messages: Vec<AppMessage>,
650        system: Option<SystemContext>,
651        tools: Option<Vec<ToolSchema>>,
652        call_options: Option<ModelParameters>,
653        token: CancellationToken,
654    ) -> Result<CompletionResponse, ApiError> {
655        let xai_messages = Self::convert_messages(messages, system)?;
656        let xai_tools = tools.map(Self::convert_tools);
657
658        // Extract thinking support and map optional effort
659        let (supports_thinking, reasoning_effort) = call_options
660            .as_ref()
661            .and_then(|opts| opts.thinking_config)
662            .map_or((false, None), |tc| {
663                let effort = tc.effort.map(|e| match e {
664                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
665                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, // xAI has Low/High only
666                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
667                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
668                });
669                (tc.enabled, effort)
670            });
671
672        // grok-4 supports thinking by default but not the reasoning_effort parameter
673        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
674            reasoning_effort.or(Some(ReasoningEffort::High))
675        } else {
676            None
677        };
678
679        let request = CompletionRequest {
680            model: model_id.id.clone(), // Use the model ID string
681            messages: xai_messages,
682            deferred: None,
683            frequency_penalty: None,
684            logit_bias: None,
685            logprobs: None,
686            max_completion_tokens: Some(32768),
687            max_tokens: None,
688            n: None,
689            parallel_tool_calls: None,
690            presence_penalty: None,
691            reasoning_effort,
692            response_format: None,
693            search_parameters: None,
694            seed: None,
695            stop: None,
696            stream: None,
697            stream_options: None,
698            temperature: call_options
699                .as_ref()
700                .and_then(|o| o.temperature)
701                .or(Some(1.0)),
702            tool_choice: None,
703            tools: xai_tools,
704            top_logprobs: None,
705            top_p: call_options.as_ref().and_then(|o| o.top_p),
706            user: None,
707            web_search_options: None,
708        };
709
710        let response = self
711            .http_client
712            .post(&self.base_url)
713            .json(&request)
714            .send()
715            .await
716            .map_err(ApiError::Network)?;
717
718        if !response.status().is_success() {
719            let status = response.status();
720            let error_text = response.text().await.unwrap_or_else(|_| String::new());
721
722            debug!(
723                target: "grok::complete",
724                "Grok API error - Status: {}, Body: {}",
725                status,
726                error_text
727            );
728
729            return match status.as_u16() {
730                429 => Err(ApiError::RateLimited {
731                    provider: self.name().to_string(),
732                    details: error_text,
733                }),
734                400 => Err(ApiError::InvalidRequest {
735                    provider: self.name().to_string(),
736                    details: error_text,
737                }),
738                401 => Err(ApiError::AuthenticationFailed {
739                    provider: self.name().to_string(),
740                    details: error_text,
741                }),
742                _ => Err(ApiError::ServerError {
743                    provider: self.name().to_string(),
744                    status_code: status.as_u16(),
745                    details: error_text,
746                }),
747            };
748        }
749
750        let response_text = tokio::select! {
751            () = token.cancelled() => {
752                debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
753                return Err(ApiError::Cancelled { provider: self.name().to_string() });
754            }
755            text_res = response.text() => {
756                text_res?
757            }
758        };
759
760        let xai_response: XAICompletionResponse =
761            serde_json::from_str(&response_text).map_err(|e| {
762                error!(
763                    target: "xai::complete",
764                    "Failed to parse response: {}, Body: {}",
765                    e,
766                    response_text
767                );
768                ApiError::ResponseParsingError {
769                    provider: self.name().to_string(),
770                    details: format!("Error: {e}, Body: {response_text}"),
771                }
772            })?;
773
774        convert_xai_completion_response(xai_response).map_err(|err| match err {
775            ApiError::NoChoices { .. } => ApiError::NoChoices {
776                provider: self.name().to_string(),
777            },
778            other => other,
779        })
780    }
781
782    async fn stream_complete(
783        &self,
784        model_id: &ModelId,
785        messages: Vec<AppMessage>,
786        system: Option<SystemContext>,
787        tools: Option<Vec<ToolSchema>>,
788        call_options: Option<ModelParameters>,
789        token: CancellationToken,
790    ) -> Result<CompletionStream, ApiError> {
791        let xai_messages = Self::convert_messages(messages, system)?;
792        let xai_tools = tools.map(Self::convert_tools);
793
794        let (supports_thinking, reasoning_effort) = call_options
795            .as_ref()
796            .and_then(|opts| opts.thinking_config)
797            .map_or((false, None), |tc| {
798                let effort = tc.effort.map(|e| match e {
799                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
800                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
801                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
802                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
803                });
804                (tc.enabled, effort)
805            });
806
807        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
808            reasoning_effort.or(Some(ReasoningEffort::High))
809        } else {
810            None
811        };
812
813        let request = CompletionRequest {
814            model: model_id.id.clone(),
815            messages: xai_messages,
816            deferred: None,
817            frequency_penalty: None,
818            logit_bias: None,
819            logprobs: None,
820            max_completion_tokens: Some(32768),
821            max_tokens: None,
822            n: None,
823            parallel_tool_calls: None,
824            presence_penalty: None,
825            reasoning_effort,
826            response_format: None,
827            search_parameters: None,
828            seed: None,
829            stop: None,
830            stream: Some(true),
831            stream_options: Some(StreamOptions {
832                include_usage: Some(true),
833            }),
834            temperature: call_options
835                .as_ref()
836                .and_then(|o| o.temperature)
837                .or(Some(1.0)),
838            tool_choice: None,
839            tools: xai_tools,
840            top_logprobs: None,
841            top_p: call_options.as_ref().and_then(|o| o.top_p),
842            user: None,
843            web_search_options: None,
844        };
845
846        let response = self
847            .http_client
848            .post(&self.base_url)
849            .json(&request)
850            .send()
851            .await
852            .map_err(ApiError::Network)?;
853
854        if !response.status().is_success() {
855            let status = response.status();
856            let error_text = response.text().await.unwrap_or_else(|_| String::new());
857
858            debug!(
859                target: "xai::stream",
860                "xAI API error - Status: {}, Body: {}",
861                status,
862                error_text
863            );
864
865            return match status.as_u16() {
866                429 => Err(ApiError::RateLimited {
867                    provider: self.name().to_string(),
868                    details: error_text,
869                }),
870                400 => Err(ApiError::InvalidRequest {
871                    provider: self.name().to_string(),
872                    details: error_text,
873                }),
874                401 => Err(ApiError::AuthenticationFailed {
875                    provider: self.name().to_string(),
876                    details: error_text,
877                }),
878                _ => Err(ApiError::ServerError {
879                    provider: self.name().to_string(),
880                    status_code: status.as_u16(),
881                    details: error_text,
882                }),
883            };
884        }
885
886        let byte_stream = response.bytes_stream();
887        let sse_stream = parse_sse_stream(byte_stream);
888
889        Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
890    }
891}
892
impl XAIClient {
    /// Converts a stream of parsed SSE events from xAI's OpenAI-compatible
    /// chat-completions endpoint into [`StreamChunk`] values.
    ///
    /// Text and reasoning deltas are forwarded immediately and also
    /// accumulated so that a complete [`CompletionResponse`] can be emitted
    /// when the terminal `[DONE]` event arrives. Tool-call fragments are
    /// buffered per tool-call index and only surfaced once both an id and a
    /// function name are known. Cancelling `token` yields
    /// `StreamChunk::Error(StreamError::Cancelled)` and ends the stream.
    fn convert_xai_stream(
        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
        + Unpin
        + Send
        + 'static,
        token: CancellationToken,
    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
        // Per-tool-call accumulator: the id, function name, and JSON argument
        // string all arrive as incremental fragments keyed by `tc.index`.
        struct ToolCallAccumulator {
            id: String,
            name: String,
            args: String,
        }

        async_stream::stream! {
            // Ordered assistant content blocks assembled from the deltas.
            let mut content: Vec<AssistantContent> = Vec::new();
            // Parallel to `content`: Some(index) marks a tool-call placeholder
            // block (rebuilt at [DONE]); None marks a text/thought block.
            let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
            // Tool-call index -> accumulated id/name/args fragments.
            let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
            // Indices for which a ToolUseStart chunk has already been yielded.
            let mut tool_calls_started: std::collections::HashSet<usize> =
                std::collections::HashSet::new();
            // Tool-call index -> position of its placeholder within `content`.
            let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
            // Most recent usage payload seen; presumably sent on the final
            // chunk when `stream_options.include_usage` is set — the last one
            // observed wins.
            let mut latest_usage: Option<TokenUsage> = None;
            loop {
                // Fast-path cancellation check before awaiting the next event.
                if token.is_cancelled() {
                    yield StreamChunk::Error(StreamError::Cancelled);
                    break;
                }

                // `biased` makes cancellation win over a ready SSE event.
                let event_result = tokio::select! {
                    biased;
                    () = token.cancelled() => {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }
                    event = sse_stream.next() => event
                };

                // Upstream ended without [DONE]: finish without a final message.
                let Some(event_result) = event_result else {
                    break;
                };

                // SSE parse failures are fatal to the stream.
                let event = match event_result {
                    Ok(e) => e,
                    Err(e) => {
                        yield StreamChunk::Error(StreamError::SseParse(e));
                        break;
                    }
                };

                if event.data == "[DONE]" {
                    // Terminal sentinel: assemble the final response,
                    // replacing tool-call placeholders with their fully
                    // accumulated values.
                    let tool_calls = std::mem::take(&mut tool_calls);
                    let mut final_content = Vec::new();

                    for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
                    {
                        if let Some(index) = tool_index {
                            let Some(tool_call) = tool_calls.get(&index) else {
                                continue;
                            };
                            // Drop tool calls that never received an id or
                            // name — they cannot be executed or referenced.
                            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                                debug!(
                                    target: "xai::stream",
                                    "Skipping tool call with missing id/name: id='{}' name='{}'",
                                    tool_call.id,
                                    tool_call.name
                                );
                                continue;
                            }
                            // Malformed or empty argument JSON degrades to an
                            // empty object rather than failing the stream.
                            let parameters = serde_json::from_str(&tool_call.args)
                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
                            final_content.push(AssistantContent::ToolCall {
                                tool_call: steer_tools::ToolCall {
                                    id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    parameters,
                                },
                                thought_signature: None,
                            });
                        } else {
                            final_content.push(block);
                        }
                    }

                    yield StreamChunk::MessageComplete(CompletionResponse {
                        content: final_content,
                        usage: latest_usage,
                    });
                    break;
                }

                // Unparseable data chunks are logged and skipped (best-effort).
                let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
                    Ok(c) => c,
                    Err(e) => {
                        debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
                        continue;
                    }
                };

                if let Some(usage) = chunk.usage.as_ref() {
                    latest_usage = Some(map_xai_usage(usage));
                }

                // Only the first choice is processed; the request does not set
                // `n`, so a single choice is expected.
                if let Some(choice) = chunk.choices.first() {
                    if let Some(text_delta) = &choice.delta.content {
                        // Coalesce consecutive text deltas into one block.
                        if let Some(AssistantContent::Text { text }) = content.last_mut() { text.push_str(text_delta) } else {
                            content.push(AssistantContent::Text {
                                text: text_delta.clone(),
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::TextDelta(text_delta.clone());
                    }

                    if let Some(thinking_delta) = &choice.delta.reasoning_content {
                        // Coalesce consecutive reasoning deltas likewise.
                        if let Some(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple { text },
                            }) = content.last_mut() { text.push_str(thinking_delta) } else {
                            content.push(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple {
                                    text: thinking_delta.clone(),
                                },
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::ThinkingDelta(thinking_delta.clone());
                    }

                    if let Some(tcs) = &choice.delta.tool_calls {
                        for tc in tcs {
                            let entry = tool_calls.entry(tc.index).or_insert_with(|| {
                                ToolCallAccumulator {
                                    id: String::new(),
                                    name: String::new(),
                                    args: String::new(),
                                }
                            });
                            // Track whether this delta emitted ToolUseStart
                            // and whether it already flushed buffered args.
                            let mut started_now = false;
                            let mut flushed_now = false;

                            // Id and name may arrive in any delta; keep the
                            // latest non-empty value.
                            if let Some(id) = &tc.id
                                && !id.is_empty() {
                                    entry.id.clone_from(id);
                                }
                            if let Some(func) = &tc.function
                                && let Some(name) = &func.name
                                    && !name.is_empty() {
                                        entry.name.clone_from(name);
                                    }

                            // First time this index is seen: reserve a
                            // placeholder slot in `content`. Its parameters
                            // hold the raw partial args string until [DONE]
                            // rebuilds the block.
                            if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
                                let pos = content.len();
                                content.push(AssistantContent::ToolCall {
                                    tool_call: steer_tools::ToolCall {
                                        id: entry.id.clone(),
                                        name: entry.name.clone(),
                                        parameters: serde_json::Value::String(entry.args.clone()),
                                    },
                                    thought_signature: None,
                                });
                                tool_call_indices.push(Some(tc.index));
                                e.insert(pos);
                            }

                            // Emit ToolUseStart exactly once per index, only
                            // after both id and name are known.
                            if !entry.id.is_empty()
                                && !entry.name.is_empty()
                                && !tool_calls_started.contains(&tc.index)
                            {
                                tool_calls_started.insert(tc.index);
                                started_now = true;
                                yield StreamChunk::ToolUseStart {
                                    id: entry.id.clone(),
                                    name: entry.name.clone(),
                                };
                            }

                            if let Some(func) = &tc.function
                                && let Some(args) = &func.arguments {
                                    entry.args.push_str(args);
                                    if tool_calls_started.contains(&tc.index) {
                                        if started_now {
                                            // Just started: flush everything
                                            // buffered so far, including args
                                            // that arrived before id/name.
                                            if !entry.args.is_empty() {
                                                yield StreamChunk::ToolUseInputDelta {
                                                    id: entry.id.clone(),
                                                    delta: entry.args.clone(),
                                                };
                                                flushed_now = true;
                                            }
                                        } else if !args.is_empty() {
                                            // Already started: forward only
                                            // this delta's fragment.
                                            yield StreamChunk::ToolUseInputDelta {
                                                id: entry.id.clone(),
                                                delta: args.clone(),
                                            };
                                        }
                                    }
                                }

                            // Started this delta but no `arguments` field came
                            // with it: flush args buffered by earlier deltas.
                            if started_now && !flushed_now && !entry.args.is_empty() {
                                yield StreamChunk::ToolUseInputDelta {
                                    id: entry.id.clone(),
                                    delta: entry.args.clone(),
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}
1102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::provider::StreamChunk;
    use crate::api::sse::SseEvent;
    use crate::app::conversation::{
        AssistantContent, ImageContent, ImageSource, Message, MessageData, UserContent,
    };
    use futures::StreamExt;
    use futures::stream;
    use std::pin::pin;
    use tokio_util::sync::CancellationToken;

    /// Builds a single user message carrying an optional text part followed by
    /// one PNG image with the given source. Shared by the conversion tests.
    fn user_image_message(text: Option<&str>, source: ImageSource) -> Message {
        let mut parts: Vec<UserContent> = Vec::new();
        if let Some(text) = text {
            parts.push(UserContent::Text {
                text: text.to_string(),
            });
        }
        parts.push(UserContent::Image {
            image: ImageContent {
                mime_type: "image/png".to_string(),
                source,
                width: None,
                height: None,
                bytes: None,
                sha256: None,
            },
        });
        Message {
            data: MessageData::User { content: parts },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }
    }

    #[test]
    fn test_convert_messages_includes_data_url_image_part() {
        let source = ImageSource::DataUrl {
            data_url: "data:image/png;base64,Zm9v".to_string(),
        };
        let converted =
            XAIClient::convert_messages(vec![user_image_message(Some("describe"), source)], None)
                .expect("convert messages");
        assert_eq!(converted.len(), 1);

        // Flattened lookup: first unwrap the user message, then its parts.
        let content = match &converted[0] {
            XAIMessage::User { content, .. } => content,
            other => panic!("Expected user message, got {other:?}"),
        };
        let parts = match content {
            XAIUserContent::Parts(parts) => parts,
            other => panic!("Expected parts content, got {other:?}"),
        };
        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            XAIUserContentPart::Text { text } if text == "describe"
        ));
        assert!(matches!(
            &parts[1],
            XAIUserContentPart::ImageUrl { image_url }
                if image_url.url == "data:image/png;base64,Zm9v"
        ));
    }

    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let source = ImageSource::SessionFile {
            relative_path: "session-1/image.png".to_string(),
        };
        let err = XAIClient::convert_messages(vec![user_image_message(None, source)], None)
            .expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "xai");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    #[test]
    fn test_map_xai_usage() {
        let raw = XAIUsage {
            prompt_tokens: 15,
            completion_tokens: 9,
            total_tokens: 24,
            num_sources_used: None,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
        let mapped = map_xai_usage(&raw);
        assert_eq!(mapped, TokenUsage::new(15, 9, 24));
    }

    #[test]
    fn test_non_stream_completion_maps_usage() {
        // Construct the full response inline, nesting choice and usage.
        let response = XAICompletionResponse {
            id: "resp_1".to_string(),
            object: "chat.completion".to_string(),
            created: 1,
            model: "grok-test".to_string(),
            choices: vec![Choice {
                index: 0,
                message: AssistantMessage {
                    content: Some("hello".to_string()),
                    tool_calls: None,
                    reasoning_content: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Some(XAIUsage {
                prompt_tokens: 6,
                completion_tokens: 4,
                total_tokens: 10,
                num_sources_used: None,
                prompt_tokens_details: None,
                completion_tokens_details: None,
            }),
            system_fingerprint: None,
            citations: None,
            debug_output: None,
        };

        let converted = convert_xai_completion_response(response).expect("response should map");

        assert_eq!(converted.usage, Some(TokenUsage::new(6, 4, 10)));
        assert!(matches!(
            converted.content.first(),
            Some(AssistantContent::Text { text }) if text == "hello"
        ));
    }

    #[tokio::test]
    async fn test_convert_xai_stream_captures_final_usage() {
        // Wraps a raw data payload into an Ok SSE event.
        fn event(data: &str) -> Result<SseEvent, SseParseError> {
            Ok(SseEvent {
                event_type: None,
                data: data.to_string(),
                id: None,
            })
        }

        let events = vec![
            event(r#"{"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#),
            event(r#"{"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":5,"total_tokens":17}}"#),
            event("[DONE]"),
        ];

        let cancel = CancellationToken::new();
        let mut chunks = pin!(XAIClient::convert_xai_stream(stream::iter(events), cancel));

        let first = chunks.next().await.unwrap();
        assert!(matches!(first, StreamChunk::TextDelta(ref t) if t == "Hello"));

        match chunks.next().await.unwrap() {
            StreamChunk::MessageComplete(response) => {
                assert_eq!(response.usage, Some(TokenUsage::new(12, 5, 17)));
                assert!(matches!(
                    response.content.first(),
                    Some(AssistantContent::Text { text }) if text == "Hello"
                ));
            }
            _ => panic!("Expected MessageComplete"),
        }
    }
}