1use crate::models::llm::{
12 GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13 LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
/// Connection settings for an OpenAI-compatible endpoint.
///
/// Both fields are optional: a `None` endpoint presumably means "use the
/// provider default" (resolved by the caller — not visible here).
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Base URL override for the API; `None` to use the default endpoint.
    pub api_endpoint: Option<String>,
    /// Bearer API key; `None` when authentication is configured elsewhere.
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32 pub fn with_api_key(api_key: impl Into<String>) -> Self {
34 Self {
35 api_key: Some(api_key.into()),
36 api_endpoint: None,
37 }
38 }
39
40 pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
42 match auth {
43 crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
44 crate::models::auth::ProviderAuth::OAuth { .. } => None, }
46 }
47
48 pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
50 match auth {
51 crate::models::auth::ProviderAuth::Api { key } => {
52 self.api_key = Some(key.clone());
53 Some(self)
54 }
55 crate::models::auth::ProviderAuth::OAuth { .. } => None, }
57 }
58}
59
/// Known OpenAI model identifiers, pinned to dated snapshot names via serde
/// renames; `Custom` carries any other model id verbatim.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    // GPT-5 is the default model for this provider.
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Escape hatch for model ids not listed above.
    Custom(String),
}
85
86impl OpenAIModel {
87 pub fn from_string(s: &str) -> Result<Self, String> {
88 serde_json::from_value(serde_json::Value::String(s.to_string()))
89 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
90 }
91
92 pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
94
95 pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
97
98 pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
100
101 pub fn default_smart_model() -> String {
103 Self::DEFAULT_SMART_MODEL.to_string()
104 }
105
106 pub fn default_eco_model() -> String {
108 Self::DEFAULT_ECO_MODEL.to_string()
109 }
110
111 pub fn default_recovery_model() -> String {
113 Self::DEFAULT_RECOVERY_MODEL.to_string()
114 }
115}
116
117impl ContextAware for OpenAIModel {
118 fn context_info(&self) -> ModelContextInfo {
119 let model_name = self.to_string();
120
121 if model_name.starts_with("o3") {
122 return ModelContextInfo {
123 max_tokens: 200_000,
124 pricing_tiers: vec![ContextPricingTier {
125 label: "Standard".to_string(),
126 input_cost_per_million: 2.0,
127 output_cost_per_million: 8.0,
128 upper_bound: None,
129 }],
130 approach_warning_threshold: 0.8,
131 };
132 }
133
134 if model_name.starts_with("o4-mini") {
135 return ModelContextInfo {
136 max_tokens: 200_000,
137 pricing_tiers: vec![ContextPricingTier {
138 label: "Standard".to_string(),
139 input_cost_per_million: 1.10,
140 output_cost_per_million: 4.40,
141 upper_bound: None,
142 }],
143 approach_warning_threshold: 0.8,
144 };
145 }
146
147 if model_name.starts_with("gpt-5-mini") {
148 return ModelContextInfo {
149 max_tokens: 400_000,
150 pricing_tiers: vec![ContextPricingTier {
151 label: "Standard".to_string(),
152 input_cost_per_million: 0.25,
153 output_cost_per_million: 2.0,
154 upper_bound: None,
155 }],
156 approach_warning_threshold: 0.8,
157 };
158 }
159
160 if model_name.starts_with("gpt-5-nano") {
161 return ModelContextInfo {
162 max_tokens: 400_000,
163 pricing_tiers: vec![ContextPricingTier {
164 label: "Standard".to_string(),
165 input_cost_per_million: 0.05,
166 output_cost_per_million: 0.40,
167 upper_bound: None,
168 }],
169 approach_warning_threshold: 0.8,
170 };
171 }
172
173 if model_name.starts_with("gpt-5") {
174 return ModelContextInfo {
175 max_tokens: 400_000,
176 pricing_tiers: vec![ContextPricingTier {
177 label: "Standard".to_string(),
178 input_cost_per_million: 1.25,
179 output_cost_per_million: 10.0,
180 upper_bound: None,
181 }],
182 approach_warning_threshold: 0.8,
183 };
184 }
185
186 ModelContextInfo::default()
187 }
188
189 fn model_name(&self) -> String {
190 match self {
191 OpenAIModel::O3 => "O3".to_string(),
192 OpenAIModel::O4Mini => "O4-mini".to_string(),
193 OpenAIModel::GPT5 => "GPT-5".to_string(),
194 OpenAIModel::GPT51 => "GPT-5.1".to_string(),
195 OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
196 OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
197 OpenAIModel::Custom(name) => format!("Custom ({})", name),
198 }
199 }
200}
201
202impl std::fmt::Display for OpenAIModel {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 match self {
205 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
206 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
207 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
208 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
209 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
210 OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
211 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
212 }
213 }
214}
215
/// Chat message author role, serialized in lowercase to match the OpenAI
/// wire format (e.g. `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    // Assistant is the default role for constructed messages.
    #[default]
    Assistant,
    Tool,
}
231
232impl std::fmt::Display for Role {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 match self {
235 Role::System => write!(f, "system"),
236 Role::Developer => write!(f, "developer"),
237 Role::User => write!(f, "user"),
238 Role::Assistant => write!(f, "assistant"),
239 Role::Tool => write!(f, "tool"),
240 }
241 }
242}
243
/// Identifies which provider/model produced a message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name (e.g. an LLM vendor identifier).
    pub provider: String,
    /// Model identifier string.
    pub id: String,
}
252
/// A single chat message in OpenAI wire format, extended with bookkeeping
/// fields (usage, cost, timestamps, metadata) that are skipped when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    pub role: Role,
    /// Message body; `None` for e.g. assistant messages that only carry
    /// tool calls.
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool invocations requested by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Role::Tool` messages: id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    /// Server-assigned message id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Provider/model that produced this message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    /// Monetary cost of generating this message (units not specified here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Creation timestamp — presumably a Unix epoch value; confirm with
    /// the producer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Completion timestamp — same convention as `created_at`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Free-form extra data attached to the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
290
291impl ChatMessage {
292 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
293 messages
294 .iter()
295 .rev()
296 .find(|message| message.role != Role::User && message.role != Role::Tool)
297 }
298
299 pub fn to_xml(&self) -> String {
300 match &self.content {
301 Some(MessageContent::String(s)) => {
302 format!("<message role=\"{}\">{}</message>", self.role, s)
303 }
304 Some(MessageContent::Array(parts)) => parts
305 .iter()
306 .map(|part| {
307 format!(
308 "<message role=\"{}\" type=\"{}\">{}</message>",
309 self.role,
310 part.r#type,
311 part.text.clone().unwrap_or_default()
312 )
313 })
314 .collect::<Vec<String>>()
315 .join("\n"),
316 None => String::new(),
317 }
318 }
319}
320
/// Message body: either a plain string or a list of typed parts
/// (text and/or images). `untagged` so it (de)serializes as either a JSON
/// string or a JSON array, matching the OpenAI content schema.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
328
329impl MessageContent {
330 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
331 match self {
332 MessageContent::String(s) => MessageContent::String(format!(
333 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
334 )),
335 MessageContent::Array(parts) => MessageContent::Array(
336 std::iter::once(ContentPart {
337 r#type: "text".to_string(),
338 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
339 image_url: None,
340 })
341 .chain(parts.iter().cloned())
342 .collect(),
343 ),
344 }
345 }
346
347 #[allow(clippy::string_slice)]
349 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
350 match self {
351 MessageContent::String(s) => s
352 .rfind("<checkpoint_id>")
353 .and_then(|start| {
354 s[start..]
355 .find("</checkpoint_id>")
356 .map(|end| (start + "<checkpoint_id>".len(), start + end))
357 })
358 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
359 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
360 part.text.as_deref().and_then(|text| {
361 text.rfind("<checkpoint_id>")
362 .and_then(|start| {
363 text[start..]
364 .find("</checkpoint_id>")
365 .map(|end| (start + "<checkpoint_id>".len(), start + end))
366 })
367 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
368 })
369 }),
370 }
371 }
372}
373
374impl std::fmt::Display for MessageContent {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 match self {
377 MessageContent::String(s) => write!(f, "{s}"),
378 MessageContent::Array(parts) => {
379 let text_parts: Vec<String> =
380 parts.iter().filter_map(|part| part.text.clone()).collect();
381 write!(f, "{}", text_parts.join("\n"))
382 }
383 }
384 }
385}
386
387impl Default for MessageContent {
388 fn default() -> Self {
389 MessageContent::String(String::new())
390 }
391}
392
/// One element of array-style message content; `r#type` selects which of the
/// payload fields is populated (observed values: "text", "image_url").
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    pub r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
402
/// Image reference for an image content part; `url` may be an http(s) URL or
/// a `data:` URL (the conversions in this file emit base64 data URLs).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail level hint — passed through verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
410
/// OpenAI tool declaration wrapper; `r#type` is expected to be "function".
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    pub r#type: String,
    pub function: FunctionDefinition,
}
421
/// Declaration of a callable function exposed to the model.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's arguments.
    pub parameters: serde_json::Value,
}
429
430impl From<Tool> for LLMTool {
431 fn from(tool: Tool) -> Self {
432 LLMTool {
433 name: tool.function.name,
434 description: tool.function.description.unwrap_or_default(),
435 input_schema: tool.function.parameters,
436 }
437 }
438}
439
/// The OpenAI `tool_choice` field: either a keyword mode ("auto"/"required",
/// serialized as bare strings via the manual impls below) or an object
/// forcing a specific function.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
447
448impl Serialize for ToolChoice {
449 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
450 where
451 S: serde::Serializer,
452 {
453 match self {
454 ToolChoice::Auto => serializer.serialize_str("auto"),
455 ToolChoice::Required => serializer.serialize_str("required"),
456 ToolChoice::Object(obj) => obj.serialize(serializer),
457 }
458 }
459}
460
461impl<'de> Deserialize<'de> for ToolChoice {
462 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
463 where
464 D: serde::Deserializer<'de>,
465 {
466 struct ToolChoiceVisitor;
467
468 impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
469 type Value = ToolChoice;
470
471 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
472 formatter.write_str("string or object")
473 }
474
475 fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
476 where
477 E: serde::de::Error,
478 {
479 match value {
480 "auto" => Ok(ToolChoice::Auto),
481 "required" => Ok(ToolChoice::Required),
482 _ => Err(serde::de::Error::unknown_variant(
483 value,
484 &["auto", "required"],
485 )),
486 }
487 }
488
489 fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
490 where
491 M: serde::de::MapAccess<'de>,
492 {
493 let obj = ToolChoiceObject::deserialize(
494 serde::de::value::MapAccessDeserializer::new(map),
495 )?;
496 Ok(ToolChoice::Object(obj))
497 }
498 }
499
500 deserializer.deserialize_any(ToolChoiceVisitor)
501 }
502}
503
/// Object form of `tool_choice`, forcing a specific function call;
/// `r#type` is expected to be "function".
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    pub r#type: String,
    pub function: FunctionChoice,
}
509
/// Names the function forced by a `ToolChoiceObject`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
514
/// A tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Call id, echoed back via `ChatMessage::tool_call_id` in the result.
    pub id: String,
    /// Call kind; this file always constructs "function".
    pub r#type: String,
    pub function: FunctionCall,
    /// Extra data carried alongside the call (not part of the OpenAI schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
525
/// The function name and its arguments for a tool call; `arguments` is a
/// JSON document encoded as a string, per the OpenAI wire format.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}
532
/// Outcome of executing a tool call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}
540
/// A completed tool call paired with its textual result and outcome status.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    pub call: ToolCall,
    pub result: String,
    pub status: ToolCallResultStatus,
}
548
/// Lightweight progress info about a tool call while its arguments are
/// still streaming in.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallStreamInfo {
    /// Tool/function name.
    pub name: String,
    /// Number of argument tokens received so far — TODO confirm the unit
    /// with the producer.
    pub args_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
560
/// Incremental progress update emitted while a tool call runs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    pub id: Uuid,
    pub message: String,
    /// What kind of progress this is; see `ProgressType`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress_type: Option<ProgressType>,
    /// Per-task status updates, when the tool manages sub-tasks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_updates: Option<Vec<TaskUpdate>>,
    /// Numeric progress — presumably a 0.0..=1.0 fraction; confirm with the
    /// producer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
}
576
/// Classifies a `ToolCallResultProgress` message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressType {
    /// Output captured from a running command.
    CommandOutput,
    /// Waiting on a (sub-)task.
    TaskWait,
    /// Uncategorized progress.
    Generic,
}
587
/// Status update for a single sub-task tracked by a tool call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskUpdate {
    pub task_id: String,
    /// Free-form status string (schema not defined in this file).
    pub status: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f64>,
    /// Truncated preview of the task's output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_preview: Option<String>,
    /// Whether this is the task the update is primarily about — defaults to
    /// false when absent from the payload.
    #[serde(default)]
    pub is_target: bool,
    /// Present when the task is paused awaiting input/approval.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pause_info: Option<TaskPauseInfo>,
}
606
/// Details about why a task is paused.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskPauseInfo {
    /// Message from the agent explaining the pause.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_message: Option<String>,
    /// Tool calls waiting to be executed/approved.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending_tool_calls: Option<Vec<crate::models::async_manifest::PendingToolCall>>,
}
617
/// Request body for the chat-completions endpoint. Only `model` and
/// `messages` are required; every optional knob is skipped from the JSON
/// when `None`, so defaults are decided server-side.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Number of completions to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Stop sequence(s): a single string or a list.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    /// Request a streamed (SSE) response when `Some(true)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// End-user identifier forwarded to the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Extra request context (not part of the standard OpenAI schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
660
661impl ChatCompletionRequest {
662 pub fn new(
663 model: String,
664 messages: Vec<ChatMessage>,
665 tools: Option<Vec<Tool>>,
666 stream: Option<bool>,
667 ) -> Self {
668 Self {
669 model,
670 messages,
671 frequency_penalty: None,
672 logit_bias: None,
673 logprobs: None,
674 max_tokens: None,
675 n: None,
676 presence_penalty: None,
677 response_format: None,
678 seed: None,
679 stop: None,
680 stream,
681 temperature: None,
682 top_p: None,
683 tools,
684 tool_choice: None,
685 user: None,
686 context: None,
687 }
688 }
689}
690
/// Extra per-request context; `scratchpad` is an opaque JSON blob passed
/// through verbatim.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    pub scratchpad: Option<Value>,
}
695
/// The `response_format` request field (e.g. `{"type": "json_object"}`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}
700
/// The `stop` request field: a single sequence or a list of sequences;
/// `untagged` so it (de)serializes as a bare string or a JSON array.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
707
/// Non-streaming chat-completions response body.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// Object type discriminator from the wire payload.
    pub object: String,
    /// Creation time — presumably a Unix timestamp; confirm with the API.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// NOTE(review): unlike the other optional fields this has no
    /// `skip_serializing_if`, so `"metadata": null` is emitted when `None`.
    pub metadata: Option<serde_json::Value>,
}
721
/// One generated completion in a non-streaming response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}
729
/// Why the model stopped generating; serialized in snake_case
/// (e.g. `"tool_calls"`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}
738
/// Log-probability payload attached to a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}
743
/// Log probability of one generated token plus its top alternatives.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when provided.
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}
751
/// One alternative token candidate with its log probability.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
758
/// One streamed (SSE) chunk of a chat completion.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    /// Creation time — presumably a Unix timestamp; confirm with the API.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Usually only present on the final chunk.
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}
773
/// One choice's incremental delta within a stream chunk.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    /// Set on the final chunk for this choice.
    pub finish_reason: Option<FinishReason>,
}
780
/// Incremental update to a message while streaming; every field is optional
/// and omitted from the JSON when absent.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    /// Text fragment to append to the message content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
790
/// Incremental update to a tool call while streaming; `index` identifies
/// which call in the message the fragment belongs to.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    /// Present on the first fragment of a call.
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
    /// Extra data carried alongside the delta (not part of the OpenAI schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
801
/// Incremental fragment of a streamed function call; `arguments` fragments
/// concatenate into the full JSON argument string.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
807
808impl From<LLMMessage> for ChatMessage {
813 fn from(llm_message: LLMMessage) -> Self {
814 let role = match llm_message.role.as_str() {
815 "system" => Role::System,
816 "user" => Role::User,
817 "assistant" => Role::Assistant,
818 "tool" => Role::Tool,
819 "developer" => Role::Developer,
820 _ => Role::User,
821 };
822
823 let (content, tool_calls, tool_call_id) = match llm_message.content {
824 LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None, None),
825 LLMMessageContent::List(items) => {
826 let mut text_parts = Vec::new();
827 let mut tool_call_parts = Vec::new();
828 let mut tool_result_id: Option<String> = None;
829
830 for item in items {
831 match item {
832 LLMMessageTypedContent::Text { text } => {
833 text_parts.push(ContentPart {
834 r#type: "text".to_string(),
835 text: Some(text),
836 image_url: None,
837 });
838 }
839 LLMMessageTypedContent::ToolCall {
840 id,
841 name,
842 args,
843 metadata,
844 } => {
845 tool_call_parts.push(ToolCall {
846 id,
847 r#type: "function".to_string(),
848 function: FunctionCall {
849 name,
850 arguments: args.to_string(),
851 },
852 metadata,
853 });
854 }
855 LLMMessageTypedContent::ToolResult {
856 tool_use_id,
857 content,
858 } => {
859 if tool_result_id.is_none() {
860 tool_result_id = Some(tool_use_id);
861 }
862 text_parts.push(ContentPart {
863 r#type: "text".to_string(),
864 text: Some(content),
865 image_url: None,
866 });
867 }
868 LLMMessageTypedContent::Image { source } => {
869 text_parts.push(ContentPart {
870 r#type: "image_url".to_string(),
871 text: None,
872 image_url: Some(ImageUrl {
873 url: format!(
874 "data:{};base64,{}",
875 source.media_type, source.data
876 ),
877 detail: None,
878 }),
879 });
880 }
881 }
882 }
883
884 let content = if !text_parts.is_empty() {
885 Some(MessageContent::Array(text_parts))
886 } else {
887 None
888 };
889
890 let tool_calls = if !tool_call_parts.is_empty() {
891 Some(tool_call_parts)
892 } else {
893 None
894 };
895
896 (content, tool_calls, tool_result_id)
897 }
898 };
899
900 ChatMessage {
901 role,
902 content,
903 name: None,
904 tool_calls,
905 tool_call_id,
906 usage: None,
907 ..Default::default()
908 }
909 }
910}
911
912impl From<ChatMessage> for LLMMessage {
913 fn from(chat_message: ChatMessage) -> Self {
914 let mut content_parts = Vec::new();
915
916 match chat_message.content {
917 Some(MessageContent::String(s)) => {
918 if !s.is_empty() {
919 content_parts.push(LLMMessageTypedContent::Text { text: s });
920 }
921 }
922 Some(MessageContent::Array(parts)) => {
923 for part in parts {
924 if let Some(text) = part.text {
925 content_parts.push(LLMMessageTypedContent::Text { text });
926 } else if let Some(image_url) = part.image_url {
927 let (media_type, data) = if image_url.url.starts_with("data:") {
928 let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
929 if parts.len() == 2 {
930 let meta = parts[0];
931 let data = parts[1];
932 let media_type = meta
933 .trim_start_matches("data:")
934 .trim_end_matches(";base64")
935 .to_string();
936 (media_type, data.to_string())
937 } else {
938 ("image/jpeg".to_string(), image_url.url)
939 }
940 } else {
941 ("image/jpeg".to_string(), image_url.url)
942 };
943
944 content_parts.push(LLMMessageTypedContent::Image {
945 source: LLMMessageImageSource {
946 r#type: "base64".to_string(),
947 media_type,
948 data,
949 },
950 });
951 }
952 }
953 }
954 None => {}
955 }
956
957 if let Some(tool_calls) = chat_message.tool_calls {
958 for tool_call in tool_calls {
959 let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
960 content_parts.push(LLMMessageTypedContent::ToolCall {
961 id: tool_call.id,
962 name: tool_call.function.name,
963 args,
964 metadata: tool_call.metadata,
965 });
966 }
967 }
968
969 if chat_message.role == Role::Tool
974 && let Some(tool_call_id) = chat_message.tool_call_id
975 {
976 let content_str = content_parts
978 .iter()
979 .filter_map(|p| match p {
980 LLMMessageTypedContent::Text { text } => Some(text.clone()),
981 _ => None,
982 })
983 .collect::<Vec<_>>()
984 .join("\n");
985
986 content_parts = vec![LLMMessageTypedContent::ToolResult {
988 tool_use_id: tool_call_id,
989 content: content_str,
990 }];
991 }
992
993 LLMMessage {
994 role: chat_message.role.to_string(),
995 content: if content_parts.is_empty() {
996 LLMMessageContent::String(String::new())
997 } else if content_parts.len() == 1 {
998 match &content_parts[0] {
999 LLMMessageTypedContent::Text { text } => {
1000 LLMMessageContent::String(text.clone())
1001 }
1002 _ => LLMMessageContent::List(content_parts),
1003 }
1004 } else {
1005 LLMMessageContent::List(content_parts)
1006 },
1007 }
1008 }
1009}
1010
1011impl From<GenerationDelta> for ChatMessageDelta {
1012 fn from(delta: GenerationDelta) -> Self {
1013 match delta {
1014 GenerationDelta::Content { content } => ChatMessageDelta {
1015 role: Some(Role::Assistant),
1016 content: Some(content),
1017 tool_calls: None,
1018 },
1019 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1020 role: Some(Role::Assistant),
1021 content: None,
1022 tool_calls: None,
1023 },
1024 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1025 role: Some(Role::Assistant),
1026 content: None,
1027 tool_calls: Some(vec![ToolCallDelta {
1028 index: tool_use.index,
1029 id: tool_use.id,
1030 r#type: Some("function".to_string()),
1031 function: Some(FunctionCallDelta {
1032 name: tool_use.name,
1033 arguments: tool_use.input,
1034 }),
1035 metadata: tool_use.metadata,
1036 }]),
1037 },
1038 _ => ChatMessageDelta {
1039 role: Some(Role::Assistant),
1040 content: None,
1041 tool_calls: None,
1042 },
1043 }
1044 }
1045}
1046
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal request serializes its required fields and omits all
    /// `None` options (they carry `skip_serializing_if`).
    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    /// String-content LLM messages convert to string-content chat messages
    /// with the role mapped through the known-role table.
    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }

    /// A ToolResult item must surface its id as `tool_call_id` and its body
    /// as a text content part.
    #[test]
    fn test_llm_message_to_chat_message_tool_result_preserves_tool_call_id() {
        let llm_message = LLMMessage {
            role: "tool".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_01Abc123".to_string(),
                content: "Tool execution result".to_string(),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::Tool);
        assert_eq!(chat_message.tool_call_id.as_deref(), Some("toolu_01Abc123"));
        assert_eq!(
            chat_message.content,
            Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("Tool execution result".to_string()),
                image_url: None,
            }]))
        );
    }

    /// A tool-role chat message with a `tool_call_id` collapses into a
    /// single ToolResult item on the LLM side.
    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: Some(MessageContent::String("Tool execution result".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_01Abc123".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one content part");
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_01Abc123");
                        assert_eq!(content, "Tool execution result");
                    }
                    _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
                }
            }
            _ => panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            ),
        }
    }

    /// Even with no content at all, a tool-role message still yields a
    /// ToolResult (with an empty body) so the id is not lost.
    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: None,
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_02Xyz789");
                        // Empty content joins to the empty string.
                        assert_eq!(content, "");
                    }
                    _ => panic!("Expected ToolResult content part"),
                }
            }
            _ => panic!("Expected List content with ToolResult"),
        }
    }

    /// Assistant text plus tool calls produce a Text item followed by a
    /// ToolCall item, with the JSON argument string parsed into a value.
    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
                metadata: None,
            }]),
            tool_call_id: None,
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "assistant");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 2, "Should have text and tool call");

                match &parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        assert_eq!(text, "I'll help you with that.");
                    }
                    _ => panic!("Expected Text content part first"),
                }

                match &parts[1] {
                    LLMMessageTypedContent::ToolCall { id, name, args, .. } => {
                        assert_eq!(id, "call_abc123");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args["location"], "Paris");
                    }
                    _ => panic!("Expected ToolCall content part second"),
                }
            }
            _ => panic!("Expected List content"),
        }
    }
}