1use crate::models::llm::{
12 GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13 LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
20#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
26pub struct OpenAIConfig {
27 pub api_endpoint: Option<String>,
28 pub api_key: Option<String>,
29}
30
31impl OpenAIConfig {
32 pub fn with_api_key(api_key: impl Into<String>) -> Self {
34 Self {
35 api_key: Some(api_key.into()),
36 api_endpoint: None,
37 }
38 }
39
40 pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
42 match auth {
43 crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
44 crate::models::auth::ProviderAuth::OAuth { .. } => None, }
46 }
47
48 pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
50 match auth {
51 crate::models::auth::ProviderAuth::Api { key } => {
52 self.api_key = Some(key.clone());
53 Some(self)
54 }
55 crate::models::auth::ProviderAuth::OAuth { .. } => None, }
57 }
58}
59
/// OpenAI chat models with built-in metadata, plus a `Custom` variant for any other id.
///
/// Named variants serialize to the dated model ids in their `serde(rename)` attributes;
/// `Display` writes the same strings.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    /// The `Default` variant.
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Any other model id; `Display` writes it verbatim.
    // NOTE(review): with only the attributes visible here this is a default
    // externally-tagged newtype variant (`{"Custom": "..."}`), so a bare id string
    // does not deserialize into it — confirm this is intended.
    Custom(String),
}
85
86impl OpenAIModel {
87 pub fn from_string(s: &str) -> Result<Self, String> {
88 serde_json::from_value(serde_json::Value::String(s.to_string()))
89 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
90 }
91
92 pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
94
95 pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
97
98 pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
100
101 pub fn default_smart_model() -> String {
103 Self::DEFAULT_SMART_MODEL.to_string()
104 }
105
106 pub fn default_eco_model() -> String {
108 Self::DEFAULT_ECO_MODEL.to_string()
109 }
110
111 pub fn default_recovery_model() -> String {
113 Self::DEFAULT_RECOVERY_MODEL.to_string()
114 }
115}
116
117impl ContextAware for OpenAIModel {
118 fn context_info(&self) -> ModelContextInfo {
119 let model_name = self.to_string();
120
121 if model_name.starts_with("o3") {
122 return ModelContextInfo {
123 max_tokens: 200_000,
124 pricing_tiers: vec![ContextPricingTier {
125 label: "Standard".to_string(),
126 input_cost_per_million: 2.0,
127 output_cost_per_million: 8.0,
128 upper_bound: None,
129 }],
130 approach_warning_threshold: 0.8,
131 };
132 }
133
134 if model_name.starts_with("o4-mini") {
135 return ModelContextInfo {
136 max_tokens: 200_000,
137 pricing_tiers: vec![ContextPricingTier {
138 label: "Standard".to_string(),
139 input_cost_per_million: 1.10,
140 output_cost_per_million: 4.40,
141 upper_bound: None,
142 }],
143 approach_warning_threshold: 0.8,
144 };
145 }
146
147 if model_name.starts_with("gpt-5-mini") {
148 return ModelContextInfo {
149 max_tokens: 400_000,
150 pricing_tiers: vec![ContextPricingTier {
151 label: "Standard".to_string(),
152 input_cost_per_million: 0.25,
153 output_cost_per_million: 2.0,
154 upper_bound: None,
155 }],
156 approach_warning_threshold: 0.8,
157 };
158 }
159
160 if model_name.starts_with("gpt-5-nano") {
161 return ModelContextInfo {
162 max_tokens: 400_000,
163 pricing_tiers: vec![ContextPricingTier {
164 label: "Standard".to_string(),
165 input_cost_per_million: 0.05,
166 output_cost_per_million: 0.40,
167 upper_bound: None,
168 }],
169 approach_warning_threshold: 0.8,
170 };
171 }
172
173 if model_name.starts_with("gpt-5") {
174 return ModelContextInfo {
175 max_tokens: 400_000,
176 pricing_tiers: vec![ContextPricingTier {
177 label: "Standard".to_string(),
178 input_cost_per_million: 1.25,
179 output_cost_per_million: 10.0,
180 upper_bound: None,
181 }],
182 approach_warning_threshold: 0.8,
183 };
184 }
185
186 ModelContextInfo::default()
187 }
188
189 fn model_name(&self) -> String {
190 match self {
191 OpenAIModel::O3 => "O3".to_string(),
192 OpenAIModel::O4Mini => "O4-mini".to_string(),
193 OpenAIModel::GPT5 => "GPT-5".to_string(),
194 OpenAIModel::GPT51 => "GPT-5.1".to_string(),
195 OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
196 OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
197 OpenAIModel::Custom(name) => format!("Custom ({})", name),
198 }
199 }
200}
201
202impl std::fmt::Display for OpenAIModel {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 match self {
205 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
206 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
207 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
208 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
209 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
210 OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
211 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
212 }
213 }
214}
215
/// Author of a chat message; serialized in lowercase (`"system"`, `"user"`, …).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    /// The `Default` role, so a defaulted `ChatMessage` is an assistant message.
    #[default]
    Assistant,
    Tool,
}
231
232impl std::fmt::Display for Role {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 match self {
235 Role::System => write!(f, "system"),
236 Role::Developer => write!(f, "developer"),
237 Role::User => write!(f, "user"),
238 Role::Assistant => write!(f, "assistant"),
239 Role::Tool => write!(f, "tool"),
240 }
241 }
242}
243
/// Provider/model pair, attached to a [`ChatMessage`] through its `model` field.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name.
    pub provider: String,
    /// Model identifier.
    pub id: String,
}

/// One message in an OpenAI-style chat conversation, plus optional metadata fields.
///
/// Every optional field except `content` is omitted from serialized output when `None`;
/// `content` has no `skip_serializing_if`, so it serializes as `null` when absent.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    pub role: Role,
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool invocations requested by this (assistant) message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Role::Tool` messages: id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Presumably a Unix timestamp — unit not established here, confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Presumably a Unix timestamp — unit not established here, confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Free-form JSON attached to the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
290
291impl ChatMessage {
292 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
293 messages
294 .iter()
295 .rev()
296 .find(|message| message.role != Role::User && message.role != Role::Tool)
297 }
298
299 pub fn to_xml(&self) -> String {
300 match &self.content {
301 Some(MessageContent::String(s)) => {
302 format!("<message role=\"{}\">{}</message>", self.role, s)
303 }
304 Some(MessageContent::Array(parts)) => parts
305 .iter()
306 .map(|part| {
307 format!(
308 "<message role=\"{}\" type=\"{}\">{}</message>",
309 self.role,
310 part.r#type,
311 part.text.clone().unwrap_or_default()
312 )
313 })
314 .collect::<Vec<String>>()
315 .join("\n"),
316 None => String::new(),
317 }
318 }
319}
320
/// Message body: plain text, or a list of typed parts (text, images, …).
///
/// `untagged`, so it (de)serializes as a bare JSON string or a bare JSON array.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
328
329impl MessageContent {
330 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
331 match self {
332 MessageContent::String(s) => MessageContent::String(format!(
333 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
334 )),
335 MessageContent::Array(parts) => MessageContent::Array(
336 std::iter::once(ContentPart {
337 r#type: "text".to_string(),
338 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
339 image_url: None,
340 })
341 .chain(parts.iter().cloned())
342 .collect(),
343 ),
344 }
345 }
346
347 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
348 match self {
349 MessageContent::String(s) => s
350 .rfind("<checkpoint_id>")
351 .and_then(|start| {
352 s[start..]
353 .find("</checkpoint_id>")
354 .map(|end| (start + "<checkpoint_id>".len(), start + end))
355 })
356 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
357 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
358 part.text.as_deref().and_then(|text| {
359 text.rfind("<checkpoint_id>")
360 .and_then(|start| {
361 text[start..]
362 .find("</checkpoint_id>")
363 .map(|end| (start + "<checkpoint_id>".len(), start + end))
364 })
365 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
366 })
367 }),
368 }
369 }
370}
371
372impl std::fmt::Display for MessageContent {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 match self {
375 MessageContent::String(s) => write!(f, "{s}"),
376 MessageContent::Array(parts) => {
377 let text_parts: Vec<String> =
378 parts.iter().filter_map(|part| part.text.clone()).collect();
379 write!(f, "{}", text_parts.join("\n"))
380 }
381 }
382 }
383}
384
385impl Default for MessageContent {
386 fn default() -> Self {
387 MessageContent::String(String::new())
388 }
389}
390
/// One element of array-form [`MessageContent`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part kind; the conversions in this file produce `"text"` and `"image_url"`.
    pub r#type: String,
    /// Text payload (set for text parts).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Image reference (set for image parts).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}

/// Image location of an `image_url` content part.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// A `data:<media-type>;base64,<payload>` URL or some other URL.
    pub url: String,
    /// Optional `detail` value, passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

/// A tool offered to the model in a chat completion request.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool kind (presumably `"function"`; not validated here).
    pub r#type: String,
    pub function: FunctionDefinition,
}

/// Name, description and parameter schema of a callable function.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    /// Serialized as `null` when `None` (no `skip_serializing_if`).
    pub description: Option<String>,
    /// Parameter schema as raw JSON; becomes `LLMTool::input_schema` on conversion.
    pub parameters: serde_json::Value,
}
427
428impl From<Tool> for LLMTool {
429 fn from(tool: Tool) -> Self {
430 LLMTool {
431 name: tool.function.name,
432 description: tool.function.description.unwrap_or_default(),
433 input_schema: tool.function.parameters,
434 }
435 }
436}
437
/// The `tool_choice` request option: `"auto"`, `"required"`, or a specific function.
///
/// `Serialize`/`Deserialize` are implemented by hand rather than derived: the first two
/// variants are bare strings and `Object` uses the [`ToolChoiceObject`] shape.
// NOTE(review): the OpenAI API also documents `"none"`, which has no variant here.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
445
446impl Serialize for ToolChoice {
447 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
448 where
449 S: serde::Serializer,
450 {
451 match self {
452 ToolChoice::Auto => serializer.serialize_str("auto"),
453 ToolChoice::Required => serializer.serialize_str("required"),
454 ToolChoice::Object(obj) => obj.serialize(serializer),
455 }
456 }
457}
458
459impl<'de> Deserialize<'de> for ToolChoice {
460 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
461 where
462 D: serde::Deserializer<'de>,
463 {
464 struct ToolChoiceVisitor;
465
466 impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
467 type Value = ToolChoice;
468
469 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
470 formatter.write_str("string or object")
471 }
472
473 fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
474 where
475 E: serde::de::Error,
476 {
477 match value {
478 "auto" => Ok(ToolChoice::Auto),
479 "required" => Ok(ToolChoice::Required),
480 _ => Err(serde::de::Error::unknown_variant(
481 value,
482 &["auto", "required"],
483 )),
484 }
485 }
486
487 fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
488 where
489 M: serde::de::MapAccess<'de>,
490 {
491 let obj = ToolChoiceObject::deserialize(
492 serde::de::value::MapAccessDeserializer::new(map),
493 )?;
494 Ok(ToolChoice::Object(obj))
495 }
496 }
497
498 deserializer.deserialize_any(ToolChoiceVisitor)
499 }
500}
501
/// Object form of `tool_choice`: names the single function the model must call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    /// Tool kind of the forced choice (not validated here).
    pub r#type: String,
    pub function: FunctionChoice,
}

/// Name of the function selected by a [`ToolChoiceObject`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}

/// A tool invocation requested by an assistant message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Call identifier (mapped to/from `LLMMessageTypedContent::ToolCall::id`).
    pub id: String,
    /// Call kind; the conversions in this file always produce `"function"`.
    pub r#type: String,
    pub function: FunctionCall,
    /// Opaque extra data, copied to/from `LLMMessageTypedContent::ToolCall::metadata`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    /// Arguments as a JSON-encoded string, not a parsed object.
    pub arguments: String,
}

/// Final outcome of executing a tool call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}

/// A tool call together with its textual result and outcome.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    pub call: ToolCall,
    pub result: String,
    pub status: ToolCallResultStatus,
}

/// Summary of a tool call while it is being streamed (inferred from the name — confirm
/// with the producer).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallStreamInfo {
    /// Tool name.
    pub name: String,
    /// Presumably the number of argument tokens received so far — confirm how it is
    /// counted.
    pub args_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}

/// Progress report emitted while a tool call runs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    /// NOTE(review): presumably identifies the call/run being reported on — confirm.
    pub id: Uuid,
    /// Human-readable progress text.
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress_type: Option<ProgressType>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_updates: Option<Vec<TaskUpdate>>,
    /// Numeric progress; its scale (0–1 vs 0–100) is not established here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
}

/// Category of a [`ToolCallResultProgress`] report.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressType {
    CommandOutput,
    TaskWait,
    Generic,
}

/// Status snapshot of one task inside a progress report.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskUpdate {
    pub task_id: String,
    /// Free-form status string; its allowed values are not defined in this file.
    pub status: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_preview: Option<String>,
    /// Deserializes as `false` when the field is absent.
    #[serde(default)]
    pub is_target: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pause_info: Option<TaskPauseInfo>,
}

/// Details attached to a paused task.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskPauseInfo {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending_tool_calls: Option<Vec<crate::models::async_manifest::PendingToolCall>>,
}
615
/// Request body for an OpenAI-compatible chat completions call.
///
/// Every optional field is omitted from the serialized JSON when `None`, so only the
/// options that were set are sent.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// NOTE(review): not a standard OpenAI request field — presumably read by a custom
    /// backend; confirm before sending such requests to the public API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
658
659impl ChatCompletionRequest {
660 pub fn new(
661 model: String,
662 messages: Vec<ChatMessage>,
663 tools: Option<Vec<Tool>>,
664 stream: Option<bool>,
665 ) -> Self {
666 Self {
667 model,
668 messages,
669 frequency_penalty: None,
670 logit_bias: None,
671 logprobs: None,
672 max_tokens: None,
673 n: None,
674 presence_penalty: None,
675 response_format: None,
676 seed: None,
677 stop: None,
678 stream,
679 temperature: None,
680 top_p: None,
681 tools,
682 tool_choice: None,
683 user: None,
684 context: None,
685 }
686 }
687}
688
/// Extra context carried in [`ChatCompletionRequest::context`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Arbitrary JSON; serialized as `null` when `None` (no `skip_serializing_if`).
    pub scratchpad: Option<Value>,
}

/// The `response_format` request option, `{"type": ...}`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}

/// The `stop` request option: one stop sequence or several.
///
/// `untagged`, so it (de)serializes as a bare string or a bare array of strings.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}

/// A non-streaming chat completion response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation time (presumably Unix seconds — confirm).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    /// Required: deserialization fails if the response has no `usage` object.
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Unlike `system_fingerprint`, serialized as `null` when `None`.
    pub metadata: Option<serde_json::Value>,
}

/// One candidate completion within a [`ChatCompletionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    /// Required: a missing or `null` finish reason fails deserialization.
    pub finish_reason: FinishReason,
}

/// Why generation stopped; serialized in snake_case (`"stop"`, `"content_filter"`, …).
///
/// NOTE(review): any other reason string fails to deserialize (there is no catch-all).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}

/// Log-probability information for a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}

/// Log probability of one generated token, with optional top alternatives.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}

/// An alternative token and its log probability.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}

/// One chunk of a streamed chat completion.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}

/// One choice inside a stream chunk.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    pub finish_reason: Option<FinishReason>,
}

/// Incremental message data in a stream chunk; `None` fields are omitted when serialized.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}

/// Incremental tool-call data in a stream chunk.
///
/// `index` presumably identifies which call a fragment extends (OpenAI streaming
/// convention) — confirm against the consumer.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
    /// Opaque extra data, copied from the source tool-use delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

/// Partial function name and/or argument text within a [`ToolCallDelta`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    /// A fragment of the JSON-encoded argument string.
    pub arguments: Option<String>,
}
805
806impl From<LLMMessage> for ChatMessage {
811 fn from(llm_message: LLMMessage) -> Self {
812 let role = match llm_message.role.as_str() {
813 "system" => Role::System,
814 "user" => Role::User,
815 "assistant" => Role::Assistant,
816 "tool" => Role::Tool,
817 "developer" => Role::Developer,
818 _ => Role::User,
819 };
820
821 let (content, tool_calls, tool_call_id) = match llm_message.content {
822 LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None, None),
823 LLMMessageContent::List(items) => {
824 let mut text_parts = Vec::new();
825 let mut tool_call_parts = Vec::new();
826 let mut tool_result_id: Option<String> = None;
827
828 for item in items {
829 match item {
830 LLMMessageTypedContent::Text { text } => {
831 text_parts.push(ContentPart {
832 r#type: "text".to_string(),
833 text: Some(text),
834 image_url: None,
835 });
836 }
837 LLMMessageTypedContent::ToolCall {
838 id,
839 name,
840 args,
841 metadata,
842 } => {
843 tool_call_parts.push(ToolCall {
844 id,
845 r#type: "function".to_string(),
846 function: FunctionCall {
847 name,
848 arguments: args.to_string(),
849 },
850 metadata,
851 });
852 }
853 LLMMessageTypedContent::ToolResult {
854 tool_use_id,
855 content,
856 } => {
857 if tool_result_id.is_none() {
858 tool_result_id = Some(tool_use_id);
859 }
860 text_parts.push(ContentPart {
861 r#type: "text".to_string(),
862 text: Some(content),
863 image_url: None,
864 });
865 }
866 LLMMessageTypedContent::Image { source } => {
867 text_parts.push(ContentPart {
868 r#type: "image_url".to_string(),
869 text: None,
870 image_url: Some(ImageUrl {
871 url: format!(
872 "data:{};base64,{}",
873 source.media_type, source.data
874 ),
875 detail: None,
876 }),
877 });
878 }
879 }
880 }
881
882 let content = if !text_parts.is_empty() {
883 Some(MessageContent::Array(text_parts))
884 } else {
885 None
886 };
887
888 let tool_calls = if !tool_call_parts.is_empty() {
889 Some(tool_call_parts)
890 } else {
891 None
892 };
893
894 (content, tool_calls, tool_result_id)
895 }
896 };
897
898 ChatMessage {
899 role,
900 content,
901 name: None,
902 tool_calls,
903 tool_call_id,
904 usage: None,
905 ..Default::default()
906 }
907 }
908}
909
910impl From<ChatMessage> for LLMMessage {
911 fn from(chat_message: ChatMessage) -> Self {
912 let mut content_parts = Vec::new();
913
914 match chat_message.content {
915 Some(MessageContent::String(s)) => {
916 if !s.is_empty() {
917 content_parts.push(LLMMessageTypedContent::Text { text: s });
918 }
919 }
920 Some(MessageContent::Array(parts)) => {
921 for part in parts {
922 if let Some(text) = part.text {
923 content_parts.push(LLMMessageTypedContent::Text { text });
924 } else if let Some(image_url) = part.image_url {
925 let (media_type, data) = if image_url.url.starts_with("data:") {
926 let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
927 if parts.len() == 2 {
928 let meta = parts[0];
929 let data = parts[1];
930 let media_type = meta
931 .trim_start_matches("data:")
932 .trim_end_matches(";base64")
933 .to_string();
934 (media_type, data.to_string())
935 } else {
936 ("image/jpeg".to_string(), image_url.url)
937 }
938 } else {
939 ("image/jpeg".to_string(), image_url.url)
940 };
941
942 content_parts.push(LLMMessageTypedContent::Image {
943 source: LLMMessageImageSource {
944 r#type: "base64".to_string(),
945 media_type,
946 data,
947 },
948 });
949 }
950 }
951 }
952 None => {}
953 }
954
955 if let Some(tool_calls) = chat_message.tool_calls {
956 for tool_call in tool_calls {
957 let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
958 content_parts.push(LLMMessageTypedContent::ToolCall {
959 id: tool_call.id,
960 name: tool_call.function.name,
961 args,
962 metadata: tool_call.metadata,
963 });
964 }
965 }
966
967 if chat_message.role == Role::Tool
972 && let Some(tool_call_id) = chat_message.tool_call_id
973 {
974 let content_str = content_parts
976 .iter()
977 .filter_map(|p| match p {
978 LLMMessageTypedContent::Text { text } => Some(text.clone()),
979 _ => None,
980 })
981 .collect::<Vec<_>>()
982 .join("\n");
983
984 content_parts = vec![LLMMessageTypedContent::ToolResult {
986 tool_use_id: tool_call_id,
987 content: content_str,
988 }];
989 }
990
991 LLMMessage {
992 role: chat_message.role.to_string(),
993 content: if content_parts.is_empty() {
994 LLMMessageContent::String(String::new())
995 } else if content_parts.len() == 1 {
996 match &content_parts[0] {
997 LLMMessageTypedContent::Text { text } => {
998 LLMMessageContent::String(text.clone())
999 }
1000 _ => LLMMessageContent::List(content_parts),
1001 }
1002 } else {
1003 LLMMessageContent::List(content_parts)
1004 },
1005 }
1006 }
1007}
1008
1009impl From<GenerationDelta> for ChatMessageDelta {
1010 fn from(delta: GenerationDelta) -> Self {
1011 match delta {
1012 GenerationDelta::Content { content } => ChatMessageDelta {
1013 role: Some(Role::Assistant),
1014 content: Some(content),
1015 tool_calls: None,
1016 },
1017 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1018 role: Some(Role::Assistant),
1019 content: None,
1020 tool_calls: None,
1021 },
1022 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1023 role: Some(Role::Assistant),
1024 content: None,
1025 tool_calls: Some(vec![ToolCallDelta {
1026 index: tool_use.index,
1027 id: tool_use.id,
1028 r#type: Some("function".to_string()),
1029 function: Some(FunctionCallDelta {
1030 name: tool_use.name,
1031 arguments: tool_use.input,
1032 }),
1033 metadata: tool_use.metadata,
1034 }]),
1035 },
1036 _ => ChatMessageDelta {
1037 role: Some(Role::Assistant),
1038 content: None,
1039 tool_calls: None,
1040 },
1041 }
1042 }
1043}
1044
#[cfg(test)]
mod tests {
    use super::*;

    /// A `ChatMessage` with only `role`, `content` and `tool_call_id` set.
    fn chat(
        role: Role,
        content: Option<MessageContent>,
        tool_call_id: Option<&str>,
    ) -> ChatMessage {
        ChatMessage {
            role,
            content,
            tool_call_id: tool_call_id.map(str::to_string),
            ..Default::default()
        }
    }

    /// Plain string message content.
    fn string_content(s: &str) -> Option<MessageContent> {
        Some(MessageContent::String(s.to_string()))
    }

    #[test]
    fn test_serialize_basic_request() {
        let mut request = ChatCompletionRequest::new(
            "gpt-4".to_string(),
            vec![
                chat(Role::System, string_content("You are a helpful assistant."), None),
                chat(Role::User, string_content("Hello!"), None),
            ],
            None,
            None,
        );
        request.max_tokens = Some(100);
        request.temperature = Some(0.7);

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let chat_message = ChatMessage::from(LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        });

        assert_eq!(chat_message.role, Role::User);
        let Some(MessageContent::String(body)) = &chat_message.content else {
            panic!("Expected string content");
        };
        assert_eq!(body, "Hello, world!");
    }

    #[test]
    fn test_llm_message_to_chat_message_tool_result_preserves_tool_call_id() {
        let chat_message = ChatMessage::from(LLMMessage {
            role: "tool".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_01Abc123".to_string(),
                content: "Tool execution result".to_string(),
            }]),
        });

        assert_eq!(chat_message.role, Role::Tool);
        assert_eq!(chat_message.tool_call_id.as_deref(), Some("toolu_01Abc123"));

        let expected_part = ContentPart {
            r#type: "text".to_string(),
            text: Some("Tool execution result".to_string()),
            image_url: None,
        };
        assert_eq!(
            chat_message.content,
            Some(MessageContent::Array(vec![expected_part]))
        );
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        let llm_message: LLMMessage = chat(
            Role::Tool,
            string_content("Tool execution result"),
            Some("toolu_01Abc123"),
        )
        .into();

        assert_eq!(llm_message.role, "tool");

        // A tool message must collapse into exactly one ToolResult item.
        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            );
        };
        assert_eq!(parts.len(), 1, "Should have exactly one content part");

        let LLMMessageTypedContent::ToolResult {
            tool_use_id,
            content,
        } = &parts[0]
        else {
            panic!("Expected ToolResult content part, got {:?}", parts[0]);
        };
        assert_eq!(tool_use_id, "toolu_01Abc123");
        assert_eq!(content, "Tool execution result");
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        let llm_message: LLMMessage = chat(Role::Tool, None, Some("toolu_02Xyz789")).into();

        assert_eq!(llm_message.role, "tool");

        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!("Expected List content with ToolResult");
        };
        assert_eq!(parts.len(), 1);

        let LLMMessageTypedContent::ToolResult {
            tool_use_id,
            content,
        } = &parts[0]
        else {
            panic!("Expected ToolResult content part");
        };
        assert_eq!(tool_use_id, "toolu_02Xyz789");
        // No text items means an empty result body.
        assert_eq!(content, "");
    }

    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        let chat_message = ChatMessage {
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
                metadata: None,
            }]),
            ..chat(
                Role::Assistant,
                string_content("I'll help you with that."),
                None,
            )
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "assistant");
        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!("Expected List content");
        };
        assert_eq!(parts.len(), 2, "Should have text and tool call");

        // Text is emitted before tool calls.
        let LLMMessageTypedContent::Text { text } = &parts[0] else {
            panic!("Expected Text content part first");
        };
        assert_eq!(text, "I'll help you with that.");

        let LLMMessageTypedContent::ToolCall { id, name, args, .. } = &parts[1] else {
            panic!("Expected ToolCall content part second");
        };
        assert_eq!(id, "call_abc123");
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], "Paris");
    }
}