1use async_trait::async_trait;
2use reqwest::{self, header};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error};
7
8use crate::api::error::ApiError;
9use crate::api::provider::{CompletionResponse, Provider};
10use crate::api::util::normalize_chat_url;
11use crate::app::conversation::{AssistantContent, Message as AppMessage, ToolResult, UserContent};
12use crate::config::model::{ModelId, ModelParameters};
13use steer_tools::ToolSchema;
14
/// Chat-completions endpoint used when the caller does not supply a base URL.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
16
17#[derive(Clone)]
18pub struct XAIClient {
19 http_client: reqwest::Client,
20 base_url: String,
21}
22
23#[derive(Debug, Serialize, Deserialize)]
25#[serde(tag = "role", rename_all = "lowercase")]
26enum XAIMessage {
27 System {
28 content: String,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 name: Option<String>,
31 },
32 User {
33 content: String,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 name: Option<String>,
36 },
37 Assistant {
38 #[serde(skip_serializing_if = "Option::is_none")]
39 content: Option<String>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 tool_calls: Option<Vec<XAIToolCall>>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 name: Option<String>,
44 },
45 Tool {
46 content: String,
47 tool_call_id: String,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 name: Option<String>,
50 },
51}
52
53#[derive(Debug, Serialize, Deserialize)]
55struct XAIFunction {
56 name: String,
57 description: String,
58 parameters: serde_json::Value,
59}
60
61#[derive(Debug, Serialize, Deserialize)]
63struct XAITool {
64 #[serde(rename = "type")]
65 tool_type: String, function: XAIFunction,
67}
68
69#[derive(Debug, Serialize, Deserialize)]
71struct XAIToolCall {
72 id: String,
73 #[serde(rename = "type")]
74 tool_type: String,
75 function: XAIFunctionCall,
76}
77
78#[derive(Debug, Serialize, Deserialize)]
79struct XAIFunctionCall {
80 name: String,
81 arguments: String, }
83
84#[derive(Debug, Serialize, Deserialize)]
85#[serde(rename_all = "lowercase")]
86enum ReasoningEffort {
87 Low,
88 High,
89}
90
91#[derive(Debug, Serialize, Deserialize)]
92struct StreamOptions {
93 #[serde(skip_serializing_if = "Option::is_none")]
94 include_usage: Option<bool>,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98#[serde(untagged)]
99enum ToolChoice {
100 String(String), Specific {
102 #[serde(rename = "type")]
103 tool_type: String,
104 function: ToolChoiceFunction,
105 },
106}
107
108#[derive(Debug, Serialize, Deserialize)]
109struct ToolChoiceFunction {
110 name: String,
111}
112
113#[derive(Debug, Serialize, Deserialize)]
114struct ResponseFormat {
115 #[serde(rename = "type")]
116 format_type: String,
117 #[serde(skip_serializing_if = "Option::is_none")]
118 json_schema: Option<serde_json::Value>,
119}
120
121#[derive(Debug, Serialize, Deserialize)]
122struct SearchParameters {
123 #[serde(skip_serializing_if = "Option::is_none")]
124 from_date: Option<String>,
125 #[serde(skip_serializing_if = "Option::is_none")]
126 to_date: Option<String>,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 max_search_results: Option<u32>,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 mode: Option<String>,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 return_citations: Option<bool>,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 sources: Option<Vec<String>>,
135}
136
137#[derive(Debug, Serialize, Deserialize)]
138struct WebSearchOptions {
139 #[serde(skip_serializing_if = "Option::is_none")]
140 search_context_size: Option<u32>,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 user_location: Option<String>,
143}
144
145#[derive(Debug, Serialize, Deserialize)]
146struct CompletionRequest {
147 model: String,
148 messages: Vec<XAIMessage>,
149 #[serde(skip_serializing_if = "Option::is_none")]
150 deferred: Option<bool>,
151 #[serde(skip_serializing_if = "Option::is_none")]
152 frequency_penalty: Option<f32>,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 logit_bias: Option<HashMap<String, f32>>,
155 #[serde(skip_serializing_if = "Option::is_none")]
156 logprobs: Option<bool>,
157 #[serde(skip_serializing_if = "Option::is_none")]
158 max_completion_tokens: Option<u32>,
159 #[serde(skip_serializing_if = "Option::is_none")]
160 max_tokens: Option<u32>,
161 #[serde(skip_serializing_if = "Option::is_none")]
162 n: Option<u32>,
163 #[serde(skip_serializing_if = "Option::is_none")]
164 parallel_tool_calls: Option<bool>,
165 #[serde(skip_serializing_if = "Option::is_none")]
166 presence_penalty: Option<f32>,
167 #[serde(skip_serializing_if = "Option::is_none")]
168 reasoning_effort: Option<ReasoningEffort>,
169 #[serde(skip_serializing_if = "Option::is_none")]
170 response_format: Option<ResponseFormat>,
171 #[serde(skip_serializing_if = "Option::is_none")]
172 search_parameters: Option<SearchParameters>,
173 #[serde(skip_serializing_if = "Option::is_none")]
174 seed: Option<u64>,
175 #[serde(skip_serializing_if = "Option::is_none")]
176 stop: Option<Vec<String>>,
177 #[serde(skip_serializing_if = "Option::is_none")]
178 stream: Option<bool>,
179 #[serde(skip_serializing_if = "Option::is_none")]
180 stream_options: Option<StreamOptions>,
181 #[serde(skip_serializing_if = "Option::is_none")]
182 temperature: Option<f32>,
183 #[serde(skip_serializing_if = "Option::is_none")]
184 tool_choice: Option<ToolChoice>,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 tools: Option<Vec<XAITool>>,
187 #[serde(skip_serializing_if = "Option::is_none")]
188 top_logprobs: Option<u32>,
189 #[serde(skip_serializing_if = "Option::is_none")]
190 top_p: Option<f32>,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 user: Option<String>,
193 #[serde(skip_serializing_if = "Option::is_none")]
194 web_search_options: Option<WebSearchOptions>,
195}
196
197#[derive(Debug, Serialize, Deserialize)]
198struct XAICompletionResponse {
199 id: String,
200 object: String,
201 created: u64,
202 model: String,
203 choices: Vec<Choice>,
204 #[serde(skip_serializing_if = "Option::is_none")]
205 usage: Option<XAIUsage>,
206 #[serde(skip_serializing_if = "Option::is_none")]
207 system_fingerprint: Option<String>,
208 #[serde(skip_serializing_if = "Option::is_none")]
209 citations: Option<Vec<serde_json::Value>>,
210 #[serde(skip_serializing_if = "Option::is_none")]
211 debug_output: Option<DebugOutput>,
212}
213
214#[derive(Debug, Serialize, Deserialize)]
215struct Choice {
216 index: i32,
217 message: AssistantMessage,
218 finish_reason: Option<String>,
219}
220
221#[derive(Debug, Serialize, Deserialize)]
222struct AssistantMessage {
223 content: Option<String>,
224 #[serde(skip_serializing_if = "Option::is_none")]
225 tool_calls: Option<Vec<XAIToolCall>>,
226 #[serde(skip_serializing_if = "Option::is_none")]
227 reasoning_content: Option<String>,
228}
229
230#[derive(Debug, Serialize, Deserialize)]
231struct PromptTokensDetails {
232 cached_tokens: usize,
233 audio_tokens: usize,
234 image_tokens: usize,
235 text_tokens: usize,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239struct CompletionTokensDetails {
240 reasoning_tokens: usize,
241 audio_tokens: usize,
242 accepted_prediction_tokens: usize,
243 rejected_prediction_tokens: usize,
244}
245
246#[derive(Debug, Serialize, Deserialize)]
247struct XAIUsage {
248 prompt_tokens: usize,
249 completion_tokens: usize,
250 total_tokens: usize,
251 #[serde(skip_serializing_if = "Option::is_none")]
252 num_sources_used: Option<usize>,
253 #[serde(skip_serializing_if = "Option::is_none")]
254 prompt_tokens_details: Option<PromptTokensDetails>,
255 #[serde(skip_serializing_if = "Option::is_none")]
256 completion_tokens_details: Option<CompletionTokensDetails>,
257}
258
259#[derive(Debug, Serialize, Deserialize)]
260struct DebugOutput {
261 attempts: usize,
262 cache_read_count: usize,
263 cache_read_input_bytes: usize,
264 cache_write_count: usize,
265 cache_write_input_bytes: usize,
266 prompt: String,
267 request: String,
268 responses: Vec<String>,
269}
270
271impl XAIClient {
272 pub fn new(api_key: String) -> Self {
273 Self::with_base_url(api_key, None)
274 }
275
276 pub fn with_base_url(api_key: String, base_url: Option<String>) -> Self {
277 let mut headers = header::HeaderMap::new();
278 headers.insert(
279 header::AUTHORIZATION,
280 header::HeaderValue::from_str(&format!("Bearer {api_key}"))
281 .expect("Invalid API key format"),
282 );
283
284 let client = reqwest::Client::builder()
285 .default_headers(headers)
286 .timeout(std::time::Duration::from_secs(300)) .build()
288 .expect("Failed to build HTTP client");
289
290 let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);
291
292 Self {
293 http_client: client,
294 base_url,
295 }
296 }
297
298 fn convert_messages(
299 &self,
300 messages: Vec<AppMessage>,
301 system: Option<String>,
302 ) -> Vec<XAIMessage> {
303 let mut xai_messages = Vec::new();
304
305 if let Some(system_content) = system {
307 xai_messages.push(XAIMessage::System {
308 content: system_content,
309 name: None,
310 });
311 }
312
313 for message in messages {
315 match &message.data {
316 crate::app::conversation::MessageData::User { content, .. } => {
317 let combined_text = content
319 .iter()
320 .filter_map(|user_content| match user_content {
321 UserContent::Text { text } => Some(text.clone()),
322 UserContent::CommandExecution {
323 command,
324 stdout,
325 stderr,
326 exit_code,
327 } => Some(UserContent::format_command_execution_as_xml(
328 command, stdout, stderr, *exit_code,
329 )),
330 UserContent::AppCommand { .. } => {
331 None
333 }
334 })
335 .collect::<Vec<_>>()
336 .join("\n");
337
338 if !combined_text.trim().is_empty() {
340 xai_messages.push(XAIMessage::User {
341 content: combined_text,
342 name: None,
343 });
344 }
345 }
346 crate::app::conversation::MessageData::Assistant { content, .. } => {
347 let mut text_parts = Vec::new();
349 let mut tool_calls = Vec::new();
350
351 for content_block in content {
352 match content_block {
353 AssistantContent::Text { text } => {
354 text_parts.push(text.clone());
355 }
356 AssistantContent::ToolCall { tool_call } => {
357 tool_calls.push(XAIToolCall {
358 id: tool_call.id.clone(),
359 tool_type: "function".to_string(),
360 function: XAIFunctionCall {
361 name: tool_call.name.clone(),
362 arguments: tool_call.parameters.to_string(),
363 },
364 });
365 }
366 AssistantContent::Thought { .. } => {
367 continue;
369 }
370 }
371 }
372
373 let content = if text_parts.is_empty() {
375 None
376 } else {
377 Some(text_parts.join("\n"))
378 };
379
380 let tool_calls_opt = if tool_calls.is_empty() {
381 None
382 } else {
383 Some(tool_calls)
384 };
385
386 xai_messages.push(XAIMessage::Assistant {
387 content,
388 tool_calls: tool_calls_opt,
389 name: None,
390 });
391 }
392 crate::app::conversation::MessageData::Tool {
393 tool_use_id,
394 result,
395 ..
396 } => {
397 let content_text = match result {
399 ToolResult::Error(e) => format!("Error: {e}"),
400 _ => {
401 let text = result.llm_format();
402 if text.trim().is_empty() {
403 "(No output)".to_string()
404 } else {
405 text
406 }
407 }
408 };
409
410 xai_messages.push(XAIMessage::Tool {
411 content: content_text,
412 tool_call_id: tool_use_id.clone(),
413 name: None,
414 });
415 }
416 }
417 }
418
419 xai_messages
420 }
421
422 fn convert_tools(&self, tools: Vec<ToolSchema>) -> Vec<XAITool> {
423 tools
424 .into_iter()
425 .map(|tool| XAITool {
426 tool_type: "function".to_string(),
427 function: XAIFunction {
428 name: tool.name,
429 description: tool.description,
430 parameters: serde_json::json!({
431 "type": tool.input_schema.schema_type,
432 "properties": tool.input_schema.properties,
433 "required": tool.input_schema.required,
434 }),
435 },
436 })
437 .collect()
438 }
439}
440
441#[async_trait]
442impl Provider for XAIClient {
443 fn name(&self) -> &'static str {
444 "xai"
445 }
446
447 async fn complete(
448 &self,
449 model_id: &ModelId,
450 messages: Vec<AppMessage>,
451 system: Option<String>,
452 tools: Option<Vec<ToolSchema>>,
453 call_options: Option<ModelParameters>,
454 token: CancellationToken,
455 ) -> Result<CompletionResponse, ApiError> {
456 let xai_messages = self.convert_messages(messages, system);
457 let xai_tools = tools.map(|t| self.convert_tools(t));
458
459 let supports_thinking = call_options
461 .as_ref()
462 .and_then(|opts| opts.thinking_config.as_ref())
463 .map(|tc| tc.enabled)
464 .unwrap_or(false);
465
466 let reasoning_effort = if supports_thinking && model_id.1 != "grok-4-0709" {
468 Some(ReasoningEffort::High)
469 } else {
470 None
471 };
472
473 let request = CompletionRequest {
474 model: model_id.1.clone(), messages: xai_messages,
476 deferred: None,
477 frequency_penalty: None,
478 logit_bias: None,
479 logprobs: None,
480 max_completion_tokens: Some(32768),
481 max_tokens: None,
482 n: None,
483 parallel_tool_calls: None,
484 presence_penalty: None,
485 reasoning_effort,
486 response_format: None,
487 search_parameters: None,
488 seed: None,
489 stop: None,
490 stream: None,
491 stream_options: None,
492 temperature: call_options
493 .as_ref()
494 .and_then(|o| o.temperature)
495 .or(Some(1.0)),
496 tool_choice: None,
497 tools: xai_tools,
498 top_logprobs: None,
499 top_p: call_options.as_ref().and_then(|o| o.top_p),
500 user: None,
501 web_search_options: None,
502 };
503
504 let response = self
505 .http_client
506 .post(&self.base_url)
507 .json(&request)
508 .send()
509 .await
510 .map_err(ApiError::Network)?;
511
512 if !response.status().is_success() {
513 let status = response.status();
514 let error_text = response.text().await.unwrap_or_else(|_| String::new());
515
516 debug!(
517 target: "grok::complete",
518 "Grok API error - Status: {}, Body: {}",
519 status,
520 error_text
521 );
522
523 return match status.as_u16() {
524 429 => Err(ApiError::RateLimited {
525 provider: self.name().to_string(),
526 details: error_text,
527 }),
528 400 => Err(ApiError::InvalidRequest {
529 provider: self.name().to_string(),
530 details: error_text,
531 }),
532 401 => Err(ApiError::AuthenticationFailed {
533 provider: self.name().to_string(),
534 details: error_text,
535 }),
536 _ => Err(ApiError::ServerError {
537 provider: self.name().to_string(),
538 status_code: status.as_u16(),
539 details: error_text,
540 }),
541 };
542 }
543
544 let response_text = tokio::select! {
545 _ = token.cancelled() => {
546 debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
547 return Err(ApiError::Cancelled { provider: self.name().to_string() });
548 }
549 text_res = response.text() => {
550 text_res?
551 }
552 };
553
554 let xai_response: XAICompletionResponse =
555 serde_json::from_str(&response_text).map_err(|e| {
556 error!(
557 target: "xai::complete",
558 "Failed to parse response: {}, Body: {}",
559 e,
560 response_text
561 );
562 ApiError::ResponseParsingError {
563 provider: self.name().to_string(),
564 details: format!("Error: {e}, Body: {response_text}"),
565 }
566 })?;
567
568 if let Some(choice) = xai_response.choices.first() {
570 let mut content_blocks = Vec::new();
571
572 if let Some(reasoning) = &choice.message.reasoning_content {
574 if !reasoning.trim().is_empty() {
575 content_blocks.push(AssistantContent::Thought {
576 thought: crate::app::conversation::ThoughtContent::Simple {
577 text: reasoning.clone(),
578 },
579 });
580 }
581 }
582
583 if let Some(content) = &choice.message.content {
585 if !content.trim().is_empty() {
586 content_blocks.push(AssistantContent::Text {
587 text: content.clone(),
588 });
589 }
590 }
591
592 if let Some(tool_calls) = &choice.message.tool_calls {
594 for tool_call in tool_calls {
595 let parameters = serde_json::from_str(&tool_call.function.arguments)
597 .unwrap_or(serde_json::Value::Null);
598
599 content_blocks.push(AssistantContent::ToolCall {
600 tool_call: steer_tools::ToolCall {
601 id: tool_call.id.clone(),
602 name: tool_call.function.name.clone(),
603 parameters,
604 },
605 });
606 }
607 }
608
609 Ok(crate::api::provider::CompletionResponse {
610 content: content_blocks,
611 })
612 } else {
613 Err(ApiError::NoChoices {
614 provider: self.name().to_string(),
615 })
616 }
617 }
618}