1use crate::models::llm::{
2 GenerationDelta, LLMChoice, LLMCompletionResponse, LLMMessage, LLMMessageContent,
3 LLMMessageImageSource, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
4};
5use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use uuid::Uuid;
9
10#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
11pub struct OpenAIConfig {
12 pub api_endpoint: Option<String>,
13 pub api_key: Option<String>,
14}
15
16#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
17pub enum OpenAIModel {
18 #[serde(rename = "o3-2025-04-16")]
20 O3,
21 #[serde(rename = "o4-mini-2025-04-16")]
22 O4Mini,
23
24 #[default]
25 #[serde(rename = "gpt-5-2025-08-07")]
26 GPT5,
27 #[serde(rename = "gpt-5.1-2025-11-13")]
28 GPT51,
29 #[serde(rename = "gpt-5-mini-2025-08-07")]
30 GPT5Mini,
31 #[serde(rename = "gpt-5-nano-2025-08-07")]
32 GPT5Nano,
33
34 Custom(String),
35}
36
37impl OpenAIModel {
38 pub fn from_string(s: &str) -> Result<Self, String> {
39 serde_json::from_value(serde_json::Value::String(s.to_string()))
40 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
41 }
42
43 pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
45
46 pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
48
49 pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
51
52 pub fn default_smart_model() -> String {
54 Self::DEFAULT_SMART_MODEL.to_string()
55 }
56
57 pub fn default_eco_model() -> String {
59 Self::DEFAULT_ECO_MODEL.to_string()
60 }
61
62 pub fn default_recovery_model() -> String {
64 Self::DEFAULT_RECOVERY_MODEL.to_string()
65 }
66}
67
68impl ContextAware for OpenAIModel {
69 fn context_info(&self) -> ModelContextInfo {
70 let model_name = self.to_string();
71
72 if model_name.starts_with("o3") {
73 return ModelContextInfo {
74 max_tokens: 200_000,
75 pricing_tiers: vec![ContextPricingTier {
76 label: "Standard".to_string(),
77 input_cost_per_million: 2.0,
78 output_cost_per_million: 8.0,
79 upper_bound: None,
80 }],
81 approach_warning_threshold: 0.8,
82 };
83 }
84
85 if model_name.starts_with("o4-mini") {
86 return ModelContextInfo {
87 max_tokens: 200_000,
88 pricing_tiers: vec![ContextPricingTier {
89 label: "Standard".to_string(),
90 input_cost_per_million: 1.10,
91 output_cost_per_million: 4.40,
92 upper_bound: None,
93 }],
94 approach_warning_threshold: 0.8,
95 };
96 }
97
98 if model_name.starts_with("gpt-5-mini") {
99 return ModelContextInfo {
100 max_tokens: 400_000,
101 pricing_tiers: vec![ContextPricingTier {
102 label: "Standard".to_string(),
103 input_cost_per_million: 0.25,
104 output_cost_per_million: 2.0,
105 upper_bound: None,
106 }],
107 approach_warning_threshold: 0.8,
108 };
109 }
110
111 if model_name.starts_with("gpt-5-nano") {
112 return ModelContextInfo {
113 max_tokens: 400_000,
114 pricing_tiers: vec![ContextPricingTier {
115 label: "Standard".to_string(),
116 input_cost_per_million: 0.05,
117 output_cost_per_million: 0.40,
118 upper_bound: None,
119 }],
120 approach_warning_threshold: 0.8,
121 };
122 }
123
124 if model_name.starts_with("gpt-5") {
125 return ModelContextInfo {
126 max_tokens: 400_000,
127 pricing_tiers: vec![ContextPricingTier {
128 label: "Standard".to_string(),
129 input_cost_per_million: 1.25,
130 output_cost_per_million: 10.0,
131 upper_bound: None,
132 }],
133 approach_warning_threshold: 0.8,
134 };
135 }
136
137 ModelContextInfo::default()
138 }
139
140 fn model_name(&self) -> String {
141 match self {
142 OpenAIModel::O3 => "O3".to_string(),
143 OpenAIModel::O4Mini => "O4-mini".to_string(),
144 OpenAIModel::GPT5 => "GPT-5".to_string(),
145 OpenAIModel::GPT51 => "GPT-5.1".to_string(),
146 OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
147 OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
148 OpenAIModel::Custom(name) => format!("Custom ({})", name),
149 }
150 }
151}
152
153impl std::fmt::Display for OpenAIModel {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 match self {
156 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
157
158 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
159
160 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
161 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
162 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
163 OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
164
165 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
166 }
167 }
168}
169
170#[derive(Serialize, Deserialize, Debug)]
171pub struct OpenAIInput {
172 pub model: OpenAIModel,
173 pub messages: Vec<LLMMessage>,
174 pub max_tokens: u32,
175
176 #[serde(skip_serializing_if = "Option::is_none")]
177 pub json: Option<serde_json::Value>,
178
179 #[serde(skip_serializing_if = "Option::is_none")]
180 pub tools: Option<Vec<LLMTool>>,
181
182 #[serde(skip_serializing_if = "Option::is_none")]
183 pub reasoning_effort: Option<OpenAIReasoningEffort>,
184}
185
186impl OpenAIInput {
187 pub fn is_reasoning_model(&self) -> bool {
188 matches!(self.model, |OpenAIModel::O3| OpenAIModel::O4Mini
189 | OpenAIModel::GPT5
190 | OpenAIModel::GPT5Mini
191 | OpenAIModel::GPT5Nano)
192 }
193
194 pub fn is_standard_model(&self) -> bool {
195 !self.is_reasoning_model()
196 }
197}
198
199#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
200pub enum OpenAIReasoningEffort {
201 #[serde(rename = "minimal")]
202 Minimal,
203 #[serde(rename = "low")]
204 Low,
205 #[default]
206 #[serde(rename = "medium")]
207 Medium,
208 #[serde(rename = "high")]
209 High,
210}
211
212#[derive(Serialize, Deserialize, Debug)]
213pub struct OpenAITool {
214 pub r#type: String,
215 pub function: OpenAIToolFunction,
216}
217
218#[derive(Serialize, Deserialize, Debug)]
219pub struct OpenAIToolFunction {
220 pub name: String,
221 pub description: String,
222 pub parameters: serde_json::Value,
223}
224
225impl From<LLMTool> for OpenAITool {
226 fn from(tool: LLMTool) -> Self {
227 OpenAITool {
228 r#type: "function".to_string(),
229 function: OpenAIToolFunction {
230 name: tool.name,
231 description: tool.description,
232 parameters: tool.input_schema,
233 },
234 }
235 }
236}
237
238#[derive(Serialize, Deserialize, Debug, Clone)]
239pub struct OpenAIOutput {
240 pub model: String,
241 pub object: String,
242 pub choices: Vec<OpenAILLMChoice>,
243 pub created: u64,
244 pub usage: Option<LLMTokenUsage>,
245 pub id: String,
246}
247
248impl From<OpenAIOutput> for LLMCompletionResponse {
249 fn from(val: OpenAIOutput) -> Self {
250 LLMCompletionResponse {
251 model: val.model,
252 object: val.object,
253 choices: val.choices.into_iter().map(OpenAILLMChoice::into).collect(),
254 created: val.created,
255 usage: val.usage,
256 id: val.id,
257 }
258 }
259}
260
261#[derive(Serialize, Deserialize, Debug, Clone)]
262pub struct OpenAILLMChoice {
263 pub finish_reason: Option<String>,
264 pub index: u32,
265 pub message: OpenAILLMMessage,
266}
267
268impl From<OpenAILLMChoice> for LLMChoice {
269 fn from(val: OpenAILLMChoice) -> Self {
270 LLMChoice {
271 finish_reason: val.finish_reason,
272 index: val.index,
273 message: val.message.into(),
274 }
275 }
276}
277
278#[derive(Serialize, Deserialize, Debug, Clone)]
279pub struct OpenAILLMMessage {
280 pub role: String,
281 pub content: Option<String>,
282 pub tool_calls: Option<Vec<OpenAILLMMessageToolCall>>,
283}
284impl From<OpenAILLMMessage> for LLMMessage {
285 fn from(val: OpenAILLMMessage) -> Self {
286 LLMMessage {
287 role: val.role,
288 content: match val.tool_calls {
289 None => LLMMessageContent::String(val.content.unwrap_or_default()),
290 Some(tool_calls) => LLMMessageContent::List(
291 std::iter::once(LLMMessageTypedContent::Text {
292 text: val.content.unwrap_or_default(),
293 })
294 .chain(tool_calls.into_iter().map(|tool_call| {
295 LLMMessageTypedContent::ToolCall {
296 id: tool_call.id,
297 name: tool_call.function.name,
298 args: match serde_json::from_str(&tool_call.function.arguments) {
299 Ok(args) => args,
300 Err(_) => {
301 return LLMMessageTypedContent::Text {
302 text: String::from("Error parsing tool call arguments"),
303 };
304 }
305 },
306 }
307 }))
308 .collect(),
309 ),
310 },
311 }
312 }
313}
314
315#[derive(Serialize, Deserialize, Debug, Clone)]
316pub struct OpenAILLMMessageToolCall {
317 pub id: String,
318 pub r#type: String,
319 pub function: OpenAILLMMessageToolCallFunction,
320}
321#[derive(Serialize, Deserialize, Debug, Clone)]
322pub struct OpenAILLMMessageToolCallFunction {
323 pub arguments: String,
324 pub name: String,
325}
326
327#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
328#[serde(rename_all = "lowercase")]
329pub enum Role {
330 System,
331 Developer,
332 User,
333 #[default]
334 Assistant,
335 Tool,
336 }
338
339impl std::fmt::Display for Role {
340 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 match self {
342 Role::System => write!(f, "system"),
343 Role::Developer => write!(f, "developer"),
344 Role::User => write!(f, "user"),
345 Role::Assistant => write!(f, "assistant"),
346 Role::Tool => write!(f, "tool"),
347 }
348 }
349}
350
351#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
352pub struct ChatCompletionRequest {
353 pub model: String,
354 pub messages: Vec<ChatMessage>,
355 #[serde(skip_serializing_if = "Option::is_none")]
356 pub frequency_penalty: Option<f32>,
357 #[serde(skip_serializing_if = "Option::is_none")]
358 pub logit_bias: Option<serde_json::Value>,
359 #[serde(skip_serializing_if = "Option::is_none")]
360 pub logprobs: Option<bool>,
361 #[serde(skip_serializing_if = "Option::is_none")]
362 pub max_tokens: Option<u32>,
363 #[serde(skip_serializing_if = "Option::is_none")]
364 pub n: Option<u32>,
365 #[serde(skip_serializing_if = "Option::is_none")]
366 pub presence_penalty: Option<f32>,
367 #[serde(skip_serializing_if = "Option::is_none")]
368 pub response_format: Option<ResponseFormat>,
369 #[serde(skip_serializing_if = "Option::is_none")]
370 pub seed: Option<i64>,
371 #[serde(skip_serializing_if = "Option::is_none")]
372 pub stop: Option<StopSequence>,
373 #[serde(skip_serializing_if = "Option::is_none")]
374 pub stream: Option<bool>,
375 #[serde(skip_serializing_if = "Option::is_none")]
376 pub temperature: Option<f32>,
377 #[serde(skip_serializing_if = "Option::is_none")]
378 pub top_p: Option<f32>,
379 #[serde(skip_serializing_if = "Option::is_none")]
380 pub tools: Option<Vec<Tool>>,
381 #[serde(skip_serializing_if = "Option::is_none")]
382 pub tool_choice: Option<ToolChoice>,
383 #[serde(skip_serializing_if = "Option::is_none")]
384 pub user: Option<String>,
385
386 #[serde(skip_serializing_if = "Option::is_none")]
387 pub context: Option<ChatCompletionContext>,
388}
389
390#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
391pub struct ChatCompletionContext {
392 pub scratchpad: Option<Value>,
393}
394
395#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
396pub struct ChatMessage {
397 pub role: Role,
398 pub content: Option<MessageContent>,
399 #[serde(skip_serializing_if = "Option::is_none")]
400 pub name: Option<String>,
401 #[serde(skip_serializing_if = "Option::is_none")]
402 pub tool_calls: Option<Vec<ToolCall>>,
403 #[serde(skip_serializing_if = "Option::is_none")]
404 pub tool_call_id: Option<String>,
405
406 #[serde(skip_serializing_if = "Option::is_none")]
408 pub usage: Option<LLMTokenUsage>,
409}
410
411impl ChatMessage {
412 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
413 messages
414 .iter()
415 .rev()
416 .find(|message| message.role != Role::User && message.role != Role::Tool)
417 }
418
419 pub fn to_xml(&self) -> String {
420 match &self.content {
421 Some(MessageContent::String(s)) => {
422 format!("<message role=\"{}\">{}</message>", self.role, s)
423 }
424 Some(MessageContent::Array(parts)) => parts
425 .iter()
426 .map(|part| {
427 format!(
428 "<message role=\"{}\" type=\"{}\">{}</message>",
429 self.role,
430 part.r#type,
431 part.text.clone().unwrap_or_default()
432 )
433 })
434 .collect::<Vec<String>>()
435 .join("\n"),
436 None => String::new(),
437 }
438 }
439}
440
441#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
442#[serde(untagged)]
443pub enum MessageContent {
444 String(String),
445 Array(Vec<ContentPart>),
446}
447
448impl MessageContent {
449 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
450 match self {
451 MessageContent::String(s) => MessageContent::String(format!(
452 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
453 )),
454 MessageContent::Array(parts) => MessageContent::Array(
455 std::iter::once(ContentPart {
456 r#type: "text".to_string(),
457 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
458 image_url: None,
459 })
460 .chain(parts.iter().cloned())
461 .collect(),
462 ),
463 }
464 }
465
466 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
467 match self {
468 MessageContent::String(s) => s
469 .rfind("<checkpoint_id>")
470 .and_then(|start| {
471 s[start..]
472 .find("</checkpoint_id>")
473 .map(|end| (start + "<checkpoint_id>".len(), start + end))
474 })
475 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
476 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
477 part.text.as_deref().and_then(|text| {
478 text.rfind("<checkpoint_id>")
479 .and_then(|start| {
480 text[start..]
481 .find("</checkpoint_id>")
482 .map(|end| (start + "<checkpoint_id>".len(), start + end))
483 })
484 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
485 })
486 }),
487 }
488 }
489}
490
491impl std::fmt::Display for MessageContent {
492 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
493 match self {
494 MessageContent::String(s) => write!(f, "{s}"),
495 MessageContent::Array(parts) => {
496 let text_parts: Vec<String> =
497 parts.iter().filter_map(|part| part.text.clone()).collect();
498 write!(f, "{}", text_parts.join("\n"))
499 }
500 }
501 }
502}
503impl Default for MessageContent {
504 fn default() -> Self {
505 MessageContent::String(String::new())
506 }
507}
508
509#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
510pub struct ContentPart {
511 pub r#type: String,
512 #[serde(skip_serializing_if = "Option::is_none")]
513 pub text: Option<String>,
514 #[serde(skip_serializing_if = "Option::is_none")]
515 pub image_url: Option<ImageUrl>,
516}
517
518#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
519pub struct ImageUrl {
520 pub url: String,
521 #[serde(skip_serializing_if = "Option::is_none")]
522 pub detail: Option<String>,
523}
524
525#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
526pub struct ResponseFormat {
527 pub r#type: String,
528}
529
530#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
531#[serde(untagged)]
532pub enum StopSequence {
533 String(String),
534 Array(Vec<String>),
535}
536
537#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
538pub struct Tool {
539 pub r#type: String,
540 pub function: FunctionDefinition,
541}
542
543#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
544pub struct FunctionDefinition {
545 pub name: String,
546 pub description: Option<String>,
547 pub parameters: serde_json::Value,
548}
549
550#[derive(Debug, Clone, PartialEq)]
551pub enum ToolChoice {
552 Auto,
553 Required,
554 Object(ToolChoiceObject),
555}
556
557impl Serialize for ToolChoice {
558 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
559 where
560 S: serde::Serializer,
561 {
562 match self {
563 ToolChoice::Auto => serializer.serialize_str("auto"),
564 ToolChoice::Required => serializer.serialize_str("required"),
565 ToolChoice::Object(obj) => obj.serialize(serializer),
566 }
567 }
568}
569
570impl<'de> Deserialize<'de> for ToolChoice {
571 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
572 where
573 D: serde::Deserializer<'de>,
574 {
575 struct ToolChoiceVisitor;
576
577 impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
578 type Value = ToolChoice;
579
580 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
581 formatter.write_str("string or object")
582 }
583
584 fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
585 where
586 E: serde::de::Error,
587 {
588 match value {
589 "auto" => Ok(ToolChoice::Auto),
590 "required" => Ok(ToolChoice::Required),
591 _ => Err(serde::de::Error::unknown_variant(
592 value,
593 &["auto", "required"],
594 )),
595 }
596 }
597
598 fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
599 where
600 M: serde::de::MapAccess<'de>,
601 {
602 let obj = ToolChoiceObject::deserialize(
603 serde::de::value::MapAccessDeserializer::new(map),
604 )?;
605 Ok(ToolChoice::Object(obj))
606 }
607 }
608
609 deserializer.deserialize_any(ToolChoiceVisitor)
610 }
611}
612
613#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
614pub struct ToolChoiceObject {
615 pub r#type: String,
616 pub function: FunctionChoice,
617}
618
619#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
620pub struct FunctionChoice {
621 pub name: String,
622}
623
624#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
625pub struct ToolCall {
626 pub id: String,
627 pub r#type: String,
628 pub function: FunctionCall,
629}
630
631#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
632pub struct FunctionCall {
633 pub name: String,
634 pub arguments: String,
635}
636
637#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
638pub struct ChatCompletionResponse {
639 pub id: String,
640 pub object: String,
641 pub created: u64,
642 pub model: String,
643 pub choices: Vec<ChatCompletionChoice>,
644 pub usage: LLMTokenUsage,
645 #[serde(skip_serializing_if = "Option::is_none")]
646 pub system_fingerprint: Option<String>,
647 pub metadata: Option<serde_json::Value>,
648}
649
650#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
651pub struct ChatCompletionChoice {
652 pub index: usize,
653 pub message: ChatMessage,
654 pub logprobs: Option<LogProbs>,
655 pub finish_reason: FinishReason,
656}
657
658#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
659#[serde(rename_all = "snake_case")]
660pub enum FinishReason {
661 Stop,
662 Length,
663 ContentFilter,
664 ToolCalls,
665}
666
667#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
668pub struct LogProbs {
669 pub content: Option<Vec<LogProbContent>>,
670}
671
672#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
673pub struct LogProbContent {
674 pub token: String,
675 pub logprob: f32,
676 pub bytes: Option<Vec<u8>>,
677 pub top_logprobs: Option<Vec<TokenLogprob>>,
678}
679
680#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
681pub struct TokenLogprob {
682 pub token: String,
683 pub logprob: f32,
684 pub bytes: Option<Vec<u8>>,
685}
686
687#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
688pub struct ChatCompletionStreamResponse {
689 pub id: String,
690 pub object: String,
691 pub created: u64,
692 pub model: String,
693 pub choices: Vec<ChatCompletionStreamChoice>,
694 pub usage: Option<LLMTokenUsage>,
695 pub metadata: Option<serde_json::Value>,
696}
697#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
698pub struct ChatCompletionStreamChoice {
699 pub index: usize,
700 pub delta: ChatMessageDelta,
701 pub finish_reason: Option<FinishReason>,
702}
703#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
704pub struct ChatMessageDelta {
705 #[serde(skip_serializing_if = "Option::is_none")]
706 pub role: Option<Role>,
707 #[serde(skip_serializing_if = "Option::is_none")]
708 pub content: Option<String>,
709 #[serde(skip_serializing_if = "Option::is_none")]
710 pub tool_calls: Option<Vec<ToolCallDelta>>,
711}
712
713#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
714pub struct ToolCallDelta {
715 pub index: usize,
716 pub id: Option<String>,
717 pub r#type: Option<String>,
718 pub function: Option<FunctionCallDelta>,
719}
720
721#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
722pub struct FunctionCallDelta {
723 pub name: Option<String>,
724 pub arguments: Option<String>,
725}
726
727impl From<LLMMessage> for ChatMessage {
728 fn from(llm_message: LLMMessage) -> Self {
729 let role = match llm_message.role.as_str() {
730 "system" => Role::System,
731 "user" => Role::User,
732 "assistant" => Role::Assistant,
733 "tool" => Role::Tool,
734 "developer" => Role::Developer,
736 _ => Role::User, };
738
739 let (content, tool_calls) = match llm_message.content {
740 LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
741 LLMMessageContent::List(items) => {
742 let mut text_parts = Vec::new();
743 let mut tool_call_parts = Vec::new();
744
745 for item in items {
746 match item {
747 LLMMessageTypedContent::Text { text } => {
748 text_parts.push(ContentPart {
749 r#type: "text".to_string(),
750 text: Some(text),
751 image_url: None,
752 });
753 }
754 LLMMessageTypedContent::ToolCall { id, name, args } => {
755 tool_call_parts.push(ToolCall {
756 id,
757 r#type: "function".to_string(),
758 function: FunctionCall {
759 name,
760 arguments: args.to_string(),
761 },
762 });
763 }
764 LLMMessageTypedContent::ToolResult { content, .. } => {
765 text_parts.push(ContentPart {
766 r#type: "text".to_string(),
767 text: Some(content),
768 image_url: None,
769 });
770 }
771 LLMMessageTypedContent::Image { source } => {
772 text_parts.push(ContentPart {
773 r#type: "image_url".to_string(),
774 text: None,
775 image_url: Some(ImageUrl {
776 url: format!(
777 "data:{};base64,{}",
778 source.media_type, source.data
779 ),
780 detail: None,
781 }),
782 });
783 }
784 }
785 }
786
787 let content = if !text_parts.is_empty() {
788 Some(MessageContent::Array(text_parts))
789 } else {
790 None
791 };
792
793 let tool_calls = if !tool_call_parts.is_empty() {
794 Some(tool_call_parts)
795 } else {
796 None
797 };
798
799 (content, tool_calls)
800 }
801 };
802
803 ChatMessage {
804 role,
805 content,
806 name: None, tool_calls,
808 tool_call_id: None, usage: None,
810 }
811 }
812}
813
814impl From<ChatMessage> for LLMMessage {
815 fn from(chat_message: ChatMessage) -> Self {
816 let mut content_parts = Vec::new();
817
818 match chat_message.content {
820 Some(MessageContent::String(s)) => {
821 if !s.is_empty() {
822 content_parts.push(LLMMessageTypedContent::Text { text: s });
823 }
824 }
825 Some(MessageContent::Array(parts)) => {
826 for part in parts {
827 if let Some(text) = part.text {
828 content_parts.push(LLMMessageTypedContent::Text { text });
829 } else if let Some(image_url) = part.image_url {
830 let (media_type, data) = if image_url.url.starts_with("data:") {
831 let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
832 if parts.len() == 2 {
833 let meta = parts[0];
834 let data = parts[1];
835 let media_type = meta
836 .trim_start_matches("data:")
837 .trim_end_matches(";base64")
838 .to_string();
839 (media_type, data.to_string())
840 } else {
841 ("image/jpeg".to_string(), image_url.url)
842 }
843 } else {
844 ("image/jpeg".to_string(), image_url.url)
845 };
846
847 content_parts.push(LLMMessageTypedContent::Image {
848 source: LLMMessageImageSource {
849 r#type: "base64".to_string(),
850 media_type,
851 data,
852 },
853 });
854 }
855 }
856 }
857 None => {}
858 }
859
860 if let Some(tool_calls) = chat_message.tool_calls {
862 for tool_call in tool_calls {
863 let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
864 content_parts.push(LLMMessageTypedContent::ToolCall {
865 id: tool_call.id,
866 name: tool_call.function.name,
867 args,
868 });
869 }
870 }
871
872 LLMMessage {
873 role: chat_message.role.to_string(),
874 content: if content_parts.is_empty() {
875 LLMMessageContent::String(String::new())
876 } else if content_parts.len() == 1 {
877 match &content_parts[0] {
878 LLMMessageTypedContent::Text { text } => {
879 LLMMessageContent::String(text.clone())
880 }
881 _ => LLMMessageContent::List(content_parts),
882 }
883 } else {
884 LLMMessageContent::List(content_parts)
885 },
886 }
887 }
888}
889
890#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
891pub enum AgentModel {
892 #[serde(rename = "smart")]
893 #[default]
894 Smart,
895 #[serde(rename = "eco")]
896 Eco,
897 #[serde(rename = "recovery")]
898 Recovery,
899}
900
901impl std::fmt::Display for AgentModel {
902 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
903 match self {
904 AgentModel::Smart => write!(f, "smart"),
905 AgentModel::Eco => write!(f, "eco"),
906 AgentModel::Recovery => write!(f, "recovery"),
907 }
908 }
909}
910
911impl From<String> for AgentModel {
912 fn from(value: String) -> Self {
913 match value.as_str() {
914 "eco" => AgentModel::Eco,
915 "recovery" => AgentModel::Recovery,
916 _ => AgentModel::Smart,
917 }
918 }
919}
920
921impl ChatCompletionRequest {
922 pub fn new(
923 model: String,
924 messages: Vec<ChatMessage>,
925 tools: Option<Vec<Tool>>,
926 stream: Option<bool>,
927 ) -> Self {
928 Self {
929 model,
930 messages,
931 frequency_penalty: None,
932 logit_bias: None,
933 logprobs: None,
934 max_tokens: None,
935 n: None,
936 presence_penalty: None,
937 response_format: None,
938 seed: None,
939 stop: None,
940 stream,
941 temperature: None,
942 top_p: None,
943 tools,
944 tool_choice: None,
945 user: None,
946 context: None,
947 }
948 }
949}
950
951impl From<Tool> for LLMTool {
952 fn from(tool: Tool) -> Self {
953 LLMTool {
954 name: tool.function.name,
955 description: tool.function.description.unwrap_or_default(),
956 input_schema: tool.function.parameters,
957 }
958 }
959}
960
961#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
962pub enum ToolCallResultStatus {
963 Success,
964 Error,
965 Cancelled,
966}
967
968#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
969pub struct ToolCallResult {
970 pub call: ToolCall,
971 pub result: String,
972 pub status: ToolCallResultStatus,
973}
974
975#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
976pub struct ToolCallResultProgress {
977 pub id: Uuid,
978 pub message: String,
979}
980
981impl From<GenerationDelta> for ChatMessageDelta {
982 fn from(delta: GenerationDelta) -> Self {
983 match delta {
984 GenerationDelta::Content { content } => ChatMessageDelta {
985 role: Some(Role::Assistant),
986 content: Some(content),
987 tool_calls: None,
988 },
989 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
990 role: Some(Role::Assistant),
991 content: None,
992 tool_calls: None,
993 },
994 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
995 role: Some(Role::Assistant),
996 content: None,
997 tool_calls: Some(vec![ToolCallDelta {
998 index: tool_use.index,
999 id: tool_use.id,
1000 r#type: Some("function".to_string()),
1001 function: Some(FunctionCallDelta {
1002 name: tool_use.name,
1003 arguments: tool_use.input,
1004 }),
1005 }]),
1006 },
1007 _ => ChatMessageDelta {
1008 role: Some(Role::Assistant),
1009 content: None,
1010 tool_calls: None,
1011 },
1012 }
1013 }
1014}
1015
1016#[cfg(test)]
1017mod tests {
1018 use super::*;
1019
1020 #[test]
1021 fn test_serialize_basic_request() {
1022 let request = ChatCompletionRequest {
1023 model: AgentModel::Smart.to_string(),
1024 messages: vec![
1025 ChatMessage {
1026 role: Role::System,
1027 content: Some(MessageContent::String(
1028 "You are a helpful assistant.".to_string(),
1029 )),
1030 name: None,
1031 tool_calls: None,
1032 tool_call_id: None,
1033 usage: None,
1034 },
1035 ChatMessage {
1036 role: Role::User,
1037 content: Some(MessageContent::String("Hello!".to_string())),
1038 name: None,
1039 tool_calls: None,
1040 tool_call_id: None,
1041 usage: None,
1042 },
1043 ],
1044 frequency_penalty: None,
1045 logit_bias: None,
1046 logprobs: None,
1047 max_tokens: Some(100),
1048 n: None,
1049 presence_penalty: None,
1050 response_format: None,
1051 seed: None,
1052 stop: None,
1053 stream: None,
1054 temperature: Some(0.7),
1055 top_p: None,
1056 tools: None,
1057 tool_choice: None,
1058 user: None,
1059 context: None,
1060 };
1061
1062 let json = serde_json::to_string(&request).unwrap();
1063 assert!(json.contains("\"model\":\"smart\""));
1064 assert!(json.contains("\"messages\":["));
1065 assert!(json.contains("\"role\":\"system\""));
1066 assert!(json.contains("\"content\":\"You are a helpful assistant.\""));
1067 assert!(json.contains("\"role\":\"user\""));
1068 assert!(json.contains("\"content\":\"Hello!\""));
1069 assert!(json.contains("\"max_tokens\":100"));
1070 assert!(json.contains("\"temperature\":0.7"));
1071 }
1072
1073 #[test]
1074 fn test_deserialize_response() {
1075 let json = r#"{
1076 "id": "chatcmpl-123",
1077 "object": "chat.completion",
1078 "created": 1677652288,
1079 "model": "smart",
1080 "system_fingerprint": "fp_123abc",
1081 "choices": [{
1082 "index": 0,
1083 "message": {
1084 "role": "assistant",
1085 "content": "Hello! How can I help you today?"
1086 },
1087 "finish_reason": "stop"
1088 }],
1089 "usage": {
1090 "prompt_tokens": 9,
1091 "completion_tokens": 12,
1092 "total_tokens": 21
1093 }
1094 }"#;
1095
1096 let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
1097 assert_eq!(response.id, "chatcmpl-123");
1098 assert_eq!(response.object, "chat.completion");
1099 assert_eq!(response.created, 1677652288);
1100 assert_eq!(response.model, AgentModel::Smart.to_string());
1101 assert_eq!(response.system_fingerprint, Some("fp_123abc".to_string()));
1102
1103 assert_eq!(response.choices.len(), 1);
1104 assert_eq!(response.choices[0].index, 0);
1105 assert_eq!(response.choices[0].message.role, Role::Assistant);
1106
1107 match &response.choices[0].message.content {
1108 Some(MessageContent::String(content)) => {
1109 assert_eq!(content, "Hello! How can I help you today?");
1110 }
1111 _ => panic!("Expected string content"),
1112 }
1113
1114 assert_eq!(response.choices[0].finish_reason, FinishReason::Stop);
1115 assert_eq!(response.usage.prompt_tokens, 9);
1116 assert_eq!(response.usage.completion_tokens, 12);
1117 assert_eq!(response.usage.total_tokens, 21);
1118 }
1119
#[test]
fn test_tool_calls_request_response() {
    // --- Request side: one user question plus a single `get_weather` function tool. ---
    let tools_request = ChatCompletionRequest {
        model: AgentModel::Smart.to_string(),
        messages: vec![ChatMessage {
            role: Role::User,
            content: Some(MessageContent::String(
                "What's the weather in San Francisco?".to_string(),
            )),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        }],
        tools: Some(vec![Tool {
            r#type: "function".to_string(),
            function: FunctionDefinition {
                name: "get_weather".to_string(),
                description: Some("Get the current weather in a given location".to_string()),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        }
                    },
                    "required": ["location"]
                }),
            },
        }]),
        tool_choice: Some(ToolChoice::Auto),
        max_tokens: Some(100),
        temperature: Some(0.7),
        frequency_penalty: None,
        logit_bias: None,
        logprobs: None,
        n: None,
        presence_penalty: None,
        response_format: None,
        seed: None,
        stop: None,
        stream: None,
        top_p: None,
        user: None,
        context: None,
    };

    // Inspect the serialized request structurally instead of by substring, so every
    // assertion is tied to the field it is about. A bare `"type":"function"` substring
    // would also match an unrelated object elsewhere in the payload.
    let request_json =
        serde_json::to_value(&tools_request).expect("tool request should serialize to JSON");
    let tools = request_json["tools"]
        .as_array()
        .expect("`tools` should serialize as a JSON array");
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0]["type"], "function");
    assert_eq!(tools[0]["function"]["name"], "get_weather");
    assert_eq!(request_json["tool_choice"], "auto");

    // --- Response side: the assistant answers with a tool call and no text content. ---
    let tool_response_json = r#"{
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "eco",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": null,
                "tool_calls": [
                    {
                        "id": "call_abc123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\":\"San Francisco, CA\"}"
                        }
                    }
                ]
            },
            "finish_reason": "tool_calls"
        }],
        "usage": {
            "prompt_tokens": 82,
            "completion_tokens": 17,
            "total_tokens": 99
        }
    }"#;

    let tool_response: ChatCompletionResponse = serde_json::from_str(tool_response_json)
        .expect("tool-call response should deserialize");
    assert_eq!(tool_response.choices[0].message.role, Role::Assistant);
    // `"content": null` must deserialize to `None`, not to an empty string.
    assert_eq!(tool_response.choices[0].message.content, None);

    let tool_calls = tool_response.choices[0]
        .message
        .tool_calls
        .as_ref()
        .expect("assistant message should carry tool calls");
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0].id, "call_abc123");
    assert_eq!(tool_calls[0].r#type, "function");
    assert_eq!(tool_calls[0].function.name, "get_weather");
    // Arguments stay as the raw JSON-encoded string sent by the API.
    assert_eq!(
        tool_calls[0].function.arguments,
        "{\"location\":\"San Francisco, CA\"}"
    );

    assert_eq!(
        tool_response.choices[0].finish_reason,
        FinishReason::ToolCalls,
    );
}
1234
#[test]
fn test_content_with_image() {
    // A user turn mixing a text part with an image part that carries a `detail` hint.
    let message_with_image = ChatMessage {
        role: Role::User,
        content: Some(MessageContent::Array(vec![
            ContentPart {
                r#type: "text".to_string(),
                text: Some("What's in this image?".to_string()),
                image_url: None,
            },
            ContentPart {
                r#type: "input_image".to_string(),
                text: None,
                image_url: Some(ImageUrl {
                    url: "...".to_string(),
                    detail: Some("low".to_string()),
                }),
            },
        ])),
        name: None,
        tool_calls: None,
        tool_call_id: None,
        usage: None,
    };

    let json =
        serde_json::to_string(&message_with_image).expect("image message should serialize");

    // Every fragment must appear in serde_json's compact output. Checking them in one
    // loop with an explicit message means a failure names the missing fragment and
    // shows the full payload, rather than only the failed expression.
    let expected_fragments = [
        "\"role\":\"user\"",
        "\"type\":\"text\"",
        "\"text\":\"What's in this image?\"",
        "\"type\":\"input_image\"",
        "\"url\":\"...\"",
        "\"detail\":\"low\"",
    ];
    for fragment in expected_fragments {
        assert!(
            json.contains(fragment),
            "expected {fragment} in serialized message: {json}"
        );
    }
}
1270
#[test]
fn test_response_format() {
    // Request JSON-mode output through `response_format`; every other option stays unset.
    let json_format_request = ChatCompletionRequest {
        model: AgentModel::Smart.to_string(),
        messages: vec![ChatMessage {
            role: Role::User,
            content: Some(MessageContent::String(
                "Generate a JSON object with name and age fields".to_string(),
            )),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        }],
        response_format: Some(ResponseFormat {
            r#type: "json_object".to_string(),
        }),
        max_tokens: Some(100),
        temperature: None,
        frequency_penalty: None,
        logit_bias: None,
        logprobs: None,
        n: None,
        presence_penalty: None,
        seed: None,
        stop: None,
        stream: None,
        top_p: None,
        tools: None,
        tool_choice: None,
        user: None,
        context: None,
    };

    // Compare the top-level `response_format` object structurally. The previous substring
    // check depended on serde_json's compact formatting and on key order.
    let request_json = serde_json::to_value(&json_format_request)
        .expect("response-format request should serialize to JSON");
    assert_eq!(
        request_json["response_format"],
        serde_json::json!({ "type": "json_object" })
    );
}
1308
#[test]
fn test_llm_message_to_chat_message() {
    // Case 1: plain string content maps to `MessageContent::String` with no tool calls.
    let llm_message = LLMMessage {
        role: "user".to_string(),
        content: LLMMessageContent::String("Hello, world!".to_string()),
    };

    let chat_message = ChatMessage::from(llm_message);
    assert_eq!(chat_message.role, Role::User);
    match &chat_message.content {
        Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
        other => panic!("Expected string content, got {other:?}"),
    }
    assert_eq!(chat_message.tool_calls, None);

    // Case 2: a list holding only a tool call yields no content, only `tool_calls`.
    let llm_message_with_tool = LLMMessage {
        role: "assistant".to_string(),
        content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolCall {
            id: "call_123".to_string(),
            name: "get_weather".to_string(),
            args: serde_json::json!({"location": "San Francisco"}),
        }]),
    };

    let chat_message = ChatMessage::from(llm_message_with_tool);
    assert_eq!(chat_message.role, Role::Assistant);
    assert_eq!(chat_message.content, None);

    let tool_calls = chat_message
        .tool_calls
        .expect("tool-call-only message should produce tool_calls");
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0].id, "call_123");
    assert_eq!(tool_calls[0].function.name, "get_weather");
    assert!(tool_calls[0].function.arguments.contains("San Francisco"));

    // Case 3: mixed text and tool call. The text goes to an array of content parts
    // and the tool call goes to `tool_calls`.
    let llm_message_mixed = LLMMessage {
        role: "assistant".to_string(),
        content: LLMMessageContent::List(vec![
            LLMMessageTypedContent::Text {
                text: "The weather is:".to_string(),
            },
            LLMMessageTypedContent::ToolCall {
                id: "call_456".to_string(),
                name: "get_weather".to_string(),
                args: serde_json::json!({"location": "New York"}),
            },
        ]),
    };

    let chat_message = ChatMessage::from(llm_message_mixed);
    assert_eq!(chat_message.role, Role::Assistant);

    match &chat_message.content {
        Some(MessageContent::Array(parts)) => {
            assert_eq!(parts.len(), 1);
            assert_eq!(parts[0].r#type, "text");
            assert_eq!(parts[0].text, Some("The weather is:".to_string()));
        }
        other => panic!("Expected array content, got {other:?}"),
    }

    let tool_calls = chat_message
        .tool_calls
        .expect("mixed message should still produce tool_calls");
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0].id, "call_456");
    assert_eq!(tool_calls[0].function.name, "get_weather");
    assert!(tool_calls[0].function.arguments.contains("New York"));
}
1379}