// steer_core/api/xai/client.rs

use async_trait::async_trait;
use reqwest::{self, header};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};

use crate::api::provider::{CompletionResponse, Provider};
use crate::api::{Model, error::ApiError};
use crate::app::conversation::{AssistantContent, Message as AppMessage, ToolResult, UserContent};
use steer_tools::ToolSchema;

const API_URL: &str = "https://api.x.ai/v1/chat/completions";

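/// Client for the xAI (Grok) chat completions API.
///
/// Wraps a pre-configured `reqwest::Client` (authorization header and timeout
/// already installed), so the struct is cheap to clone and share.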
#[derive(Clone)]
pub struct XAIClient {
    http_client: reqwest::Client,
}

// xAI-specific message format (similar to OpenAI but with some differences)
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Tool {
        content: String,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}

// xAI function calling format
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

// xAI tool format
#[derive(Debug, Serialize, Deserialize)]
struct XAITool {
    #[serde(rename = "type")]
    tool_type: String, // "function"
    function: XAIFunction,
}

// xAI tool call
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}

#[derive(Debug, Serialize, Deserialize)]
struct XAIFunctionCall {
    name: String,
    arguments: String, // JSON string
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}

#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    String(String), // "auto", "required", "none"
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}

#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}

#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    cached_tokens: usize,
    audio_tokens: usize,
    image_tokens: usize,
    text_tokens: usize,
}

#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    reasoning_tokens: usize,
    audio_tokens: usize,
    accepted_prediction_tokens: usize,
    rejected_prediction_tokens: usize,
}

#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}

#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}

impl XAIClient {
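    /// Builds a client whose requests carry an `Authorization: Bearer <api_key>`
    /// header and use a 5-minute timeout.
    ///
    /// Panics if the API key cannot be encoded as a header value or if the
    /// HTTP client fails to build.
    ///
    /// Minimal usage sketch (assumes the key is supplied via an `XAI_API_KEY`
    /// environment variable, which is not required by this code):
    ///
    /// ```ignore
    /// let client = XAIClient::new(std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set"));
    /// ```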
    pub fn new(api_key: String) -> Self {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_str(&format!("Bearer {api_key}"))
                .expect("Invalid API key format"),
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(std::time::Duration::from_secs(300)) // 5 minute timeout
            .build()
            .expect("Failed to build HTTP client");

        Self {
            http_client: client,
        }
    }

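    /// Flattens app-level conversation messages into xAI's flat, role-tagged
    /// message list, prepending the system prompt (if any) as a `system`
    /// message. Local-only content such as app commands is filtered out, and
    /// user messages that end up empty are dropped.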
    fn convert_messages(
        &self,
        messages: Vec<AppMessage>,
        system: Option<String>,
    ) -> Vec<XAIMessage> {
        let mut xai_messages = Vec::new();

        // Add system message if provided
        if let Some(system_content) = system {
            xai_messages.push(XAIMessage::System {
                content: system_content,
                name: None,
            });
        }

        // Convert our messages to xAI format
        for message in messages {
            match &message.data {
                crate::app::conversation::MessageData::User { content, .. } => {
                    // Convert UserContent to text
                    let combined_text = content
                        .iter()
                        .filter_map(|user_content| match user_content {
                            UserContent::Text { text } => Some(text.clone()),
                            UserContent::CommandExecution {
                                command,
                                stdout,
                                stderr,
                                exit_code,
                            } => Some(UserContent::format_command_execution_as_xml(
                                command, stdout, stderr, *exit_code,
                            )),
                            UserContent::AppCommand { .. } => {
                                // Don't send app commands to the model - they're for local execution only
                                None
                            }
                        })
                        .collect::<Vec<_>>()
                        .join("\n");

                    // Only add the message if it has content after filtering
                    if !combined_text.trim().is_empty() {
                        xai_messages.push(XAIMessage::User {
                            content: combined_text,
                            name: None,
                        });
                    }
                }
                crate::app::conversation::MessageData::Assistant { content, .. } => {
                    // Convert AssistantContent to xAI format
                    let mut text_parts = Vec::new();
                    let mut tool_calls = Vec::new();

                    for content_block in content {
                        match content_block {
                            AssistantContent::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            AssistantContent::ToolCall { tool_call } => {
                                tool_calls.push(XAIToolCall {
                                    id: tool_call.id.clone(),
                                    tool_type: "function".to_string(),
                                    function: XAIFunctionCall {
                                        name: tool_call.name.clone(),
                                        arguments: tool_call.parameters.to_string(),
                                    },
                                });
                            }
                            AssistantContent::Thought { .. } => {
                                // xAI doesn't support thinking blocks in requests, only in responses
                                continue;
                            }
                        }
                    }

                    // Build the assistant message
                    let content = if text_parts.is_empty() {
                        None
                    } else {
                        Some(text_parts.join("\n"))
                    };

                    let tool_calls_opt = if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    };

                    xai_messages.push(XAIMessage::Assistant {
                        content,
                        tool_calls: tool_calls_opt,
                        name: None,
                    });
                }
                crate::app::conversation::MessageData::Tool {
                    tool_use_id,
                    result,
                    ..
                } => {
                    // Convert ToolResult to xAI format
                    let content_text = match result {
                        ToolResult::Error(e) => format!("Error: {e}"),
                        _ => {
                            let text = result.llm_format();
                            if text.trim().is_empty() {
                                "(No output)".to_string()
                            } else {
                                text
                            }
                        }
                    };

                    xai_messages.push(XAIMessage::Tool {
                        content: content_text,
                        tool_call_id: tool_use_id.clone(),
                        name: None,
                    });
                }
            }
        }

        xai_messages
    }

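    /// Maps internal tool schemas onto xAI's OpenAI-style `function` tools,
    /// embedding each tool's input schema as the `parameters` JSON object.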
    fn convert_tools(&self, tools: Vec<ToolSchema>) -> Vec<XAITool> {
        tools
            .into_iter()
            .map(|tool| XAITool {
                tool_type: "function".to_string(),
                function: XAIFunction {
                    name: tool.name,
                    description: tool.description,
                    parameters: serde_json::json!({
                        "type": tool.input_schema.schema_type,
                        "properties": tool.input_schema.properties,
                        "required": tool.input_schema.required,
                    }),
                },
            })
            .collect()
    }
}

#[async_trait]
impl Provider for XAIClient {
    fn name(&self) -> &'static str {
        "xai"
    }

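    /// Sends a single non-streaming chat completion request and converts the
    /// first choice of the response into a provider-agnostic
    /// `CompletionResponse`, preserving reasoning content, text, and tool
    /// calls in that order.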
    async fn complete(
        &self,
        model: Model,
        messages: Vec<AppMessage>,
        system: Option<String>,
        tools: Option<Vec<ToolSchema>>,
        token: CancellationToken,
    ) -> Result<CompletionResponse, ApiError> {
        let xai_messages = self.convert_messages(messages, system);
        let xai_tools = tools.map(|t| self.convert_tools(t));

        // grok-4 supports thinking by default but not the reasoning_effort parameter
        let reasoning_effort = if model.supports_thinking() && !matches!(model, Model::Grok4_0709) {
            Some(ReasoningEffort::High)
        } else {
            None
        };

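        // Build a non-streaming request; most optional parameters are left
        // unset so the API defaults apply, with a fixed completion-token cap.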
        let request = CompletionRequest {
            model: model.as_ref().to_string(),
            messages: xai_messages,
            deferred: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_completion_tokens: Some(32768),
            max_tokens: None,
            n: None,
            parallel_tool_calls: None,
            presence_penalty: None,
            reasoning_effort,
            response_format: None,
            search_parameters: None,
            seed: None,
            stop: None,
            stream: None,
            stream_options: None,
            temperature: Some(1.0),
            tool_choice: None,
            tools: xai_tools,
            top_logprobs: None,
            top_p: None,
            user: None,
            web_search_options: None,
        };

        let response = self
            .http_client
            .post(API_URL)
            .json(&request)
            .send()
            .await
            .map_err(ApiError::Network)?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_else(|_| String::new());

            debug!(
                target: "xai::complete",
                "xAI API error - Status: {}, Body: {}",
                status,
                error_text
            );

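            // Map well-known HTTP status codes to typed API errors; anything
            // else is surfaced as a generic server error with its status code.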
            return match status.as_u16() {
                429 => Err(ApiError::RateLimited {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                400 => Err(ApiError::InvalidRequest {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                401 => Err(ApiError::AuthenticationFailed {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                _ => Err(ApiError::ServerError {
                    provider: self.name().to_string(),
                    status_code: status.as_u16(),
                    details: error_text,
                }),
            };
        }

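        // Read the response body, bailing out early if the caller cancels.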
        let response_text = tokio::select! {
            _ = token.cancelled() => {
                debug!(target: "xai::complete", "Cancellation token triggered while reading successful response body.");
                return Err(ApiError::Cancelled { provider: self.name().to_string() });
            }
            text_res = response.text() => {
                text_res?
            }
        };

        let xai_response: XAICompletionResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                error!(
                    target: "xai::complete",
                    "Failed to parse response: {}, Body: {}",
                    e,
                    response_text
                );
                ApiError::ResponseParsingError {
                    provider: self.name().to_string(),
                    details: format!("Error: {e}, Body: {response_text}"),
                }
            })?;

        // Convert xAI response to our CompletionResponse
        if let Some(choice) = xai_response.choices.first() {
            let mut content_blocks = Vec::new();

            // Add reasoning content (thinking) first if present
            if let Some(reasoning) = &choice.message.reasoning_content {
                if !reasoning.trim().is_empty() {
                    content_blocks.push(AssistantContent::Thought {
                        thought: crate::app::conversation::ThoughtContent::Simple {
                            text: reasoning.clone(),
                        },
                    });
                }
            }

            // Add regular content
            if let Some(content) = &choice.message.content {
                if !content.trim().is_empty() {
                    content_blocks.push(AssistantContent::Text {
                        text: content.clone(),
                    });
                }
            }

            // Add tool calls
            if let Some(tool_calls) = &choice.message.tool_calls {
                for tool_call in tool_calls {
                    // Parse the arguments JSON string
                    let parameters = serde_json::from_str(&tool_call.function.arguments)
                        .unwrap_or(serde_json::Value::Null);

                    content_blocks.push(AssistantContent::ToolCall {
                        tool_call: steer_tools::ToolCall {
                            id: tool_call.id.clone(),
                            name: tool_call.function.name.clone(),
                            parameters,
                        },
                    });
                }
            }

            Ok(crate::api::provider::CompletionResponse {
                content: content_blocks,
            })
        } else {
            Err(ApiError::NoChoices {
                provider: self.name().to_string(),
            })
        }
    }
}