1use async_trait::async_trait;
2use reqwest::{self, header};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error};
7
8use crate::api::provider::{CompletionResponse, Provider};
9use crate::api::{Model, error::ApiError};
10use crate::app::conversation::{AssistantContent, Message as AppMessage, ToolResult, UserContent};
11use steer_tools::ToolSchema;
12
/// Endpoint of the xAI chat-completions API; `XAIClient` POSTs every completion request here.
const API_URL: &str = "https://api.x.ai/v1/chat/completions";
14
15#[derive(Clone)]
16pub struct XAIClient {
17 http_client: reqwest::Client,
18}
19
20#[derive(Debug, Serialize, Deserialize)]
22#[serde(tag = "role", rename_all = "lowercase")]
23enum XAIMessage {
24 System {
25 content: String,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 name: Option<String>,
28 },
29 User {
30 content: String,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 name: Option<String>,
33 },
34 Assistant {
35 #[serde(skip_serializing_if = "Option::is_none")]
36 content: Option<String>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 tool_calls: Option<Vec<XAIToolCall>>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 name: Option<String>,
41 },
42 Tool {
43 content: String,
44 tool_call_id: String,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 name: Option<String>,
47 },
48}
49
50#[derive(Debug, Serialize, Deserialize)]
52struct XAIFunction {
53 name: String,
54 description: String,
55 parameters: serde_json::Value,
56}
57
58#[derive(Debug, Serialize, Deserialize)]
60struct XAITool {
61 #[serde(rename = "type")]
62 tool_type: String, function: XAIFunction,
64}
65
66#[derive(Debug, Serialize, Deserialize)]
68struct XAIToolCall {
69 id: String,
70 #[serde(rename = "type")]
71 tool_type: String,
72 function: XAIFunctionCall,
73}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct XAIFunctionCall {
77 name: String,
78 arguments: String, }
80
81#[derive(Debug, Serialize, Deserialize)]
82#[serde(rename_all = "lowercase")]
83enum ReasoningEffort {
84 Low,
85 High,
86}
87
88#[derive(Debug, Serialize, Deserialize)]
89struct StreamOptions {
90 #[serde(skip_serializing_if = "Option::is_none")]
91 include_usage: Option<bool>,
92}
93
94#[derive(Debug, Serialize, Deserialize)]
95#[serde(untagged)]
96enum ToolChoice {
97 String(String), Specific {
99 #[serde(rename = "type")]
100 tool_type: String,
101 function: ToolChoiceFunction,
102 },
103}
104
105#[derive(Debug, Serialize, Deserialize)]
106struct ToolChoiceFunction {
107 name: String,
108}
109
110#[derive(Debug, Serialize, Deserialize)]
111struct ResponseFormat {
112 #[serde(rename = "type")]
113 format_type: String,
114 #[serde(skip_serializing_if = "Option::is_none")]
115 json_schema: Option<serde_json::Value>,
116}
117
118#[derive(Debug, Serialize, Deserialize)]
119struct SearchParameters {
120 #[serde(skip_serializing_if = "Option::is_none")]
121 from_date: Option<String>,
122 #[serde(skip_serializing_if = "Option::is_none")]
123 to_date: Option<String>,
124 #[serde(skip_serializing_if = "Option::is_none")]
125 max_search_results: Option<u32>,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 mode: Option<String>,
128 #[serde(skip_serializing_if = "Option::is_none")]
129 return_citations: Option<bool>,
130 #[serde(skip_serializing_if = "Option::is_none")]
131 sources: Option<Vec<String>>,
132}
133
134#[derive(Debug, Serialize, Deserialize)]
135struct WebSearchOptions {
136 #[serde(skip_serializing_if = "Option::is_none")]
137 search_context_size: Option<u32>,
138 #[serde(skip_serializing_if = "Option::is_none")]
139 user_location: Option<String>,
140}
141
142#[derive(Debug, Serialize, Deserialize)]
143struct CompletionRequest {
144 model: String,
145 messages: Vec<XAIMessage>,
146 #[serde(skip_serializing_if = "Option::is_none")]
147 deferred: Option<bool>,
148 #[serde(skip_serializing_if = "Option::is_none")]
149 frequency_penalty: Option<f32>,
150 #[serde(skip_serializing_if = "Option::is_none")]
151 logit_bias: Option<HashMap<String, f32>>,
152 #[serde(skip_serializing_if = "Option::is_none")]
153 logprobs: Option<bool>,
154 #[serde(skip_serializing_if = "Option::is_none")]
155 max_completion_tokens: Option<u32>,
156 #[serde(skip_serializing_if = "Option::is_none")]
157 max_tokens: Option<u32>,
158 #[serde(skip_serializing_if = "Option::is_none")]
159 n: Option<u32>,
160 #[serde(skip_serializing_if = "Option::is_none")]
161 parallel_tool_calls: Option<bool>,
162 #[serde(skip_serializing_if = "Option::is_none")]
163 presence_penalty: Option<f32>,
164 #[serde(skip_serializing_if = "Option::is_none")]
165 reasoning_effort: Option<ReasoningEffort>,
166 #[serde(skip_serializing_if = "Option::is_none")]
167 response_format: Option<ResponseFormat>,
168 #[serde(skip_serializing_if = "Option::is_none")]
169 search_parameters: Option<SearchParameters>,
170 #[serde(skip_serializing_if = "Option::is_none")]
171 seed: Option<u64>,
172 #[serde(skip_serializing_if = "Option::is_none")]
173 stop: Option<Vec<String>>,
174 #[serde(skip_serializing_if = "Option::is_none")]
175 stream: Option<bool>,
176 #[serde(skip_serializing_if = "Option::is_none")]
177 stream_options: Option<StreamOptions>,
178 #[serde(skip_serializing_if = "Option::is_none")]
179 temperature: Option<f32>,
180 #[serde(skip_serializing_if = "Option::is_none")]
181 tool_choice: Option<ToolChoice>,
182 #[serde(skip_serializing_if = "Option::is_none")]
183 tools: Option<Vec<XAITool>>,
184 #[serde(skip_serializing_if = "Option::is_none")]
185 top_logprobs: Option<u32>,
186 #[serde(skip_serializing_if = "Option::is_none")]
187 top_p: Option<f32>,
188 #[serde(skip_serializing_if = "Option::is_none")]
189 user: Option<String>,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 web_search_options: Option<WebSearchOptions>,
192}
193
194#[derive(Debug, Serialize, Deserialize)]
195struct XAICompletionResponse {
196 id: String,
197 object: String,
198 created: u64,
199 model: String,
200 choices: Vec<Choice>,
201 #[serde(skip_serializing_if = "Option::is_none")]
202 usage: Option<XAIUsage>,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 system_fingerprint: Option<String>,
205 #[serde(skip_serializing_if = "Option::is_none")]
206 citations: Option<Vec<serde_json::Value>>,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 debug_output: Option<DebugOutput>,
209}
210
211#[derive(Debug, Serialize, Deserialize)]
212struct Choice {
213 index: i32,
214 message: AssistantMessage,
215 finish_reason: Option<String>,
216}
217
218#[derive(Debug, Serialize, Deserialize)]
219struct AssistantMessage {
220 content: Option<String>,
221 #[serde(skip_serializing_if = "Option::is_none")]
222 tool_calls: Option<Vec<XAIToolCall>>,
223 #[serde(skip_serializing_if = "Option::is_none")]
224 reasoning_content: Option<String>,
225}
226
227#[derive(Debug, Serialize, Deserialize)]
228struct PromptTokensDetails {
229 cached_tokens: usize,
230 audio_tokens: usize,
231 image_tokens: usize,
232 text_tokens: usize,
233}
234
235#[derive(Debug, Serialize, Deserialize)]
236struct CompletionTokensDetails {
237 reasoning_tokens: usize,
238 audio_tokens: usize,
239 accepted_prediction_tokens: usize,
240 rejected_prediction_tokens: usize,
241}
242
243#[derive(Debug, Serialize, Deserialize)]
244struct XAIUsage {
245 prompt_tokens: usize,
246 completion_tokens: usize,
247 total_tokens: usize,
248 #[serde(skip_serializing_if = "Option::is_none")]
249 num_sources_used: Option<usize>,
250 #[serde(skip_serializing_if = "Option::is_none")]
251 prompt_tokens_details: Option<PromptTokensDetails>,
252 #[serde(skip_serializing_if = "Option::is_none")]
253 completion_tokens_details: Option<CompletionTokensDetails>,
254}
255
256#[derive(Debug, Serialize, Deserialize)]
257struct DebugOutput {
258 attempts: usize,
259 cache_read_count: usize,
260 cache_read_input_bytes: usize,
261 cache_write_count: usize,
262 cache_write_input_bytes: usize,
263 prompt: String,
264 request: String,
265 responses: Vec<String>,
266}
267
268impl XAIClient {
269 pub fn new(api_key: String) -> Self {
270 let mut headers = header::HeaderMap::new();
271 headers.insert(
272 header::AUTHORIZATION,
273 header::HeaderValue::from_str(&format!("Bearer {api_key}"))
274 .expect("Invalid API key format"),
275 );
276
277 let client = reqwest::Client::builder()
278 .default_headers(headers)
279 .timeout(std::time::Duration::from_secs(300)) .build()
281 .expect("Failed to build HTTP client");
282
283 Self {
284 http_client: client,
285 }
286 }
287
288 fn convert_messages(
289 &self,
290 messages: Vec<AppMessage>,
291 system: Option<String>,
292 ) -> Vec<XAIMessage> {
293 let mut xai_messages = Vec::new();
294
295 if let Some(system_content) = system {
297 xai_messages.push(XAIMessage::System {
298 content: system_content,
299 name: None,
300 });
301 }
302
303 for message in messages {
305 match &message.data {
306 crate::app::conversation::MessageData::User { content, .. } => {
307 let combined_text = content
309 .iter()
310 .filter_map(|user_content| match user_content {
311 UserContent::Text { text } => Some(text.clone()),
312 UserContent::CommandExecution {
313 command,
314 stdout,
315 stderr,
316 exit_code,
317 } => Some(UserContent::format_command_execution_as_xml(
318 command, stdout, stderr, *exit_code,
319 )),
320 UserContent::AppCommand { .. } => {
321 None
323 }
324 })
325 .collect::<Vec<_>>()
326 .join("\n");
327
328 if !combined_text.trim().is_empty() {
330 xai_messages.push(XAIMessage::User {
331 content: combined_text,
332 name: None,
333 });
334 }
335 }
336 crate::app::conversation::MessageData::Assistant { content, .. } => {
337 let mut text_parts = Vec::new();
339 let mut tool_calls = Vec::new();
340
341 for content_block in content {
342 match content_block {
343 AssistantContent::Text { text } => {
344 text_parts.push(text.clone());
345 }
346 AssistantContent::ToolCall { tool_call } => {
347 tool_calls.push(XAIToolCall {
348 id: tool_call.id.clone(),
349 tool_type: "function".to_string(),
350 function: XAIFunctionCall {
351 name: tool_call.name.clone(),
352 arguments: tool_call.parameters.to_string(),
353 },
354 });
355 }
356 AssistantContent::Thought { .. } => {
357 continue;
359 }
360 }
361 }
362
363 let content = if text_parts.is_empty() {
365 None
366 } else {
367 Some(text_parts.join("\n"))
368 };
369
370 let tool_calls_opt = if tool_calls.is_empty() {
371 None
372 } else {
373 Some(tool_calls)
374 };
375
376 xai_messages.push(XAIMessage::Assistant {
377 content,
378 tool_calls: tool_calls_opt,
379 name: None,
380 });
381 }
382 crate::app::conversation::MessageData::Tool {
383 tool_use_id,
384 result,
385 ..
386 } => {
387 let content_text = match result {
389 ToolResult::Error(e) => format!("Error: {e}"),
390 _ => {
391 let text = result.llm_format();
392 if text.trim().is_empty() {
393 "(No output)".to_string()
394 } else {
395 text
396 }
397 }
398 };
399
400 xai_messages.push(XAIMessage::Tool {
401 content: content_text,
402 tool_call_id: tool_use_id.clone(),
403 name: None,
404 });
405 }
406 }
407 }
408
409 xai_messages
410 }
411
412 fn convert_tools(&self, tools: Vec<ToolSchema>) -> Vec<XAITool> {
413 tools
414 .into_iter()
415 .map(|tool| XAITool {
416 tool_type: "function".to_string(),
417 function: XAIFunction {
418 name: tool.name,
419 description: tool.description,
420 parameters: serde_json::json!({
421 "type": tool.input_schema.schema_type,
422 "properties": tool.input_schema.properties,
423 "required": tool.input_schema.required,
424 }),
425 },
426 })
427 .collect()
428 }
429}
430
431#[async_trait]
432impl Provider for XAIClient {
433 fn name(&self) -> &'static str {
434 "xai"
435 }
436
437 async fn complete(
438 &self,
439 model: Model,
440 messages: Vec<AppMessage>,
441 system: Option<String>,
442 tools: Option<Vec<ToolSchema>>,
443 token: CancellationToken,
444 ) -> Result<CompletionResponse, ApiError> {
445 let xai_messages = self.convert_messages(messages, system);
446 let xai_tools = tools.map(|t| self.convert_tools(t));
447
448 let reasoning_effort = if model.supports_thinking() && !matches!(model, Model::Grok4_0709) {
450 Some(ReasoningEffort::High)
451 } else {
452 None
453 };
454
455 let request = CompletionRequest {
456 model: model.as_ref().to_string(),
457 messages: xai_messages,
458 deferred: None,
459 frequency_penalty: None,
460 logit_bias: None,
461 logprobs: None,
462 max_completion_tokens: Some(32768),
463 max_tokens: None,
464 n: None,
465 parallel_tool_calls: None,
466 presence_penalty: None,
467 reasoning_effort,
468 response_format: None,
469 search_parameters: None,
470 seed: None,
471 stop: None,
472 stream: None,
473 stream_options: None,
474 temperature: Some(1.0),
475 tool_choice: None,
476 tools: xai_tools,
477 top_logprobs: None,
478 top_p: None,
479 user: None,
480 web_search_options: None,
481 };
482
483 let response = self
484 .http_client
485 .post(API_URL)
486 .json(&request)
487 .send()
488 .await
489 .map_err(ApiError::Network)?;
490
491 if !response.status().is_success() {
492 let status = response.status();
493 let error_text = response.text().await.unwrap_or_else(|_| String::new());
494
495 debug!(
496 target: "grok::complete",
497 "Grok API error - Status: {}, Body: {}",
498 status,
499 error_text
500 );
501
502 return match status.as_u16() {
503 429 => Err(ApiError::RateLimited {
504 provider: self.name().to_string(),
505 details: error_text,
506 }),
507 400 => Err(ApiError::InvalidRequest {
508 provider: self.name().to_string(),
509 details: error_text,
510 }),
511 401 => Err(ApiError::AuthenticationFailed {
512 provider: self.name().to_string(),
513 details: error_text,
514 }),
515 _ => Err(ApiError::ServerError {
516 provider: self.name().to_string(),
517 status_code: status.as_u16(),
518 details: error_text,
519 }),
520 };
521 }
522
523 let response_text = tokio::select! {
524 _ = token.cancelled() => {
525 debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
526 return Err(ApiError::Cancelled { provider: self.name().to_string() });
527 }
528 text_res = response.text() => {
529 text_res?
530 }
531 };
532
533 let xai_response: XAICompletionResponse =
534 serde_json::from_str(&response_text).map_err(|e| {
535 error!(
536 target: "xai::complete",
537 "Failed to parse response: {}, Body: {}",
538 e,
539 response_text
540 );
541 ApiError::ResponseParsingError {
542 provider: self.name().to_string(),
543 details: format!("Error: {e}, Body: {response_text}"),
544 }
545 })?;
546
547 if let Some(choice) = xai_response.choices.first() {
549 let mut content_blocks = Vec::new();
550
551 if let Some(reasoning) = &choice.message.reasoning_content {
553 if !reasoning.trim().is_empty() {
554 content_blocks.push(AssistantContent::Thought {
555 thought: crate::app::conversation::ThoughtContent::Simple {
556 text: reasoning.clone(),
557 },
558 });
559 }
560 }
561
562 if let Some(content) = &choice.message.content {
564 if !content.trim().is_empty() {
565 content_blocks.push(AssistantContent::Text {
566 text: content.clone(),
567 });
568 }
569 }
570
571 if let Some(tool_calls) = &choice.message.tool_calls {
573 for tool_call in tool_calls {
574 let parameters = serde_json::from_str(&tool_call.function.arguments)
576 .unwrap_or(serde_json::Value::Null);
577
578 content_blocks.push(AssistantContent::ToolCall {
579 tool_call: steer_tools::ToolCall {
580 id: tool_call.id.clone(),
581 name: tool_call.function.name.clone(),
582 parameters,
583 },
584 });
585 }
586 }
587
588 Ok(crate::api::provider::CompletionResponse {
589 content: content_blocks,
590 })
591 } else {
592 Err(ApiError::NoChoices {
593 provider: self.name().to_string(),
594 })
595 }
596 }
597}