use crate::models::llm::{
    GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
    LLMTokenUsage, LLMTool,
};
use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use uuid::Uuid;

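/// Configuration for the OpenAI provider: an optional custom API endpoint and
/// an optional API key.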
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    pub api_endpoint: Option<String>,
    pub api_key: Option<String>,
}

impl OpenAIConfig {
    pub fn with_api_key(api_key: impl Into<String>) -> Self {
        Self {
            api_key: Some(api_key.into()),
            api_endpoint: None,
        }
    }

    pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
        match auth {
            crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
            crate::models::auth::ProviderAuth::OAuth { .. } => None,
        }
    }

    pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
        match auth {
            crate::models::auth::ProviderAuth::Api { key } => {
                self.api_key = Some(key.clone());
                Some(self)
            }
            crate::models::auth::ProviderAuth::OAuth { .. } => None,
        }
    }
}

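/// Known OpenAI models plus a `Custom` fallback; the serde renames pin each
/// named variant to a dated model snapshot, and `Display` emits the same
/// identifiers.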
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    Custom(String),
}

impl OpenAIModel {
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
    }

    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;

    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;

    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;

    pub fn default_smart_model() -> String {
        Self::DEFAULT_SMART_MODEL.to_string()
    }

    pub fn default_eco_model() -> String {
        Self::DEFAULT_ECO_MODEL.to_string()
    }

    pub fn default_recovery_model() -> String {
        Self::DEFAULT_RECOVERY_MODEL.to_string()
    }
}

impl ContextAware for OpenAIModel {
    fn context_info(&self) -> ModelContextInfo {
        let model_name = self.to_string();

        if model_name.starts_with("o3") {
            return ModelContextInfo {
                max_tokens: 200_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 2.0,
                    output_cost_per_million: 8.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("o4-mini") {
            return ModelContextInfo {
                max_tokens: 200_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 1.10,
                    output_cost_per_million: 4.40,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("gpt-5-mini") {
            return ModelContextInfo {
                max_tokens: 400_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.25,
                    output_cost_per_million: 2.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("gpt-5-nano") {
            return ModelContextInfo {
                max_tokens: 400_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.05,
                    output_cost_per_million: 0.40,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("gpt-5") {
            return ModelContextInfo {
                max_tokens: 400_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 1.25,
                    output_cost_per_million: 10.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        ModelContextInfo::default()
    }

    fn model_name(&self) -> String {
        match self {
            OpenAIModel::O3 => "O3".to_string(),
            OpenAIModel::O4Mini => "O4-mini".to_string(),
            OpenAIModel::GPT5 => "GPT-5".to_string(),
            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
            OpenAIModel::Custom(name) => format!("Custom ({})", name),
        }
    }
}

impl std::fmt::Display for OpenAIModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
        }
    }
}

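/// Logical model tier selected for the agent; unrecognized strings fall back
/// to `Smart` when converting from `String`.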
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub enum AgentModel {
    #[serde(rename = "smart")]
    #[default]
    Smart,
    #[serde(rename = "eco")]
    Eco,
    #[serde(rename = "recovery")]
    Recovery,
}

impl std::fmt::Display for AgentModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AgentModel::Smart => write!(f, "smart"),
            AgentModel::Eco => write!(f, "eco"),
            AgentModel::Recovery => write!(f, "recovery"),
        }
    }
}

impl From<String> for AgentModel {
    fn from(value: String) -> Self {
        match value.as_str() {
            "eco" => AgentModel::Eco,
            "recovery" => AgentModel::Recovery,
            _ => AgentModel::Smart,
        }
    }
}

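/// Message roles accepted by the chat completions API, serialized in lowercase.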
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
}

impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::System => write!(f, "system"),
            Role::Developer => write!(f, "developer"),
            Role::User => write!(f, "user"),
            Role::Assistant => write!(f, "assistant"),
            Role::Tool => write!(f, "tool"),
        }
    }
}

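/// A single chat message, with optional content, tool calls, the id of the
/// tool call being answered, and token usage reported alongside it.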
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessage {
    pub role: Role,
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
}

impl ChatMessage {
    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
        messages
            .iter()
            .rev()
            .find(|message| message.role != Role::User && message.role != Role::Tool)
    }

    pub fn to_xml(&self) -> String {
        match &self.content {
            Some(MessageContent::String(s)) => {
                format!("<message role=\"{}\">{}</message>", self.role, s)
            }
            Some(MessageContent::Array(parts)) => parts
                .iter()
                .map(|part| {
                    format!(
                        "<message role=\"{}\" type=\"{}\">{}</message>",
                        self.role,
                        part.r#type,
                        part.text.clone().unwrap_or_default()
                    )
                })
                .collect::<Vec<String>>()
                .join("\n"),
            None => String::new(),
        }
    }
}

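/// Message content, either a plain string or an array of typed parts; the
/// checkpoint helpers prepend a `<checkpoint_id>` marker and recover the most
/// recently injected id.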
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}

impl MessageContent {
    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
        match self {
            MessageContent::String(s) => MessageContent::String(format!(
                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
            )),
            MessageContent::Array(parts) => MessageContent::Array(
                std::iter::once(ContentPart {
                    r#type: "text".to_string(),
                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
                    image_url: None,
                })
                .chain(parts.iter().cloned())
                .collect(),
            ),
        }
    }

    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
        match self {
            MessageContent::String(s) => s
                .rfind("<checkpoint_id>")
                .and_then(|start| {
                    s[start..]
                        .find("</checkpoint_id>")
                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
                })
                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
                part.text.as_deref().and_then(|text| {
                    text.rfind("<checkpoint_id>")
                        .and_then(|start| {
                            text[start..]
                                .find("</checkpoint_id>")
                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
                        })
                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
                })
            }),
        }
    }
}

impl std::fmt::Display for MessageContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MessageContent::String(s) => write!(f, "{s}"),
            MessageContent::Array(parts) => {
                let text_parts: Vec<String> =
                    parts.iter().filter_map(|part| part.text.clone()).collect();
                write!(f, "{}", text_parts.join("\n"))
            }
        }
    }
}

impl Default for MessageContent {
    fn default() -> Self {
        MessageContent::String(String::new())
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    pub r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    pub r#type: String,
    pub function: FunctionDefinition,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: serde_json::Value,
}

impl From<Tool> for LLMTool {
    fn from(tool: Tool) -> Self {
        LLMTool {
            name: tool.function.name,
            description: tool.function.description.unwrap_or_default(),
            input_schema: tool.function.parameters,
        }
    }
}

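/// Tool choice that serializes either as the bare strings `"auto"` /
/// `"required"` or as a function-selection object; the hand-written serde
/// impls below accept both shapes.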
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}

impl Serialize for ToolChoice {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            ToolChoice::Auto => serializer.serialize_str("auto"),
            ToolChoice::Required => serializer.serialize_str("required"),
            ToolChoice::Object(obj) => obj.serialize(serializer),
        }
    }
}

impl<'de> Deserialize<'de> for ToolChoice {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ToolChoiceVisitor;

        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
            type Value = ToolChoice;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("string or object")
            }

            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
            where
                E: serde::de::Error,
            {
                match value {
                    "auto" => Ok(ToolChoice::Auto),
                    "required" => Ok(ToolChoice::Required),
                    _ => Err(serde::de::Error::unknown_variant(
                        value,
                        &["auto", "required"],
                    )),
                }
            }

            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
            where
                M: serde::de::MapAccess<'de>,
            {
                let obj = ToolChoiceObject::deserialize(
                    serde::de::value::MapAccessDeserializer::new(map),
                )?;
                Ok(ToolChoice::Object(obj))
            }
        }

        deserializer.deserialize_any(ToolChoiceVisitor)
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    pub r#type: String,
    pub function: FunctionChoice,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    pub id: String,
    pub r#type: String,
    pub function: FunctionCall,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    pub call: ToolCall,
    pub result: String,
    pub status: ToolCallResultStatus,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    pub id: Uuid,
    pub message: String,
}

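/// Request body for the chat completions endpoint; unset optional fields are
/// omitted from the serialized JSON.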
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}

impl ChatCompletionRequest {
    pub fn new(
        model: String,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        stream: Option<bool>,
    ) -> Self {
        Self {
            model,
            messages,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: None,
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream,
            temperature: None,
            top_p: None,
            tools,
            tool_choice: None,
            user: None,
            context: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    pub scratchpad: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub metadata: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    pub finish_reason: Option<FinishReason>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

impl From<LLMMessage> for ChatMessage {
    fn from(llm_message: LLMMessage) -> Self {
        let role = match llm_message.role.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" => Role::Tool,
            "developer" => Role::Developer,
            _ => Role::User,
        };

        let (content, tool_calls) = match llm_message.content {
            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
            LLMMessageContent::List(items) => {
                let mut text_parts = Vec::new();
                let mut tool_call_parts = Vec::new();

                for item in items {
                    match item {
                        LLMMessageTypedContent::Text { text } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(text),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::ToolCall { id, name, args } => {
                            tool_call_parts.push(ToolCall {
                                id,
                                r#type: "function".to_string(),
                                function: FunctionCall {
                                    name,
                                    arguments: args.to_string(),
                                },
                            });
                        }
                        LLMMessageTypedContent::ToolResult { content, .. } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(content),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::Image { source } => {
                            text_parts.push(ContentPart {
                                r#type: "image_url".to_string(),
                                text: None,
                                image_url: Some(ImageUrl {
                                    url: format!(
                                        "data:{};base64,{}",
                                        source.media_type, source.data
                                    ),
                                    detail: None,
                                }),
                            });
                        }
                    }
                }

                let content = if !text_parts.is_empty() {
                    Some(MessageContent::Array(text_parts))
                } else {
                    None
                };

                let tool_calls = if !tool_call_parts.is_empty() {
                    Some(tool_call_parts)
                } else {
                    None
                };

                (content, tool_calls)
            }
        };

        ChatMessage {
            role,
            content,
            name: None,
            tool_calls,
            tool_call_id: None,
            usage: None,
        }
    }
}

impl From<ChatMessage> for LLMMessage {
    fn from(chat_message: ChatMessage) -> Self {
        let mut content_parts = Vec::new();

        match chat_message.content {
            Some(MessageContent::String(s)) => {
                if !s.is_empty() {
                    content_parts.push(LLMMessageTypedContent::Text { text: s });
                }
            }
            Some(MessageContent::Array(parts)) => {
                for part in parts {
                    if let Some(text) = part.text {
                        content_parts.push(LLMMessageTypedContent::Text { text });
                    } else if let Some(image_url) = part.image_url {
                        let (media_type, data) = if image_url.url.starts_with("data:") {
                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
                            if parts.len() == 2 {
                                let meta = parts[0];
                                let data = parts[1];
                                let media_type = meta
                                    .trim_start_matches("data:")
                                    .trim_end_matches(";base64")
                                    .to_string();
                                (media_type, data.to_string())
                            } else {
                                ("image/jpeg".to_string(), image_url.url)
                            }
                        } else {
                            ("image/jpeg".to_string(), image_url.url)
                        };

                        content_parts.push(LLMMessageTypedContent::Image {
                            source: LLMMessageImageSource {
                                r#type: "base64".to_string(),
                                media_type,
                                data,
                            },
                        });
                    }
                }
            }
            None => {}
        }

        if let Some(tool_calls) = chat_message.tool_calls {
            for tool_call in tool_calls {
                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
                content_parts.push(LLMMessageTypedContent::ToolCall {
                    id: tool_call.id,
                    name: tool_call.function.name,
                    args,
                });
            }
        }

        LLMMessage {
            role: chat_message.role.to_string(),
            content: if content_parts.is_empty() {
                LLMMessageContent::String(String::new())
            } else if content_parts.len() == 1 {
                match &content_parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        LLMMessageContent::String(text.clone())
                    }
                    _ => LLMMessageContent::List(content_parts),
                }
            } else {
                LLMMessageContent::List(content_parts)
            },
        }
    }
}

impl From<GenerationDelta> for ChatMessageDelta {
    fn from(delta: GenerationDelta) -> Self {
        match delta {
            GenerationDelta::Content { content } => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: Some(content),
                tool_calls: None,
            },
            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: None,
                tool_calls: None,
            },
            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: None,
                tool_calls: Some(vec![ToolCallDelta {
                    index: tool_use.index,
                    id: tool_use.id,
                    r#type: Some("function".to_string()),
                    function: Some(FunctionCallDelta {
                        name: tool_use.name,
                        arguments: tool_use.input,
                    }),
                }]),
            },
            _ => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: None,
                tool_calls: None,
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"smart\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }
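
    // Additional coverage sketches added alongside the existing tests. They only
    // exercise behavior defined in this file (checkpoint injection/extraction,
    // model-name round-tripping, the hand-written ToolChoice serde impls, and
    // the ChatMessage -> LLMMessage conversion); literal values such as
    // "call_1", "get_weather", "search", and "Paris" are illustrative only.
    #[test]
    fn test_checkpoint_id_round_trip() {
        let checkpoint_id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let content = MessageContent::String("hello".to_string());

        // The injected marker should be recoverable from the same content.
        let injected = content.inject_checkpoint_id(checkpoint_id);
        assert_eq!(injected.extract_checkpoint_id(), Some(checkpoint_id));

        // Content without a marker yields no checkpoint id.
        assert_eq!(content.extract_checkpoint_id(), None);
    }

    #[test]
    fn test_openai_model_string_round_trip() {
        // Display output and the serde rename agree on the pinned snapshot name.
        let model = OpenAIModel::from_string("gpt-5-2025-08-07").unwrap();
        assert_eq!(model, OpenAIModel::GPT5);
        assert_eq!(model.to_string(), "gpt-5-2025-08-07");
    }

    #[test]
    fn test_tool_choice_serde_round_trip() {
        // Plain strings map to the unit-like variants.
        assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), "\"auto\"");
        let parsed: ToolChoice = serde_json::from_str("\"required\"").unwrap();
        assert_eq!(parsed, ToolChoice::Required);

        // Objects deserialize into ToolChoice::Object via the map visitor.
        let parsed: ToolChoice =
            serde_json::from_str(r#"{"type":"function","function":{"name":"search"}}"#).unwrap();
        assert_eq!(
            parsed,
            ToolChoice::Object(ToolChoiceObject {
                r#type: "function".to_string(),
                function: FunctionChoice {
                    name: "search".to_string(),
                },
            })
        );
    }

    #[test]
    fn test_chat_message_with_tool_calls_to_llm_message() {
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: None,
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_1".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: "{\"city\":\"Paris\"}".to_string(),
                },
            }]),
            tool_call_id: None,
            usage: None,
        };

        let llm_message = LLMMessage::from(chat_message);
        assert_eq!(llm_message.role, "assistant");
        match llm_message.content {
            LLMMessageContent::List(items) => {
                assert_eq!(items.len(), 1);
                match &items[0] {
                    LLMMessageTypedContent::ToolCall { id, name, args } => {
                        assert_eq!(id, "call_1");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args["city"], "Paris");
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected list content"),
        }
    }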
}