1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{self, header};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::api::util::normalize_chat_url;
13use crate::app::SystemContext;
14use crate::app::conversation::{AssistantContent, Message as AppMessage, ToolResult, UserContent};
15use crate::config::model::{ModelId, ModelParameters};
16use steer_tools::ToolSchema;
17
/// Default endpoint for xAI's OpenAI-compatible chat completions API.
const DEFAULT_API_URL: &str = "https://api.x.ai/v1/chat/completions";

/// HTTP client for the xAI (Grok) chat completions API.
///
/// Cloning is cheap: `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct XAIClient {
    // Pre-configured with the Authorization header and a request timeout.
    http_client: reqwest::Client,
    // Normalized chat-completions URL (see `normalize_chat_url`).
    base_url: String,
}
25
/// One message in an xAI chat request, tagged on the wire by its `role`
/// ("system" | "user" | "assistant" | "tool").
///
/// Variant and field names are part of the JSON wire format — do not rename.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum XAIMessage {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // An assistant turn may be tool calls only, so content is optional.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<XAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Tool {
        content: String,
        /// Links this result back to the assistant tool call it answers.
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
55
/// Description of a callable function advertised to the model.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunction {
    name: String,
    description: String,
    /// JSON Schema describing the function's parameters.
    parameters: serde_json::Value,
}

/// A tool entry in the request's `tools` array; `type` is always "function"
/// in this client (see `convert_tools`).
#[derive(Debug, Serialize, Deserialize)]
struct XAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunction,
}

/// A tool invocation emitted by (or replayed to) the model.
#[derive(Debug, Serialize, Deserialize)]
struct XAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: XAIFunctionCall,
}

/// The function name plus its arguments, which travel as a JSON-encoded
/// string rather than a structured object.
#[derive(Debug, Serialize, Deserialize)]
struct XAIFunctionCall {
    name: String,
    arguments: String,
}
86
/// Reasoning effort levels accepted by xAI ("low" / "high" on the wire).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    High,
}

/// Extra options for streaming requests.
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    // Presumably requests a trailing usage chunk — confirm with xAI docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}

/// Tool-choice directive: either a bare mode string (e.g. "auto") or a
/// specific function selection. `untagged` lets both JSON shapes round-trip.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    String(String),
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}

/// Names the single function the model is forced to call.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}

/// Requested output format (e.g. plain text vs. JSON-schema-constrained).
#[derive(Debug, Serialize, Deserialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    format_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<serde_json::Value>,
}
123
/// Parameters for xAI's live-search feature. All fields are optional and
/// omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
struct SearchParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    from_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to_date: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_search_results: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    return_citations: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sources: Option<Vec<String>>,
}

/// Options controlling web search behavior for a request.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
147
/// Request body for xAI's chat completions endpoint.
///
/// Only `model` and `messages` are required; every optional field is
/// omitted from the serialized JSON when `None`, so the wire payload stays
/// minimal. Field names mirror the API and must not be renamed.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    deferred: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    // Legacy alias of max_completion_tokens; this client never sets it.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<SearchParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    /// Some(true) selects the SSE streaming response mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
199
/// Top-level body of a non-streaming completion response.
#[derive(Debug, Serialize, Deserialize)]
struct XAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    // Only the first choice is consumed by this client (see `complete`).
    choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<XAIUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    citations: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_output: Option<DebugOutput>,
}

/// One candidate completion.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}

/// The assistant's message inside a choice.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<XAIToolCall>>,
    /// Model reasoning text, surfaced as a Thought block by `complete`.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
232
/// Breakdown of prompt-side token usage. Local field names are shortened;
/// `rename` preserves the API's `*_tokens` wire names.
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    #[serde(rename = "cached_tokens")]
    cached: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "image_tokens")]
    image: usize,
    #[serde(rename = "text_tokens")]
    text: usize,
}

/// Breakdown of completion-side token usage.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    #[serde(rename = "reasoning_tokens")]
    reasoning: usize,
    #[serde(rename = "audio_tokens")]
    audio: usize,
    #[serde(rename = "accepted_prediction_tokens")]
    accepted_prediction: usize,
    #[serde(rename = "rejected_prediction_tokens")]
    rejected_prediction: usize,
}

/// Aggregate token accounting for a completion.
#[derive(Debug, Serialize, Deserialize)]
struct XAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_sources_used: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}

/// Server-side debug information occasionally attached to responses.
#[derive(Debug, Serialize, Deserialize)]
struct DebugOutput {
    attempts: usize,
    cache_read_count: usize,
    cache_read_input_bytes: usize,
    cache_write_count: usize,
    cache_write_input_bytes: usize,
    prompt: String,
    request: String,
    responses: Vec<String>,
}
281
282#[derive(Debug, Deserialize)]
283struct XAIStreamChunk {
284 #[expect(dead_code)]
285 id: String,
286 choices: Vec<XAIStreamChoice>,
287}
288
289#[derive(Debug, Deserialize)]
290struct XAIStreamChoice {
291 #[expect(dead_code)]
292 index: u32,
293 delta: XAIStreamDelta,
294 #[expect(dead_code)]
295 finish_reason: Option<String>,
296}
297
298#[derive(Debug, Deserialize)]
299struct XAIStreamDelta {
300 #[serde(skip_serializing_if = "Option::is_none")]
301 content: Option<String>,
302 #[serde(skip_serializing_if = "Option::is_none")]
303 tool_calls: Option<Vec<XAIStreamToolCall>>,
304 #[serde(skip_serializing_if = "Option::is_none")]
305 reasoning_content: Option<String>,
306}
307
308#[derive(Debug, Deserialize)]
309struct XAIStreamToolCall {
310 index: usize,
311 #[serde(skip_serializing_if = "Option::is_none")]
312 id: Option<String>,
313 #[serde(skip_serializing_if = "Option::is_none")]
314 function: Option<XAIStreamFunction>,
315}
316
317#[derive(Debug, Deserialize)]
318struct XAIStreamFunction {
319 #[serde(skip_serializing_if = "Option::is_none")]
320 name: Option<String>,
321 #[serde(skip_serializing_if = "Option::is_none")]
322 arguments: Option<String>,
323}
324
impl XAIClient {
    /// Creates a client against the default xAI API endpoint.
    pub fn new(api_key: String) -> Result<Self, ApiError> {
        Self::with_base_url(api_key, None)
    }

    /// Creates a client, optionally overriding the API base URL (e.g. for
    /// proxies or API-compatible endpoints).
    ///
    /// # Errors
    /// Returns `ApiError::AuthenticationFailed` if the key cannot be encoded
    /// as an HTTP header value, or `ApiError::Network` if the underlying
    /// reqwest client cannot be built.
    pub fn with_base_url(api_key: String, base_url: Option<String>) -> Result<Self, ApiError> {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
                ApiError::AuthenticationFailed {
                    provider: "xai".to_string(),
                    details: format!("Invalid API key: {e}"),
                }
            })?,
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            // 5-minute timeout — completions over large contexts can be slow.
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .map_err(ApiError::Network)?;

        let base_url = normalize_chat_url(base_url.as_deref(), DEFAULT_API_URL);

        Ok(Self {
            http_client: client,
            base_url,
        })
    }

    /// Flattens app-level conversation messages into xAI's chat format.
    ///
    /// - The rendered system context (if any) becomes the leading system
    ///   message.
    /// - User content blocks are joined into one text message; command
    ///   executions are serialized as XML. All-whitespace results are dropped.
    /// - Assistant text blocks are joined with newlines; tool calls become
    ///   `tool_calls` entries; `Thought` blocks are silently dropped.
    /// - Tool results become `role: "tool"` messages keyed by `tool_call_id`.
    fn convert_messages(
        messages: Vec<AppMessage>,
        system: Option<SystemContext>,
    ) -> Vec<XAIMessage> {
        let mut xai_messages = Vec::new();

        if let Some(system_content) = system.and_then(|context| context.render()) {
            xai_messages.push(XAIMessage::System {
                content: system_content,
                name: None,
            });
        }

        for message in messages {
            match &message.data {
                crate::app::conversation::MessageData::User { content, .. } => {
                    let combined_text = content
                        .iter()
                        .map(|user_content| match user_content {
                            UserContent::Text { text } => text.clone(),
                            UserContent::CommandExecution {
                                command,
                                stdout,
                                stderr,
                                exit_code,
                            } => UserContent::format_command_execution_as_xml(
                                command, stdout, stderr, *exit_code,
                            ),
                        })
                        .collect::<Vec<_>>()
                        .join("\n");

                    // Skip user turns that carry no visible content.
                    if !combined_text.trim().is_empty() {
                        xai_messages.push(XAIMessage::User {
                            content: combined_text,
                            name: None,
                        });
                    }
                }
                crate::app::conversation::MessageData::Assistant { content, .. } => {
                    let mut text_parts = Vec::new();
                    let mut tool_calls = Vec::new();

                    for content_block in content {
                        match content_block {
                            AssistantContent::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            AssistantContent::ToolCall { tool_call, .. } => {
                                tool_calls.push(XAIToolCall {
                                    id: tool_call.id.clone(),
                                    tool_type: "function".to_string(),
                                    function: XAIFunctionCall {
                                        name: tool_call.name.clone(),
                                        // Arguments travel as a JSON string.
                                        arguments: tool_call.parameters.to_string(),
                                    },
                                });
                            }
                            AssistantContent::Thought { .. } => {
                                // Reasoning content is not replayed to the API.
                            }
                        }
                    }

                    let content = if text_parts.is_empty() {
                        None
                    } else {
                        Some(text_parts.join("\n"))
                    };

                    let tool_calls_opt = if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    };

                    xai_messages.push(XAIMessage::Assistant {
                        content,
                        tool_calls: tool_calls_opt,
                        name: None,
                    });
                }
                crate::app::conversation::MessageData::Tool {
                    tool_use_id,
                    result,
                    ..
                } => {
                    let content_text = if let ToolResult::Error(e) = result {
                        format!("Error: {e}")
                    } else {
                        let text = result.llm_format();
                        if text.trim().is_empty() {
                            // Avoid sending empty tool content.
                            "(No output)".to_string()
                        } else {
                            text
                        }
                    };

                    xai_messages.push(XAIMessage::Tool {
                        content: content_text,
                        tool_call_id: tool_use_id.clone(),
                        name: None,
                    });
                }
            }
        }

        xai_messages
    }

    /// Converts tool schemas into xAI's function-tool wire format.
    fn convert_tools(tools: Vec<ToolSchema>) -> Vec<XAITool> {
        tools
            .into_iter()
            .map(|tool| XAITool {
                tool_type: "function".to_string(),
                function: XAIFunction {
                    name: tool.name,
                    description: tool.description,
                    parameters: tool.input_schema.as_value().clone(),
                },
            })
            .collect()
    }
}
487
488#[async_trait]
489impl Provider for XAIClient {
490 fn name(&self) -> &'static str {
491 "xai"
492 }
493
494 async fn complete(
495 &self,
496 model_id: &ModelId,
497 messages: Vec<AppMessage>,
498 system: Option<SystemContext>,
499 tools: Option<Vec<ToolSchema>>,
500 call_options: Option<ModelParameters>,
501 token: CancellationToken,
502 ) -> Result<CompletionResponse, ApiError> {
503 let xai_messages = Self::convert_messages(messages, system);
504 let xai_tools = tools.map(Self::convert_tools);
505
506 let (supports_thinking, reasoning_effort) = call_options
508 .as_ref()
509 .and_then(|opts| opts.thinking_config)
510 .map_or((false, None), |tc| {
511 let effort = tc.effort.map(|e| match e {
512 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
513 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High, crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
515 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
517 (tc.enabled, effort)
518 });
519
520 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
522 reasoning_effort.or(Some(ReasoningEffort::High))
523 } else {
524 None
525 };
526
527 let request = CompletionRequest {
528 model: model_id.id.clone(), messages: xai_messages,
530 deferred: None,
531 frequency_penalty: None,
532 logit_bias: None,
533 logprobs: None,
534 max_completion_tokens: Some(32768),
535 max_tokens: None,
536 n: None,
537 parallel_tool_calls: None,
538 presence_penalty: None,
539 reasoning_effort,
540 response_format: None,
541 search_parameters: None,
542 seed: None,
543 stop: None,
544 stream: None,
545 stream_options: None,
546 temperature: call_options
547 .as_ref()
548 .and_then(|o| o.temperature)
549 .or(Some(1.0)),
550 tool_choice: None,
551 tools: xai_tools,
552 top_logprobs: None,
553 top_p: call_options.as_ref().and_then(|o| o.top_p),
554 user: None,
555 web_search_options: None,
556 };
557
558 let response = self
559 .http_client
560 .post(&self.base_url)
561 .json(&request)
562 .send()
563 .await
564 .map_err(ApiError::Network)?;
565
566 if !response.status().is_success() {
567 let status = response.status();
568 let error_text = response.text().await.unwrap_or_else(|_| String::new());
569
570 debug!(
571 target: "grok::complete",
572 "Grok API error - Status: {}, Body: {}",
573 status,
574 error_text
575 );
576
577 return match status.as_u16() {
578 429 => Err(ApiError::RateLimited {
579 provider: self.name().to_string(),
580 details: error_text,
581 }),
582 400 => Err(ApiError::InvalidRequest {
583 provider: self.name().to_string(),
584 details: error_text,
585 }),
586 401 => Err(ApiError::AuthenticationFailed {
587 provider: self.name().to_string(),
588 details: error_text,
589 }),
590 _ => Err(ApiError::ServerError {
591 provider: self.name().to_string(),
592 status_code: status.as_u16(),
593 details: error_text,
594 }),
595 };
596 }
597
598 let response_text = tokio::select! {
599 () = token.cancelled() => {
600 debug!(target: "grok::complete", "Cancellation token triggered while reading successful response body.");
601 return Err(ApiError::Cancelled { provider: self.name().to_string() });
602 }
603 text_res = response.text() => {
604 text_res?
605 }
606 };
607
608 let xai_response: XAICompletionResponse =
609 serde_json::from_str(&response_text).map_err(|e| {
610 error!(
611 target: "xai::complete",
612 "Failed to parse response: {}, Body: {}",
613 e,
614 response_text
615 );
616 ApiError::ResponseParsingError {
617 provider: self.name().to_string(),
618 details: format!("Error: {e}, Body: {response_text}"),
619 }
620 })?;
621
622 if let Some(choice) = xai_response.choices.first() {
624 let mut content_blocks = Vec::new();
625
626 if let Some(reasoning) = &choice.message.reasoning_content
628 && !reasoning.trim().is_empty()
629 {
630 content_blocks.push(AssistantContent::Thought {
631 thought: crate::app::conversation::ThoughtContent::Simple {
632 text: reasoning.clone(),
633 },
634 });
635 }
636
637 if let Some(content) = &choice.message.content
639 && !content.trim().is_empty()
640 {
641 content_blocks.push(AssistantContent::Text {
642 text: content.clone(),
643 });
644 }
645
646 if let Some(tool_calls) = &choice.message.tool_calls {
648 for tool_call in tool_calls {
649 let parameters = serde_json::from_str(&tool_call.function.arguments)
651 .unwrap_or(serde_json::Value::Null);
652
653 content_blocks.push(AssistantContent::ToolCall {
654 tool_call: steer_tools::ToolCall {
655 id: tool_call.id.clone(),
656 name: tool_call.function.name.clone(),
657 parameters,
658 },
659 thought_signature: None,
660 });
661 }
662 }
663
664 Ok(crate::api::provider::CompletionResponse {
665 content: content_blocks,
666 })
667 } else {
668 Err(ApiError::NoChoices {
669 provider: self.name().to_string(),
670 })
671 }
672 }
673
674 async fn stream_complete(
675 &self,
676 model_id: &ModelId,
677 messages: Vec<AppMessage>,
678 system: Option<SystemContext>,
679 tools: Option<Vec<ToolSchema>>,
680 call_options: Option<ModelParameters>,
681 token: CancellationToken,
682 ) -> Result<CompletionStream, ApiError> {
683 let xai_messages = Self::convert_messages(messages, system);
684 let xai_tools = tools.map(Self::convert_tools);
685
686 let (supports_thinking, reasoning_effort) = call_options
687 .as_ref()
688 .and_then(|opts| opts.thinking_config)
689 .map_or((false, None), |tc| {
690 let effort = tc.effort.map(|e| match e {
691 crate::config::toml_types::ThinkingEffort::Low => ReasoningEffort::Low,
692 crate::config::toml_types::ThinkingEffort::Medium => ReasoningEffort::High,
693 crate::config::toml_types::ThinkingEffort::High => ReasoningEffort::High,
694 crate::config::toml_types::ThinkingEffort::XHigh => ReasoningEffort::High, });
696 (tc.enabled, effort)
697 });
698
699 let reasoning_effort = if supports_thinking && model_id.id != "grok-4-0709" {
700 reasoning_effort.or(Some(ReasoningEffort::High))
701 } else {
702 None
703 };
704
705 let request = CompletionRequest {
706 model: model_id.id.clone(),
707 messages: xai_messages,
708 deferred: None,
709 frequency_penalty: None,
710 logit_bias: None,
711 logprobs: None,
712 max_completion_tokens: Some(32768),
713 max_tokens: None,
714 n: None,
715 parallel_tool_calls: None,
716 presence_penalty: None,
717 reasoning_effort,
718 response_format: None,
719 search_parameters: None,
720 seed: None,
721 stop: None,
722 stream: Some(true),
723 stream_options: None,
724 temperature: call_options
725 .as_ref()
726 .and_then(|o| o.temperature)
727 .or(Some(1.0)),
728 tool_choice: None,
729 tools: xai_tools,
730 top_logprobs: None,
731 top_p: call_options.as_ref().and_then(|o| o.top_p),
732 user: None,
733 web_search_options: None,
734 };
735
736 let response = self
737 .http_client
738 .post(&self.base_url)
739 .json(&request)
740 .send()
741 .await
742 .map_err(ApiError::Network)?;
743
744 if !response.status().is_success() {
745 let status = response.status();
746 let error_text = response.text().await.unwrap_or_else(|_| String::new());
747
748 debug!(
749 target: "xai::stream",
750 "xAI API error - Status: {}, Body: {}",
751 status,
752 error_text
753 );
754
755 return match status.as_u16() {
756 429 => Err(ApiError::RateLimited {
757 provider: self.name().to_string(),
758 details: error_text,
759 }),
760 400 => Err(ApiError::InvalidRequest {
761 provider: self.name().to_string(),
762 details: error_text,
763 }),
764 401 => Err(ApiError::AuthenticationFailed {
765 provider: self.name().to_string(),
766 details: error_text,
767 }),
768 _ => Err(ApiError::ServerError {
769 provider: self.name().to_string(),
770 status_code: status.as_u16(),
771 details: error_text,
772 }),
773 };
774 }
775
776 let byte_stream = response.bytes_stream();
777 let sse_stream = parse_sse_stream(byte_stream);
778
779 Ok(Box::pin(XAIClient::convert_xai_stream(sse_stream, token)))
780 }
781}
782
impl XAIClient {
    /// Adapts the parsed SSE event stream into provider-agnostic
    /// `StreamChunk`s, accumulating text, reasoning, and tool-call fragments
    /// so a complete `CompletionResponse` can be emitted at `[DONE]`.
    fn convert_xai_stream(
        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
        + Unpin
        + Send
        + 'static,
        token: CancellationToken,
    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
        // Collects the fragments of one tool call as deltas arrive.
        struct ToolCallAccumulator {
            id: String,
            name: String,
            args: String,
        }

        async_stream::stream! {
            // Content blocks in arrival order; tool-call entries hold
            // placeholder parameters until finalized at [DONE].
            let mut content: Vec<AssistantContent> = Vec::new();
            // Parallel to `content`: Some(provider tool index) marks a
            // tool-call placeholder; None marks a text/thought block.
            let mut tool_call_indices: Vec<Option<usize>> = Vec::new();
            // Provider tool index -> accumulated id/name/arguments.
            let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
            // Indices for which ToolUseStart has already been yielded.
            let mut tool_calls_started: std::collections::HashSet<usize> =
                std::collections::HashSet::new();
            // Provider tool index -> position of its placeholder in `content`.
            // (The stored position is currently only read as an occupancy
            // marker via the Entry check below.)
            let mut tool_call_positions: HashMap<usize, usize> = HashMap::new();
            loop {
                if token.is_cancelled() {
                    yield StreamChunk::Error(StreamError::Cancelled);
                    break;
                }

                // `biased` so cancellation wins over an already-ready event.
                let event_result = tokio::select! {
                    biased;
                    () = token.cancelled() => {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }
                    event = sse_stream.next() => event
                };

                // Upstream ended without [DONE]: stop without a final message.
                let Some(event_result) = event_result else {
                    break;
                };

                let event = match event_result {
                    Ok(e) => e,
                    Err(e) => {
                        yield StreamChunk::Error(StreamError::SseParse(e));
                        break;
                    }
                };

                if event.data == "[DONE]" {
                    // Finalize: swap each tool-call placeholder for its fully
                    // accumulated call; pass text/thought blocks through.
                    let tool_calls = std::mem::take(&mut tool_calls);
                    let mut final_content = Vec::new();

                    for (block, tool_index) in content.into_iter().zip(tool_call_indices.into_iter())
                    {
                        if let Some(index) = tool_index {
                            let Some(tool_call) = tool_calls.get(&index) else {
                                continue;
                            };
                            // Drop calls that never received an id or name.
                            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                                debug!(
                                    target: "xai::stream",
                                    "Skipping tool call with missing id/name: id='{}' name='{}'",
                                    tool_call.id,
                                    tool_call.name
                                );
                                continue;
                            }
                            // Unparsable arguments degrade to an empty object.
                            let parameters = serde_json::from_str(&tool_call.args)
                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
                            final_content.push(AssistantContent::ToolCall {
                                tool_call: steer_tools::ToolCall {
                                    id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    parameters,
                                },
                                thought_signature: None,
                            });
                        } else {
                            final_content.push(block);
                        }
                    }

                    yield StreamChunk::MessageComplete(CompletionResponse { content: final_content });
                    break;
                }

                // Unparsable (non-terminal) payloads are skipped, not fatal.
                let chunk: XAIStreamChunk = match serde_json::from_str(&event.data) {
                    Ok(c) => c,
                    Err(e) => {
                        debug!(target: "xai::stream", "Failed to parse chunk: {} data: {}", e, event.data);
                        continue;
                    }
                };

                if let Some(choice) = chunk.choices.first() {
                    // Text deltas: extend the trailing text block or open one.
                    if let Some(text_delta) = &choice.delta.content {
                        if let Some(AssistantContent::Text { text }) = content.last_mut() {
                            text.push_str(text_delta)
                        } else {
                            content.push(AssistantContent::Text {
                                text: text_delta.clone(),
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::TextDelta(text_delta.clone());
                    }

                    // Reasoning deltas: same append-or-open strategy.
                    if let Some(thinking_delta) = &choice.delta.reasoning_content {
                        if let Some(AssistantContent::Thought {
                            thought: crate::app::conversation::ThoughtContent::Simple { text },
                        }) = content.last_mut() {
                            text.push_str(thinking_delta)
                        } else {
                            content.push(AssistantContent::Thought {
                                thought: crate::app::conversation::ThoughtContent::Simple {
                                    text: thinking_delta.clone(),
                                },
                            });
                            tool_call_indices.push(None);
                        }
                        yield StreamChunk::ThinkingDelta(thinking_delta.clone());
                    }

                    if let Some(tcs) = &choice.delta.tool_calls {
                        for tc in tcs {
                            let entry = tool_calls.entry(tc.index).or_insert_with(|| {
                                ToolCallAccumulator {
                                    id: String::new(),
                                    name: String::new(),
                                    args: String::new(),
                                }
                            });
                            // Whether ToolUseStart fired on THIS delta, and
                            // whether buffered args were flushed alongside it.
                            let mut started_now = false;
                            let mut flushed_now = false;

                            // id / name typically arrive on the first delta.
                            if let Some(id) = &tc.id
                                && !id.is_empty() {
                                entry.id.clone_from(id);
                            }
                            if let Some(func) = &tc.function
                                && let Some(name) = &func.name
                                && !name.is_empty() {
                                entry.name.clone_from(name);
                            }

                            // First sighting of this index: reserve a
                            // placeholder slot in `content`.
                            if let std::collections::hash_map::Entry::Vacant(e) = tool_call_positions.entry(tc.index) {
                                let pos = content.len();
                                content.push(AssistantContent::ToolCall {
                                    tool_call: steer_tools::ToolCall {
                                        id: entry.id.clone(),
                                        name: entry.name.clone(),
                                        parameters: serde_json::Value::String(entry.args.clone()),
                                    },
                                    thought_signature: None,
                                });
                                tool_call_indices.push(Some(tc.index));
                                e.insert(pos);
                            }

                            // Announce the call once both id and name are known.
                            if !entry.id.is_empty()
                                && !entry.name.is_empty()
                                && !tool_calls_started.contains(&tc.index)
                            {
                                tool_calls_started.insert(tc.index);
                                started_now = true;
                                yield StreamChunk::ToolUseStart {
                                    id: entry.id.clone(),
                                    name: entry.name.clone(),
                                };
                            }

                            if let Some(func) = &tc.function
                                && let Some(args) = &func.arguments {
                                entry.args.push_str(args);
                                if tool_calls_started.contains(&tc.index) {
                                    if started_now {
                                        // Just started: flush everything
                                        // buffered before the start event.
                                        if !entry.args.is_empty() {
                                            yield StreamChunk::ToolUseInputDelta {
                                                id: entry.id.clone(),
                                                delta: entry.args.clone(),
                                            };
                                            flushed_now = true;
                                        }
                                    } else if !args.is_empty() {
                                        // Already started: forward only the
                                        // new fragment.
                                        yield StreamChunk::ToolUseInputDelta {
                                            id: entry.id.clone(),
                                            delta: args.clone(),
                                        };
                                    }
                                }
                            }

                            // Started on a delta carrying no argument fragment
                            // but with earlier buffered args: flush them now.
                            if started_now && !flushed_now && !entry.args.is_empty() {
                                yield StreamChunk::ToolUseInputDelta {
                                    id: entry.id.clone(),
                                    delta: entry.args.clone(),
                                };
                            }
                        }
                    }
                }
            }
        }
    }
}