1use async_trait::async_trait;
2use reqwest::{self, header};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error};
7
8use crate::api::error::ApiError;
9use crate::api::provider::{CompletionResponse, Provider};
10use crate::api::util::normalize_chat_url;
11use crate::app::conversation::{AssistantContent, Message as AppMessage, ToolResult, UserContent};
12use crate::config::model::{ModelId, ModelParameters};
13use steer_tools::ToolSchema;
14
/// Chat-completions endpoint used when no custom base URL is supplied.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
16
17#[derive(Clone)]
18pub struct XAIClient {
19 http_client: reqwest::Client,
20 base_url: String,
21}
22
23#[derive(Debug, Serialize, Deserialize)]
25#[serde(tag = "role", rename_all = "lowercase")]
26enum XAIMessage {
27 System {
28 content: String,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 name: Option<String>,
31 },
32 User {
33 content: String,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 name: Option<String>,
36 },
37 Assistant {
38 #[serde(skip_serializing_if = "Option::is_none")]
39 content: Option<String>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 tool_calls: Option<Vec<XAIToolCall>>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 name: Option<String>,
44 },
45 Tool {
46 content: String,
47 tool_call_id: String,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 name: Option<String>,
50 },
51}
52
53#[derive(Debug, Serialize, Deserialize)]
55struct XAIFunction {
56 name: String,
57 description: String,
58 parameters: serde_json::Value,
59}
60
61#[derive(Debug, Serialize, Deserialize)]
63struct XAITool {
64 #[serde(rename = "type")]
65 tool_type: String, function: XAIFunction,
67}
68
69#[derive(Debug, Serialize, Deserialize)]
71struct XAIToolCall {
72 id: String,
73 #[serde(rename = "type")]
74 tool_type: String,
75 function: XAIFunctionCall,
76}
77
78#[derive(Debug, Serialize, Deserialize)]
79struct XAIFunctionCall {
80 name: String,
81 arguments: String, }
83
84#[derive(Debug, Serialize, Deserialize)]
85#[serde(rename_all = "lowercase")]
86enum ReasoningEffort {
87 Low,
88 High,
89}
90
91#[derive(Debug, Serialize, Deserialize)]
92struct StreamOptions {
93 #[serde(skip_serializing_if = "Option::is_none")]
94 include_usage: Option<bool>,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98#[serde(untagged)]
99enum ToolChoice {
100 String(String), Specific {
102 #[serde(rename = "type")]
103 tool_type: String,
104 function: ToolChoiceFunction,
105 },
106}
107
108#[derive(Debug, Serialize, Deserialize)]
109struct ToolChoiceFunction {
110 name: String,
111}
112
113#[derive(Debug, Serialize, Deserialize)]
114struct ResponseFormat {
115 #[serde(rename = "type")]
116 format_type: String,
117 #[serde(skip_serializing_if = "Option::is_none")]
118 json_schema: Option<serde_json::Value>,
119}
120
121#[derive(Debug, Serialize, Deserialize)]
122struct SearchParameters {
123 #[serde(skip_serializing_if = "Option::is_none")]
124 from_date: Option<String>,
125 #[serde(skip_serializing_if = "Option::is_none")]
126 to_date: Option<String>,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 max_search_results: Option<u32>,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 mode: Option<String>,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 return_citations: Option<bool>,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 sources: Option<Vec<String>>,
135}
136
137#[derive(Debug, Serialize, Deserialize)]
138struct WebSearchOptions {
139 #[serde(skip_serializing_if = "Option::is_none")]
140 search_context_size: Option<u32>,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 user_location: Option<String>,
143}
144
145#[derive(Debug, Serialize, Deserialize)]
146struct CompletionRequest {
147 model: String,
148 messages: Vec<XAIMessage>,
149 #[serde(skip_serializing_if = "Option::is_none")]
150 deferred: Option<bool>,
151 #[serde(skip_serializing_if = "Option::is_none")]
152 frequency_penalty: Option<f32>,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 logit_bias: Option<HashMap<String, f32>>,
155 #[serde(skip_serializing_if = "Option::is_none")]
156 logprobs: Option<bool>,
157 #[serde(skip_serializing_if = "Option::is_none")]
158 max_completion_tokens: Option<u32>,
159 #[serde(skip_serializing_if = "Option::is_none")]
160 max_tokens: Option<u32>,
161 #[serde(skip_serializing_if = "Option::is_none")]
162 n: Option<u32>,
163 #[serde(skip_serializing_if = "Option::is_none")]
164 parallel_tool_calls: Option<bool>,
165 #[serde(skip_serializing_if = "Option::is_none")]
166 presence_penalty: Option<f32>,
167 #[serde(skip_serializing_if = "Option::is_none")]
168 reasoning_effort: Option<ReasoningEffort>,
169 #[serde(skip_serializing_if = "Option::is_none")]
170 response_format: Option<ResponseFormat>,
171 #[serde(skip_serializing_if = "Option::is_none")]
172 search_parameters: Option<SearchParameters>,
173 #[serde(skip_serializing_if = "Option::is_none")]
174 seed: Option<u64>,
175 #[serde(skip_serializing_if = "Option::is_none")]
176 stop: Option<Vec<String>>,
177 #[serde(skip_serializing_if = "Option::is_none")]
178 stream: Option<bool>,
179 #[serde(skip_serializing_if = "Option::is_none")]
180 stream_options: Option<StreamOptions>,
181 #[serde(skip_serializing_if = "Option::is_none")]
182 temperature: Option<f32>,
183 #[serde(skip_serializing_if = "Option::is_none")]
184 tool_choice: Option<ToolChoice>,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 tools: Option<Vec<XAITool>>,
187 #[serde(skip_serializing_if = "Option::is_none")]
188 top_logprobs: Option<u32>,
189 #[serde(skip_serializing_if = "Option::is_none")]
190 top_p: Option<f32>,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 user: Option<String>,
193 #[serde(skip_serializing_if = "Option::is_none")]
194 web_search_options: Option<WebSearchOptions>,
195}
196
197#[derive(Debug, Serialize, Deserialize)]
198struct XAICompletionResponse {
199 id: String,
200 object: String,
201 created: u64,
202 model: String,
203 choices: Vec<Choice>,
204 #[serde(skip_serializing_if = "Option::is_none")]
205 usage: Option<XAIUsage>,
206 #[serde(skip_serializing_if = "Option::is_none")]
207 system_fingerprint: Option<String>,
208 #[serde(skip_serializing_if = "Option::is_none")]
209 citations: Option<Vec<serde_json::Value>>,
210 #[serde(skip_serializing_if = "Option::is_none")]
211 debug_output: Option<DebugOutput>,
212}
213
214#[derive(Debug, Serialize, Deserialize)]
215struct Choice {
216 index: i32,
217 message: AssistantMessage,
218 finish_reason: Option<String>,
219}
220
221#[derive(Debug, Serialize, Deserialize)]
222struct AssistantMessage {
223 content: Option<String>,
224 #[serde(skip_serializing_if = "Option::is_none")]
225 tool_calls: Option<Vec<XAIToolCall>>,
226 #[serde(skip_serializing_if = "Option::is_none")]
227 reasoning_content: Option<String>,
228}
229
230#[derive(Debug, Serialize, Deserialize)]
231struct PromptTokensDetails {
232 cached_tokens: usize,
233 audio_tokens: usize,
234 image_tokens: usize,
235 text_tokens: usize,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239struct CompletionTokensDetails {
240 reasoning_tokens: usize,
241 audio_tokens: usize,
242 accepted_prediction_tokens: usize,
243 rejected_prediction_tokens: usize,
244}
245
246#[derive(Debug, Serialize, Deserialize)]
247struct XAIUsage {
248 prompt_tokens: usize,
249 completion_tokens: usize,
250 total_tokens: usize,
251 #[serde(skip_serializing_if = "Option::is_none")]
252 num_sources_used: Option<usize>,
253 #[serde(skip_serializing_if = "Option::is_none")]
254 prompt_tokens_details: Option<PromptTokensDetails>,
255 #[serde(skip_serializing_if = "Option::is_none")]
256 completion_tokens_details: Option<CompletionTokensDetails>,
257}
258
259#[derive(Debug, Serialize, Deserialize)]
260struct DebugOutput {
261 attempts: usize,
262 cache_read_count: usize,
263 cache_read_input_bytes: usize,
264 cache_write_count: usize,
265 cache_write_input_bytes: usize,
266 prompt: String,
267 request: String,
268 responses: Vec<String>,
269}
270
271impl XAIClient {
272 pub fn new(api_key: String) -> Self {
273 Self::with_base_url(api_key, None)
274 }
275
276 pub fn with_base_url(api_key: String, base_url: Option<String>) -> Self {
277 let mut headers = header::HeaderMap::new();
278 headers.insert(
279 header::AUTHORIZATION,
280 header::HeaderValue::from_str(&format!("Bearer {api_key}"))
281 .expect("Invalid API key format"),
282 );
283
284 let client = reqwest::Client::builder()
285 .default_headers(headers)
286 .timeout(std::time::Duration::from_secs(300)) .build()
288 .expect("Failed to build HTTP client");
289
290 let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);
291
292 Self {
293 http_client: client,
294 base_url,
295 }
296 }
297
298 fn convert_messages(
299 &self,
300 messages: Vec<AppMessage>,
301 system: Option<String>,
302 ) -> Vec<XAIMessage> {
303 let mut xai_messages = Vec::new();
304
305 if let Some(system_content) = system {
307 xai_messages.push(XAIMessage::System {
308 content: system_content,
309 name: None,
310 });
311 }
312
313 for message in messages {
315 match &message.data {
316 crate::app::conversation::MessageData::User { content, .. } => {
317 let combined_text = content
319 .iter()
320 .filter_map(|user_content| match user_content {
321 UserContent::Text { text } => Some(text.clone()),
322 UserContent::CommandExecution {
323 command,
324 stdout,
325 stderr,
326 exit_code,
327 } => Some(UserContent::format_command_execution_as_xml(
328 command, stdout, stderr, *exit_code,
329 )),
330 UserContent::AppCommand { .. } => {
331 None
333 }
334 })
335 .collect::<Vec<_>>()
336 .join("\n");
337
338 if !combined_text.trim().is_empty() {
340 xai_messages.push(XAIMessage::User {
341 content: combined_text,
342 name: None,
343 });
344 }
345 }
346 crate::app::conversation::MessageData::Assistant { content, .. } => {
347 let mut text_parts = Vec::new();
349 let mut tool_calls = Vec::new();
350
351 for content_block in content {
352 match content_block {
353 AssistantContent::Text { text } => {
354 text_parts.push(text.clone());
355 }
356 AssistantContent::ToolCall { tool_call } => {
357 tool_calls.push(XAIToolCall {
358 id: tool_call.id.clone(),
359 tool_type: "function".to_string(),
360 function: XAIFunctionCall {
361 name: tool_call.name.clone(),
362 arguments: tool_call.parameters.to_string(),
363 },
364 });
365 }
366 AssistantContent::Thought { .. } => {
367 continue;
369 }
370 }
371 }
372
373 let content = if text_parts.is_empty() {
375 None
376 } else {
377 Some(text_parts.join("\n"))
378 };
379
380 let tool_calls_opt = if tool_calls.is_empty() {
381 None
382 } else {
383 Some(tool_calls)
384 };
385
386 xai_messages.push(XAIMessage::Assistant {
387 content,
388 tool_calls: tool_calls_opt,
389 name: None,
390 });
391 }
392 crate::app::conversation::MessageData::Tool {
393 tool_use_id,
394 result,
395 ..
396 } => {
397 let content_text = match result {
399 ToolResult::Error(e) => format!("Error: {e}"),
400 _ => {
401 let text = result.llm_format();
402 if text.trim().is_empty() {
403 "(No output)".to_string()
404 } else {
405 text
406 }
407 }
408 };
409
410 xai_messages.push(XAIMessage::Tool {
411 content: content_text,
412 tool_call_id: tool_use_id.clone(),
413 name: None,
414 });
415 }
416 }
417 }
418
419 xai_messages
420 }
421
422 fn convert_tools(&self, tools: Vec<ToolSchema>) -> Vec<XAITool> {
423 tools
424 .into_iter()
425 .map(|tool| XAITool {
426 tool_type: "function".to_string(),
427 function: XAIFunction {
428 name: tool.name,
429 description: tool.description,
430 parameters: serde_json::json!({
431 "type": tool.input_schema.schema_type,
432 "properties": tool.input_schema.properties,
433 "required": tool.input_schema.required,
434 }),
435 },
436 })
437 .collect()
438 }
439}
440
441#[async_trait]
442impl Provider for XAIClient {
443 fn name(&self) -> &'static str {
444 "xai"
445 }
446
447 async fn complete(
448 &self,
449 model_id: &ModelId,
450 messages: Vec<AppMessage>,
451 system: Option<String>,
452 tools: Option<Vec<ToolSchema>>,
453 call_options: Option<ModelParameters>,
454 token: CancellationToken,
455 ) -> Result<CompletionResponse, ApiError> {
456 let xai_messages = self.convert_messages(messages, system);
457 let xai_tools = tools.map(|t| self.convert_tools(t));
458
459 let (supports_thinking, reasoning_effort) = call_options
461 .as_ref()
462 .and_then(|opts| opts.thinking_config)
463 .map(|tc| {
464 let effort = tc.effort.map(|e| match e {
465 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
466 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
468 });
469 (tc.enabled, effort)
470 })
471 .unwrap_or((false, None));
472
473 let reasoning_effort = if supports_thinking && model_id.1 != "grok-4-0709" {
475 reasoning_effort.or(Some(ReasoningEffort::High))
476 } else {
477 None
478 };
479
480 let request = CompletionRequest {
481 model: model_id.1.clone(), messages: xai_messages,
483 deferred: None,
484 frequency_penalty: None,
485 logit_bias: None,
486 logprobs: None,
487 max_completion_tokens: Some(32768),
488 max_tokens: None,
489 n: None,
490 parallel_tool_calls: None,
491 presence_penalty: None,
492 reasoning_effort,
493 response_format: None,
494 search_parameters: None,
495 seed: None,
496 stop: None,
497 stream: None,
498 stream_options: None,
499 temperature: call_options
500 .as_ref()
501 .and_then(|o| o.temperature)
502 .or(Some(1.0)),
503 tool_choice: None,
504 tools: xai_tools,
505 top_logprobs: None,
506 top_p: call_options.as_ref().and_then(|o| o.top_p),
507 user: None,
508 web_search_options: None,
509 };
510
511 let response = self
512 .http_client
513 .post(&self.base_url)
514 .json(&request)
515 .send()
516 .await
517 .map_err(ApiError::Network)?;
518
519 if !response.status().is_success() {
520 let status = response.status();
521 let error_text = response.text().await.unwrap_or_else(|_| String::new());
522
523 debug!(
524 target: "grok::complete",
525 "Grok API error - Status: {}, Body: {}",
526 status,
527 error_text
528 );
529
530 return match status.as_u16() {
531 429 => Err(ApiError::RateLimited {
532 provider: self.name().to_string(),
533 details: error_text,
534 }),
535 400 => Err(ApiError::InvalidRequest {
536 provider: self.name().to_string(),
537 details: error_text,
538 }),
539 401 => Err(ApiError::AuthenticationFailed {
540 provider: self.name().to_string(),
541 details: error_text,
542 }),
543 _ => Err(ApiError::ServerError {
544 provider: self.name().to_string(),
545 status_code: status.as_u16(),
546 details: error_text,
547 }),
548 };
549 }
550
551 let response_text = tokio::select! {
552 _ = token.cancelled() => {
553 debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
554 return Err(ApiError::Cancelled { provider: self.name().to_string() });
555 }
556 text_res = response.text() => {
557 text_res?
558 }
559 };
560
561 let xai_response: XAICompletionResponse =
562 serde_json::from_str(&response_text).map_err(|e| {
563 error!(
564 target: "xai::complete",
565 "Failed to parse response: {}, Body: {}",
566 e,
567 response_text
568 );
569 ApiError::ResponseParsingError {
570 provider: self.name().to_string(),
571 details: format!("Error: {e}, Body: {response_text}"),
572 }
573 })?;
574
575 if let Some(choice) = xai_response.choices.first() {
577 let mut content_blocks = Vec::new();
578
579 if let Some(reasoning) = &choice.message.reasoning_content {
581 if !reasoning.trim().is_empty() {
582 content_blocks.push(AssistantContent::Thought {
583 thought: crate::app::conversation::ThoughtContent::Simple {
584 text: reasoning.clone(),
585 },
586 });
587 }
588 }
589
590 if let Some(content) = &choice.message.content {
592 if !content.trim().is_empty() {
593 content_blocks.push(AssistantContent::Text {
594 text: content.clone(),
595 });
596 }
597 }
598
599 if let Some(tool_calls) = &choice.message.tool_calls {
601 for tool_call in tool_calls {
602 let parameters = serde_json::from_str(&tool_call.function.arguments)
604 .unwrap_or(serde_json::Value::Null);
605
606 content_blocks.push(AssistantContent::ToolCall {
607 tool_call: steer_tools::ToolCall {
608 id: tool_call.id.clone(),
609 name: tool_call.function.name.clone(),
610 parameters,
611 },
612 });
613 }
614 }
615
616 Ok(crate::api::provider::CompletionResponse {
617 content: content_blocks,
618 })
619 } else {
620 Err(ApiError::NoChoices {
621 provider: self.name().to_string(),
622 })
623 }
624 }
625}