// steer_core/api/xai/client.rs

1use async_trait::async_trait;
2use reqwest::{self, header};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error};
7
8use crate::api::error::ApiError;
9use crate::api::provider::{CompletionResponse, Provider};
10use crate::api::util::normalize_chat_url;
11use crate::app::conversation::{AssistantContent, Message as AppMessage, ToolResult, UserContent};
12use crate::config::model::{ModelId, ModelParameters};
13use steer_tools::ToolSchema;
14
15const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
16
17#[derive(Clone)]
18pub struct XAIClient {
19    http_client: reqwest::Client,
20    base_url: String,
21}
22
23// xAI-specific message format (similar to OpenAI but with some differences)
24#[derive(Debug, Serialize, Deserialize)]
25#[serde(tag = "role", rename_all = "lowercase")]
26enum XAIMessage {
27    System {
28        content: String,
29        #[serde(skip_serializing_if = "Option::is_none")]
30        name: Option<String>,
31    },
32    User {
33        content: String,
34        #[serde(skip_serializing_if = "Option::is_none")]
35        name: Option<String>,
36    },
37    Assistant {
38        #[serde(skip_serializing_if = "Option::is_none")]
39        content: Option<String>,
40        #[serde(skip_serializing_if = "Option::is_none")]
41        tool_calls: Option<Vec<XAIToolCall>>,
42        #[serde(skip_serializing_if = "Option::is_none")]
43        name: Option<String>,
44    },
45    Tool {
46        content: String,
47        tool_call_id: String,
48        #[serde(skip_serializing_if = "Option::is_none")]
49        name: Option<String>,
50    },
51}
52
/// xAI function-calling format: a function definition advertised to the model
/// inside an [`XAITool`].
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    /// Name the model uses when it calls this function.
    name: String,
    /// Description forwarded from the tool schema.
    description: String,
    /// JSON-schema-style object describing the arguments; `convert_tools` builds it
    /// from the tool's `type`, `properties` and `required`.
    parameters: serde_json::Value,
}
60
/// xAI tool format: one entry of the request's `tools` array.
#[derive(Debug, Serialize, Deserialize)]
struct XAITool {
    /// Serialized as `type`; this client always sets it to `"function"`.
    #[serde(rename = "type")]
    tool_type: String, // "function"
    function: XAIFunction,
}
68
/// xAI tool call.
///
/// Deserialized from a response's `tool_calls`, and serialized again when earlier
/// assistant turns are replayed in a request.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    /// Call id; the matching tool-result message echoes it as `tool_call_id`.
    id: String,
    /// Serialized as `type`; this client writes `"function"` when replaying calls.
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}
77
/// Name and arguments of a function the model asked to call.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunctionCall {
    name: String,
    /// Arguments encoded as a JSON string (not as a nested JSON value).
    arguments: String, // JSON string
}
83
84#[derive(Debug, Serialize, Deserialize)]
85#[serde(rename_all = "lowercase")]
86enum ReasoningEffort {
87    Low,
88    High,
89}
90
/// Value of the request's `stream_options` field.
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}
96
/// Value of the request's `tool_choice` field.
///
/// The enum is `#[serde(untagged)]`. It serializes either as a bare string or as a
/// `{"type": ..., "function": {"name": ...}}` object. When deserializing, serde tries
/// the variants in declaration order, so keep `String` first.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    String(String), // "auto", "required", "none"
    /// Names one specific function; `tool_type` is serialized as `type`.
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}
107
/// Function reference used by [`ToolChoice::Specific`].
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}
112
/// Value of the request's `response_format` field.
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    format_type: String,
    /// Optional schema object; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}
120
/// Value of the request's `search_parameters` field.
///
/// Every field is optional and omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}
136
/// Value of the request's `web_search_options` field.
///
/// Both fields are optional and omitted from the JSON when `None`.
/// NOTE(review): the field types (`u32`, `String`) have not been checked against the
/// xAI API reference; confirm them before populating this struct.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
144
/// JSON body POSTed to the xAI chat-completions endpoint.
///
/// Every `Option` field carries `skip_serializing_if = "Option::is_none"`, so unset
/// parameters are left out of the JSON entirely rather than sent as `null`.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    /// Model identifier string (`complete` passes the second element of `ModelId`).
    model: String,
    /// Conversation in xAI format, including any leading system message.
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Serialized as `"low"` or `"high"` (see [`ReasoningEffort`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Function definitions the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
196
/// Response body of a non-streaming chat-completion request.
///
/// `complete` consumes only `choices`. The `Option` fields deserialize as `None` when
/// absent from the JSON.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    /// Completion alternatives; `complete` uses the first one.
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}
213
/// One completion alternative in [`XAICompletionResponse::choices`].
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    /// `None` when the field is absent or `null`.
    finish_reason: Option<String>,
}
220
/// The assistant message carried by a [`Choice`].
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    /// Answer text; `complete` ignores it when `None` or blank.
    content: Option<String>,
    /// Tool calls requested by the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    /// Reasoning text; `complete` surfaces it as a thought block when non-blank.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
229
230#[derive(Debug, Serialize, Deserialize)]
231struct PromptTokensDetails {
232    cached_tokens: usize,
233    audio_tokens: usize,
234    image_tokens: usize,
235    text_tokens: usize,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239struct CompletionTokensDetails {
240    reasoning_tokens: usize,
241    audio_tokens: usize,
242    accepted_prediction_tokens: usize,
243    rejected_prediction_tokens: usize,
244}
245
/// Token accounting from the response's `usage` object.
///
/// The three totals are required. The optional extras deserialize as `None` when
/// absent and are omitted when re-serialized as `None`.
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}
258
259#[derive(Debug, Serialize, Deserialize)]
260struct DebugOutput {
261    attempts: usize,
262    cache_read_count: usize,
263    cache_read_input_bytes: usize,
264    cache_write_count: usize,
265    cache_write_input_bytes: usize,
266    prompt: String,
267    request: String,
268    responses: Vec<String>,
269}
270
271impl XAIClient {
272    pub fn new(api_key: String) -> Self {
273        Self::with_base_url(api_key, None)
274    }
275
276    pub fn with_base_url(api_key: String, base_url: Option<String>) -> Self {
277        let mut headers = header::HeaderMap::new();
278        headers.insert(
279            header::AUTHORIZATION,
280            header::HeaderValue::from_str(&format!("Bearer {api_key}"))
281                .expect("Invalid API key format"),
282        );
283
284        let client = reqwest::Client::builder()
285            .default_headers(headers)
286            .timeout(std::time::Duration::from_secs(300)) // 5 minute timeout
287            .build()
288            .expect("Failed to build HTTP client");
289
290        let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);
291
292        Self {
293            http_client: client,
294            base_url,
295        }
296    }
297
298    fn convert_messages(
299        &self,
300        messages: Vec<AppMessage>,
301        system: Option<String>,
302    ) -> Vec<XAIMessage> {
303        let mut xai_messages = Vec::new();
304
305        // Add system message if provided
306        if let Some(system_content) = system {
307            xai_messages.push(XAIMessage::System {
308                content: system_content,
309                name: None,
310            });
311        }
312
313        // Convert our messages to xAI format
314        for message in messages {
315            match &message.data {
316                crate::app::conversation::MessageData::User { content, .. } => {
317                    // Convert UserContent to text
318                    let combined_text = content
319                        .iter()
320                        .filter_map(|user_content| match user_content {
321                            UserContent::Text { text } => Some(text.clone()),
322                            UserContent::CommandExecution {
323                                command,
324                                stdout,
325                                stderr,
326                                exit_code,
327                            } => Some(UserContent::format_command_execution_as_xml(
328                                command, stdout, stderr, *exit_code,
329                            )),
330                            UserContent::AppCommand { .. } => {
331                                // Don't send app commands to the model - they're for local execution only
332                                None
333                            }
334                        })
335                        .collect::<Vec<_>>()
336                        .join("\n");
337
338                    // Only add the message if it has content after filtering
339                    if !combined_text.trim().is_empty() {
340                        xai_messages.push(XAIMessage::User {
341                            content: combined_text,
342                            name: None,
343                        });
344                    }
345                }
346                crate::app::conversation::MessageData::Assistant { content, .. } => {
347                    // Convert AssistantContent to xAI format
348                    let mut text_parts = Vec::new();
349                    let mut tool_calls = Vec::new();
350
351                    for content_block in content {
352                        match content_block {
353                            AssistantContent::Text { text } => {
354                                text_parts.push(text.clone());
355                            }
356                            AssistantContent::ToolCall { tool_call } => {
357                                tool_calls.push(XAIToolCall {
358                                    id: tool_call.id.clone(),
359                                    tool_type: "function".to_string(),
360                                    function: XAIFunctionCall {
361                                        name: tool_call.name.clone(),
362                                        arguments: tool_call.parameters.to_string(),
363                                    },
364                                });
365                            }
366                            AssistantContent::Thought { .. } => {
367                                // xAI doesn't support thinking blocks in requests, only in responses
368                                continue;
369                            }
370                        }
371                    }
372
373                    // Build the assistant message
374                    let content = if text_parts.is_empty() {
375                        None
376                    } else {
377                        Some(text_parts.join("\n"))
378                    };
379
380                    let tool_calls_opt = if tool_calls.is_empty() {
381                        None
382                    } else {
383                        Some(tool_calls)
384                    };
385
386                    xai_messages.push(XAIMessage::Assistant {
387                        content,
388                        tool_calls: tool_calls_opt,
389                        name: None,
390                    });
391                }
392                crate::app::conversation::MessageData::Tool {
393                    tool_use_id,
394                    result,
395                    ..
396                } => {
397                    // Convert ToolResult to xAI format
398                    let content_text = match result {
399                        ToolResult::Error(e) => format!("Error: {e}"),
400                        _ => {
401                            let text = result.llm_format();
402                            if text.trim().is_empty() {
403                                "(No output)".to_string()
404                            } else {
405                                text
406                            }
407                        }
408                    };
409
410                    xai_messages.push(XAIMessage::Tool {
411                        content: content_text,
412                        tool_call_id: tool_use_id.clone(),
413                        name: None,
414                    });
415                }
416            }
417        }
418
419        xai_messages
420    }
421
422    fn convert_tools(&self, tools: Vec<ToolSchema>) -> Vec<XAITool> {
423        tools
424            .into_iter()
425            .map(|tool| XAITool {
426                tool_type: "function".to_string(),
427                function: XAIFunction {
428                    name: tool.name,
429                    description: tool.description,
430                    parameters: serde_json::json!({
431                        "type": tool.input_schema.schema_type,
432                        "properties": tool.input_schema.properties,
433                        "required": tool.input_schema.required,
434                    }),
435                },
436            })
437            .collect()
438    }
439}
440
441#[async_trait]
442impl Provider for XAIClient {
443    fn name(&self) -> &'static str {
444        "xai"
445    }
446
447    async fn complete(
448        &self,
449        model_id: &ModelId,
450        messages: Vec<AppMessage>,
451        system: Option<String>,
452        tools: Option<Vec<ToolSchema>>,
453        call_options: Option<ModelParameters>,
454        token: CancellationToken,
455    ) -> Result<CompletionResponse, ApiError> {
456        let xai_messages = self.convert_messages(messages, system);
457        let xai_tools = tools.map(|t| self.convert_tools(t));
458
459        // Extract thinking support and map optional effort
460        let (supports_thinking, reasoning_effort) = call_options
461            .as_ref()
462            .and_then(|opts| opts.thinking_config)
463            .map(|tc| {
464                let effort = tc.effort.map(|e| match e {
465                    crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
466                    crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, // xAI has Low/High only
467                    crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
468                });
469                (tc.enabled, effort)
470            })
471            .unwrap_or((false, None));
472
473        // grok-4 supports thinking by default but not the reasoning_effort parameter
474        let reasoning_effort = if supports_thinking && model_id.1 != "grok-4-0709" {
475            reasoning_effort.or(Some(ReasoningEffort::High))
476        } else {
477            None
478        };
479
480        let request = CompletionRequest {
481            model: model_id.1.clone(), // Use the model ID string
482            messages: xai_messages,
483            deferred: None,
484            frequency_penalty: None,
485            logit_bias: None,
486            logprobs: None,
487            max_completion_tokens: Some(32768),
488            max_tokens: None,
489            n: None,
490            parallel_tool_calls: None,
491            presence_penalty: None,
492            reasoning_effort,
493            response_format: None,
494            search_parameters: None,
495            seed: None,
496            stop: None,
497            stream: None,
498            stream_options: None,
499            temperature: call_options
500                .as_ref()
501                .and_then(|o| o.temperature)
502                .or(Some(1.0)),
503            tool_choice: None,
504            tools: xai_tools,
505            top_logprobs: None,
506            top_p: call_options.as_ref().and_then(|o| o.top_p),
507            user: None,
508            web_search_options: None,
509        };
510
511        let response = self
512            .http_client
513            .post(&self.base_url)
514            .json(&request)
515            .send()
516            .await
517            .map_err(ApiError::Network)?;
518
519        if !response.status().is_success() {
520            let status = response.status();
521            let error_text = response.text().await.unwrap_or_else(|_| String::new());
522
523            debug!(
524                target: "grok::complete",
525                "Grok API error - Status: {}, Body: {}",
526                status,
527                error_text
528            );
529
530            return match status.as_u16() {
531                429 => Err(ApiError::RateLimited {
532                    provider: self.name().to_string(),
533                    details: error_text,
534                }),
535                400 => Err(ApiError::InvalidRequest {
536                    provider: self.name().to_string(),
537                    details: error_text,
538                }),
539                401 => Err(ApiError::AuthenticationFailed {
540                    provider: self.name().to_string(),
541                    details: error_text,
542                }),
543                _ => Err(ApiError::ServerError {
544                    provider: self.name().to_string(),
545                    status_code: status.as_u16(),
546                    details: error_text,
547                }),
548            };
549        }
550
551        let response_text = tokio::select! {
552            _ = token.cancelled() => {
553                debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
554                return Err(ApiError::Cancelled { provider: self.name().to_string() });
555            }
556            text_res = response.text() => {
557                text_res?
558            }
559        };
560
561        let xai_response: XAICompletionResponse =
562            serde_json::from_str(&response_text).map_err(|e| {
563                error!(
564                    target: "xai::complete",
565                    "Failed to parse response: {}, Body: {}",
566                    e,
567                    response_text
568                );
569                ApiError::ResponseParsingError {
570                    provider: self.name().to_string(),
571                    details: format!("Error: {e}, Body: {response_text}"),
572                }
573            })?;
574
575        // Convert xAI response to our CompletionResponse
576        if let Some(choice) = xai_response.choices.first() {
577            let mut content_blocks = Vec::new();
578
579            // Add reasoning content (thinking) first if present
580            if let Some(reasoning) = &choice.message.reasoning_content {
581                if !reasoning.trim().is_empty() {
582                    content_blocks.push(AssistantContent::Thought {
583                        thought: crate::app::conversation::ThoughtContent::Simple {
584                            text: reasoning.clone(),
585                        },
586                    });
587                }
588            }
589
590            // Add regular content
591            if let Some(content) = &choice.message.content {
592                if !content.trim().is_empty() {
593                    content_blocks.push(AssistantContent::Text {
594                        text: content.clone(),
595                    });
596                }
597            }
598
599            // Add tool calls
600            if let Some(tool_calls) = &choice.message.tool_calls {
601                for tool_call in tool_calls {
602                    // Parse the arguments JSON string
603                    let parameters = serde_json::from_str(&tool_call.function.arguments)
604                        .unwrap_or(serde_json::Value::Null);
605
606                    content_blocks.push(AssistantContent::ToolCall {
607                        tool_call: steer_tools::ToolCall {
608                            id: tool_call.id.clone(),
609                            name: tool_call.function.name.clone(),
610                            parameters,
611                        },
612                    });
613                }
614            }
615
616            Ok(crate::api::provider::CompletionResponse {
617                content: content_blocks,
618            })
619        } else {
620            Err(ApiError::NoChoices {
621                provider: self.name().to_string(),
622            })
623        }
624    }
625}