// steer_core/api/xai/client.rs
1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::api::util::normalize_chat_url;
13use crate::app::SystemContext;
14use crate::app::conversation::{
15    AssistantContent, ImageSource, Message as AppMessage, ToolResult, UserContent,
16};
17use crate::config::model::{ModelId, ModelParameters};
18use steer_tools::ToolSchema;
19
/// Default xAI chat-completions endpoint, used when no base-URL override is supplied.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
21
/// HTTP client for the xAI chat-completions API.
#[derive(Clone)]
pub struct XAIClient {
    // Pre-configured reqwest client; carries the Authorization header set in `with_base_url`.
    http_client: reqwest::Client,
    // Normalized chat-completions endpoint URL (see `normalize_chat_url`).
    base_url: String,
}
27
// xAI-specific message format (similar to OpenAI but with some differences)
/// One chat message in the xAI wire format.
///
/// `role` is the serde tag, so variants serialize as
/// `{"role": "system", ...}`, `{"role": "user", ...}`, etc.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    /// System prompt text.
    System {
        content: String,
        // Optional participant name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user turn; content is plain text or multi-part (text + images).
    User {
        content: XAIUserContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model turn; may carry text and/or tool calls (either may be absent).
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Tool result, keyed back to the originating call by `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
57
/// User-message content: either a bare string or a list of typed parts.
/// Untagged, so a plain JSON string deserializes as `Text`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum XAIUserContent {
    Text(String),
    Parts(Vec<XAIUserContentPart>),
}
64
/// A single part of a multi-part user message: text or an image reference.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum XAIUserContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: XAIImageUrl },
}
73
/// Wrapper for an image reference; `url` may be a public URL or a data URL
/// (see `user_image_part`).
#[derive(Debug, Serialize, Deserialize)]
struct XAIImageUrl {
    url: String,
}
78
// xAI function calling format
/// Function declaration advertised to the model; `parameters` is a JSON Schema value.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
86
// xAI tool format
/// Tool wrapper; `tool_type` is always `"function"` in this client.
#[derive(Debug, Serialize, Deserialize)]
struct XAITool {
    #[serde(rename = "type")]
    tool_type: String, // "function"
    function: XAIFunction,
}
94
// xAI tool call
/// A tool invocation emitted by (or echoed back to) the model.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}
103
/// Function-call payload: `arguments` is the raw JSON-encoded argument object,
/// not a parsed value (parsed later in `complete`).
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunctionCall {
    name: String,
    arguments: String, // JSON string
}
109
/// Reasoning-effort knob; xAI only accepts `low`/`high` (app-level Medium/XHigh
/// are collapsed to High during request building).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}
116
/// Streaming options; `include_usage` requests usage stats in the final chunk.
/// Currently never set by this client.
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}
122
/// Tool-choice directive: either a mode string or a specific function to force.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    String(String), // "auto", "required", "none"
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}
133
/// Names the function forced by `ToolChoice::Specific`.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}
138
/// Response-format constraint (e.g. JSON mode); unused by this client today.
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}
146
/// Live-search controls mirrored from the xAI API; all optional and currently
/// never populated by this client. Field semantics follow the upstream API
/// (dates presumably ISO-8601 strings — confirm against xAI docs).
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}
162
/// Web-search options mirrored from the xAI API; unused by this client today.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
170
/// Request body for the xAI `/v1/chat/completions` endpoint.
///
/// Only `model` and `messages` are required; every optional field is skipped
/// from the JSON when `None`, so the wire payload stays minimal. This client
/// populates only a subset (see `complete`/`stream_complete`); the rest mirror
/// the upstream API surface.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    // Preferred over the deprecated-style `max_tokens`; this client sets it to 32768.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    // `Some(true)` selects SSE streaming (`stream_complete`); `None` is non-streaming.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
222
/// Top-level non-streaming response; only `choices[0]` is consumed by `complete`.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    // Citations from live search, shape left opaque.
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}
239
/// One candidate completion; `finish_reason` is `None` while generation is incomplete.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}
246
/// Assistant payload of a choice: optional text, tool calls, and (for
/// reasoning models) a separate `reasoning_content` trace.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
255
/// Breakdown of prompt tokens by kind (renamed from the API's `*_tokens` keys).
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    #[serde(rename = "cached_tokens")]
    cached: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "image_tokens")]
    image: usize,
    #[serde(rename = "text_tokens")]
    text: usize,
}
267
/// Breakdown of completion tokens by kind (renamed from the API's `*_tokens` keys).
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    #[serde(rename = "reasoning_tokens")]
    reasoning: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "accepted_prediction_tokens")]
    accepted_prediction: usize,
    #[serde(rename = "rejected_prediction_tokens")]
    rejected_prediction: usize,
}
279
/// Token-usage accounting attached to a response; detail blocks are optional.
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    // Number of live-search sources consulted, when search was used.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}
292
/// Server-side debug diagnostics occasionally attached to a response;
/// deserialized but not otherwise consumed by this client.
#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}
304
/// One SSE data payload from a streaming completion.
#[derive(Debug, Deserialize)]
struct XAIStreamChunk {
    #[expect(dead_code)]
    id: String,
    choices: Vec<XAIStreamChoice>,
}
311
/// A streamed choice; only `delta` is consumed downstream.
#[derive(Debug, Deserialize)]
struct XAIStreamChoice {
    #[expect(dead_code)]
    index: u32,
    delta: XAIStreamDelta,
    #[expect(dead_code)]
    finish_reason: Option<String>,
}
320
321#[derive(Debug, Deserialize)]
322struct XAIStreamDelta {
323    #[serde(skip_serializing_if = "Option::is_none")]
324    content: Option<String>,
325    #[serde(skip_serializing_if = "Option::is_none")]
326    tool_calls: Option<Vec<XAIStreamToolCall>>,
327    #[serde(skip_serializing_if = "Option::is_none")]
328    reasoning_content: Option<String>,
329}
330
331#[derive(Debug, Deserialize)]
332struct XAIStreamToolCall {
333    index: usize,
334    #[serde(skip_serializing_if = "Option::is_none")]
335    id: Option<String>,
336    #[serde(skip_serializing_if = "Option::is_none")]
337    function: Option<XAIStreamFunction>,
338}
339
340#[derive(Debug, Deserialize)]
341struct XAIStreamFunction {
342    #[serde(skip_serializing_if = "Option::is_none")]
343    name: Option<String>,
344    #[serde(skip_serializing_if = "Option::is_none")]
345    arguments: Option<String>,
346}
347
348impl XAIClient {
    /// Creates a client against the default xAI endpoint.
    ///
    /// # Errors
    /// Returns `ApiError::AuthenticationFailed` if `api_key` cannot be encoded
    /// as an HTTP header value.
    pub fn new(api_key: String) -> Result<Self, ApiError> {
        Self::with_base_url(api_key, None)
    }
352
353    pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
354        let mut headers = header::HeaderMap::new();
355        headers.insert(
356            header::AUTHORIZATION,
357            header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
358                ApiError::AuthenticationFailed {
359                    provider: "xai".to_string(),
360                    details: format!("Invalid API key: {e}"),
361                }
362            })?,
363        );
364
365        let client = reqwest::Client::builder()
366            .default_headers(headers)
367            .timeout(std::time::Duration::from_secs(300)) // 5 minute timeout
368            .build()
369            .map_err(ApiError::Network)?;
370
371        let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);
372
373        Ok(Self {
374            http_client: client,
375            base_url,
376        })
377    }
378
379    fn user_image_part(
380        image: &crate::app::conversation::ImageContent,
381    ) -> Result<XAIUserContentPart, ApiError> {
382        let image_url = match &image.source {
383            ImageSource::DataUrl { data_url } => data_url.clone(),
384            ImageSource::Url { url } => url.clone(),
385            ImageSource::SessionFile { relative_path } => {
386                return Err(ApiError::UnsupportedFeature {
387                    provider: "xai".to_string(),
388                    feature: "image input source".to_string(),
389                    details: format!(
390                        "xAI chat API cannot access session file '{}' directly; use data URLs or public URLs",
391                        relative_path
392                    ),
393                });
394            }
395        };
396
397        Ok(XAIUserContentPart::ImageUrl {
398            image_url: XAIImageUrl { url: image_url },
399        })
400    }
401
    /// Converts app-level conversation messages (plus an optional rendered
    /// system prompt) into the xAI wire-format message list.
    ///
    /// - User content is flattened to text/image parts; a lone text part
    ///   collapses to a plain string, and empty user messages are dropped.
    /// - Assistant text blocks are joined with newlines; tool calls map to
    ///   xAI function calls; thought blocks are dropped; image blocks are
    ///   rendered as placeholder text (assistant image parts aren't sent).
    /// - Tool results become `role: "tool"` messages keyed by the call id,
    ///   with errors and empty output rendered as text.
    ///
    /// # Errors
    /// Propagates `ApiError::UnsupportedFeature` from `user_image_part` for
    /// user images backed by session files.
    fn convert_messages(
        messages: Vec<AppMessage>,
        system: Option<SystemContext>,
    ) -> Result<Vec<XAIMessage>, ApiError> {
        let mut xai_messages = Vec::new();

        // Add system message if provided
        if let Some(system_content) = system.and_then(|context| context.render()) {
            xai_messages.push(XAIMessage::System {
                content: system_content,
                name: None,
            });
        }

        // Convert our messages to xAI format
        for message in messages {
            match &message.data {
                crate::app::conversation::MessageData::User { content, .. } => {
                    let mut content_parts = Vec::new();

                    for user_content in content {
                        match user_content {
                            UserContent::Text { text } => {
                                content_parts.push(XAIUserContentPart::Text { text: text.clone() });
                            }
                            UserContent::Image { image } => {
                                content_parts.push(Self::user_image_part(image)?);
                            }
                            // Command executions are serialized to an XML-ish text block.
                            UserContent::CommandExecution {
                                command,
                                stdout,
                                stderr,
                                exit_code,
                            } => {
                                content_parts.push(XAIUserContentPart::Text {
                                    text: UserContent::format_command_execution_as_xml(
                                        command, stdout, stderr, *exit_code,
                                    ),
                                });
                            }
                        }
                    }

                    // Only add the message if it has content after filtering
                    if !content_parts.is_empty() {
                        // A single text part is sent as a bare string, matching the
                        // simpler wire form accepted by the API.
                        let content = match content_parts.as_slice() {
                            [XAIUserContentPart::Text { text }] => {
                                XAIUserContent::Text(text.clone())
                            }
                            _ => XAIUserContent::Parts(content_parts),
                        };

                        xai_messages.push(XAIMessage::User {
                            content,
                            name: None,
                        });
                    }
                }
                crate::app::conversation::MessageData::Assistant { content, .. } => {
                    // Convert AssistantContent to xAI format
                    let mut text_parts = Vec::new();
                    let mut tool_calls = Vec::new();

                    for content_block in content {
                        match content_block {
                            AssistantContent::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            // Assistant messages carry only text in xAI's format, so
                            // images degrade to a textual placeholder.
                            AssistantContent::Image { image } => {
                                match Self::user_image_part(image) {
                                    Ok(XAIUserContentPart::ImageUrl { image_url }) => {
                                        text_parts.push(format!("[Image URL: {}]", image_url.url));
                                    }
                                    Ok(XAIUserContentPart::Text { text }) => {
                                        text_parts.push(text);
                                    }
                                    Err(err) => {
                                        debug!(
                                            target: "xai::convert_messages",
                                            "Skipping unsupported assistant image block: {}",
                                            err
                                        );
                                    }
                                }
                            }
                            AssistantContent::ToolCall { tool_call, .. } => {
                                tool_calls.push(XAIToolCall {
                                    id: tool_call.id.clone(),
                                    tool_type: "function".to_string(),
                                    function: XAIFunctionCall {
                                        name: tool_call.name.clone(),
                                        // Arguments travel as a JSON string on the wire.
                                        arguments: tool_call.parameters.to_string(),
                                    },
                                });
                            }
                            AssistantContent::Thought { .. } => {

                                // xAI doesn't support thinking blocks in requests, only in responses
                            }
                        }
                    }

                    // Build the assistant message
                    let content = if text_parts.is_empty() {
                        None
                    } else {
                        Some(text_parts.join("\n"))
                    };

                    let tool_calls_opt = if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    };

                    xai_messages.push(XAIMessage::Assistant {
                        content,
                        tool_calls: tool_calls_opt,
                        name: None,
                    });
                }
                crate::app::conversation::MessageData::Tool {
                    tool_use_id,
                    result,
                    ..
                } => {
                    // Convert ToolResult to xAI format
                    let content_text = if let ToolResult::Error(e) = result {
                        format!("Error: {e}")
                    } else {
                        let text = result.llm_format();
                        // The API rejects empty tool content, so substitute a marker.
                        if text.trim().is_empty() {
                            "(No output)".to_string()
                        } else {
                            text
                        }
                    };

                    xai_messages.push(XAIMessage::Tool {
                        content: content_text,
                        tool_call_id: tool_use_id.clone(),
                        name: None,
                    });
                }
            }
        }

        Ok(xai_messages)
    }
551
552    fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
553        tools
554            .into_iter()
555            .map(|tool| XAITool {
556                tool_type: "function".to_string(),
557                function: XAIFunction {
558                    name: tool.name,
559                    description: tool.description,
560                    parameters: tool.input_schema.as_value().clone(),
561                },
562            })
563            .collect()
564    }
565}
566
567#[async_trait]
568impl Provider for XAIClient {
    /// Stable provider identifier, used in error payloads and log targets.
    fn name(&self) -> &'static str {
        "xai"
    }
572
573    async fn complete(
574        &self,
575        model_id: &ModelId,
576        messages: Vec<AppMessage>,
577        system: Option<SystemContext>,
578        tools: Option<Vec<ToolSchema>>,
579        call_options: Option<ModelParameters>,
580        token: CancellationToken,
581    ) -> Result<CompletionResponse, ApiError> {
582        let xai_messages = Self::convert_messages(messages, system)?;
583        let xai_tools = tools.map(Self::convert_tools);
584
585        // Extract thinking support and map optional effort
586        let (supports_thinking, reasoning_effort) = call_options
587            .as_ref()
588            .and_then(|opts| opts.thinking_config)
589            .map_or((false, None), |tc| {
590                let effort = tc.effort.map(|e| match e {
591                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
592                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, // xAI has Low/High only
593                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
594                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
595                });
596                (tc.enabled, effort)
597            });
598
599        // grok-4 supports thinking by default but not the reasoning_effort parameter
600        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
601            reasoning_effort.or(Some(ReasoningEffort::High))
602        } else {
603            None
604        };
605
606        let request = CompletionRequest {
607            model: model_id.id.clone(), // Use the model ID string
608            messages: xai_messages,
609            deferred: None,
610            frequency_penalty: None,
611            logit_bias: None,
612            logprobs: None,
613            max_completion_tokens: Some(32768),
614            max_tokens: None,
615            n: None,
616            parallel_tool_calls: None,
617            presence_penalty: None,
618            reasoning_effort,
619            response_format: None,
620            search_parameters: None,
621            seed: None,
622            stop: None,
623            stream: None,
624            stream_options: None,
625            temperature: call_options
626                .as_ref()
627                .and_then(|o| o.temperature)
628                .or(Some(1.0)),
629            tool_choice: None,
630            tools: xai_tools,
631            top_logprobs: None,
632            top_p: call_options.as_ref().and_then(|o| o.top_p),
633            user: None,
634            web_search_options: None,
635        };
636
637        let response = self
638            .http_client
639            .post(&self.base_url)
640            .json(&request)
641            .send()
642            .await
643            .map_err(ApiError::Network)?;
644
645        if !response.status().is_success() {
646            let status = response.status();
647            let error_text = response.text().await.unwrap_or_else(|_| String::new());
648
649            debug!(
650                target: "grok::complete",
651                "Grok API error - Status: {}, Body: {}",
652                status,
653                error_text
654            );
655
656            return match status.as_u16() {
657                429 => Err(ApiError::RateLimited {
658                    provider: self.name().to_string(),
659                    details: error_text,
660                }),
661                400 => Err(ApiError::InvalidRequest {
662                    provider: self.name().to_string(),
663                    details: error_text,
664                }),
665                401 => Err(ApiError::AuthenticationFailed {
666                    provider: self.name().to_string(),
667                    details: error_text,
668                }),
669                _ => Err(ApiError::ServerError {
670                    provider: self.name().to_string(),
671                    status_code: status.as_u16(),
672                    details: error_text,
673                }),
674            };
675        }
676
677        let response_text = tokio::select! {
678            () = token.cancelled() => {
679                debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
680                return Err(ApiError::Cancelled { provider: self.name().to_string() });
681            }
682            text_res = response.text() => {
683                text_res?
684            }
685        };
686
687        let xai_response: XAICompletionResponse =
688            serde_json::from_str(&response_text).map_err(|e| {
689                error!(
690                    target: "xai::complete",
691                    "Failed to parse response: {}, Body: {}",
692                    e,
693                    response_text
694                );
695                ApiError::ResponseParsingError {
696                    provider: self.name().to_string(),
697                    details: format!("Error: {e}, Body: {response_text}"),
698                }
699            })?;
700
701        // Convert xAI response to our CompletionResponse
702        if let Some(choice) = xai_response.choices.first() {
703            let mut content_blocks = Vec::new();
704
705            // Add reasoning content (thinking) first if present
706            if let Some(reasoning) = &choice.message.reasoning_content
707                && !reasoning.trim().is_empty()
708            {
709                content_blocks.push(AssistantContent::Thought {
710                    thought: crate::app::conversation::ThoughtContent::Simple {
711                        text: reasoning.clone(),
712                    },
713                });
714            }
715
716            // Add regular content
717            if let Some(content) = &choice.message.content
718                && !content.trim().is_empty()
719            {
720                content_blocks.push(AssistantContent::Text {
721                    text: content.clone(),
722                });
723            }
724
725            // Add tool calls
726            if let Some(tool_calls) = &choice.message.tool_calls {
727                for tool_call in tool_calls {
728                    // Parse the arguments JSON string
729                    let parameters = serde_json::from_str(&tool_call.function.arguments)
730                        .unwrap_or(serde_json::Value::Null);
731
732                    content_blocks.push(AssistantContent::ToolCall {
733                        tool_call: steer_tools::ToolCall {
734                            id: tool_call.id.clone(),
735                            name: tool_call.function.name.clone(),
736                            parameters,
737                        },
738                        thought_signature: None,
739                    });
740                }
741            }
742
743            Ok(crate::api::provider::CompletionResponse {
744                content: content_blocks,
745            })
746        } else {
747            Err(ApiError::NoChoices {
748                provider: self.name().to_string(),
749            })
750        }
751    }
752
    /// Starts a streaming chat completion and returns an SSE-backed chunk stream.
    ///
    /// Request construction mirrors `complete` (same thinking/effort mapping,
    /// token limits, and sampling defaults) except that `stream` is enabled.
    ///
    /// # Errors
    /// HTTP 429/400/401 map to `RateLimited`/`InvalidRequest`/
    /// `AuthenticationFailed`; other failures to `ServerError`; transport
    /// errors to `Network`. Mid-stream errors surface through the returned
    /// stream, not this result.
    async fn stream_complete(
        &self,
        model_id: &ModelId,
        messages: Vec<AppMessage>,
        system: Option<SystemContext>,
        tools: Option<Vec<ToolSchema>>,
        call_options: Option<ModelParameters>,
        token: CancellationToken,
    ) -> Result<CompletionStream, ApiError> {
        let xai_messages = Self::convert_messages(messages, system)?;
        let xai_tools = tools.map(Self::convert_tools);

        // Map app-level thinking config onto xAI's Low/High effort scale.
        let (supports_thinking, reasoning_effort) = call_options
            .as_ref()
            .and_then(|opts| opts.thinking_config)
            .map_or((false, None), |tc| {
                let effort = tc.effort.map(|e| match e {
                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
                    crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, // xAI has Low/High only
                });
                (tc.enabled, effort)
            });

        // grok-4 supports thinking by default but not the reasoning_effort parameter.
        let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
            reasoning_effort.or(Some(ReasoningEffort::High))
        } else {
            None
        };

        let request = CompletionRequest {
            model: model_id.id.clone(),
            messages: xai_messages,
            deferred: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_completion_tokens: Some(32768),
            max_tokens: None,
            n: None,
            parallel_tool_calls: None,
            presence_penalty: None,
            reasoning_effort,
            response_format: None,
            search_parameters: None,
            seed: None,
            stop: None,
            stream: Some(true),
            stream_options: None,
            temperature: call_options
                .as_ref()
                .and_then(|o| o.temperature)
                .or(Some(1.0)),
            tool_choice: None,
            tools: xai_tools,
            top_logprobs: None,
            top_p: call_options.as_ref().and_then(|o| o.top_p),
            user: None,
            web_search_options: None,
        };

        let response = self
            .http_client
            .post(&self.base_url)
            .json(&request)
            .send()
            .await
            .map_err(ApiError::Network)?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_else(|_| String::new());

            debug!(
                target: "xai::stream",
                "xAI API error - Status: {}, Body: {}",
                status,
                error_text
            );

            return match status.as_u16() {
                429 => Err(ApiError::RateLimited {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                400 => Err(ApiError::InvalidRequest {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                401 => Err(ApiError::AuthenticationFailed {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                _ => Err(ApiError::ServerError {
                    provider: self.name().to_string(),
                    status_code: status.as_u16(),
                    details: error_text,
                }),
            };
        }

        // Parse the response body as server-sent events and adapt them into
        // provider-agnostic stream chunks, honoring the cancellation token.
        let byte_stream = response.bytes_stream();
        let sse_stream = parse_sse_stream(byte_stream);

        Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
    }
860}
861
impl XAIClient {
    /// Converts a parsed SSE event stream from the xAI chat-completions
    /// endpoint into the provider-agnostic [`StreamChunk`] stream.
    ///
    /// Text and reasoning deltas are forwarded as they arrive. Tool-call
    /// fragments are accumulated per provider `index` and only surfaced as
    /// `ToolUseStart` / `ToolUseInputDelta` once both `id` and `name` are
    /// known. On the terminal `[DONE]` event the accumulated blocks are
    /// assembled into one `MessageComplete` chunk. Cancelling `token`
    /// yields `StreamError::Cancelled` and ends the stream.
    fn convert_xai_stream(
        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
        + Unpin
        + Send
        + 'static,
        token: CancellationToken,
    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
        // Per-tool-call accumulation state, keyed by `tool_calls[].index`;
        // id, name and argument JSON may each arrive split across chunks.
        struct ToolCallAccumulator {
            id: String,
            name: String,
            args: String,
        }

        async_stream::stream! {
            // Ordered assistant content blocks built up as deltas arrive.
            let mut content: Vec<AssistantContent> = Vec::new();
            // Parallel to `content`: Some(index) marks a placeholder for the
            // tool call with that index; None marks a text/thought block.
            let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
            // index -> accumulated id/name/arguments fragments.
            let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
            // Indices for which ToolUseStart has already been yielded.
            let mut tool_calls_started: std::collections::HashSet<usize> =
                std::collections::HashSet::new();
            // index -> position of its placeholder block within `content`.
            let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
            loop {
                // Fast-path cancellation check before awaiting the next event.
                if token.is_cancelled() {
                    yield StreamChunk::Error(StreamError::Cancelled);
                    break;
                }

                // `biased` so cancellation wins over a simultaneously-ready
                // SSE event.
                let event_result = tokio::select! {
                    biased;
                    () = token.cancelled() => {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }
                    event = sse_stream.next() => event
                };

                // Upstream closed without a [DONE] sentinel: just stop.
                let Some(event_result) = event_result else {
                    break;
                };

                let event = match event_result {
                    Ok(e) => e,
                    Err(e) => {
                        yield StreamChunk::Error(StreamError::SseParse(e));
                        break;
                    }
                };

                if event.data == "[DONE]" {
                    // Terminal sentinel: swap each tool-call placeholder for
                    // the fully accumulated call, then emit the final response.
                    let tool_calls = std::mem::take(&mut tool_calls);
                    let mut final_content = Vec::new();

                    for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
                    {
                        if let Some(index) = tool_index {
                            let Some(tool_call) = tool_calls.get(&index) else {
                                continue;
                            };
                            // Drop calls that never received a complete id/name.
                            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                                debug!(
                                    target: "xai::stream",
                                    "Skipping tool call with missing id/name: id='{}' name='{}'",
                                    tool_call.id,
                                    tool_call.name
                                );
                                continue;
                            }
                            // Unparseable (e.g. empty) argument JSON degrades to
                            // an empty object instead of failing the whole turn.
                            let parameters = serde_json::from_str(&tool_call.args)
                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
                            final_content.push(AssistantContent::ToolCall {
                                tool_call: steer_tools::ToolCall {
                                    id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    parameters,
                                },
                                thought_signature: None,
                            });
                        } else {
                            final_content.push(block);
                        }
                    }

                    yield StreamChunk::MessageComplete(CompletionResponse { content: final_content });
                    break;
                }

                // A malformed data payload is logged and skipped, not fatal.
                let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
                    Ok(c) => c,
                    Err(e) => {
                        debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
                        continue;
                    }
                };

                if let Some(choice) = chunk.choices.first() {
                    // Text delta: extend the trailing text block if there is
                    // one, otherwise open a new block, then forward the delta.
                    if let Some(text_delta) = &choice.delta.content {
                        if let Some(AssistantContent::Text { text }) = content.last_mut() { text.push_str(text_delta) } else {
                            content.push(AssistantContent::Text {
                                text: text_delta.clone(),
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::TextDelta(text_delta.clone());
                    }

                    // Reasoning ("thinking") delta: same append-or-open logic
                    // against a trailing Thought block.
                    if let Some(thinking_delta) = &choice.delta.reasoning_content {
                        if let Some(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple { text },
                            }) = content.last_mut() { text.push_str(thinking_delta) } else {
                            content.push(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple {
                                    text: thinking_delta.clone(),
                                },
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::ThinkingDelta(thinking_delta.clone());
                    }

                    if let Some(tcs) = &choice.delta.tool_calls {
                        for tc in tcs {
                            let entry = tool_calls.entry(tc.index).or_insert_with(|| {
                                ToolCallAccumulator {
                                    id: String::new(),
                                    name: String::new(),
                                    args: String::new(),
                                }
                            });
                            // started_now: ToolUseStart was yielded in THIS
                            // iteration; flushed_now: the accumulated args
                            // backlog was already flushed below.
                            let mut started_now = false;
                            let mut flushed_now = false;

                            // id/name fragments overwrite (not append) when
                            // non-empty; the provider may resend them.
                            if let Some(id) = &tc.id
                                && !id.is_empty() {
                                    entry.id.clone_from(id);
                                }
                            if let Some(func) = &tc.function
                                && let Some(name) = &func.name
                                    && !name.is_empty() {
                                        entry.name.clone_from(name);
                                    }

                            // First sighting of this index: reserve a
                            // placeholder block so final ordering matches
                            // arrival order (args stored as a raw string
                            // until [DONE] parses them).
                            if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
                                let pos = content.len();
                                content.push(AssistantContent::ToolCall {
                                    tool_call: steer_tools::ToolCall {
                                        id: entry.id.clone(),
                                        name: entry.name.clone(),
                                        parameters: serde_json::Value::String(entry.args.clone()),
                                    },
                                    thought_signature: None,
                                });
                                tool_call_indices.push(Some(tc.index));
                                e.insert(pos);
                            }

                            // Announce the call once id AND name are complete.
                            if !entry.id.is_empty()
                                && !entry.name.is_empty()
                                && !tool_calls_started.contains(&tc.index)
                            {
                                tool_calls_started.insert(tc.index);
                                started_now = true;
                                yield StreamChunk::ToolUseStart {
                                    id: entry.id.clone(),
                                    name: entry.name.clone(),
                                };
                            }

                            if let Some(func) = &tc.function
                                && let Some(args) = &func.arguments {
                                    entry.args.push_str(args);
                                    if tool_calls_started.contains(&tc.index) {
                                        if started_now {
                                            // Just started: flush everything
                                            // accumulated so far, not only
                                            // this fragment.
                                            if !entry.args.is_empty() {
                                                yield StreamChunk::ToolUseInputDelta {
                                                    id: entry.id.clone(),
                                                    delta: entry.args.clone(),
                                                };
                                                flushed_now = true;
                                            }
                                        } else if !args.is_empty() {
                                            yield StreamChunk::ToolUseInputDelta {
                                                id: entry.id.clone(),
                                                delta: args.clone(),
                                            };
                                        }
                                    }
                                }

                            // Start arrived in a chunk carrying no new args:
                            // flush the pre-start backlog exactly once.
                            if started_now && !flushed_now && !entry.args.is_empty() {
                                yield StreamChunk::ToolUseInputDelta {
                                    id: entry.id.clone(),
                                    delta: entry.args.clone(),
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}
1063
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};

    /// A data-URL image in a user message must survive conversion as an
    /// `image_url` part alongside the accompanying text part.
    #[test]
    fn test_convert_messages_includes_data_url_image_part() {
        let image = ImageContent {
            mime_type: "image/png".to_string(),
            source: ImageSource::DataUrl {
                data_url: "data:image/png;base64,Zm9v".to_string(),
            },
            width: None,
            height: None,
            bytes: None,
            sha256: None,
        };
        let input = vec![Message {
            data: MessageData::User {
                content: vec![
                    UserContent::Text {
                        text: "describe".to_string(),
                    },
                    UserContent::Image { image },
                ],
            },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }];

        let converted = XAIClient::convert_messages(input, None).expect("convert messages");
        assert_eq!(converted.len(), 1);

        // Drill down to the multi-part user content.
        let parts = match &converted[0] {
            XAIMessage::User { content, .. } => match content {
                XAIUserContent::Parts(parts) => parts,
                other => panic!("Expected parts content, got {other:?}"),
            },
            other => panic!("Expected user message, got {other:?}"),
        };

        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            XAIUserContentPart::Text { text } if text == "describe"
        ));
        assert!(matches!(
            &parts[1],
            XAIUserContentPart::ImageUrl { image_url }
                if image_url.url == "data:image/png;base64,Zm9v"
        ));
    }

    /// A session-file image source has no URL representation, so conversion
    /// must fail with `UnsupportedFeature` rather than invent one.
    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let image = ImageContent {
            mime_type: "image/png".to_string(),
            source: ImageSource::SessionFile {
                relative_path: "session-1/image.png".to_string(),
            },
            width: None,
            height: None,
            bytes: None,
            sha256: None,
        };
        let input = vec![Message {
            data: MessageData::User {
                content: vec![UserContent::Image { image }],
            },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }];

        let err =
            XAIClient::convert_messages(input, None).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(feature, "image input source");
                assert_eq!(provider, "xai");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }
}