1use crate::models::llm::{
12 GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13 LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
/// Connection settings for an OpenAI-compatible endpoint.
///
/// Both fields are optional; [`OpenAIConfig::with_api_key`] fills only the key.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Optional API base URL.
    /// NOTE(review): `None` presumably means "use the provider default" — confirm in the client.
    pub api_endpoint: Option<String>,
    /// Optional API key used for authentication.
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32 pub fn with_api_key(api_key: impl Into<String>) -> Self {
34 Self {
35 api_key: Some(api_key.into()),
36 api_endpoint: None,
37 }
38 }
39
40 pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
42 match auth {
43 crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
44 crate::models::auth::ProviderAuth::OAuth { .. } => None, }
46 }
47
48 pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
50 match auth {
51 crate::models::auth::ProviderAuth::Api { key } => {
52 self.api_key = Some(key.clone());
53 Some(self)
54 }
55 crate::models::auth::ProviderAuth::OAuth { .. } => None, }
57 }
58}
59
/// OpenAI model identifiers.
///
/// Named variants serialize to the dated snapshot IDs given in their `serde(rename)`
/// attributes; the `Display` impl in this file prints the same IDs. `Custom` carries any
/// other model ID verbatim.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    /// Default variant.
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Any model ID without a named variant.
    /// NOTE(review): how this round-trips through serde depends on enum-level serde
    /// attributes not visible in this view — confirm.
    Custom(String),
}
85
86impl OpenAIModel {
87 pub fn from_string(s: &str) -> Result<Self, String> {
88 serde_json::from_value(serde_json::Value::String(s.to_string()))
89 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
90 }
91
92 pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
94
95 pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
97
98 pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
100
101 pub fn default_smart_model() -> String {
103 Self::DEFAULT_SMART_MODEL.to_string()
104 }
105
106 pub fn default_eco_model() -> String {
108 Self::DEFAULT_ECO_MODEL.to_string()
109 }
110
111 pub fn default_recovery_model() -> String {
113 Self::DEFAULT_RECOVERY_MODEL.to_string()
114 }
115}
116
117impl ContextAware for OpenAIModel {
118 fn context_info(&self) -> ModelContextInfo {
119 let model_name = self.to_string();
120
121 if model_name.starts_with("o3") {
122 return ModelContextInfo {
123 max_tokens: 200_000,
124 pricing_tiers: vec![ContextPricingTier {
125 label: "Standard".to_string(),
126 input_cost_per_million: 2.0,
127 output_cost_per_million: 8.0,
128 upper_bound: None,
129 }],
130 approach_warning_threshold: 0.8,
131 };
132 }
133
134 if model_name.starts_with("o4-mini") {
135 return ModelContextInfo {
136 max_tokens: 200_000,
137 pricing_tiers: vec![ContextPricingTier {
138 label: "Standard".to_string(),
139 input_cost_per_million: 1.10,
140 output_cost_per_million: 4.40,
141 upper_bound: None,
142 }],
143 approach_warning_threshold: 0.8,
144 };
145 }
146
147 if model_name.starts_with("gpt-5-mini") {
148 return ModelContextInfo {
149 max_tokens: 400_000,
150 pricing_tiers: vec![ContextPricingTier {
151 label: "Standard".to_string(),
152 input_cost_per_million: 0.25,
153 output_cost_per_million: 2.0,
154 upper_bound: None,
155 }],
156 approach_warning_threshold: 0.8,
157 };
158 }
159
160 if model_name.starts_with("gpt-5-nano") {
161 return ModelContextInfo {
162 max_tokens: 400_000,
163 pricing_tiers: vec![ContextPricingTier {
164 label: "Standard".to_string(),
165 input_cost_per_million: 0.05,
166 output_cost_per_million: 0.40,
167 upper_bound: None,
168 }],
169 approach_warning_threshold: 0.8,
170 };
171 }
172
173 if model_name.starts_with("gpt-5") {
174 return ModelContextInfo {
175 max_tokens: 400_000,
176 pricing_tiers: vec![ContextPricingTier {
177 label: "Standard".to_string(),
178 input_cost_per_million: 1.25,
179 output_cost_per_million: 10.0,
180 upper_bound: None,
181 }],
182 approach_warning_threshold: 0.8,
183 };
184 }
185
186 ModelContextInfo::default()
187 }
188
189 fn model_name(&self) -> String {
190 match self {
191 OpenAIModel::O3 => "O3".to_string(),
192 OpenAIModel::O4Mini => "O4-mini".to_string(),
193 OpenAIModel::GPT5 => "GPT-5".to_string(),
194 OpenAIModel::GPT51 => "GPT-5.1".to_string(),
195 OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
196 OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
197 OpenAIModel::Custom(name) => format!("Custom ({})", name),
198 }
199 }
200}
201
202impl std::fmt::Display for OpenAIModel {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 match self {
205 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
206 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
207 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
208 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
209 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
210 OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
211 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
212 }
213 }
214}
215
/// Agent-level model tier, serialized as `"smart"`, `"eco"` or `"recovery"`.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub enum AgentModel {
    /// Default tier.
    #[serde(rename = "smart")]
    #[default]
    Smart,
    #[serde(rename = "eco")]
    Eco,
    #[serde(rename = "recovery")]
    Recovery,
}
227
228impl std::fmt::Display for AgentModel {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 match self {
231 AgentModel::Smart => write!(f, "smart"),
232 AgentModel::Eco => write!(f, "eco"),
233 AgentModel::Recovery => write!(f, "recovery"),
234 }
235 }
236}
237
238impl From<String> for AgentModel {
239 fn from(value: String) -> Self {
240 match value.as_str() {
241 "eco" => AgentModel::Eco,
242 "recovery" => AgentModel::Recovery,
243 _ => AgentModel::Smart,
244 }
245 }
246}
247
/// Author role of a chat message, serialized in lowercase (`"system"`, `"developer"`, ...).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    /// Default role.
    #[default]
    Assistant,
    Tool,
}
263
264impl std::fmt::Display for Role {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 match self {
267 Role::System => write!(f, "system"),
268 Role::Developer => write!(f, "developer"),
269 Role::User => write!(f, "user"),
270 Role::Assistant => write!(f, "assistant"),
271 Role::Tool => write!(f, "tool"),
272 }
273 }
274}
275
/// Identifies which provider and model produced a message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name.
    pub provider: String,
    /// Model identifier within that provider.
    pub id: String,
}
284
/// A single chat message in OpenAI chat-completions shape, plus bookkeeping fields.
///
/// `content` is always serialized (as `null` when `None`); every other optional field is
/// omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    pub role: Role,
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls requested by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Role::Tool` messages: the id of the call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    // Bookkeeping fields below. NOTE(review): these are not part of the OpenAI request
    // schema — confirm they are only used internally / on responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// NOTE(review): timestamp unit (seconds vs. milliseconds) is not visible here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
322
323impl ChatMessage {
324 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
325 messages
326 .iter()
327 .rev()
328 .find(|message| message.role != Role::User && message.role != Role::Tool)
329 }
330
331 pub fn to_xml(&self) -> String {
332 match &self.content {
333 Some(MessageContent::String(s)) => {
334 format!("<message role=\"{}\">{}</message>", self.role, s)
335 }
336 Some(MessageContent::Array(parts)) => parts
337 .iter()
338 .map(|part| {
339 format!(
340 "<message role=\"{}\" type=\"{}\">{}</message>",
341 self.role,
342 part.r#type,
343 part.text.clone().unwrap_or_default()
344 )
345 })
346 .collect::<Vec<String>>()
347 .join("\n"),
348 None => String::new(),
349 }
350 }
351}
352
/// Message body: either a plain string or a list of typed parts.
///
/// `untagged`, so a JSON string maps to `String` and a JSON array to `Array`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
360
361impl MessageContent {
362 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
363 match self {
364 MessageContent::String(s) => MessageContent::String(format!(
365 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
366 )),
367 MessageContent::Array(parts) => MessageContent::Array(
368 std::iter::once(ContentPart {
369 r#type: "text".to_string(),
370 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
371 image_url: None,
372 })
373 .chain(parts.iter().cloned())
374 .collect(),
375 ),
376 }
377 }
378
379 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
380 match self {
381 MessageContent::String(s) => s
382 .rfind("<checkpoint_id>")
383 .and_then(|start| {
384 s[start..]
385 .find("</checkpoint_id>")
386 .map(|end| (start + "<checkpoint_id>".len(), start + end))
387 })
388 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
389 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
390 part.text.as_deref().and_then(|text| {
391 text.rfind("<checkpoint_id>")
392 .and_then(|start| {
393 text[start..]
394 .find("</checkpoint_id>")
395 .map(|end| (start + "<checkpoint_id>".len(), start + end))
396 })
397 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
398 })
399 }),
400 }
401 }
402}
403
404impl std::fmt::Display for MessageContent {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 match self {
407 MessageContent::String(s) => write!(f, "{s}"),
408 MessageContent::Array(parts) => {
409 let text_parts: Vec<String> =
410 parts.iter().filter_map(|part| part.text.clone()).collect();
411 write!(f, "{}", text_parts.join("\n"))
412 }
413 }
414 }
415}
416
417impl Default for MessageContent {
418 fn default() -> Self {
419 MessageContent::String(String::new())
420 }
421}
422
/// One typed part of an array message body.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part type; this file produces `"text"` and `"image_url"`.
    pub r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
432
/// Image reference inside a content part.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// Image URL; conversions in this file produce `data:<media type>;base64,<data>` URLs.
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
440
/// A tool offered to the model in a chat-completion request.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool type. NOTE(review): presumably always `"function"` — confirm with callers.
    pub r#type: String,
    pub function: FunctionDefinition,
}
451
452#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
454pub struct FunctionDefinition {
455 pub name: String,
456 pub description: Option<String>,
457 pub parameters: serde_json::Value,
458}
459
460impl From<Tool> for LLMTool {
461 fn from(tool: Tool) -> Self {
462 LLMTool {
463 name: tool.function.name,
464 description: tool.function.description.unwrap_or_default(),
465 input_schema: tool.function.parameters,
466 }
467 }
468}
469
/// How the model may choose tools.
///
/// Serializes as the bare string `"auto"` / `"required"`, or as a [`ToolChoiceObject`]
/// that names a specific function (see the manual `Serialize`/`Deserialize` impls below).
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
477
478impl Serialize for ToolChoice {
479 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
480 where
481 S: serde::Serializer,
482 {
483 match self {
484 ToolChoice::Auto => serializer.serialize_str("auto"),
485 ToolChoice::Required => serializer.serialize_str("required"),
486 ToolChoice::Object(obj) => obj.serialize(serializer),
487 }
488 }
489}
490
491impl<'de> Deserialize<'de> for ToolChoice {
492 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
493 where
494 D: serde::Deserializer<'de>,
495 {
496 struct ToolChoiceVisitor;
497
498 impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
499 type Value = ToolChoice;
500
501 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
502 formatter.write_str("string or object")
503 }
504
505 fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
506 where
507 E: serde::de::Error,
508 {
509 match value {
510 "auto" => Ok(ToolChoice::Auto),
511 "required" => Ok(ToolChoice::Required),
512 _ => Err(serde::de::Error::unknown_variant(
513 value,
514 &["auto", "required"],
515 )),
516 }
517 }
518
519 fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
520 where
521 M: serde::de::MapAccess<'de>,
522 {
523 let obj = ToolChoiceObject::deserialize(
524 serde::de::value::MapAccessDeserializer::new(map),
525 )?;
526 Ok(ToolChoice::Object(obj))
527 }
528 }
529
530 deserializer.deserialize_any(ToolChoiceVisitor)
531 }
532}
533
/// Object form of [`ToolChoice`] that forces a specific function.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    /// Choice type. NOTE(review): presumably `"function"` — confirm with callers.
    pub r#type: String,
    pub function: FunctionChoice,
}
539
/// Names the function selected by a [`ToolChoiceObject`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
544
/// A tool call requested by the model in an assistant message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Call id, echoed back as `tool_call_id` on the matching tool message.
    pub id: String,
    /// Call type; conversions in this file set `"function"`.
    pub r#type: String,
    pub function: FunctionCall,
}
552
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    /// Arguments as a JSON-encoded string (parsed with `serde_json::from_str` when
    /// converting to `LLMMessage`).
    pub arguments: String,
}
559
/// Outcome of executing a tool call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}
567
/// A tool call paired with its textual result and status.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    pub call: ToolCall,
    pub result: String,
    pub status: ToolCallResultStatus,
}
575
/// A progress message for a running tool call.
/// NOTE(review): how `id` relates to `ToolCall::id` (a `String`) is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    pub id: Uuid,
    pub message: String,
}
582
/// Body of a chat-completions request.
///
/// Only `model` and `messages` are always serialized; every optional field is omitted
/// from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Extra request context.
    /// NOTE(review): not an OpenAI API field — presumably consumed by this service; confirm
    /// it is not forwarded upstream.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
625
626impl ChatCompletionRequest {
627 pub fn new(
628 model: String,
629 messages: Vec<ChatMessage>,
630 tools: Option<Vec<Tool>>,
631 stream: Option<bool>,
632 ) -> Self {
633 Self {
634 model,
635 messages,
636 frequency_penalty: None,
637 logit_bias: None,
638 logprobs: None,
639 max_tokens: None,
640 n: None,
641 presence_penalty: None,
642 response_format: None,
643 seed: None,
644 stop: None,
645 stream,
646 temperature: None,
647 top_p: None,
648 tools,
649 tool_choice: None,
650 user: None,
651 context: None,
652 }
653 }
654}
655
/// Extra context attached to a request via `ChatCompletionRequest::context`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Arbitrary JSON scratchpad; serialized as `null` when `None`.
    pub scratchpad: Option<Value>,
}
660
/// Requested response format, carried only as its `type` string.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}
665
/// Stop sequence(s): a single JSON string or an array of strings (`untagged`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
672
/// Non-streaming chat-completions response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    /// Required here: deserialization fails if the payload has no `usage`.
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Serialized as `null` when `None` (no `skip_serializing_if`).
    pub metadata: Option<serde_json::Value>,
}
686
/// One completion choice in a [`ChatCompletionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    /// Required: a missing or unknown finish reason fails deserialization.
    pub finish_reason: FinishReason,
}
694
/// Why generation stopped; serialized in snake_case (e.g. `"tool_calls"`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}
703
/// Log-probability information for a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}
708
/// Log probability of one output token, with optional top alternatives.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    /// Raw bytes of the token, when provided.
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}
716
/// One alternative token and its log probability.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
723
/// One chunk of a streaming chat-completions response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Optional per chunk (unlike the non-streaming response).
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}
738
/// One choice inside a streaming chunk.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    /// `None` until the chunk that ends this choice.
    pub finish_reason: Option<FinishReason>,
}
745
/// Incremental message update in a streaming chunk; `None` fields are omitted.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
755
/// Streaming fragment of a tool call; `index` identifies which call it extends.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
}
763
/// Streaming fragment of a function call's name and/or arguments.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    /// Partial argument text. NOTE(review): presumably concatenated across chunks by the
    /// consumer — confirm.
    pub arguments: Option<String>,
}
769
770impl From<LLMMessage> for ChatMessage {
775 fn from(llm_message: LLMMessage) -> Self {
776 let role = match llm_message.role.as_str() {
777 "system" => Role::System,
778 "user" => Role::User,
779 "assistant" => Role::Assistant,
780 "tool" => Role::Tool,
781 "developer" => Role::Developer,
782 _ => Role::User,
783 };
784
785 let (content, tool_calls) = match llm_message.content {
786 LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
787 LLMMessageContent::List(items) => {
788 let mut text_parts = Vec::new();
789 let mut tool_call_parts = Vec::new();
790
791 for item in items {
792 match item {
793 LLMMessageTypedContent::Text { text } => {
794 text_parts.push(ContentPart {
795 r#type: "text".to_string(),
796 text: Some(text),
797 image_url: None,
798 });
799 }
800 LLMMessageTypedContent::ToolCall { id, name, args } => {
801 tool_call_parts.push(ToolCall {
802 id,
803 r#type: "function".to_string(),
804 function: FunctionCall {
805 name,
806 arguments: args.to_string(),
807 },
808 });
809 }
810 LLMMessageTypedContent::ToolResult { content, .. } => {
811 text_parts.push(ContentPart {
812 r#type: "text".to_string(),
813 text: Some(content),
814 image_url: None,
815 });
816 }
817 LLMMessageTypedContent::Image { source } => {
818 text_parts.push(ContentPart {
819 r#type: "image_url".to_string(),
820 text: None,
821 image_url: Some(ImageUrl {
822 url: format!(
823 "data:{};base64,{}",
824 source.media_type, source.data
825 ),
826 detail: None,
827 }),
828 });
829 }
830 }
831 }
832
833 let content = if !text_parts.is_empty() {
834 Some(MessageContent::Array(text_parts))
835 } else {
836 None
837 };
838
839 let tool_calls = if !tool_call_parts.is_empty() {
840 Some(tool_call_parts)
841 } else {
842 None
843 };
844
845 (content, tool_calls)
846 }
847 };
848
849 ChatMessage {
850 role,
851 content,
852 name: None,
853 tool_calls,
854 tool_call_id: None,
855 usage: None,
856 ..Default::default()
857 }
858 }
859}
860
861impl From<ChatMessage> for LLMMessage {
862 fn from(chat_message: ChatMessage) -> Self {
863 let mut content_parts = Vec::new();
864
865 match chat_message.content {
866 Some(MessageContent::String(s)) => {
867 if !s.is_empty() {
868 content_parts.push(LLMMessageTypedContent::Text { text: s });
869 }
870 }
871 Some(MessageContent::Array(parts)) => {
872 for part in parts {
873 if let Some(text) = part.text {
874 content_parts.push(LLMMessageTypedContent::Text { text });
875 } else if let Some(image_url) = part.image_url {
876 let (media_type, data) = if image_url.url.starts_with("data:") {
877 let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
878 if parts.len() == 2 {
879 let meta = parts[0];
880 let data = parts[1];
881 let media_type = meta
882 .trim_start_matches("data:")
883 .trim_end_matches(";base64")
884 .to_string();
885 (media_type, data.to_string())
886 } else {
887 ("image/jpeg".to_string(), image_url.url)
888 }
889 } else {
890 ("image/jpeg".to_string(), image_url.url)
891 };
892
893 content_parts.push(LLMMessageTypedContent::Image {
894 source: LLMMessageImageSource {
895 r#type: "base64".to_string(),
896 media_type,
897 data,
898 },
899 });
900 }
901 }
902 }
903 None => {}
904 }
905
906 if let Some(tool_calls) = chat_message.tool_calls {
907 for tool_call in tool_calls {
908 let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
909 content_parts.push(LLMMessageTypedContent::ToolCall {
910 id: tool_call.id,
911 name: tool_call.function.name,
912 args,
913 });
914 }
915 }
916
917 if chat_message.role == Role::Tool
922 && let Some(tool_call_id) = chat_message.tool_call_id
923 {
924 let content_str = content_parts
926 .iter()
927 .filter_map(|p| match p {
928 LLMMessageTypedContent::Text { text } => Some(text.clone()),
929 _ => None,
930 })
931 .collect::<Vec<_>>()
932 .join("\n");
933
934 content_parts = vec![LLMMessageTypedContent::ToolResult {
936 tool_use_id: tool_call_id,
937 content: content_str,
938 }];
939 }
940
941 LLMMessage {
942 role: chat_message.role.to_string(),
943 content: if content_parts.is_empty() {
944 LLMMessageContent::String(String::new())
945 } else if content_parts.len() == 1 {
946 match &content_parts[0] {
947 LLMMessageTypedContent::Text { text } => {
948 LLMMessageContent::String(text.clone())
949 }
950 _ => LLMMessageContent::List(content_parts),
951 }
952 } else {
953 LLMMessageContent::List(content_parts)
954 },
955 }
956 }
957}
958
959impl From<GenerationDelta> for ChatMessageDelta {
960 fn from(delta: GenerationDelta) -> Self {
961 match delta {
962 GenerationDelta::Content { content } => ChatMessageDelta {
963 role: Some(Role::Assistant),
964 content: Some(content),
965 tool_calls: None,
966 },
967 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
968 role: Some(Role::Assistant),
969 content: None,
970 tool_calls: None,
971 },
972 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
973 role: Some(Role::Assistant),
974 content: None,
975 tool_calls: Some(vec![ToolCallDelta {
976 index: tool_use.index,
977 id: tool_use.id,
978 r#type: Some("function".to_string()),
979 function: Some(FunctionCallDelta {
980 name: tool_use.name,
981 arguments: tool_use.input,
982 }),
983 }]),
984 },
985 _ => ChatMessageDelta {
986 role: Some(Role::Assistant),
987 content: None,
988 tool_calls: None,
989 },
990 }
991 }
992}
993
#[cfg(test)]
mod tests {
    //! Unit tests for the OpenAI model types and their conversions to and from the
    //! provider-agnostic `LLMMessage` types.

    use super::*;

    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"smart\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        // A tool message with a tool_call_id must become a single ToolResult part.
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: Some(MessageContent::String("Tool execution result".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_01Abc123".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one content part");
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_01Abc123");
                        assert_eq!(content, "Tool execution result");
                    }
                    _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
                }
            }
            _ => panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            ),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        // Missing content still yields a ToolResult, with empty text.
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: None,
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_02Xyz789");
                        assert_eq!(content, "");
                    }
                    _ => panic!("Expected ToolResult content part"),
                }
            }
            _ => panic!("Expected List content with ToolResult"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        // Assistant text plus a tool call keeps both, text first.
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
            }]),
            tool_call_id: None,
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "assistant");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 2, "Should have text and tool call");

                match &parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        assert_eq!(text, "I'll help you with that.");
                    }
                    _ => panic!("Expected Text content part first"),
                }

                match &parts[1] {
                    LLMMessageTypedContent::ToolCall { id, name, args } => {
                        assert_eq!(id, "call_abc123");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args["location"], "Paris");
                    }
                    _ => panic!("Expected ToolCall content part second"),
                }
            }
            _ => panic!("Expected List content"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_data_url_image() {
        // A base64 data URL is split into media type and payload.
        let chat_message = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![ContentPart {
                r#type: "image_url".to_string(),
                text: None,
                image_url: Some(ImageUrl {
                    url: "data:image/png;base64,AAAA".to_string(),
                    detail: None,
                }),
            }])),
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::Image { source } => {
                        assert_eq!(source.r#type, "base64");
                        assert_eq!(source.media_type, "image/png");
                        assert_eq!(source.data, "AAAA");
                    }
                    _ => panic!("Expected Image content part"),
                }
            }
            _ => panic!("Expected List content with an Image"),
        }
    }

    #[test]
    fn test_tool_choice_serde_roundtrip() {
        // ToolChoice has hand-written Serialize/Deserialize impls; pin both directions.
        assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), "\"auto\"");
        assert_eq!(
            serde_json::to_string(&ToolChoice::Required).unwrap(),
            "\"required\""
        );

        let object = ToolChoice::Object(ToolChoiceObject {
            r#type: "function".to_string(),
            function: FunctionChoice {
                name: "get_weather".to_string(),
            },
        });
        let json = serde_json::to_string(&object).unwrap();
        assert_eq!(
            json,
            r#"{"type":"function","function":{"name":"get_weather"}}"#
        );

        assert_eq!(serde_json::from_str::<ToolChoice>(&json).unwrap(), object);
        assert_eq!(
            serde_json::from_str::<ToolChoice>("\"auto\"").unwrap(),
            ToolChoice::Auto
        );
        assert_eq!(
            serde_json::from_str::<ToolChoice>("\"required\"").unwrap(),
            ToolChoice::Required
        );
        assert!(serde_json::from_str::<ToolChoice>("\"sometimes\"").is_err());
        assert!(serde_json::from_str::<ToolChoice>("42").is_err());
    }

    #[test]
    fn test_checkpoint_id_roundtrip() {
        let id = Uuid::parse_str("67e55044-10b1-426f-9247-bb680e5fe0c8").unwrap();

        let text = MessageContent::String("hello".to_string()).inject_checkpoint_id(id);
        assert_eq!(text.extract_checkpoint_id(), Some(id));

        let parts = MessageContent::Array(vec![ContentPart {
            r#type: "text".to_string(),
            text: Some("hello".to_string()),
            image_url: None,
        }])
        .inject_checkpoint_id(id);
        assert_eq!(parts.extract_checkpoint_id(), Some(id));

        assert_eq!(
            MessageContent::String("no checkpoint here".to_string()).extract_checkpoint_id(),
            None
        );
        assert_eq!(
            MessageContent::String("<checkpoint_id>not-a-uuid</checkpoint_id>".to_string())
                .extract_checkpoint_id(),
            None
        );
    }
}