1use crate::models::error::{AgentError, BadRequestErrorMessage};
2use crate::models::llm::{
3 GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
4 LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
5};
6use futures_util::StreamExt;
7use reqwest_middleware::ClientBuilder;
8use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use serde_json::json;
12use uuid::Uuid;
13
14const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1/chat/completions";
15
16#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
17pub struct OpenAIConfig {
18 pub api_endpoint: Option<String>,
19 pub api_key: Option<String>,
20}
21
22#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
23pub enum OpenAIModel {
24 #[serde(rename = "o3-2025-04-16")]
26 O3,
27 #[serde(rename = "o4-mini-2025-04-16")]
28 O4Mini,
29
30 #[default]
31 #[serde(rename = "gpt-5-2025-08-07")]
32 GPT5,
33 #[serde(rename = "gpt-5-mini-2025-08-07")]
34 GPT5Mini,
35 #[serde(rename = "gpt-5-nano-2025-08-07")]
36 GPT5Nano,
37
38 Custom(String),
39}
40
41impl OpenAIModel {
42 pub fn from_string(s: &str) -> Result<Self, String> {
43 serde_json::from_value(serde_json::Value::String(s.to_string()))
44 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
45 }
46
47 pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
49
50 pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
52
53 pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
55
56 pub fn default_smart_model() -> String {
58 Self::DEFAULT_SMART_MODEL.to_string()
59 }
60
61 pub fn default_eco_model() -> String {
63 Self::DEFAULT_ECO_MODEL.to_string()
64 }
65
66 pub fn default_recovery_model() -> String {
68 Self::DEFAULT_RECOVERY_MODEL.to_string()
69 }
70}
71
72impl std::fmt::Display for OpenAIModel {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 match self {
75 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
76
77 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
78
79 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
80 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
81 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
82
83 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
84 }
85 }
86}
87
88#[derive(Serialize, Deserialize, Debug)]
89pub struct OpenAIInput {
90 pub model: OpenAIModel,
91 pub messages: Vec<LLMMessage>,
92 pub max_tokens: u32,
93
94 #[serde(skip_serializing_if = "Option::is_none")]
95 pub json: Option<serde_json::Value>,
96
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub tools: Option<Vec<LLMTool>>,
99
100 #[serde(skip_serializing_if = "Option::is_none")]
101 pub reasoning_effort: Option<OpenAIReasoningEffort>,
102}
103
104impl OpenAIInput {
105 pub fn is_reasoning_model(&self) -> bool {
106 matches!(self.model, |OpenAIModel::O3| OpenAIModel::O4Mini
107 | OpenAIModel::GPT5
108 | OpenAIModel::GPT5Mini
109 | OpenAIModel::GPT5Nano)
110 }
111
112 pub fn is_standard_model(&self) -> bool {
113 !self.is_reasoning_model()
114 }
115}
116
117#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
118pub enum OpenAIReasoningEffort {
119 #[serde(rename = "minimal")]
120 Minimal,
121 #[serde(rename = "low")]
122 Low,
123 #[default]
124 #[serde(rename = "medium")]
125 Medium,
126 #[serde(rename = "high")]
127 High,
128}
129
130#[derive(Serialize, Deserialize, Debug)]
131pub struct OpenAITool {
132 pub r#type: String,
133 pub function: OpenAIToolFunction,
134}
135
136#[derive(Serialize, Deserialize, Debug)]
137pub struct OpenAIToolFunction {
138 pub name: String,
139 pub description: String,
140 pub parameters: serde_json::Value,
141}
142
143impl From<LLMTool> for OpenAITool {
144 fn from(tool: LLMTool) -> Self {
145 OpenAITool {
146 r#type: "function".to_string(),
147 function: OpenAIToolFunction {
148 name: tool.name,
149 description: tool.description,
150 parameters: tool.input_schema,
151 },
152 }
153 }
154}
155
156#[derive(Serialize, Deserialize, Debug, Clone)]
157pub struct OpenAIOutput {
158 pub model: String,
159 pub object: String,
160 pub choices: Vec<OpenAILLMChoice>,
161 pub created: u64,
162 pub usage: Option<LLMTokenUsage>,
163 pub id: String,
164}
165
166impl From<OpenAIOutput> for LLMCompletionResponse {
167 fn from(val: OpenAIOutput) -> Self {
168 LLMCompletionResponse {
169 model: val.model,
170 object: val.object,
171 choices: val.choices.into_iter().map(OpenAILLMChoice::into).collect(),
172 created: val.created,
173 usage: val.usage,
174 id: val.id,
175 }
176 }
177}
178
179#[derive(Serialize, Deserialize, Debug, Clone)]
180pub struct OpenAILLMChoice {
181 pub finish_reason: Option<String>,
182 pub index: u32,
183 pub message: OpenAILLMMessage,
184}
185
186impl From<OpenAILLMChoice> for LLMChoice {
187 fn from(val: OpenAILLMChoice) -> Self {
188 LLMChoice {
189 finish_reason: val.finish_reason,
190 index: val.index,
191 message: val.message.into(),
192 }
193 }
194}
195
196#[derive(Serialize, Deserialize, Debug, Clone)]
197pub struct OpenAILLMMessage {
198 pub role: String,
199 pub content: Option<String>,
200 pub tool_calls: Option<Vec<OpenAILLMMessageToolCall>>,
201}
202impl From<OpenAILLMMessage> for LLMMessage {
203 fn from(val: OpenAILLMMessage) -> Self {
204 LLMMessage {
205 role: val.role,
206 content: match val.tool_calls {
207 None => LLMMessageContent::String(val.content.unwrap_or_default()),
208 Some(tool_calls) => LLMMessageContent::List(
209 std::iter::once(LLMMessageTypedContent::Text {
210 text: val.content.unwrap_or_default(),
211 })
212 .chain(tool_calls.into_iter().map(|tool_call| {
213 LLMMessageTypedContent::ToolCall {
214 id: tool_call.id,
215 name: tool_call.function.name,
216 args: match serde_json::from_str(&tool_call.function.arguments) {
217 Ok(args) => args,
218 Err(_) => {
219 return LLMMessageTypedContent::Text {
220 text: String::from("Error parsing tool call arguments"),
221 };
222 }
223 },
224 }
225 }))
226 .collect(),
227 ),
228 },
229 }
230 }
231}
232
233#[derive(Serialize, Deserialize, Debug, Clone)]
234pub struct OpenAILLMMessageToolCall {
235 pub id: String,
236 pub r#type: String,
237 pub function: OpenAILLMMessageToolCallFunction,
238}
239#[derive(Serialize, Deserialize, Debug, Clone)]
240pub struct OpenAILLMMessageToolCallFunction {
241 pub arguments: String,
242 pub name: String,
243}
244
245pub struct OpenAI {}
246
247impl OpenAI {
248 pub async fn chat(
249 config: &OpenAIConfig,
250 input: OpenAIInput,
251 ) -> Result<LLMCompletionResponse, AgentError> {
252 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
253 let client = ClientBuilder::new(reqwest::Client::new())
254 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
255 .build();
256
257 let is_reasoning_model = input.is_reasoning_model();
258
259 let mut payload = json!({
260 "model": input.model.to_string(),
261 "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
262 "max_completion_tokens": input.max_tokens,
263 "stream": false,
264 });
265
266 if is_reasoning_model {
267 if let Some(reasoning_effort) = input.reasoning_effort {
268 payload["reasoning_effort"] = json!(reasoning_effort);
269 } else {
270 payload["reasoning_effort"] = json!(OpenAIReasoningEffort::Medium);
271 }
272 } else {
273 payload["temperature"] = json!(0);
274 }
275
276 if let Some(tools) = input.tools {
277 let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
278 payload["tools"] = json!(openai_tools);
279 }
280
281 if let Some(schema) = input.json {
282 payload["response_format"] = json!({
283 "type": "json_schema",
284 "json_schema": {
285 "strict": true,
286 "schema": schema,
287 "name": "my-schema"
288 }
289 });
290 }
291
292 let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
293 let api_key = config.api_key.as_ref().map_or("", |v| v);
294
295 let response = client
296 .post(api_endpoint)
297 .header("Authorization", format!("Bearer {}", api_key))
298 .header("Accept", "application/json")
299 .header("Content-Type", "application/json")
300 .body(serde_json::to_string(&payload).map_err(|e| {
301 AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string()))
302 })?)
303 .send()
304 .await;
305
306 let response = match response {
307 Ok(resp) => resp,
308 Err(e) => {
309 return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
310 e.to_string(),
311 )));
312 }
313 };
314
315 match response.json::<Value>().await {
316 Ok(json) => match serde_json::from_value::<OpenAIOutput>(json.clone()) {
317 Ok(json_response) => Ok(json_response.into()),
318 Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
319 e.to_string(),
320 ))),
321 },
322 Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
323 e.to_string(),
324 ))),
325 }
326 }
327
328 pub async fn chat_stream(
329 config: &OpenAIConfig,
330 stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
331 input: OpenAIInput,
332 ) -> Result<LLMCompletionResponse, AgentError> {
333 let is_reasoning_model = input.is_reasoning_model();
334
335 let mut payload = json!({
336 "model": input.model.to_string(),
337 "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
338 "max_completion_tokens": input.max_tokens,
339 "stream": true,
340 "stream_options":{
341 "include_usage": true
342 }
343 });
344
345 if is_reasoning_model {
346 if let Some(reasoning_effort) = input.reasoning_effort {
347 payload["reasoning_effort"] = json!(reasoning_effort);
348 } else {
349 payload["reasoning_effort"] = json!(OpenAIReasoningEffort::Medium);
350 }
351 } else {
352 payload["temperature"] = json!(0);
353 }
354
355 if let Some(tools) = input.tools {
356 let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
357 payload["tools"] = json!(openai_tools);
358 }
359
360 if let Some(schema) = input.json {
361 payload["response_format"] = json!({
362 "type": "json_schema",
363 "json_schema": {
364 "strict": true,
365 "schema": schema,
366 "name": "my-schema"
367 }
368 });
369 }
370
371 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
372 let client = ClientBuilder::new(reqwest::Client::new())
373 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
374 .build();
375
376 let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
377 let api_key = config.api_key.as_ref().map_or("", |v| v);
378
379 let response = client
381 .post(api_endpoint)
382 .header("Authorization", format!("Bearer {}", api_key))
383 .header("Accept", "application/json")
384 .header("Content-Type", "application/json")
385 .json(&payload)
386 .send()
387 .await;
388
389 let response = match response {
390 Ok(resp) => resp,
391 Err(e) => {
392 return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
393 e.to_string(),
394 )));
395 }
396 };
397
398 if !response.status().is_success() {
399 return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
400 format!(
401 "{}: {}",
402 response.status(),
403 response.text().await.unwrap_or_default(),
404 ),
405 )));
406 }
407
408 process_openai_stream(response, input.model.to_string(), stream_channel_tx)
409 .await
410 .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))
411 }
412}
413
414async fn process_openai_stream(
416 response: reqwest::Response,
417 model: String,
418 stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
419) -> Result<LLMCompletionResponse, AgentError> {
420 let mut completion_response = LLMCompletionResponse {
421 id: "".to_string(),
422 model: model.clone(),
423 object: "chat.completion".to_string(),
424 choices: vec![],
425 created: chrono::Utc::now().timestamp_millis() as u64,
426 usage: None,
427 };
428
429 let mut stream = response.bytes_stream();
430 let mut unparsed_data = String::new();
431 let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
432 std::collections::HashMap::new();
433 let mut accumulated_content = String::new();
434 let mut finish_reason: Option<String> = None;
435
436 while let Some(chunk) = stream.next().await {
437 let chunk = chunk.map_err(|e| {
438 let error_message = format!("Failed to read stream chunk from OpenAI API: {e}");
439 AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
440 })?;
441
442 let text = std::str::from_utf8(&chunk).map_err(|e| {
443 let error_message = format!("Failed to parse UTF-8 from OpenAI response: {e}");
444 AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
445 })?;
446
447 unparsed_data.push_str(text);
448
449 while let Some(line_end) = unparsed_data.find('\n') {
450 let line = unparsed_data[..line_end].to_string();
451 unparsed_data = unparsed_data[line_end + 1..].to_string();
452
453 if line.trim().is_empty() {
454 continue;
455 }
456
457 if !line.starts_with("data: ") {
458 continue;
459 }
460
461 let json_str = &line[6..];
462 if json_str == "[DONE]" {
463 continue;
464 }
465
466 match serde_json::from_str::<ChatCompletionStreamResponse>(json_str) {
467 Ok(stream_response) => {
468 if completion_response.id.is_empty() {
470 completion_response.id = stream_response.id.clone();
471 completion_response.model = stream_response.model.clone();
472 completion_response.object = stream_response.object.clone();
473 completion_response.created = stream_response.created;
474 }
475
476 for choice in &stream_response.choices {
478 if let Some(content) = &choice.delta.content {
479 stream_channel_tx
481 .send(GenerationDelta::Content {
482 content: content.clone(),
483 })
484 .await
485 .map_err(|e| {
486 AgentError::BadRequest(BadRequestErrorMessage::ApiError(
487 e.to_string(),
488 ))
489 })?;
490 accumulated_content.push_str(content);
491 }
492
493 if let Some(tool_calls) = &choice.delta.tool_calls {
495 for tool_call in tool_calls {
496 let index = tool_call.index;
497
498 let entry = current_tool_calls.entry(index).or_insert((
500 String::new(),
501 String::new(),
502 String::new(),
503 ));
504
505 if let Some(id) = &tool_call.id {
506 entry.0 = id.clone();
507 }
508 if let Some(function) = &tool_call.function {
509 if let Some(name) = &function.name {
510 entry.1 = name.clone();
511 }
512 if let Some(args) = &function.arguments {
513 entry.2.push_str(args);
514 }
515 }
516
517 stream_channel_tx
519 .send(GenerationDelta::ToolUse {
520 tool_use: GenerationDeltaToolUse {
521 id: tool_call.id.clone(),
522 name: tool_call
523 .function
524 .as_ref()
525 .and_then(|f| f.name.clone())
526 .and_then(|n| {
527 if n.is_empty() { None } else { Some(n) }
528 }),
529 input: tool_call
530 .function
531 .as_ref()
532 .and_then(|f| f.arguments.clone()),
533 index,
534 },
535 })
536 .await
537 .map_err(|e| {
538 AgentError::BadRequest(BadRequestErrorMessage::ApiError(
539 e.to_string(),
540 ))
541 })?;
542 }
543 }
544
545 if let Some(reason) = &choice.finish_reason {
546 finish_reason = Some(match reason {
547 FinishReason::Stop => "stop".to_string(),
548 FinishReason::Length => "length".to_string(),
549 FinishReason::ContentFilter => "content_filter".to_string(),
550 FinishReason::ToolCalls => "tool_calls".to_string(),
551 });
552 }
553 }
554
555 if let Some(usage) = &stream_response.usage {
557 stream_channel_tx
558 .send(GenerationDelta::Usage {
559 usage: usage.clone(),
560 })
561 .await
562 .map_err(|e| {
563 AgentError::BadRequest(BadRequestErrorMessage::ApiError(
564 e.to_string(),
565 ))
566 })?;
567 completion_response.usage = Some(usage.clone());
568 }
569 }
570 Err(e) => {
571 eprintln!("Error parsing response: {}", e);
572 }
573 }
574 }
575 }
576
577 let mut message_content = vec![];
579
580 if !accumulated_content.is_empty() {
581 message_content.push(LLMMessageTypedContent::Text {
582 text: accumulated_content,
583 });
584 }
585
586 for (_, (id, name, args)) in current_tool_calls {
587 if let Ok(parsed_args) = serde_json::from_str(&args) {
588 message_content.push(LLMMessageTypedContent::ToolCall {
589 id,
590 name,
591 args: parsed_args,
592 });
593 }
594 }
595
596 completion_response.choices = vec![LLMChoice {
597 finish_reason,
598 index: 0,
599 message: LLMMessage {
600 role: "assistant".to_string(),
601 content: if message_content.is_empty() {
602 LLMMessageContent::String(String::new())
603 } else if message_content.len() == 1
604 && matches!(&message_content[0], LLMMessageTypedContent::Text { .. })
605 {
606 if let LLMMessageTypedContent::Text { text } = &message_content[0] {
607 LLMMessageContent::String(text.clone())
608 } else {
609 LLMMessageContent::List(message_content)
610 }
611 } else {
612 LLMMessageContent::List(message_content)
613 },
614 },
615 }];
616
617 Ok(completion_response)
618}
619
620#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
621#[serde(rename_all = "lowercase")]
622pub enum Role {
623 System,
624 Developer,
625 User,
626 #[default]
627 Assistant,
628 Tool,
629 }
631
632impl std::fmt::Display for Role {
633 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
634 match self {
635 Role::System => write!(f, "system"),
636 Role::Developer => write!(f, "developer"),
637 Role::User => write!(f, "user"),
638 Role::Assistant => write!(f, "assistant"),
639 Role::Tool => write!(f, "tool"),
640 }
641 }
642}
643
644#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
645pub struct ChatCompletionRequest {
646 pub model: String,
647 pub messages: Vec<ChatMessage>,
648 #[serde(skip_serializing_if = "Option::is_none")]
649 pub frequency_penalty: Option<f32>,
650 #[serde(skip_serializing_if = "Option::is_none")]
651 pub logit_bias: Option<serde_json::Value>,
652 #[serde(skip_serializing_if = "Option::is_none")]
653 pub logprobs: Option<bool>,
654 #[serde(skip_serializing_if = "Option::is_none")]
655 pub max_tokens: Option<u32>,
656 #[serde(skip_serializing_if = "Option::is_none")]
657 pub n: Option<u32>,
658 #[serde(skip_serializing_if = "Option::is_none")]
659 pub presence_penalty: Option<f32>,
660 #[serde(skip_serializing_if = "Option::is_none")]
661 pub response_format: Option<ResponseFormat>,
662 #[serde(skip_serializing_if = "Option::is_none")]
663 pub seed: Option<i64>,
664 #[serde(skip_serializing_if = "Option::is_none")]
665 pub stop: Option<StopSequence>,
666 #[serde(skip_serializing_if = "Option::is_none")]
667 pub stream: Option<bool>,
668 #[serde(skip_serializing_if = "Option::is_none")]
669 pub temperature: Option<f32>,
670 #[serde(skip_serializing_if = "Option::is_none")]
671 pub top_p: Option<f32>,
672 #[serde(skip_serializing_if = "Option::is_none")]
673 pub tools: Option<Vec<Tool>>,
674 #[serde(skip_serializing_if = "Option::is_none")]
675 pub tool_choice: Option<ToolChoice>,
676 #[serde(skip_serializing_if = "Option::is_none")]
677 pub user: Option<String>,
678
679 #[serde(skip_serializing_if = "Option::is_none")]
680 pub context: Option<ChatCompletionContext>,
681}
682
683#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
684pub struct ChatCompletionContext {
685 pub scratchpad: Option<Value>,
686}
687
688#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
689pub struct ChatMessage {
690 pub role: Role,
691 pub content: Option<MessageContent>,
692 #[serde(skip_serializing_if = "Option::is_none")]
693 pub name: Option<String>,
694 #[serde(skip_serializing_if = "Option::is_none")]
695 pub tool_calls: Option<Vec<ToolCall>>,
696 #[serde(skip_serializing_if = "Option::is_none")]
697 pub tool_call_id: Option<String>,
698
699 #[serde(skip_serializing_if = "Option::is_none")]
701 pub usage: Option<LLMTokenUsage>,
702}
703
704impl ChatMessage {
705 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
706 messages
707 .iter()
708 .rev()
709 .find(|message| message.role != Role::User && message.role != Role::Tool)
710 }
711
712 pub fn to_xml(&self) -> String {
713 match &self.content {
714 Some(MessageContent::String(s)) => {
715 format!("<message role=\"{}\">{}</message>", self.role, s)
716 }
717 Some(MessageContent::Array(parts)) => parts
718 .iter()
719 .map(|part| {
720 format!(
721 "<message role=\"{}\" type=\"{}\">{}</message>",
722 self.role,
723 part.r#type,
724 part.text.clone().unwrap_or_default()
725 )
726 })
727 .collect::<Vec<String>>()
728 .join("\n"),
729 None => String::new(),
730 }
731 }
732}
733
734#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
735#[serde(untagged)]
736pub enum MessageContent {
737 String(String),
738 Array(Vec<ContentPart>),
739}
740
741impl MessageContent {
742 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
743 match self {
744 MessageContent::String(s) => MessageContent::String(format!(
745 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
746 )),
747 MessageContent::Array(parts) => MessageContent::Array(
748 std::iter::once(ContentPart {
749 r#type: "text".to_string(),
750 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
751 image_url: None,
752 })
753 .chain(parts.iter().cloned())
754 .collect(),
755 ),
756 }
757 }
758
759 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
760 match self {
761 MessageContent::String(s) => s
762 .rfind("<checkpoint_id>")
763 .and_then(|start| {
764 s[start..]
765 .find("</checkpoint_id>")
766 .map(|end| (start + "<checkpoint_id>".len(), start + end))
767 })
768 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
769 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
770 part.text.as_deref().and_then(|text| {
771 text.rfind("<checkpoint_id>")
772 .and_then(|start| {
773 text[start..]
774 .find("</checkpoint_id>")
775 .map(|end| (start + "<checkpoint_id>".len(), start + end))
776 })
777 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
778 })
779 }),
780 }
781 }
782}
783
784impl std::fmt::Display for MessageContent {
785 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
786 match self {
787 MessageContent::String(s) => write!(f, "{s}"),
788 MessageContent::Array(parts) => {
789 let text_parts: Vec<String> =
790 parts.iter().filter_map(|part| part.text.clone()).collect();
791 write!(f, "{}", text_parts.join("\n"))
792 }
793 }
794 }
795}
796impl Default for MessageContent {
797 fn default() -> Self {
798 MessageContent::String(String::new())
799 }
800}
801
802#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
803pub struct ContentPart {
804 pub r#type: String,
805 #[serde(skip_serializing_if = "Option::is_none")]
806 pub text: Option<String>,
807 #[serde(skip_serializing_if = "Option::is_none")]
808 pub image_url: Option<ImageUrl>,
809}
810
811#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
812pub struct ImageUrl {
813 pub url: String,
814 #[serde(skip_serializing_if = "Option::is_none")]
815 pub detail: Option<String>,
816}
817
818#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
819pub struct ResponseFormat {
820 pub r#type: String,
821}
822
823#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
824#[serde(untagged)]
825pub enum StopSequence {
826 String(String),
827 Array(Vec<String>),
828}
829
830#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
831pub struct Tool {
832 pub r#type: String,
833 pub function: FunctionDefinition,
834}
835
836#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
837pub struct FunctionDefinition {
838 pub name: String,
839 pub description: Option<String>,
840 pub parameters: serde_json::Value,
841}
842
843#[derive(Debug, Clone, PartialEq)]
844pub enum ToolChoice {
845 Auto,
846 Required,
847 Object(ToolChoiceObject),
848}
849
850impl Serialize for ToolChoice {
851 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
852 where
853 S: serde::Serializer,
854 {
855 match self {
856 ToolChoice::Auto => serializer.serialize_str("auto"),
857 ToolChoice::Required => serializer.serialize_str("required"),
858 ToolChoice::Object(obj) => obj.serialize(serializer),
859 }
860 }
861}
862
863impl<'de> Deserialize<'de> for ToolChoice {
864 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
865 where
866 D: serde::Deserializer<'de>,
867 {
868 struct ToolChoiceVisitor;
869
870 impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
871 type Value = ToolChoice;
872
873 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
874 formatter.write_str("string or object")
875 }
876
877 fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
878 where
879 E: serde::de::Error,
880 {
881 match value {
882 "auto" => Ok(ToolChoice::Auto),
883 "required" => Ok(ToolChoice::Required),
884 _ => Err(serde::de::Error::unknown_variant(
885 value,
886 &["auto", "required"],
887 )),
888 }
889 }
890
891 fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
892 where
893 M: serde::de::MapAccess<'de>,
894 {
895 let obj = ToolChoiceObject::deserialize(
896 serde::de::value::MapAccessDeserializer::new(map),
897 )?;
898 Ok(ToolChoice::Object(obj))
899 }
900 }
901
902 deserializer.deserialize_any(ToolChoiceVisitor)
903 }
904}
905
906#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
907pub struct ToolChoiceObject {
908 pub r#type: String,
909 pub function: FunctionChoice,
910}
911
912#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
913pub struct FunctionChoice {
914 pub name: String,
915}
916
917#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
918pub struct ToolCall {
919 pub id: String,
920 pub r#type: String,
921 pub function: FunctionCall,
922}
923
924#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
925pub struct FunctionCall {
926 pub name: String,
927 pub arguments: String,
928}
929
930#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
931pub struct ChatCompletionResponse {
932 pub id: String,
933 pub object: String,
934 pub created: u64,
935 pub model: String,
936 pub choices: Vec<ChatCompletionChoice>,
937 pub usage: LLMTokenUsage,
938 #[serde(skip_serializing_if = "Option::is_none")]
939 pub system_fingerprint: Option<String>,
940 pub metadata: Option<serde_json::Value>,
941}
942
943#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
944pub struct ChatCompletionChoice {
945 pub index: usize,
946 pub message: ChatMessage,
947 pub logprobs: Option<LogProbs>,
948 pub finish_reason: FinishReason,
949}
950
951#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
952#[serde(rename_all = "snake_case")]
953pub enum FinishReason {
954 Stop,
955 Length,
956 ContentFilter,
957 ToolCalls,
958}
959
960#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
961pub struct LogProbs {
962 pub content: Option<Vec<LogProbContent>>,
963}
964
965#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
966pub struct LogProbContent {
967 pub token: String,
968 pub logprob: f32,
969 pub bytes: Option<Vec<u8>>,
970 pub top_logprobs: Option<Vec<TokenLogprob>>,
971}
972
973#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
974pub struct TokenLogprob {
975 pub token: String,
976 pub logprob: f32,
977 pub bytes: Option<Vec<u8>>,
978}
979
980#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
981pub struct ChatCompletionStreamResponse {
982 pub id: String,
983 pub object: String,
984 pub created: u64,
985 pub model: String,
986 pub choices: Vec<ChatCompletionStreamChoice>,
987 pub usage: Option<LLMTokenUsage>,
988 pub metadata: Option<serde_json::Value>,
989}
990#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
991pub struct ChatCompletionStreamChoice {
992 pub index: usize,
993 pub delta: ChatMessageDelta,
994 pub finish_reason: Option<FinishReason>,
995}
996#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
997pub struct ChatMessageDelta {
998 #[serde(skip_serializing_if = "Option::is_none")]
999 pub role: Option<Role>,
1000 #[serde(skip_serializing_if = "Option::is_none")]
1001 pub content: Option<String>,
1002 #[serde(skip_serializing_if = "Option::is_none")]
1003 pub tool_calls: Option<Vec<ToolCallDelta>>,
1004}
1005
1006#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1007pub struct ToolCallDelta {
1008 pub index: usize,
1009 pub id: Option<String>,
1010 pub r#type: Option<String>,
1011 pub function: Option<FunctionCallDelta>,
1012}
1013
1014#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1015pub struct FunctionCallDelta {
1016 pub name: Option<String>,
1017 pub arguments: Option<String>,
1018}
1019
1020impl From<LLMMessage> for ChatMessage {
1021 fn from(llm_message: LLMMessage) -> Self {
1022 let role = match llm_message.role.as_str() {
1023 "system" => Role::System,
1024 "user" => Role::User,
1025 "assistant" => Role::Assistant,
1026 "tool" => Role::Tool,
1027 "developer" => Role::Developer,
1029 _ => Role::User, };
1031
1032 let (content, tool_calls) = match llm_message.content {
1033 LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
1034 LLMMessageContent::List(items) => {
1035 let mut text_parts = Vec::new();
1036 let mut tool_call_parts = Vec::new();
1037
1038 for item in items {
1039 match item {
1040 LLMMessageTypedContent::Text { text } => {
1041 text_parts.push(ContentPart {
1042 r#type: "text".to_string(),
1043 text: Some(text),
1044 image_url: None,
1045 });
1046 }
1047 LLMMessageTypedContent::ToolCall { id, name, args } => {
1048 tool_call_parts.push(ToolCall {
1049 id,
1050 r#type: "function".to_string(),
1051 function: FunctionCall {
1052 name,
1053 arguments: args.to_string(),
1054 },
1055 });
1056 }
1057 LLMMessageTypedContent::ToolResult { content, .. } => {
1058 text_parts.push(ContentPart {
1059 r#type: "text".to_string(),
1060 text: Some(content),
1061 image_url: None,
1062 });
1063 }
1064 LLMMessageTypedContent::Image { source } => {
1065 text_parts.push(ContentPart {
1066 r#type: "image_url".to_string(),
1067 text: None,
1068 image_url: Some(ImageUrl {
1069 url: format!(
1070 "data:{};base64,{}",
1071 source.media_type, source.data
1072 ),
1073 detail: None,
1074 }),
1075 });
1076 }
1077 }
1078 }
1079
1080 let content = if !text_parts.is_empty() {
1081 Some(MessageContent::Array(text_parts))
1082 } else {
1083 None
1084 };
1085
1086 let tool_calls = if !tool_call_parts.is_empty() {
1087 Some(tool_call_parts)
1088 } else {
1089 None
1090 };
1091
1092 (content, tool_calls)
1093 }
1094 };
1095
1096 ChatMessage {
1097 role,
1098 content,
1099 name: None, tool_calls,
1101 tool_call_id: None, usage: None,
1103 }
1104 }
1105}
1106
1107impl From<ChatMessage> for LLMMessage {
1108 fn from(chat_message: ChatMessage) -> Self {
1109 let mut content_parts = Vec::new();
1110
1111 match chat_message.content {
1113 Some(MessageContent::String(s)) => {
1114 if !s.is_empty() {
1115 content_parts.push(LLMMessageTypedContent::Text { text: s });
1116 }
1117 }
1118 Some(MessageContent::Array(parts)) => {
1119 for part in parts {
1120 if let Some(text) = part.text {
1121 content_parts.push(LLMMessageTypedContent::Text { text });
1122 } else if let Some(image_url) = part.image_url {
1123 let (media_type, data) = if image_url.url.starts_with("data:") {
1124 let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
1125 if parts.len() == 2 {
1126 let meta = parts[0];
1127 let data = parts[1];
1128 let media_type = meta
1129 .trim_start_matches("data:")
1130 .trim_end_matches(";base64")
1131 .to_string();
1132 (media_type, data.to_string())
1133 } else {
1134 ("image/jpeg".to_string(), image_url.url)
1135 }
1136 } else {
1137 ("image/jpeg".to_string(), image_url.url)
1138 };
1139
1140 content_parts.push(LLMMessageTypedContent::Image {
1141 source: LLMMessageImageSource {
1142 r#type: "base64".to_string(),
1143 media_type,
1144 data,
1145 },
1146 });
1147 }
1148 }
1149 }
1150 None => {}
1151 }
1152
1153 if let Some(tool_calls) = chat_message.tool_calls {
1155 for tool_call in tool_calls {
1156 let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
1157 content_parts.push(LLMMessageTypedContent::ToolCall {
1158 id: tool_call.id,
1159 name: tool_call.function.name,
1160 args,
1161 });
1162 }
1163 }
1164
1165 LLMMessage {
1166 role: chat_message.role.to_string(),
1167 content: if content_parts.is_empty() {
1168 LLMMessageContent::String(String::new())
1169 } else if content_parts.len() == 1 {
1170 match &content_parts[0] {
1171 LLMMessageTypedContent::Text { text } => {
1172 LLMMessageContent::String(text.clone())
1173 }
1174 _ => LLMMessageContent::List(content_parts),
1175 }
1176 } else {
1177 LLMMessageContent::List(content_parts)
1178 },
1179 }
1180 }
1181}
1182
1183#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
1184pub enum AgentModel {
1185 #[serde(rename = "smart")]
1186 #[default]
1187 Smart,
1188 #[serde(rename = "eco")]
1189 Eco,
1190 #[serde(rename = "recovery")]
1191 Recovery,
1192}
1193
1194impl std::fmt::Display for AgentModel {
1195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1196 match self {
1197 AgentModel::Smart => write!(f, "smart"),
1198 AgentModel::Eco => write!(f, "eco"),
1199 AgentModel::Recovery => write!(f, "recovery"),
1200 }
1201 }
1202}
1203
1204impl From<String> for AgentModel {
1205 fn from(value: String) -> Self {
1206 match value.as_str() {
1207 "eco" => AgentModel::Eco,
1208 "recovery" => AgentModel::Recovery,
1209 _ => AgentModel::Smart,
1210 }
1211 }
1212}
1213
1214impl ChatCompletionRequest {
1215 pub fn new(
1216 model: String,
1217 messages: Vec<ChatMessage>,
1218 tools: Option<Vec<Tool>>,
1219 stream: Option<bool>,
1220 ) -> Self {
1221 Self {
1222 model,
1223 messages,
1224 frequency_penalty: None,
1225 logit_bias: None,
1226 logprobs: None,
1227 max_tokens: None,
1228 n: None,
1229 presence_penalty: None,
1230 response_format: None,
1231 seed: None,
1232 stop: None,
1233 stream,
1234 temperature: None,
1235 top_p: None,
1236 tools,
1237 tool_choice: None,
1238 user: None,
1239 context: None,
1240 }
1241 }
1242}
1243
1244impl From<Tool> for LLMTool {
1245 fn from(tool: Tool) -> Self {
1246 LLMTool {
1247 name: tool.function.name,
1248 description: tool.function.description.unwrap_or_default(),
1249 input_schema: tool.function.parameters,
1250 }
1251 }
1252}
1253
1254#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1255pub enum ToolCallResultStatus {
1256 Success,
1257 Error,
1258 Cancelled,
1259}
1260
1261#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1262pub struct ToolCallResult {
1263 pub call: ToolCall,
1264 pub result: String,
1265 pub status: ToolCallResultStatus,
1266}
1267
1268#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
1269pub struct ToolCallResultProgress {
1270 pub id: Uuid,
1271 pub message: String,
1272}
1273
1274impl From<GenerationDelta> for ChatMessageDelta {
1275 fn from(delta: GenerationDelta) -> Self {
1276 match delta {
1277 GenerationDelta::Content { content } => ChatMessageDelta {
1278 role: Some(Role::Assistant),
1279 content: Some(content),
1280 tool_calls: None,
1281 },
1282 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1283 role: Some(Role::Assistant),
1284 content: None,
1285 tool_calls: None,
1286 },
1287 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1288 role: Some(Role::Assistant),
1289 content: None,
1290 tool_calls: Some(vec![ToolCallDelta {
1291 index: tool_use.index,
1292 id: tool_use.id,
1293 r#type: Some("function".to_string()),
1294 function: Some(FunctionCallDelta {
1295 name: tool_use.name,
1296 arguments: tool_use.input,
1297 }),
1298 }]),
1299 },
1300 _ => ChatMessageDelta {
1301 role: Some(Role::Assistant),
1302 content: None,
1303 tool_calls: None,
1304 },
1305 }
1306 }
1307}
1308
#[cfg(test)]
mod tests {
    //! Serialization and conversion tests for the OpenAI chat-completion types.

    use super::*;

    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"smart\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
        assert!(json.contains("\"content\":\"You are a helpful assistant.\""));
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello!\""));
        assert!(json.contains("\"max_tokens\":100"));
        assert!(json.contains("\"temperature\":0.7"));
    }

    #[test]
    fn test_deserialize_response() {
        let json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "smart",
            "system_fingerprint": "fp_123abc",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you today?"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 9,
                "completion_tokens": 12,
                "total_tokens": 21
            }
        }"#;

        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, "chat.completion");
        assert_eq!(response.created, 1677652288);
        assert_eq!(response.model, AgentModel::Smart.to_string());
        assert_eq!(response.system_fingerprint, Some("fp_123abc".to_string()));

        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].index, 0);
        assert_eq!(response.choices[0].message.role, Role::Assistant);

        match &response.choices[0].message.content {
            Some(MessageContent::String(content)) => {
                assert_eq!(content, "Hello! How can I help you today?");
            }
            _ => panic!("Expected string content"),
        }

        assert_eq!(response.choices[0].finish_reason, FinishReason::Stop);
        assert_eq!(response.usage.prompt_tokens, 9);
        assert_eq!(response.usage.completion_tokens, 12);
        assert_eq!(response.usage.total_tokens, 21);
    }

    #[test]
    fn test_tool_calls_request_response() {
        let tools_request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![ChatMessage {
                role: Role::User,
                content: Some(MessageContent::String(
                    "What's the weather in San Francisco?".to_string(),
                )),
                name: None,
                tool_calls: None,
                tool_call_id: None,
                usage: None,
            }],
            tools: Some(vec![Tool {
                r#type: "function".to_string(),
                function: FunctionDefinition {
                    name: "get_weather".to_string(),
                    description: Some("Get the current weather in a given location".to_string()),
                    parameters: serde_json::json!({
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA"
                            }
                        },
                        "required": ["location"]
                    }),
                },
            }]),
            tool_choice: Some(ToolChoice::Auto),
            max_tokens: Some(100),
            temperature: Some(0.7),
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            top_p: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&tools_request).unwrap();
        println!("Tool request JSON: {}", json);

        assert!(json.contains("\"tools\":["));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
        assert!(json.contains("\"tool_choice\":\"auto\""));

        let tool_response_json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "eco",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [
                        {
                            "id": "call_abc123",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"location\":\"San Francisco, CA\"}"
                            }
                        }
                    ]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": {
                "prompt_tokens": 82,
                "completion_tokens": 17,
                "total_tokens": 99
            }
        }"#;

        let tool_response: ChatCompletionResponse =
            serde_json::from_str(tool_response_json).unwrap();
        assert_eq!(tool_response.choices[0].message.role, Role::Assistant);
        assert_eq!(tool_response.choices[0].message.content, None);
        assert!(tool_response.choices[0].message.tool_calls.is_some());

        let tool_calls = tool_response.choices[0]
            .message
            .tool_calls
            .as_ref()
            .unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc123");
        assert_eq!(tool_calls[0].r#type, "function");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(
            tool_calls[0].function.arguments,
            "{\"location\":\"San Francisco, CA\"}"
        );

        assert_eq!(
            tool_response.choices[0].finish_reason,
            FinishReason::ToolCalls,
        );
    }

    #[test]
    fn test_content_with_image() {
        let message_with_image = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![
                ContentPart {
                    r#type: "text".to_string(),
                    text: Some("What's in this image?".to_string()),
                    image_url: None,
                },
                ContentPart {
                    r#type: "input_image".to_string(),
                    text: None,
                    image_url: Some(ImageUrl {
                        url: "...".to_string(),
                        detail: Some("low".to_string()),
                    }),
                },
            ])),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        };

        let json = serde_json::to_string(&message_with_image).unwrap();
        println!("Serialized JSON: {}", json);

        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"type\":\"text\""));
        assert!(json.contains("\"text\":\"What's in this image?\""));
        assert!(json.contains("\"type\":\"input_image\""));
        assert!(json.contains("\"url\":\"...\""));
        assert!(json.contains("\"detail\":\"low\""));
    }

    #[test]
    fn test_response_format() {
        let json_format_request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![ChatMessage {
                role: Role::User,
                content: Some(MessageContent::String(
                    "Generate a JSON object with name and age fields".to_string(),
                )),
                name: None,
                tool_calls: None,
                tool_call_id: None,
                usage: None,
            }],
            response_format: Some(ResponseFormat {
                r#type: "json_object".to_string(),
            }),
            max_tokens: Some(100),
            temperature: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            n: None,
            presence_penalty: None,
            seed: None,
            stop: None,
            stream: None,
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&json_format_request).unwrap();
        assert!(json.contains("\"response_format\":{\"type\":\"json_object\"}"));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
        assert_eq!(chat_message.tool_calls, None);

        let llm_message_with_tool = LLMMessage {
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolCall {
                id: "call_123".to_string(),
                name: "get_weather".to_string(),
                args: serde_json::json!({"location": "San Francisco"}),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message_with_tool);
        assert_eq!(chat_message.role, Role::Assistant);
        assert_eq!(chat_message.content, None);
        assert!(chat_message.tool_calls.is_some());

        let tool_calls = chat_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_123");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert!(tool_calls[0].function.arguments.contains("San Francisco"));

        let llm_message_mixed = LLMMessage {
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![
                LLMMessageTypedContent::Text {
                    text: "The weather is:".to_string(),
                },
                LLMMessageTypedContent::ToolCall {
                    id: "call_456".to_string(),
                    name: "get_weather".to_string(),
                    args: serde_json::json!({"location": "New York"}),
                },
            ]),
        };

        let chat_message = ChatMessage::from(llm_message_mixed);
        assert_eq!(chat_message.role, Role::Assistant);

        match &chat_message.content {
            Some(MessageContent::Array(parts)) => {
                assert_eq!(parts.len(), 1);
                assert_eq!(parts[0].r#type, "text");
                assert_eq!(parts[0].text, Some("The weather is:".to_string()));
            }
            _ => panic!("Expected array content"),
        }

        let tool_calls = chat_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_456");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert!(tool_calls[0].function.arguments.contains("New York"));
    }

    /// A `data:` URL image part is split into media type and base64 payload.
    #[test]
    fn test_chat_message_to_llm_message_with_data_url_image() {
        let chat_message = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![
                ContentPart {
                    r#type: "text".to_string(),
                    text: Some("Describe this image".to_string()),
                    image_url: None,
                },
                ContentPart {
                    r#type: "image_url".to_string(),
                    text: None,
                    image_url: Some(ImageUrl {
                        url: "data:image/png;base64,iVBORw0KGgo=".to_string(),
                        detail: None,
                    }),
                },
            ])),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        };

        let llm_message = LLMMessage::from(chat_message);
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 2);
                match &parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        assert_eq!(text, "Describe this image")
                    }
                    _ => panic!("Expected text part first"),
                }
                match &parts[1] {
                    LLMMessageTypedContent::Image { source } => {
                        assert_eq!(source.r#type, "base64");
                        assert_eq!(source.media_type, "image/png");
                        assert_eq!(source.data, "iVBORw0KGgo=");
                    }
                    _ => panic!("Expected image part second"),
                }
            }
            _ => panic!("Expected list content"),
        }
    }

    /// A lone text part collapses to a plain string; missing content becomes "".
    #[test]
    fn test_chat_message_to_llm_message_collapses_content() {
        let single_text = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("Only text".to_string()),
                image_url: None,
            }])),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        };
        let converted = LLMMessage::from(single_text);
        match converted.content {
            LLMMessageContent::String(text) => assert_eq!(text, "Only text"),
            _ => panic!("Expected a single text part to collapse to a string"),
        }

        let empty = ChatMessage {
            role: Role::Assistant,
            content: None,
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        };
        let converted = LLMMessage::from(empty);
        match converted.content {
            LLMMessageContent::String(text) => assert!(text.is_empty()),
            _ => panic!("Expected empty content to become an empty string"),
        }
    }

    /// Tool calls survive LLMMessage -> ChatMessage -> LLMMessage unchanged.
    #[test]
    fn test_tool_call_round_trip_through_chat_message() {
        let original = LLMMessage {
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolCall {
                id: "call_789".to_string(),
                name: "get_weather".to_string(),
                args: serde_json::json!({"location": "Paris"}),
            }]),
        };

        let back = LLMMessage::from(ChatMessage::from(original));
        match &back.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::ToolCall { id, name, args } => {
                        assert_eq!(id, "call_789");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args, &serde_json::json!({"location": "Paris"}));
                    }
                    _ => panic!("Expected tool call part"),
                }
            }
            _ => panic!("Expected list content"),
        }
    }

    /// `AgentModel` Display/From<String>/serde all agree on the lowercase names.
    #[test]
    fn test_agent_model_string_round_trip() {
        let all = [AgentModel::Smart, AgentModel::Eco, AgentModel::Recovery];
        for model in all.iter() {
            assert_eq!(&AgentModel::from(model.to_string()), model);
            let json = serde_json::to_string(model).unwrap();
            assert_eq!(json, format!("\"{}\"", model));
            let parsed: AgentModel = serde_json::from_str(&json).unwrap();
            assert_eq!(&parsed, model);
        }
        assert_eq!(AgentModel::from("unknown".to_string()), AgentModel::Smart);
        assert_eq!(AgentModel::default(), AgentModel::Smart);
    }
}