1use crate::models::llm::{
12 GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13 LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
/// Connection settings for the OpenAI provider.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Base URL override; `None` selects the provider default.
    pub api_endpoint: Option<String>,
    /// API key credential; `None` when not configured.
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32 pub const OPENAI_CODEX_BASE_URL: &'static str = "https://chatgpt.com/backend-api/codex";
33 const OPENAI_AUTH_CLAIM: &'static str = "https://api.openai.com/auth";
34
35 pub fn with_api_key(api_key: impl Into<String>) -> Self {
37 Self {
38 api_key: Some(api_key.into()),
39 api_endpoint: None,
40 }
41 }
42
43 pub fn extract_chatgpt_account_id(access_token: &str) -> Option<String> {
49 let claims = crate::jwt::decode_jwt_payload_unverified(access_token)?;
50 let auth_claim = claims.get(Self::OPENAI_AUTH_CLAIM)?;
51
52 match auth_claim {
53 Value::Object(map) => map
54 .get("chatgpt_account_id")
55 .and_then(Value::as_str)
56 .map(ToString::to_string),
57 Value::String(raw_json) => {
58 serde_json::from_str::<Value>(raw_json)
59 .ok()
60 .and_then(|value| {
61 value
62 .get("chatgpt_account_id")
63 .and_then(Value::as_str)
64 .map(ToString::to_string)
65 })
66 }
67 _ => None,
68 }
69 }
70}
71
/// Known OpenAI model identifiers, serialized as the dated API model ids,
/// plus a `Custom` escape hatch for arbitrary model strings.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Any model id not covered by the named variants.
    Custom(String),
}
97
98impl OpenAIModel {
99 pub fn from_string(s: &str) -> Result<Self, String> {
100 serde_json::from_value(serde_json::Value::String(s.to_string()))
101 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
102 }
103
104 pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
106
107 pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
109
110 pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
112
113 pub fn default_smart_model() -> String {
115 Self::DEFAULT_SMART_MODEL.to_string()
116 }
117
118 pub fn default_eco_model() -> String {
120 Self::DEFAULT_ECO_MODEL.to_string()
121 }
122
123 pub fn default_recovery_model() -> String {
125 Self::DEFAULT_RECOVERY_MODEL.to_string()
126 }
127}
128
129impl ContextAware for OpenAIModel {
130 fn context_info(&self) -> ModelContextInfo {
131 let model_name = self.to_string();
132
133 if model_name.starts_with("o3") {
134 return ModelContextInfo {
135 max_tokens: 200_000,
136 pricing_tiers: vec![ContextPricingTier {
137 label: "Standard".to_string(),
138 input_cost_per_million: 2.0,
139 output_cost_per_million: 8.0,
140 upper_bound: None,
141 }],
142 approach_warning_threshold: 0.8,
143 };
144 }
145
146 if model_name.starts_with("o4-mini") {
147 return ModelContextInfo {
148 max_tokens: 200_000,
149 pricing_tiers: vec![ContextPricingTier {
150 label: "Standard".to_string(),
151 input_cost_per_million: 1.10,
152 output_cost_per_million: 4.40,
153 upper_bound: None,
154 }],
155 approach_warning_threshold: 0.8,
156 };
157 }
158
159 if model_name.starts_with("gpt-5-mini") {
160 return ModelContextInfo {
161 max_tokens: 400_000,
162 pricing_tiers: vec![ContextPricingTier {
163 label: "Standard".to_string(),
164 input_cost_per_million: 0.25,
165 output_cost_per_million: 2.0,
166 upper_bound: None,
167 }],
168 approach_warning_threshold: 0.8,
169 };
170 }
171
172 if model_name.starts_with("gpt-5-nano") {
173 return ModelContextInfo {
174 max_tokens: 400_000,
175 pricing_tiers: vec![ContextPricingTier {
176 label: "Standard".to_string(),
177 input_cost_per_million: 0.05,
178 output_cost_per_million: 0.40,
179 upper_bound: None,
180 }],
181 approach_warning_threshold: 0.8,
182 };
183 }
184
185 if model_name.starts_with("gpt-5") {
186 return ModelContextInfo {
187 max_tokens: 400_000,
188 pricing_tiers: vec![ContextPricingTier {
189 label: "Standard".to_string(),
190 input_cost_per_million: 1.25,
191 output_cost_per_million: 10.0,
192 upper_bound: None,
193 }],
194 approach_warning_threshold: 0.8,
195 };
196 }
197
198 ModelContextInfo::default()
199 }
200
201 fn model_name(&self) -> String {
202 match self {
203 OpenAIModel::O3 => "O3".to_string(),
204 OpenAIModel::O4Mini => "O4-mini".to_string(),
205 OpenAIModel::GPT5 => "GPT-5".to_string(),
206 OpenAIModel::GPT51 => "GPT-5.1".to_string(),
207 OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
208 OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
209 OpenAIModel::Custom(name) => format!("Custom ({})", name),
210 }
211 }
212}
213
214impl std::fmt::Display for OpenAIModel {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 match self {
217 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
218 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
219 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
220 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
221 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
222 OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
223 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
224 }
225 }
226}
227
/// Chat message roles, serialized lowercase to match the OpenAI wire format.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
}
243
244impl std::fmt::Display for Role {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 match self {
247 Role::System => write!(f, "system"),
248 Role::Developer => write!(f, "developer"),
249 Role::User => write!(f, "user"),
250 Role::Assistant => write!(f, "assistant"),
251 Role::Tool => write!(f, "tool"),
252 }
253 }
254}
255
/// Identifies which provider and model produced a message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name.
    pub provider: String,
    /// Model identifier string.
    pub id: String,
}
264
/// A single chat-completion message, extended with bookkeeping fields
/// (id, model, cost, timestamps, metadata) that are skipped on the wire
/// when absent.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    pub role: Role,
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers (used with `Role::Tool`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    /// Message identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Provider/model that produced this message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    /// Generation cost. NOTE(review): units presumably USD — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Creation timestamp (epoch; units not visible here — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Completion timestamp (same clock as `created_at`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Free-form application metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
302
303impl ChatMessage {
304 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
305 messages
306 .iter()
307 .rev()
308 .find(|message| message.role != Role::User && message.role != Role::Tool)
309 }
310
311 pub fn to_xml(&self) -> String {
312 match &self.content {
313 Some(MessageContent::String(s)) => {
314 format!("<message role=\"{}\">{}</message>", self.role, s)
315 }
316 Some(MessageContent::Array(parts)) => parts
317 .iter()
318 .map(|part| {
319 format!(
320 "<message role=\"{}\" type=\"{}\">{}</message>",
321 self.role,
322 part.r#type,
323 part.text.clone().unwrap_or_default()
324 )
325 })
326 .collect::<Vec<String>>()
327 .join("\n"),
328 None => String::new(),
329 }
330 }
331}
332
/// Message content: either a bare string or a list of typed parts
/// (untagged on the wire, matching the OpenAI API shape).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
340
341impl MessageContent {
342 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
343 match self {
344 MessageContent::String(s) => MessageContent::String(format!(
345 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
346 )),
347 MessageContent::Array(parts) => MessageContent::Array(
348 std::iter::once(ContentPart {
349 r#type: "text".to_string(),
350 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
351 image_url: None,
352 })
353 .chain(parts.iter().cloned())
354 .collect(),
355 ),
356 }
357 }
358
359 #[allow(clippy::string_slice)]
361 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
362 match self {
363 MessageContent::String(s) => s
364 .rfind("<checkpoint_id>")
365 .and_then(|start| {
366 s[start..]
367 .find("</checkpoint_id>")
368 .map(|end| (start + "<checkpoint_id>".len(), start + end))
369 })
370 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
371 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
372 part.text.as_deref().and_then(|text| {
373 text.rfind("<checkpoint_id>")
374 .and_then(|start| {
375 text[start..]
376 .find("</checkpoint_id>")
377 .map(|end| (start + "<checkpoint_id>".len(), start + end))
378 })
379 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
380 })
381 }),
382 }
383 }
384}
385
386impl std::fmt::Display for MessageContent {
387 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 match self {
389 MessageContent::String(s) => write!(f, "{s}"),
390 MessageContent::Array(parts) => {
391 let text_parts: Vec<String> =
392 parts.iter().filter_map(|part| part.text.clone()).collect();
393 write!(f, "{}", text_parts.join("\n"))
394 }
395 }
396 }
397}
398
399impl Default for MessageContent {
400 fn default() -> Self {
401 MessageContent::String(String::new())
402 }
403}
404
/// One element of array-style message content. This file produces parts
/// with `type` of `"text"` or `"image_url"`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part discriminator (e.g. `"text"`, `"image_url"`).
    pub r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
414
/// Image reference inside a content part; this file builds `data:` URLs
/// with base64 payloads.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail hint; left unset by this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
422
/// A tool made available to the model (OpenAI function-calling shape).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool kind discriminator (this file uses `"function"`).
    pub r#type: String,
    pub function: FunctionDefinition,
}
433
/// Declares a callable function: name, optional description, and a JSON
/// schema for its parameters.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema describing the accepted arguments.
    pub parameters: serde_json::Value,
}
441
442impl From<Tool> for LLMTool {
443 fn from(tool: Tool) -> Self {
444 LLMTool {
445 name: tool.function.name,
446 description: tool.function.description.unwrap_or_default(),
447 input_schema: tool.function.parameters,
448 }
449 }
450}
451
/// Tool-choice setting: the strings `"auto"` / `"required"`, or an explicit
/// function selector object. Serialization is hand-written below to match
/// that mixed wire shape.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
459
460impl Serialize for ToolChoice {
461 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
462 where
463 S: serde::Serializer,
464 {
465 match self {
466 ToolChoice::Auto => serializer.serialize_str("auto"),
467 ToolChoice::Required => serializer.serialize_str("required"),
468 ToolChoice::Object(obj) => obj.serialize(serializer),
469 }
470 }
471}
472
impl<'de> Deserialize<'de> for ToolChoice {
    /// Accepts either the strings `"auto"` / `"required"` or a selector
    /// object, mirroring the hand-written `Serialize` impl above.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ToolChoiceVisitor;

        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
            type Value = ToolChoice;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("string or object")
            }

            // String form: only the two known tags are accepted.
            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
            where
                E: serde::de::Error,
            {
                match value {
                    "auto" => Ok(ToolChoice::Auto),
                    "required" => Ok(ToolChoice::Required),
                    _ => Err(serde::de::Error::unknown_variant(
                        value,
                        &["auto", "required"],
                    )),
                }
            }

            // Object form: delegate to ToolChoiceObject's derived impl.
            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
            where
                M: serde::de::MapAccess<'de>,
            {
                let obj = ToolChoiceObject::deserialize(
                    serde::de::value::MapAccessDeserializer::new(map),
                )?;
                Ok(ToolChoice::Object(obj))
            }
        }

        // deserialize_any lets the visitor dispatch on the input's own shape.
        deserializer.deserialize_any(ToolChoiceVisitor)
    }
}
515
/// Object form of a tool choice: forces a specific function.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    /// Choice kind discriminator.
    pub r#type: String,
    pub function: FunctionChoice,
}
521
/// Names the function the model is forced to call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
526
/// A tool invocation emitted by the model.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Call kind discriminator (this file uses `"function"`).
    pub r#type: String,
    pub function: FunctionCall,
    /// Extra call metadata, omitted from the wire when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
537
/// Function name plus its arguments as a raw JSON string (as received from
/// the API; parsed into a `Value` only when converting to `LLMMessage`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}
544
/// Outcome of executing a tool call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}
552
/// A completed tool call together with its textual result and status.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    /// The originating call.
    pub call: ToolCall,
    /// Result payload as text.
    pub result: String,
    pub status: ToolCallResultStatus,
}
560
/// Progress info about a tool call while its arguments are streaming.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallStreamInfo {
    /// Name of the tool being called.
    pub name: String,
    /// Argument token count so far. NOTE(review): exact counting semantics
    /// not visible from this file — confirm against the producer.
    pub args_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
572
/// Incremental progress report emitted while a tool call is running.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    pub id: Uuid,
    pub message: String,
    /// Coarse category of this progress event.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress_type: Option<ProgressType>,
    /// Per-task status updates accompanying this event.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_updates: Option<Vec<TaskUpdate>>,
    /// Completion fraction. NOTE(review): presumably 0.0–1.0 — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
}
588
/// Coarse category for a tool-call progress event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressType {
    CommandOutput,
    TaskWait,
    Generic,
}
599
/// Status update for one background task referenced by a progress event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskUpdate {
    pub task_id: String,
    /// Status as a free-form string. NOTE(review): valid values not visible
    /// from this file.
    pub status: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_preview: Option<String>,
    /// Defaults to `false` when missing from the wire.
    #[serde(default)]
    pub is_target: bool,
    /// Present when the task is paused awaiting input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pause_info: Option<TaskPauseInfo>,
}
618
/// Details about why a task is paused.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskPauseInfo {
    /// Message from the agent at the pause point.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_message: Option<String>,
    /// Tool calls waiting to run when the task resumes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending_tool_calls: Option<Vec<crate::models::async_manifest::PendingToolCall>>,
}
629
630pub use crate::models::tools::ask_user::{
631 AskUserAnswer, AskUserOption, AskUserQuestion, AskUserRequest, AskUserResult,
632};
633
/// A chat-completion request body (OpenAI-compatible). All optional tuning
/// fields are skipped on the wire when unset; `context` is an extension
/// specific to this application, not part of the OpenAI API.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Number of completions to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Stop sequence(s): a single string or a list.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Application-specific request context (not an OpenAI field).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
672
673impl ChatCompletionRequest {
674 pub fn new(
675 model: String,
676 messages: Vec<ChatMessage>,
677 tools: Option<Vec<Tool>>,
678 stream: Option<bool>,
679 ) -> Self {
680 Self {
681 model,
682 messages,
683 frequency_penalty: None,
684 logit_bias: None,
685 logprobs: None,
686 max_tokens: None,
687 n: None,
688 presence_penalty: None,
689 response_format: None,
690 seed: None,
691 stop: None,
692 stream,
693 temperature: None,
694 top_p: None,
695 tools,
696 tool_choice: None,
697 user: None,
698 context: None,
699 }
700 }
701}
702
/// Application-specific context attached to a request.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Free-form scratchpad payload. NOTE(review): semantics defined by the
    /// consumer, not visible from this file.
    pub scratchpad: Option<Value>,
}
707
/// Requested response format, identified by its type string.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}
712
/// One stop string or several; untagged so both wire shapes round-trip.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
719
/// A non-streaming chat-completion response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation time as reported by the server (epoch).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub metadata: Option<serde_json::Value>,
}
733
/// One generated alternative within a response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}
741
/// Why generation stopped, serialized snake_case per the OpenAI wire format.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}
750
/// Log-probability data for a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}
755
/// Log-probability info for a single generated token.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when provided.
    pub bytes: Option<Vec<u8>>,
    /// Most likely alternatives at this position.
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}
763
/// A candidate token with its log probability.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
770
/// One chunk of a streaming chat-completion response. `usage` is typically
/// only populated on the final chunk, hence optional.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}
785
/// One alternative within a streaming chunk; carries a delta rather than a
/// full message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    /// Set on the chunk that ends this choice.
    pub finish_reason: Option<FinishReason>,
}
792
/// Incremental message fragment within a streaming choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
802
/// Incremental fragment of a tool call; fields arrive piecemeal across
/// chunks and are keyed by `index`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
813
/// Incremental fragment of a function call's name/arguments.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    /// Partial raw-JSON arguments text.
    pub arguments: Option<String>,
}
819
impl From<LLMMessage> for ChatMessage {
    /// Converts a provider-agnostic [`LLMMessage`] into the OpenAI wire
    /// shape. Typed content is split into `content` parts and `tool_calls`;
    /// the first tool-result id (if any) becomes `tool_call_id`.
    fn from(llm_message: LLMMessage) -> Self {
        // Unknown role strings fall back to `User` rather than failing.
        let role = match llm_message.role.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" => Role::Tool,
            "developer" => Role::Developer,
            _ => Role::User,
        };

        let (content, tool_calls, tool_call_id) = match llm_message.content {
            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None, None),
            LLMMessageContent::List(items) => {
                let mut text_parts = Vec::new();
                let mut tool_call_parts = Vec::new();
                let mut tool_result_id: Option<String> = None;

                for item in items {
                    match item {
                        LLMMessageTypedContent::Text { text } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(text),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::ToolCall {
                            id,
                            name,
                            args,
                            metadata,
                        } => {
                            // Args are re-serialized to the raw-JSON string
                            // form the OpenAI shape expects.
                            tool_call_parts.push(ToolCall {
                                id,
                                r#type: "function".to_string(),
                                function: FunctionCall {
                                    name,
                                    arguments: args.to_string(),
                                },
                                metadata,
                            });
                        }
                        LLMMessageTypedContent::ToolResult {
                            tool_use_id,
                            content,
                        } => {
                            // Only the first tool-result id is kept as the
                            // message's tool_call_id; result text still lands
                            // in the content parts.
                            if tool_result_id.is_none() {
                                tool_result_id = Some(tool_use_id);
                            }
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(content),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::Image { source } => {
                            // Images are carried as base64 data: URLs.
                            text_parts.push(ContentPart {
                                r#type: "image_url".to_string(),
                                text: None,
                                image_url: Some(ImageUrl {
                                    url: format!(
                                        "data:{};base64,{}",
                                        source.media_type, source.data
                                    ),
                                    detail: None,
                                }),
                            });
                        }
                    }
                }

                // Empty collections collapse to `None` so serde skips them.
                let content = if !text_parts.is_empty() {
                    Some(MessageContent::Array(text_parts))
                } else {
                    None
                };

                let tool_calls = if !tool_call_parts.is_empty() {
                    Some(tool_call_parts)
                } else {
                    None
                };

                (content, tool_calls, tool_result_id)
            }
        };

        ChatMessage {
            role,
            content,
            name: None,
            tool_calls,
            tool_call_id,
            usage: None,
            ..Default::default()
        }
    }
}
923
impl From<ChatMessage> for LLMMessage {
    /// Converts back to the provider-agnostic shape. Text/image parts and
    /// tool calls are flattened into one typed-content list; a `Tool`-role
    /// message with a `tool_call_id` is collapsed into a single `ToolResult`
    /// whose text is all text parts joined by newlines.
    fn from(chat_message: ChatMessage) -> Self {
        let mut content_parts = Vec::new();

        match chat_message.content {
            Some(MessageContent::String(s)) => {
                // Empty strings produce no part at all.
                if !s.is_empty() {
                    content_parts.push(LLMMessageTypedContent::Text { text: s });
                }
            }
            Some(MessageContent::Array(parts)) => {
                for part in parts {
                    if let Some(text) = part.text {
                        content_parts.push(LLMMessageTypedContent::Text { text });
                    } else if let Some(image_url) = part.image_url {
                        // Split a `data:<media>;base64,<payload>` URL back
                        // into media type and payload; anything else passes
                        // through with an assumed image/jpeg media type.
                        let (media_type, data) = if image_url.url.starts_with("data:") {
                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
                            if parts.len() == 2 {
                                let meta = parts[0];
                                let data = parts[1];
                                let media_type = meta
                                    .trim_start_matches("data:")
                                    .trim_end_matches(";base64")
                                    .to_string();
                                (media_type, data.to_string())
                            } else {
                                ("image/jpeg".to_string(), image_url.url)
                            }
                        } else {
                            ("image/jpeg".to_string(), image_url.url)
                        };

                        content_parts.push(LLMMessageTypedContent::Image {
                            source: LLMMessageImageSource {
                                r#type: "base64".to_string(),
                                media_type,
                                data,
                            },
                        });
                    }
                }
            }
            None => {}
        }

        if let Some(tool_calls) = chat_message.tool_calls {
            for tool_call in tool_calls {
                // Unparseable argument JSON degrades to an empty object.
                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
                content_parts.push(LLMMessageTypedContent::ToolCall {
                    id: tool_call.id,
                    name: tool_call.function.name,
                    args,
                    metadata: tool_call.metadata,
                });
            }
        }

        if chat_message.role == Role::Tool
            && let Some(tool_call_id) = chat_message.tool_call_id
        {
            // Collapse everything gathered so far into one ToolResult; any
            // non-text parts are dropped here.
            let content_str = content_parts
                .iter()
                .filter_map(|p| match p {
                    LLMMessageTypedContent::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            content_parts = vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: tool_call_id,
                content: content_str,
            }];
        }

        LLMMessage {
            role: chat_message.role.to_string(),
            content: if content_parts.is_empty() {
                LLMMessageContent::String(String::new())
            } else if content_parts.len() == 1 {
                // A lone text part is flattened to plain string content.
                match &content_parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        LLMMessageContent::String(text.clone())
                    }
                    _ => LLMMessageContent::List(content_parts),
                }
            } else {
                LLMMessageContent::List(content_parts)
            },
        }
    }
}
1022
1023impl From<GenerationDelta> for ChatMessageDelta {
1024 fn from(delta: GenerationDelta) -> Self {
1025 match delta {
1026 GenerationDelta::Content { content } => ChatMessageDelta {
1027 role: Some(Role::Assistant),
1028 content: Some(content),
1029 tool_calls: None,
1030 },
1031 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1032 role: Some(Role::Assistant),
1033 content: None,
1034 tool_calls: None,
1035 },
1036 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1037 role: Some(Role::Assistant),
1038 content: None,
1039 tool_calls: Some(vec![ToolCallDelta {
1040 index: tool_use.index,
1041 id: tool_use.id,
1042 r#type: Some("function".to_string()),
1043 function: Some(FunctionCallDelta {
1044 name: tool_use.name,
1045 arguments: tool_use.input,
1046 }),
1047 metadata: tool_use.metadata,
1048 }]),
1049 },
1050 _ => ChatMessageDelta {
1051 role: Some(Role::Assistant),
1052 content: None,
1053 tool_calls: None,
1054 },
1055 }
1056 }
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;

    // Serialization smoke test: required fields appear in the JSON output
    // and None fields are skipped.
    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    // String content survives LLMMessage -> ChatMessage unchanged.
    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }

    // A ToolResult's tool_use_id must surface as the message's tool_call_id,
    // with the result text preserved as a text content part.
    #[test]
    fn test_llm_message_to_chat_message_tool_result_preserves_tool_call_id() {
        let llm_message = LLMMessage {
            role: "tool".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_01Abc123".to_string(),
                content: "Tool execution result".to_string(),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::Tool);
        assert_eq!(chat_message.tool_call_id.as_deref(), Some("toolu_01Abc123"));
        assert_eq!(
            chat_message.content,
            Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("Tool execution result".to_string()),
                image_url: None,
            }]))
        );
    }

    // Reverse direction: a Tool-role ChatMessage with a tool_call_id must
    // collapse into a single ToolResult content part.
    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: Some(MessageContent::String("Tool execution result".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_01Abc123".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one content part");
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_01Abc123");
                        assert_eq!(content, "Tool execution result");
                    }
                    _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
                }
            }
            _ => panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            ),
        }
    }

    // Same collapse with no content at all: a ToolResult with empty text.
    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: None,
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_02Xyz789");
                        assert_eq!(content, ""); }
                    _ => panic!("Expected ToolResult content part"),
                }
            }
            _ => panic!("Expected List content with ToolResult"),
        }
    }

    // Assistant message: text content and tool calls both land in the typed
    // content list, text first, with arguments parsed from raw JSON.
    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
                metadata: None,
            }]),
            tool_call_id: None,
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "assistant");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 2, "Should have text and tool call");

                match &parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        assert_eq!(text, "I'll help you with that.");
                    }
                    _ => panic!("Expected Text content part first"),
                }

                match &parts[1] {
                    LLMMessageTypedContent::ToolCall { id, name, args, .. } => {
                        assert_eq!(id, "call_abc123");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args["location"], "Paris");
                    }
                    _ => panic!("Expected ToolCall content part second"),
                }
            }
            _ => panic!("Expected List content"),
        }
    }

    // Auth-claim extraction: object-shaped claim in a hand-built JWT.
    #[test]
    fn test_extract_chatgpt_account_id_from_access_token() {
        use base64::Engine;

        let claim = json!({
            "chatgpt_account_id": "acct_test_123"
        });
        let payload = json!({
            "https://api.openai.com/auth": claim
        });
        let encoded_payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
        let access_token = format!("header.{}.signature", encoded_payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            Some("acct_test_123".to_string())
        );
    }

    // Missing auth claim yields None rather than an error.
    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_missing_claim() {
        use base64::Engine;

        let payload = json!({
            "sub": "user_123"
        });
        let encoded_payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
        let access_token = format!("header.{}.signature", encoded_payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }

    // Structurally invalid tokens are rejected quietly.
    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_token_shape() {
        assert_eq!(OpenAIConfig::extract_chatgpt_account_id("not-a-jwt"), None);
    }

    // String-shaped claim that is not valid JSON yields None.
    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_claim_json() {
        use base64::Engine;

        let payload = json!({
            "https://api.openai.com/auth": "{not-json}"
        });
        let encoded_payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
        let access_token = format!("header.{}.signature", encoded_payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }
}