stakpak_shared/models/integrations/
openai.rs

1use crate::models::llm::{
2    GenerationDelta, LLMChoice, LLMCompletionResponse, LLMMessage, LLMMessageContent,
3    LLMMessageImageSource, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
4};
5use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use uuid::Uuid;
9
10#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
11pub struct OpenAIConfig {
12    pub api_endpoint: Option<String>,
13    pub api_key: Option<String>,
14}
15
16#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
17pub enum OpenAIModel {
18    // Reasoning Models
19    #[serde(rename = "o3-2025-04-16")]
20    O3,
21    #[serde(rename = "o4-mini-2025-04-16")]
22    O4Mini,
23
24    #[default]
25    #[serde(rename = "gpt-5-2025-08-07")]
26    GPT5,
27    #[serde(rename = "gpt-5.1-2025-11-13")]
28    GPT51,
29    #[serde(rename = "gpt-5-mini-2025-08-07")]
30    GPT5Mini,
31    #[serde(rename = "gpt-5-nano-2025-08-07")]
32    GPT5Nano,
33
34    Custom(String),
35}
36
37impl OpenAIModel {
38    pub fn from_string(s: &str) -> Result<Self, String> {
39        serde_json::from_value(serde_json::Value::String(s.to_string()))
40            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
41    }
42
43    /// Default smart model for OpenAI
44    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
45
46    /// Default eco model for OpenAI
47    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
48
49    /// Default recovery model for OpenAI
50    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
51
52    /// Get default smart model as string
53    pub fn default_smart_model() -> String {
54        Self::DEFAULT_SMART_MODEL.to_string()
55    }
56
57    /// Get default eco model as string
58    pub fn default_eco_model() -> String {
59        Self::DEFAULT_ECO_MODEL.to_string()
60    }
61
62    /// Get default recovery model as string
63    pub fn default_recovery_model() -> String {
64        Self::DEFAULT_RECOVERY_MODEL.to_string()
65    }
66}
67
68impl ContextAware for OpenAIModel {
69    fn context_info(&self) -> ModelContextInfo {
70        let model_name = self.to_string();
71
72        if model_name.starts_with("o3") {
73            return ModelContextInfo {
74                max_tokens: 200_000,
75                pricing_tiers: vec![ContextPricingTier {
76                    label: "Standard".to_string(),
77                    input_cost_per_million: 2.0,
78                    output_cost_per_million: 8.0,
79                    upper_bound: None,
80                }],
81                approach_warning_threshold: 0.8,
82            };
83        }
84
85        if model_name.starts_with("o4-mini") {
86            return ModelContextInfo {
87                max_tokens: 200_000,
88                pricing_tiers: vec![ContextPricingTier {
89                    label: "Standard".to_string(),
90                    input_cost_per_million: 1.10,
91                    output_cost_per_million: 4.40,
92                    upper_bound: None,
93                }],
94                approach_warning_threshold: 0.8,
95            };
96        }
97
98        if model_name.starts_with("gpt-5-mini") {
99            return ModelContextInfo {
100                max_tokens: 400_000,
101                pricing_tiers: vec![ContextPricingTier {
102                    label: "Standard".to_string(),
103                    input_cost_per_million: 0.25,
104                    output_cost_per_million: 2.0,
105                    upper_bound: None,
106                }],
107                approach_warning_threshold: 0.8,
108            };
109        }
110
111        if model_name.starts_with("gpt-5-nano") {
112            return ModelContextInfo {
113                max_tokens: 400_000,
114                pricing_tiers: vec![ContextPricingTier {
115                    label: "Standard".to_string(),
116                    input_cost_per_million: 0.05,
117                    output_cost_per_million: 0.40,
118                    upper_bound: None,
119                }],
120                approach_warning_threshold: 0.8,
121            };
122        }
123
124        if model_name.starts_with("gpt-5") {
125            return ModelContextInfo {
126                max_tokens: 400_000,
127                pricing_tiers: vec![ContextPricingTier {
128                    label: "Standard".to_string(),
129                    input_cost_per_million: 1.25,
130                    output_cost_per_million: 10.0,
131                    upper_bound: None,
132                }],
133                approach_warning_threshold: 0.8,
134            };
135        }
136
137        ModelContextInfo::default()
138    }
139
140    fn model_name(&self) -> String {
141        match self {
142            OpenAIModel::O3 => "O3".to_string(),
143            OpenAIModel::O4Mini => "O4-mini".to_string(),
144            OpenAIModel::GPT5 => "GPT-5".to_string(),
145            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
146            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
147            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
148            OpenAIModel::Custom(name) => format!("Custom ({})", name),
149        }
150    }
151}
152
153impl std::fmt::Display for OpenAIModel {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        match self {
156            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
157
158            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
159
160            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
161            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
162            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
163            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
164
165            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
166        }
167    }
168}
169
170#[derive(Serialize, Deserialize, Debug)]
171pub struct OpenAIInput {
172    pub model: OpenAIModel,
173    pub messages: Vec<LLMMessage>,
174    pub max_tokens: u32,
175
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub json: Option<serde_json::Value>,
178
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub tools: Option<Vec<LLMTool>>,
181
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub reasoning_effort: Option<OpenAIReasoningEffort>,
184}
185
186impl OpenAIInput {
187    pub fn is_reasoning_model(&self) -> bool {
188        matches!(self.model, |OpenAIModel::O3| OpenAIModel::O4Mini
189            | OpenAIModel::GPT5
190            | OpenAIModel::GPT5Mini
191            | OpenAIModel::GPT5Nano)
192    }
193
194    pub fn is_standard_model(&self) -> bool {
195        !self.is_reasoning_model()
196    }
197}
198
199#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
200pub enum OpenAIReasoningEffort {
201    #[serde(rename = "minimal")]
202    Minimal,
203    #[serde(rename = "low")]
204    Low,
205    #[default]
206    #[serde(rename = "medium")]
207    Medium,
208    #[serde(rename = "high")]
209    High,
210}
211
212#[derive(Serialize, Deserialize, Debug)]
213pub struct OpenAITool {
214    pub r#type: String,
215    pub function: OpenAIToolFunction,
216}
217
218#[derive(Serialize, Deserialize, Debug)]
219pub struct OpenAIToolFunction {
220    pub name: String,
221    pub description: String,
222    pub parameters: serde_json::Value,
223}
224
225impl From<LLMTool> for OpenAITool {
226    fn from(tool: LLMTool) -> Self {
227        OpenAITool {
228            r#type: "function".to_string(),
229            function: OpenAIToolFunction {
230                name: tool.name,
231                description: tool.description,
232                parameters: tool.input_schema,
233            },
234        }
235    }
236}
237
238#[derive(Serialize, Deserialize, Debug, Clone)]
239pub struct OpenAIOutput {
240    pub model: String,
241    pub object: String,
242    pub choices: Vec<OpenAILLMChoice>,
243    pub created: u64,
244    pub usage: Option<LLMTokenUsage>,
245    pub id: String,
246}
247
248impl From<OpenAIOutput> for LLMCompletionResponse {
249    fn from(val: OpenAIOutput) -> Self {
250        LLMCompletionResponse {
251            model: val.model,
252            object: val.object,
253            choices: val.choices.into_iter().map(OpenAILLMChoice::into).collect(),
254            created: val.created,
255            usage: val.usage,
256            id: val.id,
257        }
258    }
259}
260
261#[derive(Serialize, Deserialize, Debug, Clone)]
262pub struct OpenAILLMChoice {
263    pub finish_reason: Option<String>,
264    pub index: u32,
265    pub message: OpenAILLMMessage,
266}
267
268impl From<OpenAILLMChoice> for LLMChoice {
269    fn from(val: OpenAILLMChoice) -> Self {
270        LLMChoice {
271            finish_reason: val.finish_reason,
272            index: val.index,
273            message: val.message.into(),
274        }
275    }
276}
277
278#[derive(Serialize, Deserialize, Debug, Clone)]
279pub struct OpenAILLMMessage {
280    pub role: String,
281    pub content: Option<String>,
282    pub tool_calls: Option<Vec<OpenAILLMMessageToolCall>>,
283}
284impl From<OpenAILLMMessage> for LLMMessage {
285    fn from(val: OpenAILLMMessage) -> Self {
286        LLMMessage {
287            role: val.role,
288            content: match val.tool_calls {
289                None => LLMMessageContent::String(val.content.unwrap_or_default()),
290                Some(tool_calls) => LLMMessageContent::List(
291                    std::iter::once(LLMMessageTypedContent::Text {
292                        text: val.content.unwrap_or_default(),
293                    })
294                    .chain(tool_calls.into_iter().map(|tool_call| {
295                        LLMMessageTypedContent::ToolCall {
296                            id: tool_call.id,
297                            name: tool_call.function.name,
298                            args: match serde_json::from_str(&tool_call.function.arguments) {
299                                Ok(args) => args,
300                                Err(_) => {
301                                    return LLMMessageTypedContent::Text {
302                                        text: String::from("Error parsing tool call arguments"),
303                                    };
304                                }
305                            },
306                        }
307                    }))
308                    .collect(),
309                ),
310            },
311        }
312    }
313}
314
315#[derive(Serialize, Deserialize, Debug, Clone)]
316pub struct OpenAILLMMessageToolCall {
317    pub id: String,
318    pub r#type: String,
319    pub function: OpenAILLMMessageToolCallFunction,
320}
321#[derive(Serialize, Deserialize, Debug, Clone)]
322pub struct OpenAILLMMessageToolCallFunction {
323    pub arguments: String,
324    pub name: String,
325}
326
327#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
328#[serde(rename_all = "lowercase")]
329pub enum Role {
330    System,
331    Developer,
332    User,
333    #[default]
334    Assistant,
335    Tool,
336    // Function,
337}
338
339impl std::fmt::Display for Role {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        match self {
342            Role::System => write!(f, "system"),
343            Role::Developer => write!(f, "developer"),
344            Role::User => write!(f, "user"),
345            Role::Assistant => write!(f, "assistant"),
346            Role::Tool => write!(f, "tool"),
347        }
348    }
349}
350
351#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
352pub struct ChatCompletionRequest {
353    pub model: String,
354    pub messages: Vec<ChatMessage>,
355    #[serde(skip_serializing_if = "Option::is_none")]
356    pub frequency_penalty: Option<f32>,
357    #[serde(skip_serializing_if = "Option::is_none")]
358    pub logit_bias: Option<serde_json::Value>,
359    #[serde(skip_serializing_if = "Option::is_none")]
360    pub logprobs: Option<bool>,
361    #[serde(skip_serializing_if = "Option::is_none")]
362    pub max_tokens: Option<u32>,
363    #[serde(skip_serializing_if = "Option::is_none")]
364    pub n: Option<u32>,
365    #[serde(skip_serializing_if = "Option::is_none")]
366    pub presence_penalty: Option<f32>,
367    #[serde(skip_serializing_if = "Option::is_none")]
368    pub response_format: Option<ResponseFormat>,
369    #[serde(skip_serializing_if = "Option::is_none")]
370    pub seed: Option<i64>,
371    #[serde(skip_serializing_if = "Option::is_none")]
372    pub stop: Option<StopSequence>,
373    #[serde(skip_serializing_if = "Option::is_none")]
374    pub stream: Option<bool>,
375    #[serde(skip_serializing_if = "Option::is_none")]
376    pub temperature: Option<f32>,
377    #[serde(skip_serializing_if = "Option::is_none")]
378    pub top_p: Option<f32>,
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub tools: Option<Vec<Tool>>,
381    #[serde(skip_serializing_if = "Option::is_none")]
382    pub tool_choice: Option<ToolChoice>,
383    #[serde(skip_serializing_if = "Option::is_none")]
384    pub user: Option<String>,
385
386    #[serde(skip_serializing_if = "Option::is_none")]
387    pub context: Option<ChatCompletionContext>,
388}
389
390#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
391pub struct ChatCompletionContext {
392    pub scratchpad: Option<Value>,
393}
394
395#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
396pub struct ChatMessage {
397    pub role: Role,
398    pub content: Option<MessageContent>,
399    #[serde(skip_serializing_if = "Option::is_none")]
400    pub name: Option<String>,
401    #[serde(skip_serializing_if = "Option::is_none")]
402    pub tool_calls: Option<Vec<ToolCall>>,
403    #[serde(skip_serializing_if = "Option::is_none")]
404    pub tool_call_id: Option<String>,
405
406    // TODO: check if this is needed
407    #[serde(skip_serializing_if = "Option::is_none")]
408    pub usage: Option<LLMTokenUsage>,
409}
410
411impl ChatMessage {
412    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
413        messages
414            .iter()
415            .rev()
416            .find(|message| message.role != Role::User && message.role != Role::Tool)
417    }
418
419    pub fn to_xml(&self) -> String {
420        match &self.content {
421            Some(MessageContent::String(s)) => {
422                format!("<message role=\"{}\">{}</message>", self.role, s)
423            }
424            Some(MessageContent::Array(parts)) => parts
425                .iter()
426                .map(|part| {
427                    format!(
428                        "<message role=\"{}\" type=\"{}\">{}</message>",
429                        self.role,
430                        part.r#type,
431                        part.text.clone().unwrap_or_default()
432                    )
433                })
434                .collect::<Vec<String>>()
435                .join("\n"),
436            None => String::new(),
437        }
438    }
439}
440
441#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
442#[serde(untagged)]
443pub enum MessageContent {
444    String(String),
445    Array(Vec<ContentPart>),
446}
447
448impl MessageContent {
449    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
450        match self {
451            MessageContent::String(s) => MessageContent::String(format!(
452                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
453            )),
454            MessageContent::Array(parts) => MessageContent::Array(
455                std::iter::once(ContentPart {
456                    r#type: "text".to_string(),
457                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
458                    image_url: None,
459                })
460                .chain(parts.iter().cloned())
461                .collect(),
462            ),
463        }
464    }
465
466    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
467        match self {
468            MessageContent::String(s) => s
469                .rfind("<checkpoint_id>")
470                .and_then(|start| {
471                    s[start..]
472                        .find("</checkpoint_id>")
473                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
474                })
475                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
476            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
477                part.text.as_deref().and_then(|text| {
478                    text.rfind("<checkpoint_id>")
479                        .and_then(|start| {
480                            text[start..]
481                                .find("</checkpoint_id>")
482                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
483                        })
484                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
485                })
486            }),
487        }
488    }
489}
490
491impl std::fmt::Display for MessageContent {
492    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
493        match self {
494            MessageContent::String(s) => write!(f, "{s}"),
495            MessageContent::Array(parts) => {
496                let text_parts: Vec<String> =
497                    parts.iter().filter_map(|part| part.text.clone()).collect();
498                write!(f, "{}", text_parts.join("\n"))
499            }
500        }
501    }
502}
503impl Default for MessageContent {
504    fn default() -> Self {
505        MessageContent::String(String::new())
506    }
507}
508
509#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
510pub struct ContentPart {
511    pub r#type: String,
512    #[serde(skip_serializing_if = "Option::is_none")]
513    pub text: Option<String>,
514    #[serde(skip_serializing_if = "Option::is_none")]
515    pub image_url: Option<ImageUrl>,
516}
517
518#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
519pub struct ImageUrl {
520    pub url: String,
521    #[serde(skip_serializing_if = "Option::is_none")]
522    pub detail: Option<String>,
523}
524
525#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
526pub struct ResponseFormat {
527    pub r#type: String,
528}
529
530#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
531#[serde(untagged)]
532pub enum StopSequence {
533    String(String),
534    Array(Vec<String>),
535}
536
537#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
538pub struct Tool {
539    pub r#type: String,
540    pub function: FunctionDefinition,
541}
542
543#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
544pub struct FunctionDefinition {
545    pub name: String,
546    pub description: Option<String>,
547    pub parameters: serde_json::Value,
548}
549
550#[derive(Debug, Clone, PartialEq)]
551pub enum ToolChoice {
552    Auto,
553    Required,
554    Object(ToolChoiceObject),
555}
556
557impl Serialize for ToolChoice {
558    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
559    where
560        S: serde::Serializer,
561    {
562        match self {
563            ToolChoice::Auto => serializer.serialize_str("auto"),
564            ToolChoice::Required => serializer.serialize_str("required"),
565            ToolChoice::Object(obj) => obj.serialize(serializer),
566        }
567    }
568}
569
570impl<'de> Deserialize<'de> for ToolChoice {
571    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
572    where
573        D: serde::Deserializer<'de>,
574    {
575        struct ToolChoiceVisitor;
576
577        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
578            type Value = ToolChoice;
579
580            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
581                formatter.write_str("string or object")
582            }
583
584            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
585            where
586                E: serde::de::Error,
587            {
588                match value {
589                    "auto" => Ok(ToolChoice::Auto),
590                    "required" => Ok(ToolChoice::Required),
591                    _ => Err(serde::de::Error::unknown_variant(
592                        value,
593                        &["auto", "required"],
594                    )),
595                }
596            }
597
598            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
599            where
600                M: serde::de::MapAccess<'de>,
601            {
602                let obj = ToolChoiceObject::deserialize(
603                    serde::de::value::MapAccessDeserializer::new(map),
604                )?;
605                Ok(ToolChoice::Object(obj))
606            }
607        }
608
609        deserializer.deserialize_any(ToolChoiceVisitor)
610    }
611}
612
613#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
614pub struct ToolChoiceObject {
615    pub r#type: String,
616    pub function: FunctionChoice,
617}
618
619#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
620pub struct FunctionChoice {
621    pub name: String,
622}
623
624#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
625pub struct ToolCall {
626    pub id: String,
627    pub r#type: String,
628    pub function: FunctionCall,
629}
630
631#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
632pub struct FunctionCall {
633    pub name: String,
634    pub arguments: String,
635}
636
637#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
638pub struct ChatCompletionResponse {
639    pub id: String,
640    pub object: String,
641    pub created: u64,
642    pub model: String,
643    pub choices: Vec<ChatCompletionChoice>,
644    pub usage: LLMTokenUsage,
645    #[serde(skip_serializing_if = "Option::is_none")]
646    pub system_fingerprint: Option<String>,
647    pub metadata: Option<serde_json::Value>,
648}
649
650#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
651pub struct ChatCompletionChoice {
652    pub index: usize,
653    pub message: ChatMessage,
654    pub logprobs: Option<LogProbs>,
655    pub finish_reason: FinishReason,
656}
657
658#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
659#[serde(rename_all = "snake_case")]
660pub enum FinishReason {
661    Stop,
662    Length,
663    ContentFilter,
664    ToolCalls,
665}
666
667#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
668pub struct LogProbs {
669    pub content: Option<Vec<LogProbContent>>,
670}
671
672#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
673pub struct LogProbContent {
674    pub token: String,
675    pub logprob: f32,
676    pub bytes: Option<Vec<u8>>,
677    pub top_logprobs: Option<Vec<TokenLogprob>>,
678}
679
680#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
681pub struct TokenLogprob {
682    pub token: String,
683    pub logprob: f32,
684    pub bytes: Option<Vec<u8>>,
685}
686
687#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
688pub struct ChatCompletionStreamResponse {
689    pub id: String,
690    pub object: String,
691    pub created: u64,
692    pub model: String,
693    pub choices: Vec<ChatCompletionStreamChoice>,
694    pub usage: Option<LLMTokenUsage>,
695    pub metadata: Option<serde_json::Value>,
696}
697#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
698pub struct ChatCompletionStreamChoice {
699    pub index: usize,
700    pub delta: ChatMessageDelta,
701    pub finish_reason: Option<FinishReason>,
702}
703#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
704pub struct ChatMessageDelta {
705    #[serde(skip_serializing_if = "Option::is_none")]
706    pub role: Option<Role>,
707    #[serde(skip_serializing_if = "Option::is_none")]
708    pub content: Option<String>,
709    #[serde(skip_serializing_if = "Option::is_none")]
710    pub tool_calls: Option<Vec<ToolCallDelta>>,
711}
712
713#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
714pub struct ToolCallDelta {
715    pub index: usize,
716    pub id: Option<String>,
717    pub r#type: Option<String>,
718    pub function: Option<FunctionCallDelta>,
719}
720
721#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
722pub struct FunctionCallDelta {
723    pub name: Option<String>,
724    pub arguments: Option<String>,
725}
726
727impl From<LLMMessage> for ChatMessage {
728    fn from(llm_message: LLMMessage) -> Self {
729        let role = match llm_message.role.as_str() {
730            "system" => Role::System,
731            "user" => Role::User,
732            "assistant" => Role::Assistant,
733            "tool" => Role::Tool,
734            // "function" => Role::Function,
735            "developer" => Role::Developer,
736            _ => Role::User, // Default to user for unknown roles
737        };
738
739        let (content, tool_calls) = match llm_message.content {
740            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
741            LLMMessageContent::List(items) => {
742                let mut text_parts = Vec::new();
743                let mut tool_call_parts = Vec::new();
744
745                for item in items {
746                    match item {
747                        LLMMessageTypedContent::Text { text } => {
748                            text_parts.push(ContentPart {
749                                r#type: "text".to_string(),
750                                text: Some(text),
751                                image_url: None,
752                            });
753                        }
754                        LLMMessageTypedContent::ToolCall { id, name, args } => {
755                            tool_call_parts.push(ToolCall {
756                                id,
757                                r#type: "function".to_string(),
758                                function: FunctionCall {
759                                    name,
760                                    arguments: args.to_string(),
761                                },
762                            });
763                        }
764                        LLMMessageTypedContent::ToolResult { content, .. } => {
765                            text_parts.push(ContentPart {
766                                r#type: "text".to_string(),
767                                text: Some(content),
768                                image_url: None,
769                            });
770                        }
771                        LLMMessageTypedContent::Image { source } => {
772                            text_parts.push(ContentPart {
773                                r#type: "image_url".to_string(),
774                                text: None,
775                                image_url: Some(ImageUrl {
776                                    url: format!(
777                                        "data:{};base64,{}",
778                                        source.media_type, source.data
779                                    ),
780                                    detail: None,
781                                }),
782                            });
783                        }
784                    }
785                }
786
787                let content = if !text_parts.is_empty() {
788                    Some(MessageContent::Array(text_parts))
789                } else {
790                    None
791                };
792
793                let tool_calls = if !tool_call_parts.is_empty() {
794                    Some(tool_call_parts)
795                } else {
796                    None
797                };
798
799                (content, tool_calls)
800            }
801        };
802
803        ChatMessage {
804            role,
805            content,
806            name: None, // LLMMessage doesn't have a name field
807            tool_calls,
808            tool_call_id: None, // LLMMessage doesn't have a tool_call_id field
809            usage: None,
810        }
811    }
812}
813
814impl From<ChatMessage> for LLMMessage {
815    fn from(chat_message: ChatMessage) -> Self {
816        let mut content_parts = Vec::new();
817
818        // Handle text content
819        match chat_message.content {
820            Some(MessageContent::String(s)) => {
821                if !s.is_empty() {
822                    content_parts.push(LLMMessageTypedContent::Text { text: s });
823                }
824            }
825            Some(MessageContent::Array(parts)) => {
826                for part in parts {
827                    if let Some(text) = part.text {
828                        content_parts.push(LLMMessageTypedContent::Text { text });
829                    } else if let Some(image_url) = part.image_url {
830                        let (media_type, data) = if image_url.url.starts_with("data:") {
831                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
832                            if parts.len() == 2 {
833                                let meta = parts[0];
834                                let data = parts[1];
835                                let media_type = meta
836                                    .trim_start_matches("data:")
837                                    .trim_end_matches(";base64")
838                                    .to_string();
839                                (media_type, data.to_string())
840                            } else {
841                                ("image/jpeg".to_string(), image_url.url)
842                            }
843                        } else {
844                            ("image/jpeg".to_string(), image_url.url)
845                        };
846
847                        content_parts.push(LLMMessageTypedContent::Image {
848                            source: LLMMessageImageSource {
849                                r#type: "base64".to_string(),
850                                media_type,
851                                data,
852                            },
853                        });
854                    }
855                }
856            }
857            None => {}
858        }
859
860        // Handle tool calls
861        if let Some(tool_calls) = chat_message.tool_calls {
862            for tool_call in tool_calls {
863                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
864                content_parts.push(LLMMessageTypedContent::ToolCall {
865                    id: tool_call.id,
866                    name: tool_call.function.name,
867                    args,
868                });
869            }
870        }
871
872        LLMMessage {
873            role: chat_message.role.to_string(),
874            content: if content_parts.is_empty() {
875                LLMMessageContent::String(String::new())
876            } else if content_parts.len() == 1 {
877                match &content_parts[0] {
878                    LLMMessageTypedContent::Text { text } => {
879                        LLMMessageContent::String(text.clone())
880                    }
881                    _ => LLMMessageContent::List(content_parts),
882                }
883            } else {
884                LLMMessageContent::List(content_parts)
885            },
886        }
887    }
888}
889
890#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
891pub enum AgentModel {
892    #[serde(rename = "smart")]
893    #[default]
894    Smart,
895    #[serde(rename = "eco")]
896    Eco,
897    #[serde(rename = "recovery")]
898    Recovery,
899}
900
901impl std::fmt::Display for AgentModel {
902    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
903        match self {
904            AgentModel::Smart => write!(f, "smart"),
905            AgentModel::Eco => write!(f, "eco"),
906            AgentModel::Recovery => write!(f, "recovery"),
907        }
908    }
909}
910
911impl From<String> for AgentModel {
912    fn from(value: String) -> Self {
913        match value.as_str() {
914            "eco" => AgentModel::Eco,
915            "recovery" => AgentModel::Recovery,
916            _ => AgentModel::Smart,
917        }
918    }
919}
920
921impl ChatCompletionRequest {
922    pub fn new(
923        model: String,
924        messages: Vec<ChatMessage>,
925        tools: Option<Vec<Tool>>,
926        stream: Option<bool>,
927    ) -> Self {
928        Self {
929            model,
930            messages,
931            frequency_penalty: None,
932            logit_bias: None,
933            logprobs: None,
934            max_tokens: None,
935            n: None,
936            presence_penalty: None,
937            response_format: None,
938            seed: None,
939            stop: None,
940            stream,
941            temperature: None,
942            top_p: None,
943            tools,
944            tool_choice: None,
945            user: None,
946            context: None,
947        }
948    }
949}
950
951impl From<Tool> for LLMTool {
952    fn from(tool: Tool) -> Self {
953        LLMTool {
954            name: tool.function.name,
955            description: tool.function.description.unwrap_or_default(),
956            input_schema: tool.function.parameters,
957        }
958    }
959}
960
961#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
962pub enum ToolCallResultStatus {
963    Success,
964    Error,
965    Cancelled,
966}
967
968#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
969pub struct ToolCallResult {
970    pub call: ToolCall,
971    pub result: String,
972    pub status: ToolCallResultStatus,
973}
974
975#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
976pub struct ToolCallResultProgress {
977    pub id: Uuid,
978    pub message: String,
979}
980
981impl From<GenerationDelta> for ChatMessageDelta {
982    fn from(delta: GenerationDelta) -> Self {
983        match delta {
984            GenerationDelta::Content { content } => ChatMessageDelta {
985                role: Some(Role::Assistant),
986                content: Some(content),
987                tool_calls: None,
988            },
989            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
990                role: Some(Role::Assistant),
991                content: None,
992                tool_calls: None,
993            },
994            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
995                role: Some(Role::Assistant),
996                content: None,
997                tool_calls: Some(vec![ToolCallDelta {
998                    index: tool_use.index,
999                    id: tool_use.id,
1000                    r#type: Some("function".to_string()),
1001                    function: Some(FunctionCallDelta {
1002                        name: tool_use.name,
1003                        arguments: tool_use.input,
1004                    }),
1005                }]),
1006            },
1007            _ => ChatMessageDelta {
1008                role: Some(Role::Assistant),
1009                content: None,
1010                tool_calls: None,
1011            },
1012        }
1013    }
1014}
1015
1016#[cfg(test)]
1017mod tests {
1018    use super::*;
1019
1020    #[test]
1021    fn test_serialize_basic_request() {
1022        let request = ChatCompletionRequest {
1023            model: AgentModel::Smart.to_string(),
1024            messages: vec![
1025                ChatMessage {
1026                    role: Role::System,
1027                    content: Some(MessageContent::String(
1028                        "You are a helpful assistant.".to_string(),
1029                    )),
1030                    name: None,
1031                    tool_calls: None,
1032                    tool_call_id: None,
1033                    usage: None,
1034                },
1035                ChatMessage {
1036                    role: Role::User,
1037                    content: Some(MessageContent::String("Hello!".to_string())),
1038                    name: None,
1039                    tool_calls: None,
1040                    tool_call_id: None,
1041                    usage: None,
1042                },
1043            ],
1044            frequency_penalty: None,
1045            logit_bias: None,
1046            logprobs: None,
1047            max_tokens: Some(100),
1048            n: None,
1049            presence_penalty: None,
1050            response_format: None,
1051            seed: None,
1052            stop: None,
1053            stream: None,
1054            temperature: Some(0.7),
1055            top_p: None,
1056            tools: None,
1057            tool_choice: None,
1058            user: None,
1059            context: None,
1060        };
1061
1062        let json = serde_json::to_string(&request).unwrap();
1063        assert!(json.contains("\"model\":\"smart\""));
1064        assert!(json.contains("\"messages\":["));
1065        assert!(json.contains("\"role\":\"system\""));
1066        assert!(json.contains("\"content\":\"You are a helpful assistant.\""));
1067        assert!(json.contains("\"role\":\"user\""));
1068        assert!(json.contains("\"content\":\"Hello!\""));
1069        assert!(json.contains("\"max_tokens\":100"));
1070        assert!(json.contains("\"temperature\":0.7"));
1071    }
1072
1073    #[test]
1074    fn test_deserialize_response() {
1075        let json = r#"{
1076            "id": "chatcmpl-123",
1077            "object": "chat.completion",
1078            "created": 1677652288,
1079            "model": "smart",
1080            "system_fingerprint": "fp_123abc",
1081            "choices": [{
1082                "index": 0,
1083                "message": {
1084                    "role": "assistant",
1085                    "content": "Hello! How can I help you today?"
1086                },
1087                "finish_reason": "stop"
1088            }],
1089            "usage": {
1090                "prompt_tokens": 9,
1091                "completion_tokens": 12,
1092                "total_tokens": 21
1093            }
1094        }"#;
1095
1096        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
1097        assert_eq!(response.id, "chatcmpl-123");
1098        assert_eq!(response.object, "chat.completion");
1099        assert_eq!(response.created, 1677652288);
1100        assert_eq!(response.model, AgentModel::Smart.to_string());
1101        assert_eq!(response.system_fingerprint, Some("fp_123abc".to_string()));
1102
1103        assert_eq!(response.choices.len(), 1);
1104        assert_eq!(response.choices[0].index, 0);
1105        assert_eq!(response.choices[0].message.role, Role::Assistant);
1106
1107        match &response.choices[0].message.content {
1108            Some(MessageContent::String(content)) => {
1109                assert_eq!(content, "Hello! How can I help you today?");
1110            }
1111            _ => panic!("Expected string content"),
1112        }
1113
1114        assert_eq!(response.choices[0].finish_reason, FinishReason::Stop);
1115        assert_eq!(response.usage.prompt_tokens, 9);
1116        assert_eq!(response.usage.completion_tokens, 12);
1117        assert_eq!(response.usage.total_tokens, 21);
1118    }
1119
1120    #[test]
1121    fn test_tool_calls_request_response() {
1122        // Test a request with tools
1123        let tools_request = ChatCompletionRequest {
1124            model: AgentModel::Smart.to_string(),
1125            messages: vec![ChatMessage {
1126                role: Role::User,
1127                content: Some(MessageContent::String(
1128                    "What's the weather in San Francisco?".to_string(),
1129                )),
1130                name: None,
1131                tool_calls: None,
1132                tool_call_id: None,
1133                usage: None,
1134            }],
1135            tools: Some(vec![Tool {
1136                r#type: "function".to_string(),
1137                function: FunctionDefinition {
1138                    name: "get_weather".to_string(),
1139                    description: Some("Get the current weather in a given location".to_string()),
1140                    parameters: serde_json::json!({
1141                        "type": "object",
1142                        "properties": {
1143                            "location": {
1144                                "type": "string",
1145                                "description": "The city and state, e.g. San Francisco, CA"
1146                            }
1147                        },
1148                        "required": ["location"]
1149                    }),
1150                },
1151            }]),
1152            tool_choice: Some(ToolChoice::Auto),
1153            max_tokens: Some(100),
1154            temperature: Some(0.7),
1155            frequency_penalty: None,
1156            logit_bias: None,
1157            logprobs: None,
1158            n: None,
1159            presence_penalty: None,
1160            response_format: None,
1161            seed: None,
1162            stop: None,
1163            stream: None,
1164            top_p: None,
1165            user: None,
1166            context: None,
1167        };
1168
1169        let json = serde_json::to_string(&tools_request).unwrap();
1170        println!("Tool request JSON: {}", json);
1171
1172        assert!(json.contains("\"tools\":["));
1173        assert!(json.contains("\"type\":\"function\""));
1174        assert!(json.contains("\"name\":\"get_weather\""));
1175        // Auto should be serialized as "auto" (string)
1176        assert!(json.contains("\"tool_choice\":\"auto\""));
1177
1178        // Test response with tool calls
1179        let tool_response_json = r#"{
1180            "id": "chatcmpl-123",
1181            "object": "chat.completion",
1182            "created": 1677652288,
1183            "model": "eco",
1184            "choices": [{
1185                "index": 0,
1186                "message": {
1187                    "role": "assistant",
1188                    "content": null,
1189                    "tool_calls": [
1190                        {
1191                            "id": "call_abc123",
1192                            "type": "function",
1193                            "function": {
1194                                "name": "get_weather",
1195                                "arguments": "{\"location\":\"San Francisco, CA\"}"
1196                            }
1197                        }
1198                    ]
1199                },
1200                "finish_reason": "tool_calls"
1201            }],
1202            "usage": {
1203                "prompt_tokens": 82,
1204                "completion_tokens": 17,
1205                "total_tokens": 99
1206            }
1207        }"#;
1208
1209        let tool_response: ChatCompletionResponse =
1210            serde_json::from_str(tool_response_json).unwrap();
1211        assert_eq!(tool_response.choices[0].message.role, Role::Assistant);
1212        assert_eq!(tool_response.choices[0].message.content, None);
1213        assert!(tool_response.choices[0].message.tool_calls.is_some());
1214
1215        let tool_calls = tool_response.choices[0]
1216            .message
1217            .tool_calls
1218            .as_ref()
1219            .unwrap();
1220        assert_eq!(tool_calls.len(), 1);
1221        assert_eq!(tool_calls[0].id, "call_abc123");
1222        assert_eq!(tool_calls[0].r#type, "function");
1223        assert_eq!(tool_calls[0].function.name, "get_weather");
1224        assert_eq!(
1225            tool_calls[0].function.arguments,
1226            "{\"location\":\"San Francisco, CA\"}"
1227        );
1228
1229        assert_eq!(
1230            tool_response.choices[0].finish_reason,
1231            FinishReason::ToolCalls,
1232        );
1233    }
1234
1235    #[test]
1236    fn test_content_with_image() {
1237        let message_with_image = ChatMessage {
1238            role: Role::User,
1239            content: Some(MessageContent::Array(vec![
1240                ContentPart {
1241                    r#type: "text".to_string(),
1242                    text: Some("What's in this image?".to_string()),
1243                    image_url: None,
1244                },
1245                ContentPart {
1246                    r#type: "input_image".to_string(),
1247                    text: None,
1248                    image_url: Some(ImageUrl {
1249                        url: "...".to_string(),
1250                        detail: Some("low".to_string()),
1251                    }),
1252                },
1253            ])),
1254            name: None,
1255            tool_calls: None,
1256            tool_call_id: None,
1257            usage: None,
1258        };
1259
1260        let json = serde_json::to_string(&message_with_image).unwrap();
1261        println!("Serialized JSON: {}", json);
1262
1263        assert!(json.contains("\"role\":\"user\""));
1264        assert!(json.contains("\"type\":\"text\""));
1265        assert!(json.contains("\"text\":\"What's in this image?\""));
1266        assert!(json.contains("\"type\":\"input_image\""));
1267        assert!(json.contains("\"url\":\"...\""));
1268        assert!(json.contains("\"detail\":\"low\""));
1269    }
1270
1271    #[test]
1272    fn test_response_format() {
1273        let json_format_request = ChatCompletionRequest {
1274            model: AgentModel::Smart.to_string(),
1275            messages: vec![ChatMessage {
1276                role: Role::User,
1277                content: Some(MessageContent::String(
1278                    "Generate a JSON object with name and age fields".to_string(),
1279                )),
1280                name: None,
1281                tool_calls: None,
1282                tool_call_id: None,
1283                usage: None,
1284            }],
1285            response_format: Some(ResponseFormat {
1286                r#type: "json_object".to_string(),
1287            }),
1288            max_tokens: Some(100),
1289            temperature: None,
1290            frequency_penalty: None,
1291            logit_bias: None,
1292            logprobs: None,
1293            n: None,
1294            presence_penalty: None,
1295            seed: None,
1296            stop: None,
1297            stream: None,
1298            top_p: None,
1299            tools: None,
1300            tool_choice: None,
1301            user: None,
1302            context: None,
1303        };
1304
1305        let json = serde_json::to_string(&json_format_request).unwrap();
1306        assert!(json.contains("\"response_format\":{\"type\":\"json_object\"}"));
1307    }
1308
1309    #[test]
1310    fn test_llm_message_to_chat_message() {
1311        // Test simple string content
1312        let llm_message = LLMMessage {
1313            role: "user".to_string(),
1314            content: LLMMessageContent::String("Hello, world!".to_string()),
1315        };
1316
1317        let chat_message = ChatMessage::from(llm_message);
1318        assert_eq!(chat_message.role, Role::User);
1319        match &chat_message.content {
1320            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
1321            _ => panic!("Expected string content"),
1322        }
1323        assert_eq!(chat_message.tool_calls, None);
1324
1325        // Test tool call conversion
1326        let llm_message_with_tool = LLMMessage {
1327            role: "assistant".to_string(),
1328            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolCall {
1329                id: "call_123".to_string(),
1330                name: "get_weather".to_string(),
1331                args: serde_json::json!({"location": "San Francisco"}),
1332            }]),
1333        };
1334
1335        let chat_message = ChatMessage::from(llm_message_with_tool);
1336        assert_eq!(chat_message.role, Role::Assistant);
1337        assert_eq!(chat_message.content, None); // No text content
1338        assert!(chat_message.tool_calls.is_some());
1339
1340        let tool_calls = chat_message.tool_calls.unwrap();
1341        assert_eq!(tool_calls.len(), 1);
1342        assert_eq!(tool_calls[0].id, "call_123");
1343        assert_eq!(tool_calls[0].function.name, "get_weather");
1344        assert!(tool_calls[0].function.arguments.contains("San Francisco"));
1345
1346        // Test mixed content
1347        let llm_message_mixed = LLMMessage {
1348            role: "assistant".to_string(),
1349            content: LLMMessageContent::List(vec![
1350                LLMMessageTypedContent::Text {
1351                    text: "The weather is:".to_string(),
1352                },
1353                LLMMessageTypedContent::ToolCall {
1354                    id: "call_456".to_string(),
1355                    name: "get_weather".to_string(),
1356                    args: serde_json::json!({"location": "New York"}),
1357                },
1358            ]),
1359        };
1360
1361        let chat_message = ChatMessage::from(llm_message_mixed);
1362        assert_eq!(chat_message.role, Role::Assistant);
1363
1364        match &chat_message.content {
1365            Some(MessageContent::Array(parts)) => {
1366                assert_eq!(parts.len(), 1);
1367                assert_eq!(parts[0].r#type, "text");
1368                assert_eq!(parts[0].text, Some("The weather is:".to_string()));
1369            }
1370            _ => panic!("Expected array content"),
1371        }
1372
1373        let tool_calls = chat_message.tool_calls.unwrap();
1374        assert_eq!(tool_calls.len(), 1);
1375        assert_eq!(tool_calls[0].id, "call_456");
1376        assert_eq!(tool_calls[0].function.name, "get_weather");
1377        assert!(tool_calls[0].function.arguments.contains("New York"));
1378    }
1379}