1use crate::models::llm::{
12 GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13 LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
/// Connection settings for the OpenAI provider.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Endpoint to send requests to. `None` presumably means the provider's
    /// default endpoint — confirm against the client that consumes this config.
    pub api_endpoint: Option<String>,
    /// Credential used to authenticate requests, if configured.
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32 pub const OPENAI_CODEX_BASE_URL: &'static str = "https://chatgpt.com/backend-api/codex";
33 const OPENAI_AUTH_CLAIM: &'static str = "https://api.openai.com/auth";
34
35 pub fn with_api_key(api_key: impl Into<String>) -> Self {
37 Self {
38 api_key: Some(api_key.into()),
39 api_endpoint: None,
40 }
41 }
42
43 pub fn extract_chatgpt_account_id(access_token: &str) -> Option<String> {
49 let claims = crate::jwt::decode_jwt_payload_unverified(access_token)?;
50 let auth_claim = claims.get(Self::OPENAI_AUTH_CLAIM)?;
51
52 match auth_claim {
53 Value::Object(map) => map
54 .get("chatgpt_account_id")
55 .and_then(Value::as_str)
56 .map(ToString::to_string),
57 Value::String(raw_json) => {
58 serde_json::from_str::<Value>(raw_json)
59 .ok()
60 .and_then(|value| {
61 value
62 .get("chatgpt_account_id")
63 .and_then(Value::as_str)
64 .map(ToString::to_string)
65 })
66 }
67 _ => None,
68 }
69 }
70}
71
/// OpenAI models known to this client, plus a catch-all for any other model id.
///
/// Known variants (de)serialize as the dated model id in their
/// `#[serde(rename = ...)]`. `Custom` has no rename or `untagged` marker, so serde
/// uses its default externally-tagged form for it (`{"Custom": "<id>"}`). As a
/// result, a bare unknown id string does not deserialize into `Custom`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    /// Value returned by `OpenAIModel::default()`.
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Any other model id, kept verbatim (its `Display` output is the id itself).
    Custom(String),
}
97
98impl OpenAIModel {
99 pub fn from_string(s: &str) -> Result<Self, String> {
100 serde_json::from_value(serde_json::Value::String(s.to_string()))
101 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
102 }
103}
104
105impl ContextAware for OpenAIModel {
106 fn context_info(&self) -> ModelContextInfo {
107 let model_name = self.to_string();
108
109 if model_name.starts_with("o3") {
110 return ModelContextInfo {
111 max_tokens: 200_000,
112 pricing_tiers: vec![ContextPricingTier {
113 label: "Standard".to_string(),
114 input_cost_per_million: 2.0,
115 output_cost_per_million: 8.0,
116 upper_bound: None,
117 }],
118 approach_warning_threshold: 0.8,
119 };
120 }
121
122 if model_name.starts_with("o4-mini") {
123 return ModelContextInfo {
124 max_tokens: 200_000,
125 pricing_tiers: vec![ContextPricingTier {
126 label: "Standard".to_string(),
127 input_cost_per_million: 1.10,
128 output_cost_per_million: 4.40,
129 upper_bound: None,
130 }],
131 approach_warning_threshold: 0.8,
132 };
133 }
134
135 if model_name.starts_with("gpt-5-mini") {
136 return ModelContextInfo {
137 max_tokens: 400_000,
138 pricing_tiers: vec![ContextPricingTier {
139 label: "Standard".to_string(),
140 input_cost_per_million: 0.25,
141 output_cost_per_million: 2.0,
142 upper_bound: None,
143 }],
144 approach_warning_threshold: 0.8,
145 };
146 }
147
148 if model_name.starts_with("gpt-5-nano") {
149 return ModelContextInfo {
150 max_tokens: 400_000,
151 pricing_tiers: vec![ContextPricingTier {
152 label: "Standard".to_string(),
153 input_cost_per_million: 0.05,
154 output_cost_per_million: 0.40,
155 upper_bound: None,
156 }],
157 approach_warning_threshold: 0.8,
158 };
159 }
160
161 if model_name.starts_with("gpt-5") {
162 return ModelContextInfo {
163 max_tokens: 400_000,
164 pricing_tiers: vec![ContextPricingTier {
165 label: "Standard".to_string(),
166 input_cost_per_million: 1.25,
167 output_cost_per_million: 10.0,
168 upper_bound: None,
169 }],
170 approach_warning_threshold: 0.8,
171 };
172 }
173
174 ModelContextInfo::default()
175 }
176
177 fn model_name(&self) -> String {
178 match self {
179 OpenAIModel::O3 => "O3".to_string(),
180 OpenAIModel::O4Mini => "O4-mini".to_string(),
181 OpenAIModel::GPT5 => "GPT-5".to_string(),
182 OpenAIModel::GPT51 => "GPT-5.1".to_string(),
183 OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
184 OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
185 OpenAIModel::Custom(name) => format!("Custom ({})", name),
186 }
187 }
188}
189
190impl std::fmt::Display for OpenAIModel {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 match self {
193 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
194 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
195 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
196 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
197 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
198 OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
199 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
200 }
201 }
202}
203
/// Author role of a chat message, serialized in lowercase (`"system"`, `"user"`, ...).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    /// Default role, e.g. the role of `ChatMessage::default()`.
    #[default]
    Assistant,
    Tool,
}
219
220impl std::fmt::Display for Role {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 match self {
223 Role::System => write!(f, "system"),
224 Role::Developer => write!(f, "developer"),
225 Role::User => write!(f, "user"),
226 Role::Assistant => write!(f, "assistant"),
227 Role::Tool => write!(f, "tool"),
228 }
229 }
230}
231
/// Provider/model pair recorded on a `ChatMessage` via its `model` field.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name (free-form string).
    pub provider: String,
    /// Model identifier within that provider.
    pub id: String,
}
240
/// A chat message in OpenAI chat-completions shape, plus optional bookkeeping
/// fields. Every field except `role` and `content` is omitted from JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    /// Author role of the message.
    pub role: Role,
    /// Message body. It has no `skip_serializing_if`, so `None` serializes as `null`.
    pub content: Option<MessageContent>,
    /// Optional `name` field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool invocations carried by the message (e.g. by an assistant turn).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Role::Tool` messages, the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Token usage associated with this message, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    /// Message identifier, if assigned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Provider/model that produced the message, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    /// Cost attributed to the message. Currency/unit is not established here — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    /// Raw finish-reason string, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Creation timestamp. Unit (seconds vs. milliseconds) is not visible here — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Completion timestamp, presumably in the same unit as `created_at` — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Free-form extra data.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
278
279impl ChatMessage {
280 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
281 messages
282 .iter()
283 .rev()
284 .find(|message| message.role != Role::User && message.role != Role::Tool)
285 }
286
287 pub fn to_xml(&self) -> String {
288 match &self.content {
289 Some(MessageContent::String(s)) => {
290 format!("<message role=\"{}\">{}</message>", self.role, s)
291 }
292 Some(MessageContent::Array(parts)) => parts
293 .iter()
294 .map(|part| {
295 format!(
296 "<message role=\"{}\" type=\"{}\">{}</message>",
297 self.role,
298 part.r#type,
299 part.text.clone().unwrap_or_default()
300 )
301 })
302 .collect::<Vec<String>>()
303 .join("\n"),
304 None => String::new(),
305 }
306 }
307}
308
/// Message content: either plain text or a list of typed parts.
///
/// With `#[serde(untagged)]`, a JSON string maps to `String` and a JSON array of
/// part objects maps to `Array`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
316
317impl MessageContent {
318 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
319 match self {
320 MessageContent::String(s) => MessageContent::String(format!(
321 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
322 )),
323 MessageContent::Array(parts) => MessageContent::Array(
324 std::iter::once(ContentPart {
325 r#type: "text".to_string(),
326 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
327 image_url: None,
328 })
329 .chain(parts.iter().cloned())
330 .collect(),
331 ),
332 }
333 }
334
335 #[allow(clippy::string_slice)]
337 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
338 match self {
339 MessageContent::String(s) => s
340 .rfind("<checkpoint_id>")
341 .and_then(|start| {
342 s[start..]
343 .find("</checkpoint_id>")
344 .map(|end| (start + "<checkpoint_id>".len(), start + end))
345 })
346 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
347 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
348 part.text.as_deref().and_then(|text| {
349 text.rfind("<checkpoint_id>")
350 .and_then(|start| {
351 text[start..]
352 .find("</checkpoint_id>")
353 .map(|end| (start + "<checkpoint_id>".len(), start + end))
354 })
355 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
356 })
357 }),
358 }
359 }
360}
361
362impl std::fmt::Display for MessageContent {
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 match self {
365 MessageContent::String(s) => write!(f, "{s}"),
366 MessageContent::Array(parts) => {
367 let text_parts: Vec<String> =
368 parts.iter().filter_map(|part| part.text.clone()).collect();
369 write!(f, "{}", text_parts.join("\n"))
370 }
371 }
372 }
373}
374
375impl Default for MessageContent {
376 fn default() -> Self {
377 MessageContent::String(String::new())
378 }
379}
380
/// One typed part of array-form message content.
///
/// This file uses `r#type` values `"text"` (with `text` set) and `"image_url"`
/// (with `image_url` set). `r#type` serializes as `type`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    pub r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}

/// Image reference of a `"image_url"` content part.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// Image location. The conversions in this file produce and parse
    /// `data:<media_type>;base64,<payload>` URLs here.
    pub url: String,
    /// Optional detail hint, omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
398
/// A tool offered to the model. `r#type` serializes as `type`; the
/// `From<Tool> for LLMTool` conversion ignores it.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    pub r#type: String,
    pub function: FunctionDefinition,
}

/// Name, description and parameter schema of a callable function.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    /// Serialized as `null` when `None` (there is no `skip_serializing_if`).
    pub description: Option<String>,
    /// Parameter schema, presumably a JSON Schema object — confirm. It becomes
    /// `LLMTool::input_schema` on conversion.
    pub parameters: serde_json::Value,
}
417
418impl From<Tool> for LLMTool {
419 fn from(tool: Tool) -> Self {
420 LLMTool {
421 name: tool.function.name,
422 description: tool.function.description.unwrap_or_default(),
423 input_schema: tool.function.parameters,
424 }
425 }
426}
427
/// How the model may use tools.
///
/// `Serialize`/`Deserialize` are hand-written for this type:
/// - `Auto` and `Required` map to the strings `"auto"` and `"required"`;
/// - `Object` maps to a `{ "type", "function" }` object.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
435
436impl Serialize for ToolChoice {
437 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
438 where
439 S: serde::Serializer,
440 {
441 match self {
442 ToolChoice::Auto => serializer.serialize_str("auto"),
443 ToolChoice::Required => serializer.serialize_str("required"),
444 ToolChoice::Object(obj) => obj.serialize(serializer),
445 }
446 }
447}
448
449impl<'de> Deserialize<'de> for ToolChoice {
450 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
451 where
452 D: serde::Deserializer<'de>,
453 {
454 struct ToolChoiceVisitor;
455
456 impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
457 type Value = ToolChoice;
458
459 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
460 formatter.write_str("string or object")
461 }
462
463 fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
464 where
465 E: serde::de::Error,
466 {
467 match value {
468 "auto" => Ok(ToolChoice::Auto),
469 "required" => Ok(ToolChoice::Required),
470 _ => Err(serde::de::Error::unknown_variant(
471 value,
472 &["auto", "required"],
473 )),
474 }
475 }
476
477 fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
478 where
479 M: serde::de::MapAccess<'de>,
480 {
481 let obj = ToolChoiceObject::deserialize(
482 serde::de::value::MapAccessDeserializer::new(map),
483 )?;
484 Ok(ToolChoice::Object(obj))
485 }
486 }
487
488 deserializer.deserialize_any(ToolChoiceVisitor)
489 }
490}
491
/// Object form of a tool choice that names a specific function.
/// `r#type` serializes as `type`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    pub r#type: String,
    pub function: FunctionChoice,
}

/// The function selected by a `ToolChoiceObject`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
502
/// A tool invocation requested by the model.
///
/// This file always sets `r#type` (serialized as `type`) to `"function"` when
/// building calls.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    pub id: String,
    pub r#type: String,
    pub function: FunctionCall,
    /// Extra data carried alongside the call, omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

/// Name and arguments of a called function.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    /// Arguments as a JSON-encoded string (not a JSON object).
    pub arguments: String,
}
520
/// Outcome of executing a tool call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}

/// A tool call together with its textual result and outcome.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    pub call: ToolCall,
    pub result: String,
    pub status: ToolCallResultStatus,
}
536
/// Information about a tool call whose arguments are still streaming in.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallStreamInfo {
    /// Name of the tool being called.
    pub name: String,
    /// Size of the streamed arguments so far, presumably counted in tokens — confirm.
    pub args_tokens: usize,
    /// Optional description, omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
548
/// Progress update for a tool call that is still running.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    /// Identifier of this progress stream; presumably ties it to its tool call — confirm.
    pub id: Uuid,
    /// Human-readable progress text.
    pub message: String,
    /// Category of the update, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress_type: Option<ProgressType>,
    /// Per-task status updates, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_updates: Option<Vec<TaskUpdate>>,
    /// Numeric progress. Scale (0–1 vs. 0–100) is not established here — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
}

/// Category of a `ToolCallResultProgress` update.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressType {
    /// Presumably output from a running command.
    CommandOutput,
    /// Presumably waiting on one or more tasks.
    TaskWait,
    /// Any other progress.
    Generic,
}

/// Status snapshot of one task, as reported inside a progress update.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskUpdate {
    pub task_id: String,
    /// Free-form status string.
    pub status: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Elapsed time in seconds, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f64>,
    /// Excerpt of the task's output, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_preview: Option<String>,
    /// Defaults to `false` when absent from the input JSON.
    #[serde(default)]
    pub is_target: bool,
    /// Present when the task is paused.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pause_info: Option<TaskPauseInfo>,
}

/// Details attached to a paused task.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskPauseInfo {
    /// Message from the agent, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_message: Option<String>,
    /// Tool calls still awaiting resolution, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending_tool_calls: Option<Vec<crate::models::async_manifest::PendingToolCall>>,
}
605
606pub use crate::models::tools::ask_user::{
607 AskUserAnswer, AskUserOption, AskUserQuestion, AskUserRequest, AskUserResult,
608};
609
/// Request body for an OpenAI-compatible chat-completions call.
///
/// Every field other than `model` and `messages` is optional and omitted from the
/// JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    /// Model id to use.
    pub model: String,
    /// Conversation so far.
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    /// Whether to request a streamed response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Extra context. This does not look like a standard OpenAI parameter;
    /// presumably a project-specific extension — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
648
649impl ChatCompletionRequest {
650 pub fn new(
651 model: String,
652 messages: Vec<ChatMessage>,
653 tools: Option<Vec<Tool>>,
654 stream: Option<bool>,
655 ) -> Self {
656 Self {
657 model,
658 messages,
659 frequency_penalty: None,
660 logit_bias: None,
661 logprobs: None,
662 max_tokens: None,
663 n: None,
664 presence_penalty: None,
665 response_format: None,
666 seed: None,
667 stop: None,
668 stream,
669 temperature: None,
670 top_p: None,
671 tools,
672 tool_choice: None,
673 user: None,
674 context: None,
675 }
676 }
677}
678
/// Extra request context carried in `ChatCompletionRequest::context`. Presumably
/// project-specific rather than part of the OpenAI API — confirm.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Arbitrary JSON. It has no `skip_serializing_if`, so `None` serializes as `null`.
    pub scratchpad: Option<Value>,
}

/// Requested response format, serialized as `{"type": ...}`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}

/// Stop sequence(s). Because the enum is `#[serde(untagged)]`, this is either a
/// single JSON string or a JSON array of strings.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
695
/// Non-streamed chat-completions response body.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation time as an unsigned integer (presumably Unix seconds — confirm).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    /// Required here: deserialization fails if `usage` is absent.
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Extra data. It has no `skip_serializing_if`, so `None` serializes as `null`.
    pub metadata: Option<serde_json::Value>,
}

/// One generated alternative in a `ChatCompletionResponse`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}

/// Why generation stopped, serialized in snake_case (`"stop"`, `"length"`,
/// `"content_filter"`, `"tool_calls"`). Any other value fails deserialization.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}

/// Log-probability information for a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}

/// Log probability of one output token, with optional top alternatives.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}

/// Log probability of a single candidate token.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
746
/// One chunk of a streamed chat-completions response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Token usage, when this chunk carries it.
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}

/// Incremental update for one choice within a stream chunk.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    /// Set once this choice has finished.
    pub finish_reason: Option<FinishReason>,
}

/// Partial message carried by a stream chunk. Unset fields are omitted from the JSON.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}

/// Fragment of a tool call. `index` identifies which call the fragment extends.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

/// Fragment of a function call: name and/or a chunk of the JSON-encoded arguments.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
795
796impl From<LLMMessage> for ChatMessage {
801 fn from(llm_message: LLMMessage) -> Self {
802 let role = match llm_message.role.as_str() {
803 "system" => Role::System,
804 "user" => Role::User,
805 "assistant" => Role::Assistant,
806 "tool" => Role::Tool,
807 "developer" => Role::Developer,
808 _ => Role::User,
809 };
810
811 let (content, tool_calls, tool_call_id) = match llm_message.content {
812 LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None, None),
813 LLMMessageContent::List(items) => {
814 let mut text_parts = Vec::new();
815 let mut tool_call_parts = Vec::new();
816 let mut tool_result_id: Option<String> = None;
817
818 for item in items {
819 match item {
820 LLMMessageTypedContent::Text { text } => {
821 text_parts.push(ContentPart {
822 r#type: "text".to_string(),
823 text: Some(text),
824 image_url: None,
825 });
826 }
827 LLMMessageTypedContent::ToolCall {
828 id,
829 name,
830 args,
831 metadata,
832 } => {
833 tool_call_parts.push(ToolCall {
834 id,
835 r#type: "function".to_string(),
836 function: FunctionCall {
837 name,
838 arguments: args.to_string(),
839 },
840 metadata,
841 });
842 }
843 LLMMessageTypedContent::ToolResult {
844 tool_use_id,
845 content,
846 } => {
847 if tool_result_id.is_none() {
848 tool_result_id = Some(tool_use_id);
849 }
850 text_parts.push(ContentPart {
851 r#type: "text".to_string(),
852 text: Some(content),
853 image_url: None,
854 });
855 }
856 LLMMessageTypedContent::Image { source } => {
857 text_parts.push(ContentPart {
858 r#type: "image_url".to_string(),
859 text: None,
860 image_url: Some(ImageUrl {
861 url: format!(
862 "data:{};base64,{}",
863 source.media_type, source.data
864 ),
865 detail: None,
866 }),
867 });
868 }
869 }
870 }
871
872 let content = if !text_parts.is_empty() {
873 Some(MessageContent::Array(text_parts))
874 } else {
875 None
876 };
877
878 let tool_calls = if !tool_call_parts.is_empty() {
879 Some(tool_call_parts)
880 } else {
881 None
882 };
883
884 (content, tool_calls, tool_result_id)
885 }
886 };
887
888 ChatMessage {
889 role,
890 content,
891 name: None,
892 tool_calls,
893 tool_call_id,
894 usage: None,
895 ..Default::default()
896 }
897 }
898}
899
900impl From<ChatMessage> for LLMMessage {
901 fn from(chat_message: ChatMessage) -> Self {
902 let mut content_parts = Vec::new();
903
904 match chat_message.content {
905 Some(MessageContent::String(s)) => {
906 if !s.is_empty() {
907 content_parts.push(LLMMessageTypedContent::Text { text: s });
908 }
909 }
910 Some(MessageContent::Array(parts)) => {
911 for part in parts {
912 if let Some(text) = part.text {
913 content_parts.push(LLMMessageTypedContent::Text { text });
914 } else if let Some(image_url) = part.image_url {
915 let (media_type, data) = if image_url.url.starts_with("data:") {
916 let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
917 if parts.len() == 2 {
918 let meta = parts[0];
919 let data = parts[1];
920 let media_type = meta
921 .trim_start_matches("data:")
922 .trim_end_matches(";base64")
923 .to_string();
924 (media_type, data.to_string())
925 } else {
926 ("image/jpeg".to_string(), image_url.url)
927 }
928 } else {
929 ("image/jpeg".to_string(), image_url.url)
930 };
931
932 content_parts.push(LLMMessageTypedContent::Image {
933 source: LLMMessageImageSource {
934 r#type: "base64".to_string(),
935 media_type,
936 data,
937 },
938 });
939 }
940 }
941 }
942 None => {}
943 }
944
945 if let Some(tool_calls) = chat_message.tool_calls {
946 for tool_call in tool_calls {
947 let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
948 content_parts.push(LLMMessageTypedContent::ToolCall {
949 id: tool_call.id,
950 name: tool_call.function.name,
951 args,
952 metadata: tool_call.metadata,
953 });
954 }
955 }
956
957 if chat_message.role == Role::Tool
962 && let Some(tool_call_id) = chat_message.tool_call_id
963 {
964 let content_str = content_parts
966 .iter()
967 .filter_map(|p| match p {
968 LLMMessageTypedContent::Text { text } => Some(text.clone()),
969 _ => None,
970 })
971 .collect::<Vec<_>>()
972 .join("\n");
973
974 content_parts = vec![LLMMessageTypedContent::ToolResult {
976 tool_use_id: tool_call_id,
977 content: content_str,
978 }];
979 }
980
981 LLMMessage {
982 role: chat_message.role.to_string(),
983 content: if content_parts.is_empty() {
984 LLMMessageContent::String(String::new())
985 } else if content_parts.len() == 1 {
986 match &content_parts[0] {
987 LLMMessageTypedContent::Text { text } => {
988 LLMMessageContent::String(text.clone())
989 }
990 _ => LLMMessageContent::List(content_parts),
991 }
992 } else {
993 LLMMessageContent::List(content_parts)
994 },
995 }
996 }
997}
998
999impl From<GenerationDelta> for ChatMessageDelta {
1000 fn from(delta: GenerationDelta) -> Self {
1001 match delta {
1002 GenerationDelta::Content { content } => ChatMessageDelta {
1003 role: Some(Role::Assistant),
1004 content: Some(content),
1005 tool_calls: None,
1006 },
1007 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1008 role: Some(Role::Assistant),
1009 content: None,
1010 tool_calls: None,
1011 },
1012 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1013 role: Some(Role::Assistant),
1014 content: None,
1015 tool_calls: Some(vec![ToolCallDelta {
1016 index: tool_use.index,
1017 id: tool_use.id,
1018 r#type: Some("function".to_string()),
1019 function: Some(FunctionCallDelta {
1020 name: tool_use.name,
1021 arguments: tool_use.input,
1022 }),
1023 metadata: tool_use.metadata,
1024 }]),
1025 },
1026 _ => ChatMessageDelta {
1027 role: Some(Role::Assistant),
1028 content: None,
1029 tool_calls: None,
1030 },
1031 }
1032 }
1033}
1034
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    /// Encodes `payload` as the middle segment of an unsigned, JWT-shaped token.
    fn fake_access_token(payload: &Value) -> String {
        let encoded_payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(payload.to_string().as_bytes());
        format!("header.{encoded_payload}.signature")
    }

    /// A message with the given role and content and every other field defaulted.
    fn message_with(role: Role, content: Option<MessageContent>) -> ChatMessage {
        ChatMessage {
            role,
            content,
            ..Default::default()
        }
    }

    /// Plain string message content.
    fn string_content(s: &str) -> Option<MessageContent> {
        Some(MessageContent::String(s.to_string()))
    }

    #[test]
    fn test_serialize_basic_request() {
        let messages = vec![
            message_with(Role::System, string_content("You are a helpful assistant.")),
            message_with(Role::User, string_content("Hello!")),
        ];
        let request = ChatCompletionRequest {
            max_tokens: Some(100),
            temperature: Some(0.7),
            ..ChatCompletionRequest::new("gpt-4".to_string(), messages, None, None)
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        let Some(MessageContent::String(text)) = &chat_message.content else {
            panic!("Expected string content");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_llm_message_to_chat_message_tool_result_preserves_tool_call_id() {
        let llm_message = LLMMessage {
            role: "tool".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_01Abc123".to_string(),
                content: "Tool execution result".to_string(),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::Tool);
        assert_eq!(chat_message.tool_call_id.as_deref(), Some("toolu_01Abc123"));

        let expected_part = ContentPart {
            r#type: "text".to_string(),
            text: Some("Tool execution result".to_string()),
            image_url: None,
        };
        assert_eq!(
            chat_message.content,
            Some(MessageContent::Array(vec![expected_part]))
        );
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        let chat_message = ChatMessage {
            tool_call_id: Some("toolu_01Abc123".to_string()),
            ..message_with(Role::Tool, string_content("Tool execution result"))
        };

        let llm_message: LLMMessage = chat_message.into();
        assert_eq!(llm_message.role, "tool");

        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            );
        };
        assert_eq!(parts.len(), 1, "Should have exactly one content part");

        let LLMMessageTypedContent::ToolResult {
            tool_use_id,
            content,
        } = &parts[0]
        else {
            panic!("Expected ToolResult content part, got {:?}", parts[0]);
        };
        assert_eq!(tool_use_id, "toolu_01Abc123");
        assert_eq!(content, "Tool execution result");
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        let chat_message = ChatMessage {
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            ..message_with(Role::Tool, None)
        };

        let llm_message: LLMMessage = chat_message.into();
        assert_eq!(llm_message.role, "tool");

        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!("Expected List content with ToolResult");
        };
        assert_eq!(parts.len(), 1);

        let LLMMessageTypedContent::ToolResult {
            tool_use_id,
            content,
        } = &parts[0]
        else {
            panic!("Expected ToolResult content part");
        };
        assert_eq!(tool_use_id, "toolu_02Xyz789");
        // No text parts means the joined tool result is empty.
        assert_eq!(content, "");
    }

    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        let call = ToolCall {
            id: "call_abc123".to_string(),
            r#type: "function".to_string(),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"location": "Paris"}"#.to_string(),
            },
            metadata: None,
        };
        let chat_message = ChatMessage {
            tool_calls: Some(vec![call]),
            ..message_with(Role::Assistant, string_content("I'll help you with that."))
        };

        let llm_message: LLMMessage = chat_message.into();
        assert_eq!(llm_message.role, "assistant");

        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!("Expected List content");
        };
        assert_eq!(parts.len(), 2, "Should have text and tool call");

        // Text comes first, then the tool call.
        let LLMMessageTypedContent::Text { text } = &parts[0] else {
            panic!("Expected Text content part first");
        };
        assert_eq!(text, "I'll help you with that.");

        let LLMMessageTypedContent::ToolCall { id, name, args, .. } = &parts[1] else {
            panic!("Expected ToolCall content part second");
        };
        assert_eq!(id, "call_abc123");
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], "Paris");
    }

    #[test]
    fn test_extract_chatgpt_account_id_from_access_token() {
        let access_token = fake_access_token(&json!({
            "https://api.openai.com/auth": { "chatgpt_account_id": "acct_test_123" }
        }));

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            Some("acct_test_123".to_string())
        );
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_missing_claim() {
        let access_token = fake_access_token(&json!({ "sub": "user_123" }));

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_token_shape() {
        assert_eq!(OpenAIConfig::extract_chatgpt_account_id("not-a-jwt"), None);
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_claim_json() {
        let access_token = fake_access_token(&json!({
            "https://api.openai.com/auth": "{not-json}"
        }));

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }
}