1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::api::util::normalize_chat_url;
13use crate::app::SystemContext;
14use crate::app::conversation::{
15 AssistantContent, ImageSource, Message as AppMessage, ToolResult, UserContent,
16};
17use crate::config::model::{ModelId, ModelParameters};
18use steer_tools::ToolSchema;
19
/// Chat-completions endpoint used when the caller does not supply a base URL.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";
21
22#[derive(Clone)]
23pub struct XAIClient {
24 http_client: reqwest::Client,
25 base_url: String,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
30#[serde(tag = "role", rename_all = "lowercase")]
31enum XAIMessage {
32 System {
33 content: String,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 name: Option<String>,
36 },
37 User {
38 content: XAIUserContent,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 name: Option<String>,
41 },
42 Assistant {
43 #[serde(skip_serializing_if = "Option::is_none")]
44 content: Option<String>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 tool_calls: Option<Vec<XAIToolCall>>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 name: Option<String>,
49 },
50 Tool {
51 content: String,
52 tool_call_id: String,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 name: Option<String>,
55 },
56}
57
58#[derive(Debug, Serialize, Deserialize)]
59#[serde(untagged)]
60enum XAIUserContent {
61 Text(String),
62 Parts(Vec<XAIUserContentPart>),
63}
64
65#[derive(Debug, Serialize, Deserialize)]
66#[serde(tag = "type")]
67enum XAIUserContentPart {
68 #[serde(rename = "text")]
69 Text { text: String },
70 #[serde(rename = "image_url")]
71 ImageUrl { image_url: XAIImageUrl },
72}
73
74#[derive(Debug, Serialize, Deserialize)]
75struct XAIImageUrl {
76 url: String,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
81struct XAIFunction {
82 name: String,
83 description: String,
84 parameters: serde_json::Value,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
89struct XAITool {
90 #[serde(rename = "type")]
91 tool_type: String, function: XAIFunction,
93}
94
95#[derive(Debug, Serialize, Deserialize)]
97struct XAIToolCall {
98 id: String,
99 #[serde(rename = "type")]
100 tool_type: String,
101 function: XAIFunctionCall,
102}
103
104#[derive(Debug, Serialize, Deserialize)]
105struct XAIFunctionCall {
106 name: String,
107 arguments: String, }
109
110#[derive(Debug, Serialize, Deserialize)]
111#[serde(rename_all = "lowercase")]
112enum ReasoningEffort {
113 Low,
114 High,
115}
116
117#[derive(Debug, Serialize, Deserialize)]
118struct StreamOptions {
119 #[serde(skip_serializing_if = "Option::is_none")]
120 include_usage: Option<bool>,
121}
122
123#[derive(Debug, Serialize, Deserialize)]
124#[serde(untagged)]
125enum ToolChoice {
126 String(String), Specific {
128 #[serde(rename = "type")]
129 tool_type: String,
130 function: ToolChoiceFunction,
131 },
132}
133
134#[derive(Debug, Serialize, Deserialize)]
135struct ToolChoiceFunction {
136 name: String,
137}
138
139#[derive(Debug, Serialize, Deserialize)]
140struct ResponseFormat {
141 #[serde(rename = "type")]
142 format_type: String,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 json_schema: Option<serde_json::Value>,
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148struct SearchParameters {
149 #[serde(skip_serializing_if = "Option::is_none")]
150 from_date: Option<String>,
151 #[serde(skip_serializing_if = "Option::is_none")]
152 to_date: Option<String>,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 max_search_results: Option<u32>,
155 #[serde(skip_serializing_if = "Option::is_none")]
156 mode: Option<String>,
157 #[serde(skip_serializing_if = "Option::is_none")]
158 return_citations: Option<bool>,
159 #[serde(skip_serializing_if = "Option::is_none")]
160 sources: Option<Vec<String>>,
161}
162
163#[derive(Debug, Serialize, Deserialize)]
164struct WebSearchOptions {
165 #[serde(skip_serializing_if = "Option::is_none")]
166 search_context_size: Option<u32>,
167 #[serde(skip_serializing_if = "Option::is_none")]
168 user_location: Option<String>,
169}
170
171#[derive(Debug, Serialize, Deserialize)]
172struct CompletionRequest {
173 model: String,
174 messages: Vec<XAIMessage>,
175 #[serde(skip_serializing_if = "Option::is_none")]
176 deferred: Option<bool>,
177 #[serde(skip_serializing_if = "Option::is_none")]
178 frequency_penalty: Option<f32>,
179 #[serde(skip_serializing_if = "Option::is_none")]
180 logit_bias: Option<HashMap<String, f32>>,
181 #[serde(skip_serializing_if = "Option::is_none")]
182 logprobs: Option<bool>,
183 #[serde(skip_serializing_if = "Option::is_none")]
184 max_completion_tokens: Option<u32>,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 max_tokens: Option<u32>,
187 #[serde(skip_serializing_if = "Option::is_none")]
188 n: Option<u32>,
189 #[serde(skip_serializing_if = "Option::is_none")]
190 parallel_tool_calls: Option<bool>,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 presence_penalty: Option<f32>,
193 #[serde(skip_serializing_if = "Option::is_none")]
194 reasoning_effort: Option<ReasoningEffort>,
195 #[serde(skip_serializing_if = "Option::is_none")]
196 response_format: Option<ResponseFormat>,
197 #[serde(skip_serializing_if = "Option::is_none")]
198 search_parameters: Option<SearchParameters>,
199 #[serde(skip_serializing_if = "Option::is_none")]
200 seed: Option<u64>,
201 #[serde(skip_serializing_if = "Option::is_none")]
202 stop: Option<Vec<String>>,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 stream: Option<bool>,
205 #[serde(skip_serializing_if = "Option::is_none")]
206 stream_options: Option<StreamOptions>,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 temperature: Option<f32>,
209 #[serde(skip_serializing_if = "Option::is_none")]
210 tool_choice: Option<ToolChoice>,
211 #[serde(skip_serializing_if = "Option::is_none")]
212 tools: Option<Vec<XAITool>>,
213 #[serde(skip_serializing_if = "Option::is_none")]
214 top_logprobs: Option<u32>,
215 #[serde(skip_serializing_if = "Option::is_none")]
216 top_p: Option<f32>,
217 #[serde(skip_serializing_if = "Option::is_none")]
218 user: Option<String>,
219 #[serde(skip_serializing_if = "Option::is_none")]
220 web_search_options: Option<WebSearchOptions>,
221}
222
223#[derive(Debug, Serialize, Deserialize)]
224struct XAICompletionResponse {
225 id: String,
226 object: String,
227 created: u64,
228 model: String,
229 choices: Vec<Choice>,
230 #[serde(skip_serializing_if = "Option::is_none")]
231 usage: Option<XAIUsage>,
232 #[serde(skip_serializing_if = "Option::is_none")]
233 system_fingerprint: Option<String>,
234 #[serde(skip_serializing_if = "Option::is_none")]
235 citations: Option<Vec<serde_json::Value>>,
236 #[serde(skip_serializing_if = "Option::is_none")]
237 debug_output: Option<DebugOutput>,
238}
239
240#[derive(Debug, Serialize, Deserialize)]
241struct Choice {
242 index: i32,
243 message: AssistantMessage,
244 finish_reason: Option<String>,
245}
246
247#[derive(Debug, Serialize, Deserialize)]
248struct AssistantMessage {
249 content: Option<String>,
250 #[serde(skip_serializing_if = "Option::is_none")]
251 tool_calls: Option<Vec<XAIToolCall>>,
252 #[serde(skip_serializing_if = "Option::is_none")]
253 reasoning_content: Option<String>,
254}
255
256#[derive(Debug, Serialize, Deserialize)]
257struct PromptTokensDetails {
258 #[serde(rename = "cached_tokens")]
259 cached: usize,
260 #[serde(rename = "audio_tokens")]
261 audio: usize,
262 #[serde(rename = "image_tokens")]
263 image: usize,
264 #[serde(rename = "text_tokens")]
265 text: usize,
266}
267
268#[derive(Debug, Serialize, Deserialize)]
269struct CompletionTokensDetails {
270 #[serde(rename = "reasoning_tokens")]
271 reasoning: usize,
272 #[serde(rename = "audio_tokens")]
273 audio: usize,
274 #[serde(rename = "accepted_prediction_tokens")]
275 accepted_prediction: usize,
276 #[serde(rename = "rejected_prediction_tokens")]
277 rejected_prediction: usize,
278}
279
280#[derive(Debug, Serialize, Deserialize)]
281struct XAIUsage {
282 prompt_tokens: usize,
283 completion_tokens: usize,
284 total_tokens: usize,
285 #[serde(skip_serializing_if = "Option::is_none")]
286 num_sources_used: Option<usize>,
287 #[serde(skip_serializing_if = "Option::is_none")]
288 prompt_tokens_details: Option<PromptTokensDetails>,
289 #[serde(skip_serializing_if = "Option::is_none")]
290 completion_tokens_details: Option<CompletionTokensDetails>,
291}
292
293#[derive(Debug, Serialize, Deserialize)]
294struct DebugOutput {
295 attempts: usize,
296 cache_read_count: usize,
297 cache_read_input_bytes: usize,
298 cache_write_count: usize,
299 cache_write_input_bytes: usize,
300 prompt: String,
301 request: String,
302 responses: Vec<String>,
303}
304
305#[derive(Debug, Deserialize)]
306struct XAIStreamChunk {
307 #[expect(dead_code)]
308 id: String,
309 choices: Vec<XAIStreamChoice>,
310}
311
312#[derive(Debug, Deserialize)]
313struct XAIStreamChoice {
314 #[expect(dead_code)]
315 index: u32,
316 delta: XAIStreamDelta,
317 #[expect(dead_code)]
318 finish_reason: Option<String>,
319}
320
321#[derive(Debug, Deserialize)]
322struct XAIStreamDelta {
323 #[serde(skip_serializing_if = "Option::is_none")]
324 content: Option<String>,
325 #[serde(skip_serializing_if = "Option::is_none")]
326 tool_calls: Option<Vec<XAIStreamToolCall>>,
327 #[serde(skip_serializing_if = "Option::is_none")]
328 reasoning_content: Option<String>,
329}
330
331#[derive(Debug, Deserialize)]
332struct XAIStreamToolCall {
333 index: usize,
334 #[serde(skip_serializing_if = "Option::is_none")]
335 id: Option<String>,
336 #[serde(skip_serializing_if = "Option::is_none")]
337 function: Option<XAIStreamFunction>,
338}
339
340#[derive(Debug, Deserialize)]
341struct XAIStreamFunction {
342 #[serde(skip_serializing_if = "Option::is_none")]
343 name: Option<String>,
344 #[serde(skip_serializing_if = "Option::is_none")]
345 arguments: Option<String>,
346}
347
348impl XAIClient {
349 pub fn new(api_key: String) -> Result<Self, ApiError> {
350 Self::with_base_url(api_key, None)
351 }
352
353 pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
354 let mut headers = header::HeaderMap::new();
355 headers.insert(
356 header::AUTHORIZATION,
357 header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
358 ApiError::AuthenticationFailed {
359 provider: "xai".to_string(),
360 details: format!("Invalid API key: {e}"),
361 }
362 })?,
363 );
364
365 let client = reqwest::Client::builder()
366 .default_headers(headers)
367 .timeout(std::time::Duration::from_secs(300)) .build()
369 .map_err(ApiError::Network)?;
370
371 let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);
372
373 Ok(Self {
374 http_client: client,
375 base_url,
376 })
377 }
378
379 fn user_image_part(
380 image: &crate::app::conversation::ImageContent,
381 ) -> Result<XAIUserContentPart, ApiError> {
382 let image_url = match &image.source {
383 ImageSource::DataUrl { data_url } => data_url.clone(),
384 ImageSource::Url { url } => url.clone(),
385 ImageSource::SessionFile { relative_path } => {
386 return Err(ApiError::UnsupportedFeature {
387 provider: "xai".to_string(),
388 feature: "image input source".to_string(),
389 details: format!(
390 "xAI chat API cannot access session file '{}' directly; use data URLs or public URLs",
391 relative_path
392 ),
393 });
394 }
395 };
396
397 Ok(XAIUserContentPart::ImageUrl {
398 image_url: XAIImageUrl { url: image_url },
399 })
400 }
401
402 fn convert_messages(
403 messages: Vec<AppMessage>,
404 system: Option<SystemContext>,
405 ) -> Result<Vec<XAIMessage>, ApiError> {
406 let mut xai_messages = Vec::new();
407
408 if let Some(system_content) = system.and_then(|context| context.render()) {
410 xai_messages.push(XAIMessage::System {
411 content: system_content,
412 name: None,
413 });
414 }
415
416 for message in messages {
418 match &message.data {
419 crate::app::conversation::MessageData::User { content, .. } => {
420 let mut content_parts = Vec::new();
421
422 for user_content in content {
423 match user_content {
424 UserContent::Text { text } => {
425 content_parts.push(XAIUserContentPart::Text { text: text.clone() });
426 }
427 UserContent::Image { image } => {
428 content_parts.push(Self::user_image_part(image)?);
429 }
430 UserContent::CommandExecution {
431 command,
432 stdout,
433 stderr,
434 exit_code,
435 } => {
436 content_parts.push(XAIUserContentPart::Text {
437 text: UserContent::format_command_execution_as_xml(
438 command, stdout, stderr, *exit_code,
439 ),
440 });
441 }
442 }
443 }
444
445 if !content_parts.is_empty() {
447 let content = match content_parts.as_slice() {
448 [XAIUserContentPart::Text { text }] => {
449 XAIUserContent::Text(text.clone())
450 }
451 _ => XAIUserContent::Parts(content_parts),
452 };
453
454 xai_messages.push(XAIMessage::User {
455 content,
456 name: None,
457 });
458 }
459 }
460 crate::app::conversation::MessageData::Assistant { content, .. } => {
461 let mut text_parts = Vec::new();
463 let mut tool_calls = Vec::new();
464
465 for content_block in content {
466 match content_block {
467 AssistantContent::Text { text } => {
468 text_parts.push(text.clone());
469 }
470 AssistantContent::Image { image } => {
471 match Self::user_image_part(image) {
472 Ok(XAIUserContentPart::ImageUrl { image_url }) => {
473 text_parts.push(format!("[Image URL: {}]", image_url.url));
474 }
475 Ok(XAIUserContentPart::Text { text }) => {
476 text_parts.push(text);
477 }
478 Err(err) => {
479 debug!(
480 target: "xai::convert_messages",
481 "Skipping unsupported assistant image block: {}",
482 err
483 );
484 }
485 }
486 }
487 AssistantContent::ToolCall { tool_call, .. } => {
488 tool_calls.push(XAIToolCall {
489 id: tool_call.id.clone(),
490 tool_type: "function".to_string(),
491 function: XAIFunctionCall {
492 name: tool_call.name.clone(),
493 arguments: tool_call.parameters.to_string(),
494 },
495 });
496 }
497 AssistantContent::Thought { .. } => {
498
499 }
501 }
502 }
503
504 let content = if text_parts.is_empty() {
506 None
507 } else {
508 Some(text_parts.join("\n"))
509 };
510
511 let tool_calls_opt = if tool_calls.is_empty() {
512 None
513 } else {
514 Some(tool_calls)
515 };
516
517 xai_messages.push(XAIMessage::Assistant {
518 content,
519 tool_calls: tool_calls_opt,
520 name: None,
521 });
522 }
523 crate::app::conversation::MessageData::Tool {
524 tool_use_id,
525 result,
526 ..
527 } => {
528 let content_text = if let ToolResult::Error(e) = result {
530 format!("Error: {e}")
531 } else {
532 let text = result.llm_format();
533 if text.trim().is_empty() {
534 "(No output)".to_string()
535 } else {
536 text
537 }
538 };
539
540 xai_messages.push(XAIMessage::Tool {
541 content: content_text,
542 tool_call_id: tool_use_id.clone(),
543 name: None,
544 });
545 }
546 }
547 }
548
549 Ok(xai_messages)
550 }
551
552 fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
553 tools
554 .into_iter()
555 .map(|tool| XAITool {
556 tool_type: "function".to_string(),
557 function: XAIFunction {
558 name: tool.name,
559 description: tool.description,
560 parameters: tool.input_schema.as_value().clone(),
561 },
562 })
563 .collect()
564 }
565}
566
567#[async_trait]
568impl Provider for XAIClient {
569 fn name(&self) -> &'static str {
570 "xai"
571 }
572
573 async fn complete(
574 &self,
575 model_id: &ModelId,
576 messages: Vec<AppMessage>,
577 system: Option<SystemContext>,
578 tools: Option<Vec<ToolSchema>>,
579 call_options: Option<ModelParameters>,
580 token: CancellationToken,
581 ) -> Result<CompletionResponse, ApiError> {
582 let xai_messages = Self::convert_messages(messages, system)?;
583 let xai_tools = tools.map(Self::convert_tools);
584
585 let (supports_thinking, reasoning_effort) = call_options
587 .as_ref()
588 .and_then(|opts| opts.thinking_config)
589 .map_or((false, None), |tc| {
590 let effort = tc.effort.map(|e| match e {
591 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
592 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
594 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
596 (tc.enabled, effort)
597 });
598
599 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
601 reasoning_effort.or(Some(ReasoningEffort::High))
602 } else {
603 None
604 };
605
606 let request = CompletionRequest {
607 model: model_id.id.clone(), messages: xai_messages,
609 deferred: None,
610 frequency_penalty: None,
611 logit_bias: None,
612 logprobs: None,
613 max_completion_tokens: Some(32768),
614 max_tokens: None,
615 n: None,
616 parallel_tool_calls: None,
617 presence_penalty: None,
618 reasoning_effort,
619 response_format: None,
620 search_parameters: None,
621 seed: None,
622 stop: None,
623 stream: None,
624 stream_options: None,
625 temperature: call_options
626 .as_ref()
627 .and_then(|o| o.temperature)
628 .or(Some(1.0)),
629 tool_choice: None,
630 tools: xai_tools,
631 top_logprobs: None,
632 top_p: call_options.as_ref().and_then(|o| o.top_p),
633 user: None,
634 web_search_options: None,
635 };
636
637 let response = self
638 .http_client
639 .post(&self.base_url)
640 .json(&request)
641 .send()
642 .await
643 .map_err(ApiError::Network)?;
644
645 if !response.status().is_success() {
646 let status = response.status();
647 let error_text = response.text().await.unwrap_or_else(|_| String::new());
648
649 debug!(
650 target: "grok::complete",
651 "Grok API error - Status: {}, Body: {}",
652 status,
653 error_text
654 );
655
656 return match status.as_u16() {
657 429 => Err(ApiError::RateLimited {
658 provider: self.name().to_string(),
659 details: error_text,
660 }),
661 400 => Err(ApiError::InvalidRequest {
662 provider: self.name().to_string(),
663 details: error_text,
664 }),
665 401 => Err(ApiError::AuthenticationFailed {
666 provider: self.name().to_string(),
667 details: error_text,
668 }),
669 _ => Err(ApiError::ServerError {
670 provider: self.name().to_string(),
671 status_code: status.as_u16(),
672 details: error_text,
673 }),
674 };
675 }
676
677 let response_text = tokio::select! {
678 () = token.cancelled() => {
679 debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
680 return Err(ApiError::Cancelled { provider: self.name().to_string() });
681 }
682 text_res = response.text() => {
683 text_res?
684 }
685 };
686
687 let xai_response: XAICompletionResponse =
688 serde_json::from_str(&response_text).map_err(|e| {
689 error!(
690 target: "xai::complete",
691 "Failed to parse response: {}, Body: {}",
692 e,
693 response_text
694 );
695 ApiError::ResponseParsingError {
696 provider: self.name().to_string(),
697 details: format!("Error: {e}, Body: {response_text}"),
698 }
699 })?;
700
701 if let Some(choice) = xai_response.choices.first() {
703 let mut content_blocks = Vec::new();
704
705 if let Some(reasoning) = &choice.message.reasoning_content
707 && !reasoning.trim().is_empty()
708 {
709 content_blocks.push(AssistantContent::Thought {
710 thought: crate::app::conversation::ThoughtContent::Simple {
711 text: reasoning.clone(),
712 },
713 });
714 }
715
716 if let Some(content) = &choice.message.content
718 && !content.trim().is_empty()
719 {
720 content_blocks.push(AssistantContent::Text {
721 text: content.clone(),
722 });
723 }
724
725 if let Some(tool_calls) = &choice.message.tool_calls {
727 for tool_call in tool_calls {
728 let parameters = serde_json::from_str(&tool_call.function.arguments)
730 .unwrap_or(serde_json::Value::Null);
731
732 content_blocks.push(AssistantContent::ToolCall {
733 tool_call: steer_tools::ToolCall {
734 id: tool_call.id.clone(),
735 name: tool_call.function.name.clone(),
736 parameters,
737 },
738 thought_signature: None,
739 });
740 }
741 }
742
743 Ok(crate::api::provider::CompletionResponse {
744 content: content_blocks,
745 })
746 } else {
747 Err(ApiError::NoChoices {
748 provider: self.name().to_string(),
749 })
750 }
751 }
752
753 async fn stream_complete(
754 &self,
755 model_id: &ModelId,
756 messages: Vec<AppMessage>,
757 system: Option<SystemContext>,
758 tools: Option<Vec<ToolSchema>>,
759 call_options: Option<ModelParameters>,
760 token: CancellationToken,
761 ) -> Result<CompletionStream, ApiError> {
762 let xai_messages = Self::convert_messages(messages, system)?;
763 let xai_tools = tools.map(Self::convert_tools);
764
765 let (supports_thinking, reasoning_effort) = call_options
766 .as_ref()
767 .and_then(|opts| opts.thinking_config)
768 .map_or((false, None), |tc| {
769 let effort = tc.effort.map(|e| match e {
770 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
771 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
772 crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
773 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
775 (tc.enabled, effort)
776 });
777
778 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
779 reasoning_effort.or(Some(ReasoningEffort::High))
780 } else {
781 None
782 };
783
784 let request = CompletionRequest {
785 model: model_id.id.clone(),
786 messages: xai_messages,
787 deferred: None,
788 frequency_penalty: None,
789 logit_bias: None,
790 logprobs: None,
791 max_completion_tokens: Some(32768),
792 max_tokens: None,
793 n: None,
794 parallel_tool_calls: None,
795 presence_penalty: None,
796 reasoning_effort,
797 response_format: None,
798 search_parameters: None,
799 seed: None,
800 stop: None,
801 stream: Some(true),
802 stream_options: None,
803 temperature: call_options
804 .as_ref()
805 .and_then(|o| o.temperature)
806 .or(Some(1.0)),
807 tool_choice: None,
808 tools: xai_tools,
809 top_logprobs: None,
810 top_p: call_options.as_ref().and_then(|o| o.top_p),
811 user: None,
812 web_search_options: None,
813 };
814
815 let response = self
816 .http_client
817 .post(&self.base_url)
818 .json(&request)
819 .send()
820 .await
821 .map_err(ApiError::Network)?;
822
823 if !response.status().is_success() {
824 let status = response.status();
825 let error_text = response.text().await.unwrap_or_else(|_| String::new());
826
827 debug!(
828 target: "xai::stream",
829 "xAI API error - Status: {}, Body: {}",
830 status,
831 error_text
832 );
833
834 return match status.as_u16() {
835 429 => Err(ApiError::RateLimited {
836 provider: self.name().to_string(),
837 details: error_text,
838 }),
839 400 => Err(ApiError::InvalidRequest {
840 provider: self.name().to_string(),
841 details: error_text,
842 }),
843 401 => Err(ApiError::AuthenticationFailed {
844 provider: self.name().to_string(),
845 details: error_text,
846 }),
847 _ => Err(ApiError::ServerError {
848 provider: self.name().to_string(),
849 status_code: status.as_u16(),
850 details: error_text,
851 }),
852 };
853 }
854
855 let byte_stream = response.bytes_stream();
856 let sse_stream = parse_sse_stream(byte_stream);
857
858 Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
859 }
860}
861
862impl XAIClient {
863 fn convert_xai_stream(
864 mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
865 + Unpin
866 + Send
867 + 'static,
868 token: CancellationToken,
869 ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
870 struct ToolCallAccumulator {
871 id: String,
872 name: String,
873 args: String,
874 }
875
876 async_stream::stream! {
877 let mut content: Vec<AssistantContent> = Vec::new();
878 let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
879 let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
880 let mut tool_calls_started: std::collections::HashSet<usize> =
881 std::collections::HashSet::new();
882 let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
883 loop {
884 if token.is_cancelled() {
885 yield StreamChunk::Error(StreamError::Cancelled);
886 break;
887 }
888
889 let event_result = tokio::select! {
890 biased;
891 () = token.cancelled() => {
892 yield StreamChunk::Error(StreamError::Cancelled);
893 break;
894 }
895 event = sse_stream.next() => event
896 };
897
898 let Some(event_result) = event_result else {
899 break;
900 };
901
902 let event = match event_result {
903 Ok(e) => e,
904 Err(e) => {
905 yield StreamChunk::Error(StreamError::SseParse(e));
906 break;
907 }
908 };
909
910 if event.data == "[DONE]" {
911 let tool_calls = std::mem::take(&mut tool_calls);
912 let mut final_content = Vec::new();
913
914 for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
915 {
916 if let Some(index) = tool_index {
917 let Some(tool_call) = tool_calls.get(&index) else {
918 continue;
919 };
920 if tool_call.id.is_empty() || tool_call.name.is_empty() {
921 debug!(
922 target: "xai::stream",
923 "Skipping tool call with missing id/name: id='{}' name='{}'",
924 tool_call.id,
925 tool_call.name
926 );
927 continue;
928 }
929 let parameters = serde_json::from_str(&tool_call.args)
930 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
931 final_content.push(AssistantContent::ToolCall {
932 tool_call: steer_tools::ToolCall {
933 id: tool_call.id.clone(),
934 name: tool_call.name.clone(),
935 parameters,
936 },
937 thought_signature: None,
938 });
939 } else {
940 final_content.push(block);
941 }
942 }
943
944 yield StreamChunk::MessageComplete(CompletionResponse { content: final_content });
945 break;
946 }
947
948 let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
949 Ok(c) => c,
950 Err(e) => {
951 debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
952 continue;
953 }
954 };
955
956 if let Some(choice) = chunk.choices.first() {
957 if let Some(text_delta) = &choice.delta.content {
958 if let Some(AssistantContent::Text { text }) = content.last_mut() { text.push_str(text_delta) } else {
959 content.push(AssistantContent::Text {
960 text: text_delta.clone(),
961 });
962 tool_call_indices.push(None);
963 }
964 yield StreamChunk::TextDelta(text_delta.clone());
965 }
966
967 if let Some(thinking_delta) = &choice.delta.reasoning_content {
968 if let Some(AssistantContent::Thought {
969 thought: crate::app::conversation::ThoughtContent::Simple { text },
970 }) = content.last_mut() { text.push_str(thinking_delta) } else {
971 content.push(AssistantContent::Thought {
972 thought: crate::app::conversation::ThoughtContent::Simple {
973 text: thinking_delta.clone(),
974 },
975 });
976 tool_call_indices.push(None);
977 }
978 yield StreamChunk::ThinkingDelta(thinking_delta.clone());
979 }
980
981 if let Some(tcs) = &choice.delta.tool_calls {
982 for tc in tcs {
983 let entry = tool_calls.entry(tc.index).or_insert_with(|| {
984 ToolCallAccumulator {
985 id: String::new(),
986 name: String::new(),
987 args: String::new(),
988 }
989 });
990 let mut started_now = false;
991 let mut flushed_now = false;
992
993 if let Some(id) = &tc.id
994 && !id.is_empty() {
995 entry.id.clone_from(id);
996 }
997 if let Some(func) = &tc.function
998 && let Some(name) = &func.name
999 && !name.is_empty() {
1000 entry.name.clone_from(name);
1001 }
1002
1003 if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
1004 let pos = content.len();
1005 content.push(AssistantContent::ToolCall {
1006 tool_call: steer_tools::ToolCall {
1007 id: entry.id.clone(),
1008 name: entry.name.clone(),
1009 parameters: serde_json::Value::String(entry.args.clone()),
1010 },
1011 thought_signature: None,
1012 });
1013 tool_call_indices.push(Some(tc.index));
1014 e.insert(pos);
1015 }
1016
1017 if !entry.id.is_empty()
1018 && !entry.name.is_empty()
1019 && !tool_calls_started.contains(&tc.index)
1020 {
1021 tool_calls_started.insert(tc.index);
1022 started_now = true;
1023 yield StreamChunk::ToolUseStart {
1024 id: entry.id.clone(),
1025 name: entry.name.clone(),
1026 };
1027 }
1028
1029 if let Some(func) = &tc.function
1030 && let Some(args) = &func.arguments {
1031 entry.args.push_str(args);
1032 if tool_calls_started.contains(&tc.index) {
1033 if started_now {
1034 if !entry.args.is_empty() {
1035 yield StreamChunk::ToolUseInputDelta {
1036 id: entry.id.clone(),
1037 delta: entry.args.clone(),
1038 };
1039 flushed_now = true;
1040 }
1041 } else if !args.is_empty() {
1042 yield StreamChunk::ToolUseInputDelta {
1043 id: entry.id.clone(),
1044 delta: args.clone(),
1045 };
1046 }
1047 }
1048 }
1049
1050 if started_now && !flushed_now && !entry.args.is_empty() {
1051 yield StreamChunk::ToolUseInputDelta {
1052 id: entry.id.clone(),
1053 delta: entry.args.clone(),
1054 };
1055 }
1056 }
1057 }
1058 }
1059 }
1060 }
1061 }
1062}
1063
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};

    /// PNG image fixture with the given source and no optional metadata.
    fn png_image(source: ImageSource) -> ImageContent {
        ImageContent {
            mime_type: "image/png".to_string(),
            source,
            width: None,
            height: None,
            bytes: None,
            sha256: None,
        }
    }

    /// User message fixture wrapping `content`.
    fn user_message(content: Vec<UserContent>) -> Message {
        Message {
            data: MessageData::User { content },
            timestamp: 1,
            id: "msg-1".to_string(),
            parent_message_id: None,
        }
    }

    /// A text part followed by a data-URL image must come out as a two-element parts list.
    #[test]
    fn test_convert_messages_includes_data_url_image_part() {
        let messages = vec![user_message(vec![
            UserContent::Text {
                text: "describe".to_string(),
            },
            UserContent::Image {
                image: png_image(ImageSource::DataUrl {
                    data_url: "".to_string(),
                }),
            },
        ])];

        let converted = XAIClient::convert_messages(messages, None).expect("convert messages");
        assert_eq!(converted.len(), 1);

        let XAIMessage::User { content, .. } = &converted[0] else {
            panic!("Expected user message, got {:?}", &converted[0]);
        };
        let XAIUserContent::Parts(parts) = content else {
            panic!("Expected parts content, got {content:?}");
        };

        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            XAIUserContentPart::Text { text } if text == "describe"
        ));
        assert!(matches!(
            &parts[1],
            XAIUserContentPart::ImageUrl { image_url }
                if image_url.url == ""
        ));
    }

    /// Session-file image sources cannot be sent and must surface as `UnsupportedFeature`.
    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let messages = vec![user_message(vec![UserContent::Image {
            image: png_image(ImageSource::SessionFile {
                relative_path: "session-1/image.png".to_string(),
            }),
        }])];

        let err =
            XAIClient::convert_messages(messages, None).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "xai");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }
}