use crate::models::error::{AgentError, BadRequestErrorMessage};
use crate::models::llm::{
    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
    LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
};
use futures_util::StreamExt;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

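/// Connection settings for the Gemini API. Both fields are optional:
/// `api_endpoint` falls back to [`DEFAULT_BASE_URL`] and `api_key` to an
/// empty string when unset.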
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct GeminiConfig {
    pub api_endpoint: Option<String>,
    pub api_key: Option<String>,
}

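/// Supported Gemini model identifiers. The serde rename strings match the
/// model names used in the `models/{model}:generateContent` URL path.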
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum GeminiModel {
    #[default]
    #[serde(rename = "gemini-3-pro-preview")]
    Gemini3Pro,
    #[serde(rename = "gemini-2.5-pro")]
    Gemini25Pro,
    #[serde(rename = "gemini-2.5-flash")]
    Gemini25Flash,
    #[serde(rename = "gemini-2.5-flash-lite")]
    Gemini25FlashLite,
    #[serde(rename = "gemini-2.0-flash")]
    Gemini20Flash,
    #[serde(rename = "gemini-2.0-flash-lite")]
    Gemini20FlashLite,
}

impl std::fmt::Display for GeminiModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GeminiModel::Gemini3Pro => write!(f, "gemini-3-pro-preview"),
            GeminiModel::Gemini25Pro => write!(f, "gemini-2.5-pro"),
            GeminiModel::Gemini25Flash => write!(f, "gemini-2.5-flash"),
            GeminiModel::Gemini25FlashLite => write!(f, "gemini-2.5-flash-lite"),
            GeminiModel::Gemini20Flash => write!(f, "gemini-2.0-flash"),
            GeminiModel::Gemini20FlashLite => write!(f, "gemini-2.0-flash-lite"),
        }
    }
}

impl GeminiModel {
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize Gemini model".to_string())
    }

    /// Default model for smart (quality-first) generations.
    pub const DEFAULT_SMART_MODEL: GeminiModel = GeminiModel::Gemini3Pro;

    /// Default model for eco (cost-first) generations.
    pub const DEFAULT_ECO_MODEL: GeminiModel = GeminiModel::Gemini25Flash;

    /// Default model for recovery attempts.
    pub const DEFAULT_RECOVERY_MODEL: GeminiModel = GeminiModel::Gemini25Flash;

    pub fn default_smart_model() -> String {
        Self::DEFAULT_SMART_MODEL.to_string()
    }

    pub fn default_eco_model() -> String {
        Self::DEFAULT_ECO_MODEL.to_string()
    }

    pub fn default_recovery_model() -> String {
        Self::DEFAULT_RECOVERY_MODEL.to_string()
    }
}

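/// Provider-agnostic input accepted by [`Gemini::chat`] and
/// [`Gemini::chat_stream`], converted into a [`GeminiRequest`] before sending.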
#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiInput {
    pub model: GeminiModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<LLMTool>>,
}

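/// Request body for the `generateContent` / `streamGenerateContent` REST
/// endpoints, serialized with camelCase keys to match the wire format.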
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiRequest {
    pub contents: Vec<GeminiContent>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GeminiTool>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<GeminiSystemInstruction>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GeminiGenerationConfig>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum GeminiRole {
    #[serde(rename = "user")]
    User,
    #[serde(rename = "model")]
    Model,
}

impl std::fmt::Display for GeminiRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GeminiRole::User => write!(f, "user"),
            GeminiRole::Model => write!(f, "model"),
        }
    }
}

impl GeminiRole {
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize Gemini role".to_string())
    }
}

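/// A single conversation turn: a role plus one or more parts (text, function
/// calls, function responses, or inline binary data).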
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiContent {
    pub role: GeminiRole,
    #[serde(default)]
    pub parts: Vec<GeminiPart>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum GeminiPart {
    Text {
        text: String,
    },
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: GeminiInlineData,
    },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionCall {
    #[serde(default)]
    pub id: Option<String>,
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionResponse {
    pub id: String,
    pub name: String,
    pub response: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiInlineData {
    pub mime_type: String,
    pub data: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiSystemInstruction {
    pub parts: Vec<GeminiPart>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiTool {
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiFunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters_json_schema: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiGenerationConfig {
    pub max_output_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub candidate_count: Option<u32>,
}

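/// Response body from `generateContent`. Every field is optional because the
/// streaming endpoint emits partial chunks that may omit any of them.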
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    pub candidates: Option<Vec<GeminiCandidate>>,
    pub usage_metadata: Option<GeminiUsageMetadata>,
    pub model_version: Option<String>,
    pub response_id: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
    pub content: Option<GeminiContent>,
    pub finish_reason: Option<String>,
    pub index: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiUsageMetadata {
    pub prompt_token_count: Option<u32>,
    pub cached_content_token_count: Option<u32>,
    pub candidates_token_count: Option<u32>,
    pub tool_use_prompt_token_count: Option<u32>,
    pub thoughts_token_count: Option<u32>,
    pub total_token_count: Option<u32>,
}

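// Conversions between the provider-agnostic LLM types and the Gemini wire
// types. Gemini only distinguishes "user" and "model" roles, so "assistant"
// maps to Model and "tool" (plus anything unknown) maps to User.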
impl From<LLMMessage> for GeminiContent {
    fn from(message: LLMMessage) -> Self {
        let role = match message.role.as_str() {
            "assistant" | "model" => GeminiRole::Model,
            "user" | "tool" => GeminiRole::User,
            _ => GeminiRole::User,
        };

        let parts = match message.content {
            LLMMessageContent::String(text) => vec![GeminiPart::Text { text }],
            LLMMessageContent::List(items) => items
                .into_iter()
                .map(|item| match item {
                    LLMMessageTypedContent::Text { text } => GeminiPart::Text { text },

                    LLMMessageTypedContent::ToolCall { id, name, args } => {
                        GeminiPart::FunctionCall {
                            function_call: GeminiFunctionCall {
                                id: Some(id),
                                name,
                                args,
                            },
                        }
                    }

                    // Tool results are flattened to plain text in this
                    // conversion; the request path builds proper
                    // functionResponse parts in `convert_messages_to_gemini`.
                    LLMMessageTypedContent::ToolResult { content, .. } => {
                        GeminiPart::Text { text: content }
                    }

                    LLMMessageTypedContent::Image { source } => GeminiPart::InlineData {
                        inline_data: GeminiInlineData {
                            mime_type: source.media_type,
                            data: source.data,
                        },
                    },
                })
                .collect(),
        };

        GeminiContent { role, parts }
    }
}

impl From<GeminiContent> for LLMMessage {
    fn from(content: GeminiContent) -> Self {
        let role = content.role.to_string();
        let mut message_content = Vec::new();

        for part in content.parts {
            match part {
                GeminiPart::Text { text } => {
                    message_content.push(LLMMessageTypedContent::Text { text });
                }
                GeminiPart::FunctionCall { function_call } => {
                    message_content.push(LLMMessageTypedContent::ToolCall {
                        id: function_call.id.unwrap_or_default(),
                        name: function_call.name,
                        args: function_call.args,
                    });
                }
                GeminiPart::FunctionResponse { function_response } => {
                    message_content.push(LLMMessageTypedContent::ToolResult {
                        tool_use_id: function_response.id,
                        content: function_response.response.to_string(),
                    });
                }
                // Inline data has no typed-content equivalent in this
                // direction and is dropped.
                _ => {}
            }
        }

        // Collapse a single text part into the simpler string form.
        let content = if message_content.is_empty() {
            LLMMessageContent::String(String::new())
        } else if message_content.len() == 1 {
            match &message_content[0] {
                LLMMessageTypedContent::Text { text } => LLMMessageContent::String(text.clone()),
                _ => LLMMessageContent::List(message_content),
            }
        } else {
            LLMMessageContent::List(message_content)
        };

        LLMMessage { role, content }
    }
}

impl From<LLMTool> for GeminiFunctionDeclaration {
    fn from(tool: LLMTool) -> Self {
        GeminiFunctionDeclaration {
            name: tool.name,
            description: tool.description,
            parameters_json_schema: Some(tool.input_schema),
        }
    }
}

impl From<Vec<LLMTool>> for GeminiTool {
    fn from(tools: Vec<LLMTool>) -> Self {
        GeminiTool {
            function_declarations: tools.into_iter().map(|t| t.into()).collect(),
        }
    }
}

impl From<GeminiResponse> for LLMCompletionResponse {
    fn from(response: GeminiResponse) -> Self {
        let usage = response.usage_metadata.map(|u| LLMTokenUsage {
            prompt_tokens: u.prompt_token_count.unwrap_or(0),
            completion_tokens: u.candidates_token_count.unwrap_or(0),
            total_tokens: u.total_token_count.unwrap_or(0),
            prompt_tokens_details: None,
        });

        let choices = response
            .candidates
            .unwrap_or_default()
            .into_iter()
            .enumerate()
            .map(|(index, candidate)| {
                let message = candidate
                    .content
                    .map(|c| c.into())
                    .unwrap_or_else(|| LLMMessage {
                        role: "model".to_string(),
                        content: LLMMessageContent::String(String::new()),
                    });

                let has_tool_calls = match &message.content {
                    LLMMessageContent::List(items) => items
                        .iter()
                        .any(|item| matches!(item, LLMMessageTypedContent::ToolCall { .. })),
                    _ => false,
                };

                // Gemini reports a plain stop even when the model requested
                // tool calls, so override with the OpenAI-style marker.
                let finish_reason = if has_tool_calls {
                    Some("tool_calls".to_string())
                } else {
                    candidate.finish_reason.map(|s| s.to_lowercase())
                };

                LLMChoice {
                    finish_reason,
                    index: index as u32,
                    message,
                }
            })
            .collect();

        LLMCompletionResponse {
            model: response
                .model_version
                .unwrap_or_else(|| "gemini".to_string()),
            object: "chat.completion".to_string(),
            choices,
            created: chrono::Utc::now().timestamp_millis() as u64,
            usage,
            id: response
                .response_id
                .unwrap_or_else(|| "unknown".to_string()),
        }
    }
}

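/// Stateless client for the Gemini `generateContent` REST endpoints.
///
/// A minimal usage sketch (marked `ignore` since the crate path and error
/// plumbing depend on the surrounding application):
///
/// ```ignore
/// let config = GeminiConfig {
///     api_endpoint: None, // falls back to DEFAULT_BASE_URL
///     api_key: Some("YOUR_API_KEY".to_string()),
/// };
/// let input = GeminiInput {
///     model: GeminiModel::Gemini25Flash,
///     messages: vec![LLMMessage {
///         role: "user".to_string(),
///         content: LLMMessageContent::String("Hello".to_string()),
///     }],
///     max_tokens: 1024,
///     tools: None,
/// };
/// let response = Gemini::chat(&config, input).await?;
/// ```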
pub struct Gemini {}

impl Gemini {
    pub async fn chat(
        config: &GeminiConfig,
        input: GeminiInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let (contents, system_instruction) = convert_messages_to_gemini(input.messages)?;

        let tools = input.tools.map(|t| vec![t.into()]);

        let payload = GeminiRequest {
            contents,
            tools,
            system_instruction,
            generation_config: Some(GeminiGenerationConfig {
                max_output_tokens: Some(input.max_tokens),
                temperature: Some(0.0),
                candidate_count: Some(1),
            }),
        };

        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
        let api_key = config.api_key.as_ref().map_or("", |v| v);

        let url = format!(
            "{}/models/{}:generateContent?key={}",
            api_endpoint, input.model, api_key
        );

        let response = client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await
            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;

        if !response.status().is_success() {
            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!(
                    "{}: {}",
                    response.status(),
                    response.text().await.unwrap_or_default()
                ),
            )));
        }

        // Read the body as text first so deserialization failures can report
        // the raw payload.
        let response_text = response.text().await.map_err(|e| {
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                "Failed to read response body: {}",
                e
            )))
        })?;

        let gemini_response: GeminiResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                    "Failed to deserialize Gemini response: {}. Response body: {}",
                    e, response_text
                )))
            })?;

        Ok(gemini_response.into())
    }

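    /// Streaming variant of [`Gemini::chat`]: forwards incremental
    /// [`GenerationDelta`]s over `stream_channel_tx` while the response is
    /// being generated, then returns the assembled completion.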
    pub async fn chat_stream(
        config: &GeminiConfig,
        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
        input: GeminiInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let (contents, system_instruction) = convert_messages_to_gemini(input.messages)?;

        let tools = input.tools.map(|t| vec![t.into()]);

        let payload = GeminiRequest {
            contents,
            tools,
            system_instruction,
            generation_config: Some(GeminiGenerationConfig {
                max_output_tokens: Some(input.max_tokens),
                temperature: Some(0.0),
                candidate_count: Some(1),
            }),
        };

        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
        let api_key = config.api_key.as_ref().map_or("", |v| v);

        let url = format!(
            "{}/models/{}:streamGenerateContent?key={}",
            api_endpoint, input.model, api_key
        );

        let response = client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await
            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!("{}: {}", status, error_body),
            )));
        }

        process_gemini_stream(response, input.model.to_string(), stream_channel_tx).await
    }
}

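/// Splits generic messages into Gemini `contents` and an optional
/// `systemInstruction`: system messages are pulled out into the instruction,
/// and tool results are matched back to the tool name via the call id seen
/// earlier in the conversation.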
fn convert_messages_to_gemini(
    messages: Vec<LLMMessage>,
) -> Result<(Vec<GeminiContent>, Option<GeminiSystemInstruction>), AgentError> {
    let mut contents = Vec::new();
    let mut system_parts = Vec::new();
    let mut tool_id_to_name = std::collections::HashMap::new();

    for message in messages {
        match message.role.as_str() {
            "system" => {
                if let LLMMessageContent::String(text) = message.content {
                    system_parts.push(GeminiPart::Text { text });
                }
            }
            _ => {
                let role = match message.role.as_str() {
                    "assistant" | "model" => GeminiRole::Model,
                    "user" | "tool" => GeminiRole::User,
                    _ => GeminiRole::User,
                };

                let mut parts = Vec::new();

                match message.content {
                    LLMMessageContent::String(text) => {
                        parts.push(GeminiPart::Text { text });
                    }
                    LLMMessageContent::List(items) => {
                        for item in items {
                            match item {
                                LLMMessageTypedContent::Text { text } => {
                                    parts.push(GeminiPart::Text { text });
                                }
                                LLMMessageTypedContent::ToolCall { id, name, args } => {
                                    tool_id_to_name.insert(id.clone(), name.clone());
                                    parts.push(GeminiPart::FunctionCall {
                                        function_call: GeminiFunctionCall {
                                            id: Some(id),
                                            name,
                                            args,
                                        },
                                    });
                                }
                                LLMMessageTypedContent::ToolResult {
                                    tool_use_id,
                                    content,
                                } => {
                                    let name = tool_id_to_name
                                        .get(&tool_use_id)
                                        .cloned()
                                        .unwrap_or_else(|| "unknown".to_string());

                                    let response_json = serde_json::json!({ "result": content });

                                    parts.push(GeminiPart::FunctionResponse {
                                        function_response: GeminiFunctionResponse {
                                            id: tool_use_id,
                                            name,
                                            response: response_json,
                                        },
                                    });
                                }
                                LLMMessageTypedContent::Image { source } => {
                                    parts.push(GeminiPart::InlineData {
                                        inline_data: GeminiInlineData {
                                            mime_type: source.media_type,
                                            data: source.data,
                                        },
                                    });
                                }
                            }
                        }
                    }
                }

                contents.push(GeminiContent { role, parts });
            }
        }
    }

    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GeminiSystemInstruction {
            parts: system_parts,
        })
    };

    Ok((contents, system_instruction))
}

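/// Consumes a `streamGenerateContent` response, which arrives as a JSON array
/// of [`GeminiResponse`] objects streamed incrementally (`[ {...}, {...} ]`).
/// Complete objects are detected by brace counting, parsed, forwarded as
/// deltas, and accumulated into the final [`LLMCompletionResponse`].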
async fn process_gemini_stream(
    response: reqwest::Response,
    model: String,
    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
) -> Result<LLMCompletionResponse, AgentError> {
    let mut completion_response = LLMCompletionResponse {
        id: "".to_string(),
        model: model.clone(),
        object: "chat.completion".to_string(),
        choices: vec![],
        created: chrono::Utc::now().timestamp_millis() as u64,
        usage: None,
    };

    let mut stream = response.bytes_stream();
    let mut line_buffer = String::new();
    let mut json_accumulator = String::new();
    let mut brace_depth = 0;
    let mut in_object = false;
    let mut finish_reason = None;
    let mut message_content: Vec<LLMMessageTypedContent> = Vec::new();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| {
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                "Failed to read stream chunk: {}",
                e
            )))
        })?;

        let text = std::str::from_utf8(&chunk).map_err(|e| {
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                "Failed to parse UTF-8: {}",
                e
            )))
        })?;

        line_buffer.push_str(text);

        // Process the buffer line by line, accumulating lines until a full
        // top-level JSON object has been seen.
        while let Some(line_end) = line_buffer.find('\n') {
            let line = line_buffer[..line_end].trim().to_string();
            line_buffer = line_buffer[line_end + 1..].to_string();

            // Skip blank lines and the array delimiters around the stream.
            if line.is_empty() || line == "[" || line == "]" {
                continue;
            }

            // Track brace depth to find object boundaries. Note: this simple
            // counter does not account for braces inside JSON string values.
            for ch in line.chars() {
                match ch {
                    '{' => {
                        brace_depth += 1;
                        in_object = true;
                    }
                    '}' => {
                        brace_depth -= 1;
                    }
                    _ => {}
                }
            }

            if in_object {
                if !json_accumulator.is_empty() {
                    json_accumulator.push('\n');
                }
                json_accumulator.push_str(&line);
            }

            if in_object && brace_depth == 0 {
                // Strip any leftover array brackets and separating commas.
                let mut json_str = json_accumulator.trim();
                if json_str.starts_with('[') {
                    json_str = json_str[1..].trim();
                }
                if json_str.ends_with(']') {
                    json_str = json_str[..json_str.len() - 1].trim();
                }
                let json_str = json_str.trim_matches(',').trim();

                match serde_json::from_str::<GeminiResponse>(json_str) {
                    Ok(gemini_response) => {
                        if let Some(candidates) = gemini_response.candidates {
                            for candidate in candidates {
                                if let Some(reason) = candidate.finish_reason {
                                    finish_reason = Some(reason);
                                }
                                if let Some(content) = candidate.content {
                                    for part in content.parts {
                                        match part {
                                            GeminiPart::Text { text } => {
                                                stream_channel_tx
                                                    .send(GenerationDelta::Content {
                                                        content: text.clone(),
                                                    })
                                                    .await
                                                    .map_err(|e| {
                                                        AgentError::BadRequest(
                                                            BadRequestErrorMessage::ApiError(
                                                                e.to_string(),
                                                            ),
                                                        )
                                                    })?;
                                                message_content
                                                    .push(LLMMessageTypedContent::Text { text });
                                            }
                                            GeminiPart::FunctionCall { function_call } => {
                                                let GeminiFunctionCall { id, name, args } =
                                                    function_call;

                                                // Gemini may omit the call id; generate one so
                                                // tool results can be matched back later.
                                                let id = id
                                                    .unwrap_or_else(|| Uuid::new_v4().to_string());
                                                stream_channel_tx
                                                    .send(GenerationDelta::ToolUse {
                                                        tool_use: GenerationDeltaToolUse {
                                                            id: Some(id.clone()),
                                                            name: Some(name.clone()),
                                                            input: Some(args.to_string()),
                                                            index: 0,
                                                        },
                                                    })
                                                    .await
                                                    .map_err(|e| {
                                                        AgentError::BadRequest(
                                                            BadRequestErrorMessage::ApiError(
                                                                e.to_string(),
                                                            ),
                                                        )
                                                    })?;
                                                message_content.push(
                                                    LLMMessageTypedContent::ToolCall {
                                                        id,
                                                        name,
                                                        args,
                                                    },
                                                );
                                            }
                                            _ => {}
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(usage) = gemini_response.usage_metadata {
                            let token_usage = LLMTokenUsage {
                                prompt_tokens: usage.prompt_token_count.unwrap_or(0),
                                completion_tokens: usage.candidates_token_count.unwrap_or(0),
                                total_tokens: usage.total_token_count.unwrap_or(0),
                                prompt_tokens_details: None,
                            };
                            stream_channel_tx
                                .send(GenerationDelta::Usage {
                                    usage: token_usage.clone(),
                                })
                                .await
                                .map_err(|e| {
                                    AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                        e.to_string(),
                                    ))
                                })?;
                            completion_response.usage = Some(token_usage);
                        }

                        if let Some(response_id) = gemini_response.response_id {
                            completion_response.id = response_id;
                        }
                    }
                    Err(e) => {
                        eprintln!("Failed to parse JSON object: {}. Error: {}", json_str, e);
                    }
                }

                json_accumulator.clear();
                in_object = false;
            }
        }
    }

    let has_tool_calls = message_content
        .iter()
        .any(|c| matches!(c, LLMMessageTypedContent::ToolCall { .. }));

    let final_finish_reason = if has_tool_calls {
        Some("tool_calls".to_string())
    } else {
        finish_reason.map(|s| s.to_lowercase())
    };

    completion_response.choices = vec![LLMChoice {
        finish_reason: final_finish_reason,
        index: 0,
        message: LLMMessage {
            role: "assistant".to_string(),
            content: if message_content.is_empty() {
                LLMMessageContent::String(String::new())
            } else if message_content.len() == 1
                && matches!(&message_content[0], LLMMessageTypedContent::Text { .. })
            {
                if let LLMMessageTypedContent::Text { text } = &message_content[0] {
                    LLMMessageContent::String(text.clone())
                } else {
                    LLMMessageContent::List(message_content)
                }
            } else {
                LLMMessageContent::List(message_content)
            },
        },
    }];

    Ok(completion_response)
}
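
// A small sanity-check sketch (assumes the crate's standard `cargo test`
// setup and that `AgentError` implements `Debug`; the assertions only
// exercise logic defined in this module).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_name_round_trips() {
        let model = GeminiModel::from_string("gemini-2.5-flash").unwrap();
        assert_eq!(model, GeminiModel::Gemini25Flash);
        assert_eq!(model.to_string(), "gemini-2.5-flash");
    }

    #[test]
    fn system_messages_become_system_instruction() {
        let messages = vec![
            LLMMessage {
                role: "system".to_string(),
                content: LLMMessageContent::String("be brief".to_string()),
            },
            LLMMessage {
                role: "user".to_string(),
                content: LLMMessageContent::String("hi".to_string()),
            },
        ];
        let (contents, system_instruction) = convert_messages_to_gemini(messages).unwrap();
        // The system message is pulled out of `contents` into the instruction.
        assert_eq!(contents.len(), 1);
        assert!(system_instruction.is_some());
    }
}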