1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11 CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::api::util::{map_http_status_to_api_error, normalize_chat_url};
15use crate::app::SystemContext;
16use crate::app::conversation::{
17 AssistantContent, ImageSource, Message as AppMessage, ToolResult, UserContent,
18};
19use crate::config::model::{ModelId, ModelParameters};
20use steer_tools::ToolSchema;
21
/// Default xAI chat-completions endpoint, used when no base-URL override is given.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
23
/// Client for the xAI (Grok) chat-completions API.
///
/// Holds a pre-configured `reqwest` client (auth header installed as a
/// default header) plus the resolved chat-completions endpoint URL.
#[derive(Clone)]
pub struct XAIClient {
    // HTTP client with `Authorization` default header and request timeout baked in.
    http_client: reqwest::Client,
    // Normalized chat-completions endpoint URL (see `normalize_chat_url`).
    base_url: String,
}
29
/// One message of the chat transcript, internally tagged by `role`
/// ("system" | "user" | "assistant" | "tool") on the wire.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        // Plain string or a list of text/image parts (see `XAIUserContent`).
        content: XAIUserContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Omitted when the assistant turn consists of tool calls only.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Tool {
        // Result text of an earlier tool call, linked by `tool_call_id`.
        content: String,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
59
/// User message content: either a bare string or a list of typed parts.
/// Untagged so a lone text message serializes as a plain JSON string.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum XAIUserContent {
    Text(String),
    Parts(Vec<XAIUserContentPart>),
}
66
/// One typed part of a multi-part user message, tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum XAIUserContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: XAIImageUrl },
}
75
/// Image reference for an `image_url` content part; `url` may be a
/// `data:` URL or a public URL (see `user_image_part`).
#[derive(Debug, Serialize, Deserialize)]
struct XAIImageUrl {
    url: String,
}
80
/// Declaration of a callable function tool.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    // JSON schema describing the function's parameters.
    parameters: serde_json::Value,
}
88
89#[derive(Debug, Serialize, Deserialize)]
91struct XAITool {
92 #[serde(rename = "type")]
93 tool_type: String, function: XAIFunction,
95}
96
/// A tool call issued by the assistant, identified by `id`.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}
105
106#[derive(Debug, Serialize, Deserialize)]
107struct XAIFunctionCall {
108 name: String,
109 arguments: String, }
111
/// Reasoning-effort level sent to xAI, serialized lowercase.
/// Only two levels exist here; finer internal levels collapse onto them.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}
118
/// Streaming knobs; `include_usage` requests usage totals on the stream.
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}
124
125#[derive(Debug, Serialize, Deserialize)]
126#[serde(untagged)]
127enum ToolChoice {
128 String(String), Specific {
130 #[serde(rename = "type")]
131 tool_type: String,
132 function: ToolChoiceFunction,
133 },
134}
135
/// Names the function selected by `ToolChoice::Specific`.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}
140
/// Requested response format (`type` plus optional JSON schema).
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}
148
/// Live-search parameters; all fields optional and omitted when unset.
/// Currently never populated by this client (`search_parameters: None`).
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}
164
/// Web-search options; never populated by this client today.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
172
/// Full chat-completions request body. Every optional field is omitted from
/// the JSON when `None`, so the wire payload stays minimal.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    // Current output-limit field; this client sets it together with the
    // legacy `max_tokens` to the same value.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    // Set to Some(true) only by `stream_complete`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
224
/// Top-level non-streaming chat-completion response payload.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}
241
/// One completion choice; only the first is consumed downstream
/// (see `convert_xai_completion_response`).
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}
248
/// Assistant message body of a choice: optional text, tool calls, and
/// reasoning ("thought") content.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
257
/// Prompt-token breakdown attached to usage.
// NOTE(review): all four fields are required here; if the API ever omits one,
// deserializing the parent `XAIUsage.prompt_tokens_details` fails — confirm
// against live payloads.
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    #[serde(rename = "cached_tokens")]
    cached: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "image_tokens")]
    image: usize,
    #[serde(rename = "text_tokens")]
    text: usize,
}
269
/// Completion-token breakdown attached to usage.
// NOTE(review): fields are required (no defaults); a partial payload would
// fail deserialization of `completion_tokens_details` — confirm with the API.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    #[serde(rename = "reasoning_tokens")]
    reasoning: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "accepted_prediction_tokens")]
    accepted_prediction: usize,
    #[serde(rename = "rejected_prediction_tokens")]
    rejected_prediction: usize,
}
281
/// Token-usage counters; mapped into `TokenUsage` via `map_xai_usage`.
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}
294
/// Optional server-side debug payload attached to a response; not consumed
/// anywhere in this client, only deserialized.
#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}
306
307#[derive(Debug, Deserialize)]
308struct XAIStreamChunk {
309 #[expect(dead_code)]
310 id: String,
311 choices: Vec<XAIStreamChoice>,
312 #[serde(skip_serializing_if = "Option::is_none")]
313 usage: Option<XAIUsage>,
314}
315
/// A single choice within a stream chunk; only `delta` is consumed.
#[derive(Debug, Deserialize)]
struct XAIStreamChoice {
    #[expect(dead_code)]
    index: u32,
    delta: XAIStreamDelta,
    #[expect(dead_code)]
    finish_reason: Option<String>,
}
324
325#[derive(Debug, Deserialize)]
326struct XAIStreamDelta {
327 #[serde(skip_serializing_if = "Option::is_none")]
328 content: Option<String>,
329 #[serde(skip_serializing_if = "Option::is_none")]
330 tool_calls: Option<Vec<XAIStreamToolCall>>,
331 #[serde(skip_serializing_if = "Option::is_none")]
332 reasoning_content: Option<String>,
333}
334
335#[derive(Debug, Deserialize)]
336struct XAIStreamToolCall {
337 index: usize,
338 #[serde(skip_serializing_if = "Option::is_none")]
339 id: Option<String>,
340 #[serde(skip_serializing_if = "Option::is_none")]
341 function: Option<XAIStreamFunction>,
342}
343
344#[derive(Debug, Deserialize)]
345struct XAIStreamFunction {
346 #[serde(skip_serializing_if = "Option::is_none")]
347 name: Option<String>,
348 #[serde(skip_serializing_if = "Option::is_none")]
349 arguments: Option<String>,
350}
351
impl XAIClient {
    /// Creates a client against the default xAI endpoint.
    pub fn new(api_key: String) -> Result<Self, ApiError> {
        Self::with_base_url(api_key, None)
    }

    /// Creates a client, optionally overriding the chat-completions URL.
    ///
    /// Installs `Authorization: Bearer <key>` as a default header and applies
    /// a 300-second request timeout. Fails with `AuthenticationFailed` when
    /// the key cannot form a valid header value.
    pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
                ApiError::AuthenticationFailed {
                    provider: "xai".to_string(),
                    details: format!("Invalid API key: {e}"),
                }
            })?,
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(std::time::Duration::from_secs(300)) .build()
            .map_err(ApiError::Network)?;

        let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);

        Ok(Self {
            http_client: client,
            base_url,
        })
    }

    /// Converts an image block into an xAI `image_url` content part.
    ///
    /// Data URLs and public URLs pass through unchanged; session-file sources
    /// are rejected because the remote API cannot read local files.
    fn user_image_part(
        image: &crate::app::conversation::ImageContent,
    ) -> Result<XAIUserContentPart, ApiError> {
        let image_url = match &image.source {
            ImageSource::DataUrl { data_url } => data_url.clone(),
            ImageSource::Url { url } => url.clone(),
            ImageSource::SessionFile { relative_path } => {
                return Err(ApiError::UnsupportedFeature {
                    provider: "xai".to_string(),
                    feature: "image input source".to_string(),
                    details: format!(
                        "xAI chat API cannot access session file '{}' directly; use data URLs or public URLs",
                        relative_path
                    ),
                });
            }
        };

        Ok(XAIUserContentPart::ImageUrl {
            image_url: XAIImageUrl { url: image_url },
        })
    }

    /// Converts app-level messages (plus optional system context) into the
    /// xAI wire format, preserving transcript order.
    ///
    /// Returns an error only for unsupported image sources in user messages;
    /// unsupported assistant images are logged and skipped instead.
    fn convert_messages(
        messages: Vec<AppMessage>,
        system: Option<SystemContext>,
    ) -> Result<Vec<XAIMessage>, ApiError> {
        let mut xai_messages = Vec::new();

        // The rendered system context, when present, always leads the transcript.
        if let Some(system_content) = system.and_then(|context| context.render()) {
            xai_messages.push(XAIMessage::System {
                content: system_content,
                name: None,
            });
        }

        for message in messages {
            match &message.data {
                crate::app::conversation::MessageData::User { content, .. } => {
                    let mut content_parts = Vec::new();

                    for user_content in content {
                        match user_content {
                            UserContent::Text { text } => {
                                content_parts.push(XAIUserContentPart::Text { text: text.clone() });
                            }
                            UserContent::Image { image } => {
                                content_parts.push(Self::user_image_part(image)?);
                            }
                            UserContent::CommandExecution {
                                command,
                                stdout,
                                stderr,
                                exit_code,
                            } => {
                                // Command runs are flattened to an XML-formatted text part.
                                content_parts.push(XAIUserContentPart::Text {
                                    text: UserContent::format_command_execution_as_xml(
                                        command, stdout, stderr, *exit_code,
                                    ),
                                });
                            }
                        }
                    }

                    // Empty user messages are dropped entirely.
                    if !content_parts.is_empty() {
                        // A lone text part collapses to the plain-string form.
                        let content = match content_parts.as_slice() {
                            [XAIUserContentPart::Text { text }] => {
                                XAIUserContent::Text(text.clone())
                            }
                            _ => XAIUserContent::Parts(content_parts),
                        };

                        xai_messages.push(XAIMessage::User {
                            content,
                            name: None,
                        });
                    }
                }
                crate::app::conversation::MessageData::Assistant { content, .. } => {
                    let mut text_parts = Vec::new();
                    let mut tool_calls = Vec::new();

                    for content_block in content {
                        match content_block {
                            AssistantContent::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            AssistantContent::Image { image } => {
                                // Assistant turns carry only text here, so images
                                // degrade to a textual placeholder.
                                match Self::user_image_part(image) {
                                    Ok(XAIUserContentPart::ImageUrl { image_url }) => {
                                        text_parts.push(format!("[Image URL: {}]", image_url.url));
                                    }
                                    Ok(XAIUserContentPart::Text { text }) => {
                                        text_parts.push(text);
                                    }
                                    Err(err) => {
                                        // Unsupported sources are skipped, not fatal.
                                        debug!(
                                            target: "xai::convert_messages",
                                            "Skipping unsupported assistant image block: {}",
                                            err
                                        );
                                    }
                                }
                            }
                            AssistantContent::ToolCall { tool_call, .. } => {
                                tool_calls.push(XAIToolCall {
                                    id: tool_call.id.clone(),
                                    tool_type: "function".to_string(),
                                    function: XAIFunctionCall {
                                        name: tool_call.name.clone(),
                                        arguments: tool_call.parameters.to_string(),
                                    },
                                });
                            }
                            AssistantContent::Thought { .. } => {
                                // Thought blocks are intentionally not replayed to the API.
                            }
                        }
                    }

                    let content = if text_parts.is_empty() {
                        None
                    } else {
                        Some(text_parts.join("\n"))
                    };

                    let tool_calls_opt = if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    };

                    xai_messages.push(XAIMessage::Assistant {
                        content,
                        tool_calls: tool_calls_opt,
                        name: None,
                    });
                }
                crate::app::conversation::MessageData::Tool {
                    tool_use_id,
                    result,
                    ..
                } => {
                    // Errors are prefixed; blank successful output is replaced
                    // so the model always receives some content.
                    let content_text = if let ToolResult::Error(e) = result {
                        format!("Error: {e}")
                    } else {
                        let text = result.llm_format();
                        if text.trim().is_empty() {
                            "(No output)".to_string()
                        } else {
                            text
                        }
                    };

                    xai_messages.push(XAIMessage::Tool {
                        content: content_text,
                        tool_call_id: tool_use_id.clone(),
                        name: None,
                    });
                }
            }
        }

        Ok(xai_messages)
    }

    /// Converts internal tool schemas into xAI "function" tool definitions.
    fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
        tools
            .into_iter()
            .map(|tool| XAITool {
                tool_type: "function".to_string(),
                function: XAIFunction {
                    name: tool.name,
                    description: tool.description,
                    parameters: tool.input_schema.as_value().clone(),
                },
            })
            .collect()
    }
}
570
/// Narrows a `usize` counter to `u32`, clamping to `u32::MAX` on overflow.
fn saturating_u32(value: usize) -> u32 {
    match u32::try_from(value) {
        Ok(narrowed) => narrowed,
        Err(_) => u32::MAX,
    }
}
574
575fn map_xai_usage(usage: &XAIUsage) -> TokenUsage {
576 TokenUsage::new(
577 saturating_u32(usage.prompt_tokens),
578 saturating_u32(usage.completion_tokens),
579 saturating_u32(usage.total_tokens),
580 )
581}
582
583fn convert_xai_completion_response(
584 xai_response: XAICompletionResponse,
585) -> Result<CompletionResponse, ApiError> {
586 let usage = xai_response.usage.as_ref().map(map_xai_usage);
587
588 let Some(choice) = xai_response.choices.first() else {
589 return Err(ApiError::NoChoices {
590 provider: "xai".to_string(),
591 });
592 };
593
594 let mut content_blocks = Vec::new();
595
596 if let Some(reasoning) = &choice.message.reasoning_content
598 && !reasoning.trim().is_empty()
599 {
600 content_blocks.push(AssistantContent::Thought {
601 thought: crate::app::conversation::ThoughtContent::Simple {
602 text: reasoning.clone(),
603 },
604 });
605 }
606
607 if let Some(content) = &choice.message.content
609 && !content.trim().is_empty()
610 {
611 content_blocks.push(AssistantContent::Text {
612 text: content.clone(),
613 });
614 }
615
616 if let Some(tool_calls) = &choice.message.tool_calls {
618 for tool_call in tool_calls {
619 let parameters = serde_json::from_str(&tool_call.function.arguments)
621 .unwrap_or(serde_json::Value::Null);
622
623 content_blocks.push(AssistantContent::ToolCall {
624 tool_call: steer_tools::ToolCall {
625 id: tool_call.id.clone(),
626 name: tool_call.function.name.clone(),
627 parameters,
628 },
629 thought_signature: None,
630 });
631 }
632 }
633
634 Ok(CompletionResponse {
635 content: content_blocks,
636 usage,
637 })
638}
639
640#[async_trait]
641impl Provider for XAIClient {
642 fn name(&self) -> &'static str {
643 "xai"
644 }
645
646 async fn complete(
647 &self,
648 model_id: &ModelId,
649 messages: Vec<AppMessage>,
650 system: Option<SystemContext>,
651 tools: Option<Vec<ToolSchema>>,
652 call_options: Option<ModelParameters>,
653 token: CancellationToken,
654 ) -> Result<CompletionResponse, ApiError> {
655 let xai_messages = Self::convert_messages(messages, system)?;
656 let xai_tools = tools.map(Self::convert_tools);
657
658 let (supports_thinking, reasoning_effort) = call_options
660 .as_ref()
661 .and_then(|opts| opts.thinking_config)
662 .map_or((false, None), |tc| {
663 let effort = tc.effort.map(|e| match e {
664 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
665 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
667 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
669 (tc.enabled, effort)
670 });
671
672 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
674 reasoning_effort.or(Some(ReasoningEffort::High))
675 } else {
676 None
677 };
678
679 let max_output_tokens = call_options
680 .as_ref()
681 .and_then(|o| o.max_output_tokens)
682 .or(Some(32_768));
683
684 let request = CompletionRequest {
685 model: model_id.id.clone(), messages: xai_messages,
687 deferred: None,
688 frequency_penalty: None,
689 logit_bias: None,
690 logprobs: None,
691 max_completion_tokens: max_output_tokens,
692 max_tokens: max_output_tokens,
693 n: None,
694 parallel_tool_calls: None,
695 presence_penalty: None,
696 reasoning_effort,
697 response_format: None,
698 search_parameters: None,
699 seed: None,
700 stop: None,
701 stream: None,
702 stream_options: None,
703 temperature: call_options
704 .as_ref()
705 .and_then(|o| o.temperature)
706 .or(Some(1.0)),
707 tool_choice: None,
708 tools: xai_tools,
709 top_logprobs: None,
710 top_p: call_options.as_ref().and_then(|o| o.top_p),
711 user: None,
712 web_search_options: None,
713 };
714
715 let response = self
716 .http_client
717 .post(&self.base_url)
718 .json(&request)
719 .send()
720 .await
721 .map_err(ApiError::Network)?;
722
723 if !response.status().is_success() {
724 let status = response.status();
725 let error_text = response.text().await.unwrap_or_else(|_| String::new());
726
727 debug!(
728 target: "grok::complete",
729 "Grok API error - Status: {}, Body: {}",
730 status,
731 error_text
732 );
733
734 return Err(map_http_status_to_api_error(
735 self.name(),
736 status.as_u16(),
737 error_text,
738 ));
739 }
740
741 let response_text = tokio::select! {
742 () = token.cancelled() => {
743 debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
744 return Err(ApiError::Cancelled { provider: self.name().to_string() });
745 }
746 text_res = response.text() => {
747 text_res?
748 }
749 };
750
751 let xai_response: XAICompletionResponse =
752 serde_json::from_str(&response_text).map_err(|e| {
753 error!(
754 target: "xai::complete",
755 "Failed to parse response: {}, Body: {}",
756 e,
757 response_text
758 );
759 ApiError::ResponseParsingError {
760 provider: self.name().to_string(),
761 details: format!("Error: {e}, Body: {response_text}"),
762 }
763 })?;
764
765 convert_xai_completion_response(xai_response).map_err(|err| match err {
766 ApiError::NoChoices { .. } => ApiError::NoChoices {
767 provider: self.name().to_string(),
768 },
769 other => other,
770 })
771 }
772
773 async fn stream_complete(
774 &self,
775 model_id: &ModelId,
776 messages: Vec<AppMessage>,
777 system: Option<SystemContext>,
778 tools: Option<Vec<ToolSchema>>,
779 call_options: Option<ModelParameters>,
780 token: CancellationToken,
781 ) -> Result<CompletionStream, ApiError> {
782 let xai_messages = Self::convert_messages(messages, system)?;
783 let xai_tools = tools.map(Self::convert_tools);
784
785 let (supports_thinking, reasoning_effort) = call_options
786 .as_ref()
787 .and_then(|opts| opts.thinking_config)
788 .map_or((false, None), |tc| {
789 let effort = tc.effort.map(|e| match e {
790 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
791 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
792 crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
793 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
795 (tc.enabled, effort)
796 });
797
798 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
799 reasoning_effort.or(Some(ReasoningEffort::High))
800 } else {
801 None
802 };
803
804 let max_output_tokens = call_options
805 .as_ref()
806 .and_then(|o| o.max_output_tokens)
807 .or(Some(32_768));
808
809 let request = CompletionRequest {
810 model: model_id.id.clone(),
811 messages: xai_messages,
812 deferred: None,
813 frequency_penalty: None,
814 logit_bias: None,
815 logprobs: None,
816 max_completion_tokens: max_output_tokens,
817 max_tokens: max_output_tokens,
818 n: None,
819 parallel_tool_calls: None,
820 presence_penalty: None,
821 reasoning_effort,
822 response_format: None,
823 search_parameters: None,
824 seed: None,
825 stop: None,
826 stream: Some(true),
827 stream_options: Some(StreamOptions {
828 include_usage: Some(true),
829 }),
830 temperature: call_options
831 .as_ref()
832 .and_then(|o| o.temperature)
833 .or(Some(1.0)),
834 tool_choice: None,
835 tools: xai_tools,
836 top_logprobs: None,
837 top_p: call_options.as_ref().and_then(|o| o.top_p),
838 user: None,
839 web_search_options: None,
840 };
841
842 let response = self
843 .http_client
844 .post(&self.base_url)
845 .json(&request)
846 .send()
847 .await
848 .map_err(ApiError::Network)?;
849
850 if !response.status().is_success() {
851 let status = response.status();
852 let error_text = response.text().await.unwrap_or_else(|_| String::new());
853
854 debug!(
855 target: "xai::stream",
856 "xAI API error - Status: {}, Body: {}",
857 status,
858 error_text
859 );
860
861 return Err(map_http_status_to_api_error(
862 self.name(),
863 status.as_u16(),
864 error_text,
865 ));
866 }
867
868 let byte_stream = response.bytes_stream();
869 let sse_stream = parse_sse_stream(byte_stream);
870
871 Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
872 }
873}
874
impl XAIClient {
    /// Converts a parsed SSE event stream from xAI into [`StreamChunk`]s.
    ///
    /// Text and reasoning deltas are forwarded immediately and also
    /// accumulated so that the `[DONE]` sentinel can emit a complete
    /// [`CompletionResponse`]. Tool-call fragments are accumulated per stream
    /// `index`; `ToolUseStart` is emitted once id and name are both known,
    /// and argument deltas are forwarded from that point on. Cancellation and
    /// SSE parse failures terminate the stream with a `StreamChunk::Error`.
    fn convert_xai_stream(
        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
        + Unpin
        + Send
        + 'static,
        token: CancellationToken,
    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
        // Per-index accumulator for a tool call assembled across chunks.
        struct ToolCallAccumulator {
            id: String,
            name: String,
            args: String,
        }

        async_stream::stream! {
            // Accumulated content blocks, in arrival order.
            let mut content: Vec<AssistantContent> = Vec::new();
            // Parallel to `content`: Some(stream index) marks a tool-call slot.
            let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
            let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
            // Indices whose ToolUseStart has already been emitted.
            let mut tool_calls_started: std::collections::HashSet<usize> =
                std::collections::HashSet::new();
            // Stream index -> position of its placeholder in `content`.
            let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
            let mut latest_usage: Option<TokenUsage> = None;
            loop {
                if token.is_cancelled() {
                    yield StreamChunk::Error(StreamError::Cancelled);
                    break;
                }

                // `biased` so cancellation is checked before the next event.
                let event_result = tokio::select! {
                    biased;
                    () = token.cancelled() => {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }
                    event = sse_stream.next() => event
                };

                // Upstream exhausted without [DONE]: end silently.
                let Some(event_result) = event_result else {
                    break;
                };

                let event = match event_result {
                    Ok(e) => e,
                    Err(e) => {
                        yield StreamChunk::Error(StreamError::SseParse(e));
                        break;
                    }
                };

                // [DONE]: materialize accumulated blocks into the final response.
                if event.data == "[DONE]" {
                    let tool_calls = std::mem::take(&mut tool_calls);
                    let mut final_content = Vec::new();

                    for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
                    {
                        if let Some(index) = tool_index {
                            // Replace the placeholder with the fully accumulated call.
                            let Some(tool_call) = tool_calls.get(&index) else {
                                continue;
                            };
                            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                                debug!(
                                    target: "xai::stream",
                                    "Skipping tool call with missing id/name: id='{}' name='{}'",
                                    tool_call.id,
                                    tool_call.name
                                );
                                continue;
                            }
                            // Unparseable args degrade to an empty JSON object.
                            let parameters = serde_json::from_str(&tool_call.args)
                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
                            final_content.push(AssistantContent::ToolCall {
                                tool_call: steer_tools::ToolCall {
                                    id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    parameters,
                                },
                                thought_signature: None,
                            });
                        } else {
                            final_content.push(block);
                        }
                    }

                    yield StreamChunk::MessageComplete(CompletionResponse {
                        content: final_content,
                        usage: latest_usage,
                    });
                    break;
                }

                // Malformed chunks are logged and skipped, not fatal.
                let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
                    Ok(c) => c,
                    Err(e) => {
                        debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
                        continue;
                    }
                };

                // Keep the most recent usage snapshot for the final response.
                if let Some(usage) = chunk.usage.as_ref() {
                    latest_usage = Some(map_xai_usage(usage));
                }

                if let Some(choice) = chunk.choices.first() {
                    // Text delta: append to a trailing Text block or open a new one.
                    if let Some(text_delta) = &choice.delta.content {
                        if let Some(AssistantContent::Text { text }) = content.last_mut() { text.push_str(text_delta) } else {
                            content.push(AssistantContent::Text {
                                text: text_delta.clone(),
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::TextDelta(text_delta.clone());
                    }

                    // Reasoning delta: same append-or-open logic for Thought blocks.
                    if let Some(thinking_delta) = &choice.delta.reasoning_content {
                        if let Some(AssistantContent::Thought {
                            thought: crate::app::conversation::ThoughtContent::Simple { text },
                        }) = content.last_mut() { text.push_str(thinking_delta) } else {
                            content.push(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple {
                                    text: thinking_delta.clone(),
                                },
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::ThinkingDelta(thinking_delta.clone());
                    }

                    if let Some(tcs) = &choice.delta.tool_calls {
                        for tc in tcs {
                            let entry = tool_calls.entry(tc.index).or_insert_with(|| {
                                ToolCallAccumulator {
                                    id: String::new(),
                                    name: String::new(),
                                    args: String::new(),
                                }
                            });
                            let mut started_now = false;
                            let mut flushed_now = false;

                            // id/name may arrive on any fragment; keep latest non-empty.
                            if let Some(id) = &tc.id
                                && !id.is_empty() {
                                entry.id.clone_from(id);
                            }
                            if let Some(func) = &tc.function
                                && let Some(name) = &func.name
                                && !name.is_empty() {
                                entry.name.clone_from(name);
                            }

                            // First sighting of this index: reserve a placeholder
                            // slot in `content`, resolved at [DONE].
                            if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
                                let pos = content.len();
                                content.push(AssistantContent::ToolCall {
                                    tool_call: steer_tools::ToolCall {
                                        id: entry.id.clone(),
                                        name: entry.name.clone(),
                                        parameters: serde_json::Value::String(entry.args.clone()),
                                    },
                                    thought_signature: None,
                                });
                                tool_call_indices.push(Some(tc.index));
                                e.insert(pos);
                            }

                            // Announce the call once both id and name are known.
                            if !entry.id.is_empty()
                                && !entry.name.is_empty()
                                && !tool_calls_started.contains(&tc.index)
                            {
                                tool_calls_started.insert(tc.index);
                                started_now = true;
                                yield StreamChunk::ToolUseStart {
                                    id: entry.id.clone(),
                                    name: entry.name.clone(),
                                };
                            }

                            if let Some(func) = &tc.function
                                && let Some(args) = &func.arguments {
                                entry.args.push_str(args);
                                if tool_calls_started.contains(&tc.index) {
                                    if started_now {
                                        // Just started: flush everything buffered so far.
                                        if !entry.args.is_empty() {
                                            yield StreamChunk::ToolUseInputDelta {
                                                id: entry.id.clone(),
                                                delta: entry.args.clone(),
                                            };
                                            flushed_now = true;
                                        }
                                    } else if !args.is_empty() {
                                        // Already started: forward only this fragment.
                                        yield StreamChunk::ToolUseInputDelta {
                                            id: entry.id.clone(),
                                            delta: args.clone(),
                                        };
                                    }
                                }
                            }

                            // Started on a fragment with no args payload: flush any
                            // args buffered before the start was announced.
                            if started_now && !flushed_now && !entry.args.is_empty() {
                                yield StreamChunk::ToolUseInputDelta {
                                    id: entry.id.clone(),
                                    delta: entry.args.clone(),
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}
1084
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::provider::StreamChunk;
    use crate::api::sse::SseEvent;
    use crate::app::conversation::{
        AssistantContent, ImageContent, ImageSource, Message, MessageData, UserContent,
    };
    use futures::StreamExt;
    use futures::stream;
    use std::pin::pin;
    use tokio_util::sync::CancellationToken;

    // A text part plus a data-URL image should convert to a two-part user message.
    #[test]
    fn test_convert_messages_includes_data_url_image_part() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![
                    UserContent::Text {
                        text: "describe".to_string(),
                    },
                    UserContent::Image {
                        image: ImageContent {
                            mime_type: "image/png".to_string(),
                            source: ImageSource::DataUrl {
                                data_url: "data:image/png;base64,Zm9v".to_string(),
                            },
                            width: None,
                            height: None,
                            bytes: None,
                            sha256: None,
                        },
                    },
                ],
            },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }];

        let converted = XAIClient::convert_messages(messages, None).expect("convert messages");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            XAIMessage::User { content, .. } => match content {
                XAIUserContent::Parts(parts) => {
                    assert_eq!(parts.len(), 2);
                    assert!(matches!(
                        &parts[0],
                        XAIUserContentPart::Text { text } if text == "describe"
                    ));
                    assert!(matches!(
                        &parts[1],
                        XAIUserContentPart::ImageUrl { image_url }
                        if image_url.url == "data:image/png;base64,Zm9v"
                    ));
                }
                other => panic!("Expected parts content, got {other:?}"),
            },
            other => panic!("Expected user message, got {other:?}"),
        }
    }

    // Session-file image sources must be rejected with UnsupportedFeature.
    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source: ImageSource::SessionFile {
                            relative_path: "session-1/image.png".to_string(),
                        },
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }];

        let err =
            XAIClient::convert_messages(messages, None).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "xai");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    // Usage counters map 1:1 into TokenUsage.
    #[test]
    fn test_map_xai_usage() {
        let usage = XAIUsage {
            prompt_tokens: 15,
            completion_tokens: 9,
            total_tokens: 24,
            num_sources_used: None,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };

        assert_eq!(map_xai_usage(&usage), TokenUsage::new(15, 9, 24));
    }

    // Non-streaming conversion surfaces both text content and usage.
    #[test]
    fn test_non_stream_completion_maps_usage() {
        let usage = XAIUsage {
            prompt_tokens: 6,
            completion_tokens: 4,
            total_tokens: 10,
            num_sources_used: None,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
        let choice = Choice {
            index: 0,
            message: AssistantMessage {
                content: Some("hello".to_string()),
                tool_calls: None,
                reasoning_content: None,
            },
            finish_reason: Some("stop".to_string()),
        };
        let response = XAICompletionResponse {
            id: "resp_1".to_string(),
            object: "chat.completion".to_string(),
            created: 1,
            model: "grok-test".to_string(),
            choices: vec![choice],
            usage: Some(usage),
            system_fingerprint: None,
            citations: None,
            debug_output: None,
        };

        let converted = convert_xai_completion_response(response).expect("response should map");

        assert_eq!(converted.usage, Some(TokenUsage::new(6, 4, 10)));
        assert!(matches!(
            converted.content.first(),
            Some(AssistantContent::Text { text }) if text == "hello"
        ));
    }

    // The usage carried on the final (choice-less) chunk must survive into
    // the MessageComplete response.
    #[tokio::test]
    async fn test_convert_xai_stream_captures_final_usage() {
        let events = vec![
            Ok(SseEvent {
                event_type: None,
                data: r#"{"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#.to_string(),
                id: None,
            }),
            Ok(SseEvent {
                event_type: None,
                data: r#"{"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":5,"total_tokens":17}}"#.to_string(),
                id: None,
            }),
            Ok(SseEvent {
                event_type: None,
                data: "[DONE]".to_string(),
                id: None,
            }),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(XAIClient::convert_xai_stream(sse_stream, token));

        let first_delta = stream.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let complete = stream.next().await.unwrap();
        if let StreamChunk::MessageComplete(response) = complete {
            assert_eq!(response.usage, Some(TokenUsage::new(12, 5, 17)));
            assert!(matches!(
                response.content.first(),
                Some(AssistantContent::Text { text }) if text == "Hello"
            ));
        } else {
            panic!("Expected MessageComplete");
        }
    }
}