1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11 CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::api::util::normalize_chat_url;
15use crate::app::SystemContext;
16use crate::app::conversation::{
17 AssistantContent, ImageSource, Message as AppMessage, ToolResult, UserContent,
18};
19use crate::config::model::{ModelId, ModelParameters};
20use steer_tools::ToolSchema;
21
/// Default xAI OpenAI-compatible chat-completions endpoint; overridable via
/// `XAIClient::with_base_url`.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
23
/// HTTP client for xAI's OpenAI-compatible chat-completions API.
#[derive(Clone)]
pub struct XAIClient {
    // Pre-configured with a Bearer auth header and a request timeout.
    http_client: reqwest::Client,
    // Fully-qualified chat-completions URL (already normalized).
    base_url: String,
}
29
/// Chat message in xAI's wire format. The JSON `role` field selects the
/// variant (`system` / `user` / `assistant` / `tool`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: XAIUserContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Omitted entirely for tool-call-only assistant turns.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Tool {
        content: String,
        // Id of the assistant tool call this result answers.
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
59
/// User-message payload: either a plain string or a list of typed parts.
/// Untagged — the serialized JSON shape alone distinguishes the variants.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum XAIUserContent {
    Text(String),
    Parts(Vec<XAIUserContentPart>),
}
66
/// One part of a multi-part user message; the JSON `type` field is the tag.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum XAIUserContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: XAIImageUrl },
}
75
/// Image reference inside a user message; `url` may be a data URL or an
/// HTTP(S) URL (see `XAIClient::user_image_part`).
#[derive(Debug, Serialize, Deserialize)]
struct XAIImageUrl {
    url: String,
}
80
/// Function declaration inside a tool definition; `parameters` is a JSON
/// Schema object.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
88
89#[derive(Debug, Serialize, Deserialize)]
91struct XAITool {
92 #[serde(rename = "type")]
93 tool_type: String, function: XAIFunction,
95}
96
/// Tool call emitted by the assistant; `tool_type` is `"function"`.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}
105
106#[derive(Debug, Serialize, Deserialize)]
107struct XAIFunctionCall {
108 name: String,
109 arguments: String, }
111
/// xAI `reasoning_effort` request values — the API exposes only two levels,
/// serialized lowercase ("low" / "high").
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}
118
/// Streaming options; `include_usage` asks the API to append a final chunk
/// carrying token usage.
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}
124
125#[derive(Debug, Serialize, Deserialize)]
126#[serde(untagged)]
127enum ToolChoice {
128 String(String), Specific {
130 #[serde(rename = "type")]
131 tool_type: String,
132 function: ToolChoiceFunction,
133 },
134}
135
/// Named function referenced by `ToolChoice::Specific`.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}
140
/// `response_format` request field (e.g. `"json_object"`, `"json_schema"`).
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}
148
/// Live-search configuration for xAI's `search_parameters` request field.
/// All fields optional; absent fields are omitted from the request JSON.
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}
164
/// `web_search_options` request field.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
172
/// Full chat-completions request body. Mirrors xAI's OpenAI-compatible
/// schema; every optional field is omitted from the JSON when `None`, so a
/// minimal request is just `model` + `messages`.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
224
/// Top-level non-streaming chat-completions response.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    // Present when live search returned citations.
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}
241
/// One completion candidate; only the first is consumed by this client.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}
248
/// Assistant message inside a response choice. `reasoning_content` carries
/// model "thinking" text when reasoning is enabled.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
257
/// Breakdown of prompt tokens in the usage block.
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    #[serde(rename = "cached_tokens")]
    cached: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "image_tokens")]
    image: usize,
    #[serde(rename = "text_tokens")]
    text: usize,
}
269
/// Breakdown of completion tokens in the usage block.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    #[serde(rename = "reasoning_tokens")]
    reasoning: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "accepted_prediction_tokens")]
    accepted_prediction: usize,
    #[serde(rename = "rejected_prediction_tokens")]
    rejected_prediction: usize,
}
281
/// Token usage reported by xAI; mapped to the provider-agnostic `TokenUsage`
/// by `map_xai_usage`.
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}
294
/// Server-side debug payload occasionally attached to responses.
#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}
306
307#[derive(Debug, Deserialize)]
308struct XAIStreamChunk {
309 #[expect(dead_code)]
310 id: String,
311 choices: Vec<XAIStreamChoice>,
312 #[serde(skip_serializing_if = "Option::is_none")]
313 usage: Option<XAIUsage>,
314}
315
/// One choice inside a streaming chunk; this client reads only the delta.
#[derive(Debug, Deserialize)]
struct XAIStreamChoice {
    #[expect(dead_code)]
    index: u32,
    delta: XAIStreamDelta,
    #[expect(dead_code)]
    finish_reason: Option<String>,
}
324
325#[derive(Debug, Deserialize)]
326struct XAIStreamDelta {
327 #[serde(skip_serializing_if = "Option::is_none")]
328 content: Option<String>,
329 #[serde(skip_serializing_if = "Option::is_none")]
330 tool_calls: Option<Vec<XAIStreamToolCall>>,
331 #[serde(skip_serializing_if = "Option::is_none")]
332 reasoning_content: Option<String>,
333}
334
335#[derive(Debug, Deserialize)]
336struct XAIStreamToolCall {
337 index: usize,
338 #[serde(skip_serializing_if = "Option::is_none")]
339 id: Option<String>,
340 #[serde(skip_serializing_if = "Option::is_none")]
341 function: Option<XAIStreamFunction>,
342}
343
344#[derive(Debug, Deserialize)]
345struct XAIStreamFunction {
346 #[serde(skip_serializing_if = "Option::is_none")]
347 name: Option<String>,
348 #[serde(skip_serializing_if = "Option::is_none")]
349 arguments: Option<String>,
350}
351
352impl XAIClient {
353 pub fn new(api_key: String) -> Result<Self, ApiError> {
354 Self::with_base_url(api_key, None)
355 }
356
357 pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
358 let mut headers = header::HeaderMap::new();
359 headers.insert(
360 header::AUTHORIZATION,
361 header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
362 ApiError::AuthenticationFailed {
363 provider: "xai".to_string(),
364 details: format!("Invalid API key: {e}"),
365 }
366 })?,
367 );
368
369 let client = reqwest::Client::builder()
370 .default_headers(headers)
371 .timeout(std::time::Duration::from_secs(300)) .build()
373 .map_err(ApiError::Network)?;
374
375 let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);
376
377 Ok(Self {
378 http_client: client,
379 base_url,
380 })
381 }
382
383 fn user_image_part(
384 image: &crate::app::conversation::ImageContent,
385 ) -> Result<XAIUserContentPart, ApiError> {
386 let image_url = match &image.source {
387 ImageSource::DataUrl { data_url } => data_url.clone(),
388 ImageSource::Url { url } => url.clone(),
389 ImageSource::SessionFile { relative_path } => {
390 return Err(ApiError::UnsupportedFeature {
391 provider: "xai".to_string(),
392 feature: "image input source".to_string(),
393 details: format!(
394 "xAI chat API cannot access session file '{}' directly; use data URLs or public URLs",
395 relative_path
396 ),
397 });
398 }
399 };
400
401 Ok(XAIUserContentPart::ImageUrl {
402 image_url: XAIImageUrl { url: image_url },
403 })
404 }
405
406 fn convert_messages(
407 messages: Vec<AppMessage>,
408 system: Option<SystemContext>,
409 ) -> Result<Vec<XAIMessage>, ApiError> {
410 let mut xai_messages = Vec::new();
411
412 if let Some(system_content) = system.and_then(|context| context.render()) {
414 xai_messages.push(XAIMessage::System {
415 content: system_content,
416 name: None,
417 });
418 }
419
420 for message in messages {
422 match &message.data {
423 crate::app::conversation::MessageData::User { content, .. } => {
424 let mut content_parts = Vec::new();
425
426 for user_content in content {
427 match user_content {
428 UserContent::Text { text } => {
429 content_parts.push(XAIUserContentPart::Text { text: text.clone() });
430 }
431 UserContent::Image { image } => {
432 content_parts.push(Self::user_image_part(image)?);
433 }
434 UserContent::CommandExecution {
435 command,
436 stdout,
437 stderr,
438 exit_code,
439 } => {
440 content_parts.push(XAIUserContentPart::Text {
441 text: UserContent::format_command_execution_as_xml(
442 command, stdout, stderr, *exit_code,
443 ),
444 });
445 }
446 }
447 }
448
449 if !content_parts.is_empty() {
451 let content = match content_parts.as_slice() {
452 [XAIUserContentPart::Text { text }] => {
453 XAIUserContent::Text(text.clone())
454 }
455 _ => XAIUserContent::Parts(content_parts),
456 };
457
458 xai_messages.push(XAIMessage::User {
459 content,
460 name: None,
461 });
462 }
463 }
464 crate::app::conversation::MessageData::Assistant { content, .. } => {
465 let mut text_parts = Vec::new();
467 let mut tool_calls = Vec::new();
468
469 for content_block in content {
470 match content_block {
471 AssistantContent::Text { text } => {
472 text_parts.push(text.clone());
473 }
474 AssistantContent::Image { image } => {
475 match Self::user_image_part(image) {
476 Ok(XAIUserContentPart::ImageUrl { image_url }) => {
477 text_parts.push(format!("[Image URL: {}]", image_url.url));
478 }
479 Ok(XAIUserContentPart::Text { text }) => {
480 text_parts.push(text);
481 }
482 Err(err) => {
483 debug!(
484 target: "xai::convert_messages",
485 "Skipping unsupported assistant image block: {}",
486 err
487 );
488 }
489 }
490 }
491 AssistantContent::ToolCall { tool_call, .. } => {
492 tool_calls.push(XAIToolCall {
493 id: tool_call.id.clone(),
494 tool_type: "function".to_string(),
495 function: XAIFunctionCall {
496 name: tool_call.name.clone(),
497 arguments: tool_call.parameters.to_string(),
498 },
499 });
500 }
501 AssistantContent::Thought { .. } => {
502
503 }
505 }
506 }
507
508 let content = if text_parts.is_empty() {
510 None
511 } else {
512 Some(text_parts.join("\n"))
513 };
514
515 let tool_calls_opt = if tool_calls.is_empty() {
516 None
517 } else {
518 Some(tool_calls)
519 };
520
521 xai_messages.push(XAIMessage::Assistant {
522 content,
523 tool_calls: tool_calls_opt,
524 name: None,
525 });
526 }
527 crate::app::conversation::MessageData::Tool {
528 tool_use_id,
529 result,
530 ..
531 } => {
532 let content_text = if let ToolResult::Error(e) = result {
534 format!("Error: {e}")
535 } else {
536 let text = result.llm_format();
537 if text.trim().is_empty() {
538 "(No output)".to_string()
539 } else {
540 text
541 }
542 };
543
544 xai_messages.push(XAIMessage::Tool {
545 content: content_text,
546 tool_call_id: tool_use_id.clone(),
547 name: None,
548 });
549 }
550 }
551 }
552
553 Ok(xai_messages)
554 }
555
556 fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
557 tools
558 .into_iter()
559 .map(|tool| XAITool {
560 tool_type: "function".to_string(),
561 function: XAIFunction {
562 name: tool.name,
563 description: tool.description,
564 parameters: tool.input_schema.as_value().clone(),
565 },
566 })
567 .collect()
568 }
569}
570
/// Narrows a `usize` token count to `u32`, saturating at `u32::MAX` instead
/// of truncating on overflow.
fn saturating_u32(value: usize) -> u32 {
    match u32::try_from(value) {
        Ok(v) => v,
        Err(_) => u32::MAX,
    }
}
574
575fn map_xai_usage(usage: &XAIUsage) -> TokenUsage {
576 TokenUsage::new(
577 saturating_u32(usage.prompt_tokens),
578 saturating_u32(usage.completion_tokens),
579 saturating_u32(usage.total_tokens),
580 )
581}
582
583fn convert_xai_completion_response(
584 xai_response: XAICompletionResponse,
585) -> Result<CompletionResponse, ApiError> {
586 let usage = xai_response.usage.as_ref().map(map_xai_usage);
587
588 let Some(choice) = xai_response.choices.first() else {
589 return Err(ApiError::NoChoices {
590 provider: "xai".to_string(),
591 });
592 };
593
594 let mut content_blocks = Vec::new();
595
596 if let Some(reasoning) = &choice.message.reasoning_content
598 && !reasoning.trim().is_empty()
599 {
600 content_blocks.push(AssistantContent::Thought {
601 thought: crate::app::conversation::ThoughtContent::Simple {
602 text: reasoning.clone(),
603 },
604 });
605 }
606
607 if let Some(content) = &choice.message.content
609 && !content.trim().is_empty()
610 {
611 content_blocks.push(AssistantContent::Text {
612 text: content.clone(),
613 });
614 }
615
616 if let Some(tool_calls) = &choice.message.tool_calls {
618 for tool_call in tool_calls {
619 let parameters = serde_json::from_str(&tool_call.function.arguments)
621 .unwrap_or(serde_json::Value::Null);
622
623 content_blocks.push(AssistantContent::ToolCall {
624 tool_call: steer_tools::ToolCall {
625 id: tool_call.id.clone(),
626 name: tool_call.function.name.clone(),
627 parameters,
628 },
629 thought_signature: None,
630 });
631 }
632 }
633
634 Ok(CompletionResponse {
635 content: content_blocks,
636 usage,
637 })
638}
639
640#[async_trait]
641impl Provider for XAIClient {
642 fn name(&self) -> &'static str {
643 "xai"
644 }
645
646 async fn complete(
647 &self,
648 model_id: &ModelId,
649 messages: Vec<AppMessage>,
650 system: Option<SystemContext>,
651 tools: Option<Vec<ToolSchema>>,
652 call_options: Option<ModelParameters>,
653 token: CancellationToken,
654 ) -> Result<CompletionResponse, ApiError> {
655 let xai_messages = Self::convert_messages(messages, system)?;
656 let xai_tools = tools.map(Self::convert_tools);
657
658 let (supports_thinking, reasoning_effort) = call_options
660 .as_ref()
661 .and_then(|opts| opts.thinking_config)
662 .map_or((false, None), |tc| {
663 let effort = tc.effort.map(|e| match e {
664 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
665 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
667 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
669 (tc.enabled, effort)
670 });
671
672 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
674 reasoning_effort.or(Some(ReasoningEffort::High))
675 } else {
676 None
677 };
678
679 let request = CompletionRequest {
680 model: model_id.id.clone(), messages: xai_messages,
682 deferred: None,
683 frequency_penalty: None,
684 logit_bias: None,
685 logprobs: None,
686 max_completion_tokens: Some(32768),
687 max_tokens: None,
688 n: None,
689 parallel_tool_calls: None,
690 presence_penalty: None,
691 reasoning_effort,
692 response_format: None,
693 search_parameters: None,
694 seed: None,
695 stop: None,
696 stream: None,
697 stream_options: None,
698 temperature: call_options
699 .as_ref()
700 .and_then(|o| o.temperature)
701 .or(Some(1.0)),
702 tool_choice: None,
703 tools: xai_tools,
704 top_logprobs: None,
705 top_p: call_options.as_ref().and_then(|o| o.top_p),
706 user: None,
707 web_search_options: None,
708 };
709
710 let response = self
711 .http_client
712 .post(&self.base_url)
713 .json(&request)
714 .send()
715 .await
716 .map_err(ApiError::Network)?;
717
718 if !response.status().is_success() {
719 let status = response.status();
720 let error_text = response.text().await.unwrap_or_else(|_| String::new());
721
722 debug!(
723 target: "grok::complete",
724 "Grok API error - Status: {}, Body: {}",
725 status,
726 error_text
727 );
728
729 return match status.as_u16() {
730 429 => Err(ApiError::RateLimited {
731 provider: self.name().to_string(),
732 details: error_text,
733 }),
734 400 => Err(ApiError::InvalidRequest {
735 provider: self.name().to_string(),
736 details: error_text,
737 }),
738 401 => Err(ApiError::AuthenticationFailed {
739 provider: self.name().to_string(),
740 details: error_text,
741 }),
742 _ => Err(ApiError::ServerError {
743 provider: self.name().to_string(),
744 status_code: status.as_u16(),
745 details: error_text,
746 }),
747 };
748 }
749
750 let response_text = tokio::select! {
751 () = token.cancelled() => {
752 debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
753 return Err(ApiError::Cancelled { provider: self.name().to_string() });
754 }
755 text_res = response.text() => {
756 text_res?
757 }
758 };
759
760 let xai_response: XAICompletionResponse =
761 serde_json::from_str(&response_text).map_err(|e| {
762 error!(
763 target: "xai::complete",
764 "Failed to parse response: {}, Body: {}",
765 e,
766 response_text
767 );
768 ApiError::ResponseParsingError {
769 provider: self.name().to_string(),
770 details: format!("Error: {e}, Body: {response_text}"),
771 }
772 })?;
773
774 convert_xai_completion_response(xai_response).map_err(|err| match err {
775 ApiError::NoChoices { .. } => ApiError::NoChoices {
776 provider: self.name().to_string(),
777 },
778 other => other,
779 })
780 }
781
782 async fn stream_complete(
783 &self,
784 model_id: &ModelId,
785 messages: Vec<AppMessage>,
786 system: Option<SystemContext>,
787 tools: Option<Vec<ToolSchema>>,
788 call_options: Option<ModelParameters>,
789 token: CancellationToken,
790 ) -> Result<CompletionStream, ApiError> {
791 let xai_messages = Self::convert_messages(messages, system)?;
792 let xai_tools = tools.map(Self::convert_tools);
793
794 let (supports_thinking, reasoning_effort) = call_options
795 .as_ref()
796 .and_then(|opts| opts.thinking_config)
797 .map_or((false, None), |tc| {
798 let effort = tc.effort.map(|e| match e {
799 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
800 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
801 crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
802 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
804 (tc.enabled, effort)
805 });
806
807 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
808 reasoning_effort.or(Some(ReasoningEffort::High))
809 } else {
810 None
811 };
812
813 let request = CompletionRequest {
814 model: model_id.id.clone(),
815 messages: xai_messages,
816 deferred: None,
817 frequency_penalty: None,
818 logit_bias: None,
819 logprobs: None,
820 max_completion_tokens: Some(32768),
821 max_tokens: None,
822 n: None,
823 parallel_tool_calls: None,
824 presence_penalty: None,
825 reasoning_effort,
826 response_format: None,
827 search_parameters: None,
828 seed: None,
829 stop: None,
830 stream: Some(true),
831 stream_options: Some(StreamOptions {
832 include_usage: Some(true),
833 }),
834 temperature: call_options
835 .as_ref()
836 .and_then(|o| o.temperature)
837 .or(Some(1.0)),
838 tool_choice: None,
839 tools: xai_tools,
840 top_logprobs: None,
841 top_p: call_options.as_ref().and_then(|o| o.top_p),
842 user: None,
843 web_search_options: None,
844 };
845
846 let response = self
847 .http_client
848 .post(&self.base_url)
849 .json(&request)
850 .send()
851 .await
852 .map_err(ApiError::Network)?;
853
854 if !response.status().is_success() {
855 let status = response.status();
856 let error_text = response.text().await.unwrap_or_else(|_| String::new());
857
858 debug!(
859 target: "xai::stream",
860 "xAI API error - Status: {}, Body: {}",
861 status,
862 error_text
863 );
864
865 return match status.as_u16() {
866 429 => Err(ApiError::RateLimited {
867 provider: self.name().to_string(),
868 details: error_text,
869 }),
870 400 => Err(ApiError::InvalidRequest {
871 provider: self.name().to_string(),
872 details: error_text,
873 }),
874 401 => Err(ApiError::AuthenticationFailed {
875 provider: self.name().to_string(),
876 details: error_text,
877 }),
878 _ => Err(ApiError::ServerError {
879 provider: self.name().to_string(),
880 status_code: status.as_u16(),
881 details: error_text,
882 }),
883 };
884 }
885
886 let byte_stream = response.bytes_stream();
887 let sse_stream = parse_sse_stream(byte_stream);
888
889 Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
890 }
891}
892
impl XAIClient {
    /// Adapts a parsed SSE event stream from xAI into the provider-agnostic
    /// `StreamChunk` stream.
    ///
    /// Text/thinking deltas and tool-call fragments are forwarded as they
    /// arrive while also being accumulated, so a complete
    /// `CompletionResponse` can be emitted when the `[DONE]` sentinel is
    /// seen. Cancellation of `token` yields `StreamChunk::Error(Cancelled)`
    /// and ends the stream; SSE parse errors end it with an error chunk,
    /// while individually malformed JSON chunks are logged and skipped.
    fn convert_xai_stream(
        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
        + Unpin
        + Send
        + 'static,
        token: CancellationToken,
    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
        // Partial tool call assembled from incremental deltas.
        struct ToolCallAccumulator {
            id: String,
            name: String,
            args: String,
        }

        async_stream::stream! {
            // Accumulated assistant content blocks, in arrival order.
            let mut content: Vec<AssistantContent> = Vec::new();
            // Parallel to `content`: Some(provider tool-call index) for
            // tool-call placeholder blocks, None for text/thought blocks.
            let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
            // Accumulators keyed by the provider's tool-call index.
            let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
            // Indices for which ToolUseStart has already been emitted.
            let mut tool_calls_started: std::collections::HashSet<usize> =
                std::collections::HashSet::new();
            // Maps a tool-call index to its placeholder position in `content`.
            let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
            // Most recent usage payload (xAI sends it on the final chunk).
            let mut latest_usage: Option<TokenUsage> = None;
            loop {
                if token.is_cancelled() {
                    yield StreamChunk::Error(StreamError::Cancelled);
                    break;
                }

                // `biased` so cancellation wins over a ready SSE event.
                let event_result = tokio::select! {
                    biased;
                    () = token.cancelled() => {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }
                    event = sse_stream.next() => event
                };

                // Upstream stream exhausted without [DONE]: just stop.
                let Some(event_result) = event_result else {
                    break;
                };

                let event = match event_result {
                    Ok(e) => e,
                    Err(e) => {
                        yield StreamChunk::Error(StreamError::SseParse(e));
                        break;
                    }
                };

                // Terminal sentinel: assemble and emit the final response.
                if event.data == "[DONE]" {
                    let tool_calls = std::mem::take(&mut tool_calls);
                    let mut final_content = Vec::new();

                    for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
                    {
                        if let Some(index) = tool_index {
                            // Replace the placeholder block with the fully
                            // accumulated tool call.
                            let Some(tool_call) = tool_calls.get(&index) else {
                                continue;
                            };
                            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                                debug!(
                                    target: "xai::stream",
                                    "Skipping tool call with missing id/name: id='{}' name='{}'",
                                    tool_call.id,
                                    tool_call.name
                                );
                                continue;
                            }
                            // Unparseable args fall back to an empty object.
                            let parameters = serde_json::from_str(&tool_call.args)
                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
                            final_content.push(AssistantContent::ToolCall {
                                tool_call: steer_tools::ToolCall {
                                    id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    parameters,
                                },
                                thought_signature: None,
                            });
                        } else {
                            final_content.push(block);
                        }
                    }

                    yield StreamChunk::MessageComplete(CompletionResponse {
                        content: final_content,
                        usage: latest_usage,
                    });
                    break;
                }

                // Malformed chunks are logged and skipped, not fatal.
                let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
                    Ok(c) => c,
                    Err(e) => {
                        debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
                        continue;
                    }
                };

                if let Some(usage) = chunk.usage.as_ref() {
                    latest_usage = Some(map_xai_usage(usage));
                }

                if let Some(choice) = chunk.choices.first() {
                    // Text delta: extend the trailing text block or open one.
                    if let Some(text_delta) = &choice.delta.content {
                        if let Some(AssistantContent::Text { text }) = content.last_mut() {
                            text.push_str(text_delta)
                        } else {
                            content.push(AssistantContent::Text {
                                text: text_delta.clone(),
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::TextDelta(text_delta.clone());
                    }

                    // Reasoning delta: same pattern, but for thought blocks.
                    if let Some(thinking_delta) = &choice.delta.reasoning_content {
                        if let Some(AssistantContent::Thought {
                            thought: crate::app::conversation::ThoughtContent::Simple { text },
                        }) = content.last_mut() {
                            text.push_str(thinking_delta)
                        } else {
                            content.push(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple {
                                    text: thinking_delta.clone(),
                                },
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::ThinkingDelta(thinking_delta.clone());
                    }

                    if let Some(tcs) = &choice.delta.tool_calls {
                        for tc in tcs {
                            let entry = tool_calls.entry(tc.index).or_insert_with(|| {
                                ToolCallAccumulator {
                                    id: String::new(),
                                    name: String::new(),
                                    args: String::new(),
                                }
                            });
                            // Whether ToolUseStart was emitted on this delta.
                            let mut started_now = false;
                            // Whether buffered args were flushed below.
                            let mut flushed_now = false;

                            // id/name may arrive on any delta; keep the latest
                            // non-empty values.
                            if let Some(id) = &tc.id
                                && !id.is_empty() {
                                entry.id.clone_from(id);
                            }
                            if let Some(func) = &tc.function
                                && let Some(name) = &func.name
                                && !name.is_empty() {
                                entry.name.clone_from(name);
                            }

                            // First sighting of this index: reserve a
                            // placeholder block so content ordering holds.
                            if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
                                let pos = content.len();
                                content.push(AssistantContent::ToolCall {
                                    tool_call: steer_tools::ToolCall {
                                        id: entry.id.clone(),
                                        name: entry.name.clone(),
                                        parameters: serde_json::Value::String(entry.args.clone()),
                                    },
                                    thought_signature: None,
                                });
                                tool_call_indices.push(Some(tc.index));
                                e.insert(pos);
                            }

                            // Emit ToolUseStart once both id and name exist.
                            if !entry.id.is_empty()
                                && !entry.name.is_empty()
                                && !tool_calls_started.contains(&tc.index)
                            {
                                tool_calls_started.insert(tc.index);
                                started_now = true;
                                yield StreamChunk::ToolUseStart {
                                    id: entry.id.clone(),
                                    name: entry.name.clone(),
                                };
                            }

                            if let Some(func) = &tc.function
                                && let Some(args) = &func.arguments {
                                entry.args.push_str(args);
                                if tool_calls_started.contains(&tc.index) {
                                    if started_now {
                                        // Just started: flush everything
                                        // buffered so far, not only this delta.
                                        if !entry.args.is_empty() {
                                            yield StreamChunk::ToolUseInputDelta {
                                                id: entry.id.clone(),
                                                delta: entry.args.clone(),
                                            };
                                            flushed_now = true;
                                        }
                                    } else if !args.is_empty() {
                                        yield StreamChunk::ToolUseInputDelta {
                                            id: entry.id.clone(),
                                            delta: args.clone(),
                                        };
                                    }
                                }
                            }

                            // Started on a delta that carried no new args:
                            // flush args buffered from earlier deltas.
                            if started_now && !flushed_now && !entry.args.is_empty() {
                                yield StreamChunk::ToolUseInputDelta {
                                    id: entry.id.clone(),
                                    delta: entry.args.clone(),
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}
1102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::provider::StreamChunk;
    use crate::api::sse::SseEvent;
    use crate::app::conversation::{
        AssistantContent, ImageContent, ImageSource, Message, MessageData, UserContent,
    };
    use futures::StreamExt;
    use futures::stream;
    use std::pin::pin;
    use tokio_util::sync::CancellationToken;

    // A user message with text + data-URL image should become a two-part
    // `Parts` payload preserving order.
    #[test]
    fn test_convert_messages_includes_data_url_image_part() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![
                    UserContent::Text {
                        text: "describe".to_string(),
                    },
                    UserContent::Image {
                        image: ImageContent {
                            mime_type: "image/png".to_string(),
                            source: ImageSource::DataUrl {
                                data_url: "data:image/png;base64,Zm9v".to_string(),
                            },
                            width: None,
                            height: None,
                            bytes: None,
                            sha256: None,
                        },
                    },
                ],
            },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }];

        let converted = XAIClient::convert_messages(messages, None).expect("convert messages");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            XAIMessage::User { content, .. } => match content {
                XAIUserContent::Parts(parts) => {
                    assert_eq!(parts.len(), 2);
                    assert!(matches!(
                        &parts[0],
                        XAIUserContentPart::Text { text } if text == "describe"
                    ));
                    assert!(matches!(
                        &parts[1],
                        XAIUserContentPart::ImageUrl { image_url }
                        if image_url.url == "data:image/png;base64,Zm9v"
                    ));
                }
                other => panic!("Expected parts content, got {other:?}"),
            },
            other => panic!("Expected user message, got {other:?}"),
        }
    }

    // Session-file image sources cannot be uploaded and must be rejected with
    // `UnsupportedFeature` rather than silently dropped.
    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source: ImageSource::SessionFile {
                            relative_path: "session-1/image.png".to_string(),
                        },
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }];

        let err =
            XAIClient::convert_messages(messages, None).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "xai");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    // Usage counts map 1:1 into TokenUsage.
    #[test]
    fn test_map_xai_usage() {
        let usage = XAIUsage {
            prompt_tokens: 15,
            completion_tokens: 9,
            total_tokens: 24,
            num_sources_used: None,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };

        assert_eq!(map_xai_usage(&usage), TokenUsage::new(15, 9, 24));
    }

    // Non-streaming conversion keeps both the usage block and the text content.
    #[test]
    fn test_non_stream_completion_maps_usage() {
        let usage = XAIUsage {
            prompt_tokens: 6,
            completion_tokens: 4,
            total_tokens: 10,
            num_sources_used: None,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
        let choice = Choice {
            index: 0,
            message: AssistantMessage {
                content: Some("hello".to_string()),
                tool_calls: None,
                reasoning_content: None,
            },
            finish_reason: Some("stop".to_string()),
        };
        let response = XAICompletionResponse {
            id: "resp_1".to_string(),
            object: "chat.completion".to_string(),
            created: 1,
            model: "grok-test".to_string(),
            choices: vec![choice],
            usage: Some(usage),
            system_fingerprint: None,
            citations: None,
            debug_output: None,
        };

        let converted = convert_xai_completion_response(response).expect("response should map");

        assert_eq!(converted.usage, Some(TokenUsage::new(6, 4, 10)));
        assert!(matches!(
            converted.content.first(),
            Some(AssistantContent::Text { text }) if text == "hello"
        ));
    }

    // The usage payload arrives on a trailing chunk with empty `choices`;
    // the final MessageComplete must still carry it.
    #[tokio::test]
    async fn test_convert_xai_stream_captures_final_usage() {
        let events = vec![
            Ok(SseEvent {
                event_type: None,
                data: r#"{"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#.to_string(),
                id: None,
            }),
            Ok(SseEvent {
                event_type: None,
                data: r#"{"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":5,"total_tokens":17}}"#.to_string(),
                id: None,
            }),
            Ok(SseEvent {
                event_type: None,
                data: "[DONE]".to_string(),
                id: None,
            }),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(XAIClient::convert_xai_stream(sse_stream, token));

        let first_delta = stream.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let complete = stream.next().await.unwrap();
        if let StreamChunk::MessageComplete(response) = complete {
            assert_eq!(response.usage, Some(TokenUsage::new(12, 5, 17)));
            assert!(matches!(
                response.content.first(),
                Some(AssistantContent::Text { text }) if text == "Hello"
            ));
        } else {
            panic!("Expected MessageComplete");
        }
    }
}