1use crate::models::error::{AgentError, BadRequestErrorMessage};
2use crate::models::llm::{
3 GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
4 LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
5};
6use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
7use futures_util::StreamExt;
8use reqwest_middleware::ClientBuilder;
9use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use serde_json::json;
13use uuid::Uuid;
14
/// Chat-completions endpoint used when `OpenAIConfig::api_endpoint` is `None`.
const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1/chat/completions";
16
17#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
18pub struct OpenAIConfig {
19 pub api_endpoint: Option<String>,
20 pub api_key: Option<String>,
21}
22
23#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
24pub enum OpenAIModel {
25 #[serde(rename = "o3-2025-04-16")]
27 O3,
28 #[serde(rename = "o4-mini-2025-04-16")]
29 O4Mini,
30
31 #[default]
32 #[serde(rename = "gpt-5-2025-08-07")]
33 GPT5,
34 #[serde(rename = "gpt-5.1-2025-11-13")]
35 GPT51,
36 #[serde(rename = "gpt-5-mini-2025-08-07")]
37 GPT5Mini,
38 #[serde(rename = "gpt-5-nano-2025-08-07")]
39 GPT5Nano,
40
41 Custom(String),
42}
43
44impl OpenAIModel {
45 pub fn from_string(s: &str) -> Result<Self, String> {
46 serde_json::from_value(serde_json::Value::String(s.to_string()))
47 .map_err(|_| "Failed to deserialize OpenAI model".to_string())
48 }
49
50 pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
52
53 pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
55
56 pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
58
59 pub fn default_smart_model() -> String {
61 Self::DEFAULT_SMART_MODEL.to_string()
62 }
63
64 pub fn default_eco_model() -> String {
66 Self::DEFAULT_ECO_MODEL.to_string()
67 }
68
69 pub fn default_recovery_model() -> String {
71 Self::DEFAULT_RECOVERY_MODEL.to_string()
72 }
73}
74
75impl ContextAware for OpenAIModel {
76 fn context_info(&self) -> ModelContextInfo {
77 let model_name = self.to_string();
78
79 if model_name.starts_with("o3") {
80 return ModelContextInfo {
81 max_tokens: 200_000,
82 pricing_tiers: vec![ContextPricingTier {
83 label: "Standard".to_string(),
84 input_cost_per_million: 2.0,
85 output_cost_per_million: 8.0,
86 upper_bound: None,
87 }],
88 approach_warning_threshold: 0.8,
89 };
90 }
91
92 if model_name.starts_with("o4-mini") {
93 return ModelContextInfo {
94 max_tokens: 200_000,
95 pricing_tiers: vec![ContextPricingTier {
96 label: "Standard".to_string(),
97 input_cost_per_million: 1.10,
98 output_cost_per_million: 4.40,
99 upper_bound: None,
100 }],
101 approach_warning_threshold: 0.8,
102 };
103 }
104
105 if model_name.starts_with("gpt-5-mini") {
106 return ModelContextInfo {
107 max_tokens: 400_000,
108 pricing_tiers: vec![ContextPricingTier {
109 label: "Standard".to_string(),
110 input_cost_per_million: 0.25,
111 output_cost_per_million: 2.0,
112 upper_bound: None,
113 }],
114 approach_warning_threshold: 0.8,
115 };
116 }
117
118 if model_name.starts_with("gpt-5-nano") {
119 return ModelContextInfo {
120 max_tokens: 400_000,
121 pricing_tiers: vec![ContextPricingTier {
122 label: "Standard".to_string(),
123 input_cost_per_million: 0.05,
124 output_cost_per_million: 0.40,
125 upper_bound: None,
126 }],
127 approach_warning_threshold: 0.8,
128 };
129 }
130
131 if model_name.starts_with("gpt-5") {
132 return ModelContextInfo {
133 max_tokens: 400_000,
134 pricing_tiers: vec![ContextPricingTier {
135 label: "Standard".to_string(),
136 input_cost_per_million: 1.25,
137 output_cost_per_million: 10.0,
138 upper_bound: None,
139 }],
140 approach_warning_threshold: 0.8,
141 };
142 }
143
144 ModelContextInfo::default()
145 }
146
147 fn model_name(&self) -> String {
148 match self {
149 OpenAIModel::O3 => "O3".to_string(),
150 OpenAIModel::O4Mini => "O4-mini".to_string(),
151 OpenAIModel::GPT5 => "GPT-5".to_string(),
152 OpenAIModel::GPT51 => "GPT-5.1".to_string(),
153 OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
154 OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
155 OpenAIModel::Custom(name) => format!("Custom ({})", name),
156 }
157 }
158}
159
160impl std::fmt::Display for OpenAIModel {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
164
165 OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
166
167 OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
168 OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
169 OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
170 OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
171
172 OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
173 }
174 }
175}
176
177#[derive(Serialize, Deserialize, Debug)]
178pub struct OpenAIInput {
179 pub model: OpenAIModel,
180 pub messages: Vec<LLMMessage>,
181 pub max_tokens: u32,
182
183 #[serde(skip_serializing_if = "Option::is_none")]
184 pub json: Option<serde_json::Value>,
185
186 #[serde(skip_serializing_if = "Option::is_none")]
187 pub tools: Option<Vec<LLMTool>>,
188
189 #[serde(skip_serializing_if = "Option::is_none")]
190 pub reasoning_effort: Option<OpenAIReasoningEffort>,
191}
192
193impl OpenAIInput {
194 pub fn is_reasoning_model(&self) -> bool {
195 matches!(self.model, |OpenAIModel::O3| OpenAIModel::O4Mini
196 | OpenAIModel::GPT5
197 | OpenAIModel::GPT5Mini
198 | OpenAIModel::GPT5Nano)
199 }
200
201 pub fn is_standard_model(&self) -> bool {
202 !self.is_reasoning_model()
203 }
204}
205
206#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
207pub enum OpenAIReasoningEffort {
208 #[serde(rename = "minimal")]
209 Minimal,
210 #[serde(rename = "low")]
211 Low,
212 #[default]
213 #[serde(rename = "medium")]
214 Medium,
215 #[serde(rename = "high")]
216 High,
217}
218
219#[derive(Serialize, Deserialize, Debug)]
220pub struct OpenAITool {
221 pub r#type: String,
222 pub function: OpenAIToolFunction,
223}
224
225#[derive(Serialize, Deserialize, Debug)]
226pub struct OpenAIToolFunction {
227 pub name: String,
228 pub description: String,
229 pub parameters: serde_json::Value,
230}
231
232impl From<LLMTool> for OpenAITool {
233 fn from(tool: LLMTool) -> Self {
234 OpenAITool {
235 r#type: "function".to_string(),
236 function: OpenAIToolFunction {
237 name: tool.name,
238 description: tool.description,
239 parameters: tool.input_schema,
240 },
241 }
242 }
243}
244
245#[derive(Serialize, Deserialize, Debug, Clone)]
246pub struct OpenAIOutput {
247 pub model: String,
248 pub object: String,
249 pub choices: Vec<OpenAILLMChoice>,
250 pub created: u64,
251 pub usage: Option<LLMTokenUsage>,
252 pub id: String,
253}
254
255impl From<OpenAIOutput> for LLMCompletionResponse {
256 fn from(val: OpenAIOutput) -> Self {
257 LLMCompletionResponse {
258 model: val.model,
259 object: val.object,
260 choices: val.choices.into_iter().map(OpenAILLMChoice::into).collect(),
261 created: val.created,
262 usage: val.usage,
263 id: val.id,
264 }
265 }
266}
267
268#[derive(Serialize, Deserialize, Debug, Clone)]
269pub struct OpenAILLMChoice {
270 pub finish_reason: Option<String>,
271 pub index: u32,
272 pub message: OpenAILLMMessage,
273}
274
275impl From<OpenAILLMChoice> for LLMChoice {
276 fn from(val: OpenAILLMChoice) -> Self {
277 LLMChoice {
278 finish_reason: val.finish_reason,
279 index: val.index,
280 message: val.message.into(),
281 }
282 }
283}
284
285#[derive(Serialize, Deserialize, Debug, Clone)]
286pub struct OpenAILLMMessage {
287 pub role: String,
288 pub content: Option<String>,
289 pub tool_calls: Option<Vec<OpenAILLMMessageToolCall>>,
290}
291impl From<OpenAILLMMessage> for LLMMessage {
292 fn from(val: OpenAILLMMessage) -> Self {
293 LLMMessage {
294 role: val.role,
295 content: match val.tool_calls {
296 None => LLMMessageContent::String(val.content.unwrap_or_default()),
297 Some(tool_calls) => LLMMessageContent::List(
298 std::iter::once(LLMMessageTypedContent::Text {
299 text: val.content.unwrap_or_default(),
300 })
301 .chain(tool_calls.into_iter().map(|tool_call| {
302 LLMMessageTypedContent::ToolCall {
303 id: tool_call.id,
304 name: tool_call.function.name,
305 args: match serde_json::from_str(&tool_call.function.arguments) {
306 Ok(args) => args,
307 Err(_) => {
308 return LLMMessageTypedContent::Text {
309 text: String::from("Error parsing tool call arguments"),
310 };
311 }
312 },
313 }
314 }))
315 .collect(),
316 ),
317 },
318 }
319 }
320}
321
322#[derive(Serialize, Deserialize, Debug, Clone)]
323pub struct OpenAILLMMessageToolCall {
324 pub id: String,
325 pub r#type: String,
326 pub function: OpenAILLMMessageToolCallFunction,
327}
328#[derive(Serialize, Deserialize, Debug, Clone)]
329pub struct OpenAILLMMessageToolCallFunction {
330 pub arguments: String,
331 pub name: String,
332}
333
334pub struct OpenAI {}
335
336impl OpenAI {
337 pub async fn chat(
338 config: &OpenAIConfig,
339 input: OpenAIInput,
340 ) -> Result<LLMCompletionResponse, AgentError> {
341 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
342 let client = ClientBuilder::new(reqwest::Client::new())
343 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
344 .build();
345
346 let is_reasoning_model = input.is_reasoning_model();
347
348 let mut payload = json!({
349 "model": input.model.to_string(),
350 "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
351 "max_completion_tokens": input.max_tokens,
352 "stream": false,
353 });
354
355 if is_reasoning_model {
356 if let Some(reasoning_effort) = input.reasoning_effort {
357 payload["reasoning_effort"] = json!(reasoning_effort);
358 } else {
359 payload["reasoning_effort"] = json!(OpenAIReasoningEffort::Medium);
360 }
361 } else {
362 payload["temperature"] = json!(0);
363 }
364
365 if let Some(tools) = input.tools {
366 let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
367 payload["tools"] = json!(openai_tools);
368 }
369
370 if let Some(schema) = input.json {
371 payload["response_format"] = json!({
372 "type": "json_schema",
373 "json_schema": {
374 "strict": true,
375 "schema": schema,
376 "name": "my-schema"
377 }
378 });
379 }
380
381 let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
382 let api_key = config.api_key.as_ref().map_or("", |v| v);
383
384 let final_payload_str = serde_json::to_string(&payload)
385 .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;
386
387 let response = client
388 .post(api_endpoint)
389 .header("Authorization", format!("Bearer {}", api_key))
390 .header("Accept", "application/json")
391 .header("Content-Type", "application/json")
392 .body(final_payload_str)
393 .send()
394 .await;
395
396 let response = match response {
397 Ok(resp) => resp,
398 Err(e) => {
399 return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
400 e.to_string(),
401 )));
402 }
403 };
404
405 match response.json::<Value>().await {
406 Ok(json) => {
407 if let Some(error_obj) = json.get("error") {
409 let error_message =
410 if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
411 message.to_string()
412 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
413 format!("API error: {}", code)
414 } else {
415 serde_json::to_string(error_obj)
416 .unwrap_or_else(|_| "Unknown API error".to_string())
417 };
418 return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
419 error_message,
420 )));
421 }
422
423 match serde_json::from_value::<OpenAIOutput>(json.clone()) {
425 Ok(json_response) => Ok(json_response.into()),
426 Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
427 e.to_string(),
428 ))),
429 }
430 }
431 Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
432 e.to_string(),
433 ))),
434 }
435 }
436
437 pub async fn chat_stream(
438 config: &OpenAIConfig,
439 stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
440 input: OpenAIInput,
441 ) -> Result<LLMCompletionResponse, AgentError> {
442 let is_reasoning_model = input.is_reasoning_model();
443
444 let mut payload = json!({
445 "model": input.model.to_string(),
446 "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
447 "max_completion_tokens": input.max_tokens,
448 "stream": true,
449 "stream_options":{
450 "include_usage": true
451 }
452 });
453
454 if is_reasoning_model {
455 if let Some(reasoning_effort) = input.reasoning_effort {
456 payload["reasoning_effort"] = json!(reasoning_effort);
457 } else {
458 payload["reasoning_effort"] = json!(OpenAIReasoningEffort::Medium);
459 }
460 } else {
461 payload["temperature"] = json!(0);
462 }
463
464 if let Some(tools) = input.tools {
465 let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
466 payload["tools"] = json!(openai_tools);
467 }
468
469 if let Some(schema) = input.json {
470 payload["response_format"] = json!({
471 "type": "json_schema",
472 "json_schema": {
473 "strict": true,
474 "schema": schema,
475 "name": "my-schema"
476 }
477 });
478 }
479
480 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
481 let client = ClientBuilder::new(reqwest::Client::new())
482 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
483 .build();
484
485 let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
486 let api_key = config.api_key.as_ref().map_or("", |v| v);
487
488 let final_payload_str = serde_json::to_string(&payload)
489 .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;
490
491 let response = client
492 .post(api_endpoint)
493 .header("Authorization", format!("Bearer {}", api_key))
494 .header("Accept", "application/json")
495 .header("Content-Type", "application/json")
496 .body(final_payload_str)
497 .send()
498 .await;
499
500 let response = match response {
501 Ok(resp) => resp,
502 Err(e) => {
503 return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
504 e.to_string(),
505 )));
506 }
507 };
508
509 if !response.status().is_success() {
510 let error_text = response.text().await.unwrap_or_default();
511
512 let error_message = if let Ok(json) = serde_json::from_str::<Value>(&error_text) {
514 if let Some(error_obj) = json.get("error") {
515 if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
516 message.to_string()
517 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
518 format!("API error: {}", code)
519 } else {
520 error_text
521 }
522 } else {
523 error_text
524 }
525 } else {
526 error_text
527 };
528
529 return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
530 error_message,
531 )));
532 }
533
534 process_openai_stream(response, input.model.to_string(), stream_channel_tx)
535 .await
536 .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))
537 }
538}
539
540async fn process_openai_stream(
542 response: reqwest::Response,
543 model: String,
544 stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
545) -> Result<LLMCompletionResponse, AgentError> {
546 let mut completion_response = LLMCompletionResponse {
547 id: "".to_string(),
548 model: model.clone(),
549 object: "chat.completion".to_string(),
550 choices: vec![],
551 created: chrono::Utc::now().timestamp_millis() as u64,
552 usage: None,
553 };
554
555 let mut stream = response.bytes_stream();
556 let mut unparsed_data = String::new();
557 let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
558 std::collections::HashMap::new();
559 let mut accumulated_content = String::new();
560 let mut finish_reason: Option<String> = None;
561
562 while let Some(chunk) = stream.next().await {
563 let chunk = chunk.map_err(|e| {
564 let error_message = format!("Failed to read stream chunk from OpenAI API: {e}");
565 AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
566 })?;
567
568 let text = std::str::from_utf8(&chunk).map_err(|e| {
569 let error_message = format!("Failed to parse UTF-8 from OpenAI response: {e}");
570 AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
571 })?;
572
573 unparsed_data.push_str(text);
574
575 while let Some(line_end) = unparsed_data.find('\n') {
576 let line = unparsed_data[..line_end].to_string();
577 unparsed_data = unparsed_data[line_end + 1..].to_string();
578
579 if line.trim().is_empty() {
580 continue;
581 }
582
583 if !line.starts_with("data: ") {
584 continue;
585 }
586
587 let json_str = &line[6..];
588 if json_str == "[DONE]" {
589 continue;
590 }
591
592 match serde_json::from_str::<ChatCompletionStreamResponse>(json_str) {
593 Ok(stream_response) => {
594 if completion_response.id.is_empty() {
596 completion_response.id = stream_response.id.clone();
597 completion_response.model = stream_response.model.clone();
598 completion_response.object = stream_response.object.clone();
599 completion_response.created = stream_response.created;
600 }
601
602 for choice in &stream_response.choices {
604 if let Some(content) = &choice.delta.content {
605 stream_channel_tx
607 .send(GenerationDelta::Content {
608 content: content.clone(),
609 })
610 .await
611 .map_err(|e| {
612 AgentError::BadRequest(BadRequestErrorMessage::ApiError(
613 e.to_string(),
614 ))
615 })?;
616 accumulated_content.push_str(content);
617 }
618
619 if let Some(tool_calls) = &choice.delta.tool_calls {
621 for tool_call in tool_calls {
622 let index = tool_call.index;
623
624 let entry = current_tool_calls.entry(index).or_insert((
626 String::new(),
627 String::new(),
628 String::new(),
629 ));
630
631 if let Some(id) = &tool_call.id {
632 entry.0 = id.clone();
633 }
634 if let Some(function) = &tool_call.function {
635 if let Some(name) = &function.name {
636 entry.1 = name.clone();
637 }
638 if let Some(args) = &function.arguments {
639 entry.2.push_str(args);
640 }
641 }
642
643 stream_channel_tx
645 .send(GenerationDelta::ToolUse {
646 tool_use: GenerationDeltaToolUse {
647 id: tool_call.id.clone(),
648 name: tool_call
649 .function
650 .as_ref()
651 .and_then(|f| f.name.clone())
652 .and_then(|n| {
653 if n.is_empty() { None } else { Some(n) }
654 }),
655 input: tool_call
656 .function
657 .as_ref()
658 .and_then(|f| f.arguments.clone()),
659 index,
660 },
661 })
662 .await
663 .map_err(|e| {
664 AgentError::BadRequest(BadRequestErrorMessage::ApiError(
665 e.to_string(),
666 ))
667 })?;
668 }
669 }
670
671 if let Some(reason) = &choice.finish_reason {
672 finish_reason = Some(match reason {
673 FinishReason::Stop => "stop".to_string(),
674 FinishReason::Length => "length".to_string(),
675 FinishReason::ContentFilter => "content_filter".to_string(),
676 FinishReason::ToolCalls => "tool_calls".to_string(),
677 });
678 }
679 }
680
681 if let Some(usage) = &stream_response.usage {
683 stream_channel_tx
684 .send(GenerationDelta::Usage {
685 usage: usage.clone(),
686 })
687 .await
688 .map_err(|e| {
689 AgentError::BadRequest(BadRequestErrorMessage::ApiError(
690 e.to_string(),
691 ))
692 })?;
693 completion_response.usage = Some(usage.clone());
694 }
695 }
696 Err(e) => {
697 eprintln!("Error parsing response: {}", e);
698 }
699 }
700 }
701 }
702
703 let mut message_content = vec![];
705
706 if !accumulated_content.is_empty() {
707 message_content.push(LLMMessageTypedContent::Text {
708 text: accumulated_content,
709 });
710 }
711
712 for (_, (id, name, args)) in current_tool_calls {
713 if let Ok(parsed_args) = serde_json::from_str(&args) {
714 message_content.push(LLMMessageTypedContent::ToolCall {
715 id,
716 name,
717 args: parsed_args,
718 });
719 }
720 }
721
722 completion_response.choices = vec![LLMChoice {
723 finish_reason,
724 index: 0,
725 message: LLMMessage {
726 role: "assistant".to_string(),
727 content: if message_content.is_empty() {
728 LLMMessageContent::String(String::new())
729 } else if message_content.len() == 1
730 && matches!(&message_content[0], LLMMessageTypedContent::Text { .. })
731 {
732 if let LLMMessageTypedContent::Text { text } = &message_content[0] {
733 LLMMessageContent::String(text.clone())
734 } else {
735 LLMMessageContent::List(message_content)
736 }
737 } else {
738 LLMMessageContent::List(message_content)
739 },
740 },
741 }];
742
743 Ok(completion_response)
744}
745
746#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
747#[serde(rename_all = "lowercase")]
748pub enum Role {
749 System,
750 Developer,
751 User,
752 #[default]
753 Assistant,
754 Tool,
755 }
757
758impl std::fmt::Display for Role {
759 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
760 match self {
761 Role::System => write!(f, "system"),
762 Role::Developer => write!(f, "developer"),
763 Role::User => write!(f, "user"),
764 Role::Assistant => write!(f, "assistant"),
765 Role::Tool => write!(f, "tool"),
766 }
767 }
768}
769
770#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
771pub struct ChatCompletionRequest {
772 pub model: String,
773 pub messages: Vec<ChatMessage>,
774 #[serde(skip_serializing_if = "Option::is_none")]
775 pub frequency_penalty: Option<f32>,
776 #[serde(skip_serializing_if = "Option::is_none")]
777 pub logit_bias: Option<serde_json::Value>,
778 #[serde(skip_serializing_if = "Option::is_none")]
779 pub logprobs: Option<bool>,
780 #[serde(skip_serializing_if = "Option::is_none")]
781 pub max_tokens: Option<u32>,
782 #[serde(skip_serializing_if = "Option::is_none")]
783 pub n: Option<u32>,
784 #[serde(skip_serializing_if = "Option::is_none")]
785 pub presence_penalty: Option<f32>,
786 #[serde(skip_serializing_if = "Option::is_none")]
787 pub response_format: Option<ResponseFormat>,
788 #[serde(skip_serializing_if = "Option::is_none")]
789 pub seed: Option<i64>,
790 #[serde(skip_serializing_if = "Option::is_none")]
791 pub stop: Option<StopSequence>,
792 #[serde(skip_serializing_if = "Option::is_none")]
793 pub stream: Option<bool>,
794 #[serde(skip_serializing_if = "Option::is_none")]
795 pub temperature: Option<f32>,
796 #[serde(skip_serializing_if = "Option::is_none")]
797 pub top_p: Option<f32>,
798 #[serde(skip_serializing_if = "Option::is_none")]
799 pub tools: Option<Vec<Tool>>,
800 #[serde(skip_serializing_if = "Option::is_none")]
801 pub tool_choice: Option<ToolChoice>,
802 #[serde(skip_serializing_if = "Option::is_none")]
803 pub user: Option<String>,
804
805 #[serde(skip_serializing_if = "Option::is_none")]
806 pub context: Option<ChatCompletionContext>,
807}
808
809#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
810pub struct ChatCompletionContext {
811 pub scratchpad: Option<Value>,
812}
813
814#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
815pub struct ChatMessage {
816 pub role: Role,
817 pub content: Option<MessageContent>,
818 #[serde(skip_serializing_if = "Option::is_none")]
819 pub name: Option<String>,
820 #[serde(skip_serializing_if = "Option::is_none")]
821 pub tool_calls: Option<Vec<ToolCall>>,
822 #[serde(skip_serializing_if = "Option::is_none")]
823 pub tool_call_id: Option<String>,
824
825 #[serde(skip_serializing_if = "Option::is_none")]
827 pub usage: Option<LLMTokenUsage>,
828}
829
830impl ChatMessage {
831 pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
832 messages
833 .iter()
834 .rev()
835 .find(|message| message.role != Role::User && message.role != Role::Tool)
836 }
837
838 pub fn to_xml(&self) -> String {
839 match &self.content {
840 Some(MessageContent::String(s)) => {
841 format!("<message role=\"{}\">{}</message>", self.role, s)
842 }
843 Some(MessageContent::Array(parts)) => parts
844 .iter()
845 .map(|part| {
846 format!(
847 "<message role=\"{}\" type=\"{}\">{}</message>",
848 self.role,
849 part.r#type,
850 part.text.clone().unwrap_or_default()
851 )
852 })
853 .collect::<Vec<String>>()
854 .join("\n"),
855 None => String::new(),
856 }
857 }
858}
859
860#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
861#[serde(untagged)]
862pub enum MessageContent {
863 String(String),
864 Array(Vec<ContentPart>),
865}
866
867impl MessageContent {
868 pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
869 match self {
870 MessageContent::String(s) => MessageContent::String(format!(
871 "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
872 )),
873 MessageContent::Array(parts) => MessageContent::Array(
874 std::iter::once(ContentPart {
875 r#type: "text".to_string(),
876 text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
877 image_url: None,
878 })
879 .chain(parts.iter().cloned())
880 .collect(),
881 ),
882 }
883 }
884
885 pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
886 match self {
887 MessageContent::String(s) => s
888 .rfind("<checkpoint_id>")
889 .and_then(|start| {
890 s[start..]
891 .find("</checkpoint_id>")
892 .map(|end| (start + "<checkpoint_id>".len(), start + end))
893 })
894 .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
895 MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
896 part.text.as_deref().and_then(|text| {
897 text.rfind("<checkpoint_id>")
898 .and_then(|start| {
899 text[start..]
900 .find("</checkpoint_id>")
901 .map(|end| (start + "<checkpoint_id>".len(), start + end))
902 })
903 .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
904 })
905 }),
906 }
907 }
908}
909
910impl std::fmt::Display for MessageContent {
911 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
912 match self {
913 MessageContent::String(s) => write!(f, "{s}"),
914 MessageContent::Array(parts) => {
915 let text_parts: Vec<String> =
916 parts.iter().filter_map(|part| part.text.clone()).collect();
917 write!(f, "{}", text_parts.join("\n"))
918 }
919 }
920 }
921}
922impl Default for MessageContent {
923 fn default() -> Self {
924 MessageContent::String(String::new())
925 }
926}
927
928#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
929pub struct ContentPart {
930 pub r#type: String,
931 #[serde(skip_serializing_if = "Option::is_none")]
932 pub text: Option<String>,
933 #[serde(skip_serializing_if = "Option::is_none")]
934 pub image_url: Option<ImageUrl>,
935}
936
937#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
938pub struct ImageUrl {
939 pub url: String,
940 #[serde(skip_serializing_if = "Option::is_none")]
941 pub detail: Option<String>,
942}
943
944#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
945pub struct ResponseFormat {
946 pub r#type: String,
947}
948
949#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
950#[serde(untagged)]
951pub enum StopSequence {
952 String(String),
953 Array(Vec<String>),
954}
955
956#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
957pub struct Tool {
958 pub r#type: String,
959 pub function: FunctionDefinition,
960}
961
962#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
963pub struct FunctionDefinition {
964 pub name: String,
965 pub description: Option<String>,
966 pub parameters: serde_json::Value,
967}
968
969#[derive(Debug, Clone, PartialEq)]
970pub enum ToolChoice {
971 Auto,
972 Required,
973 Object(ToolChoiceObject),
974}
975
976impl Serialize for ToolChoice {
977 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
978 where
979 S: serde::Serializer,
980 {
981 match self {
982 ToolChoice::Auto => serializer.serialize_str("auto"),
983 ToolChoice::Required => serializer.serialize_str("required"),
984 ToolChoice::Object(obj) => obj.serialize(serializer),
985 }
986 }
987}
988
989impl<'de> Deserialize<'de> for ToolChoice {
990 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
991 where
992 D: serde::Deserializer<'de>,
993 {
994 struct ToolChoiceVisitor;
995
996 impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
997 type Value = ToolChoice;
998
999 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
1000 formatter.write_str("string or object")
1001 }
1002
1003 fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
1004 where
1005 E: serde::de::Error,
1006 {
1007 match value {
1008 "auto" => Ok(ToolChoice::Auto),
1009 "required" => Ok(ToolChoice::Required),
1010 _ => Err(serde::de::Error::unknown_variant(
1011 value,
1012 &["auto", "required"],
1013 )),
1014 }
1015 }
1016
1017 fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
1018 where
1019 M: serde::de::MapAccess<'de>,
1020 {
1021 let obj = ToolChoiceObject::deserialize(
1022 serde::de::value::MapAccessDeserializer::new(map),
1023 )?;
1024 Ok(ToolChoice::Object(obj))
1025 }
1026 }
1027
1028 deserializer.deserialize_any(ToolChoiceVisitor)
1029 }
1030}
1031
1032#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1033pub struct ToolChoiceObject {
1034 pub r#type: String,
1035 pub function: FunctionChoice,
1036}
1037
1038#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1039pub struct FunctionChoice {
1040 pub name: String,
1041}
1042
1043#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1044pub struct ToolCall {
1045 pub id: String,
1046 pub r#type: String,
1047 pub function: FunctionCall,
1048}
1049
1050#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1051pub struct FunctionCall {
1052 pub name: String,
1053 pub arguments: String,
1054}
1055
1056#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1057pub struct ChatCompletionResponse {
1058 pub id: String,
1059 pub object: String,
1060 pub created: u64,
1061 pub model: String,
1062 pub choices: Vec<ChatCompletionChoice>,
1063 pub usage: LLMTokenUsage,
1064 #[serde(skip_serializing_if = "Option::is_none")]
1065 pub system_fingerprint: Option<String>,
1066 pub metadata: Option<serde_json::Value>,
1067}
1068
1069#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1070pub struct ChatCompletionChoice {
1071 pub index: usize,
1072 pub message: ChatMessage,
1073 pub logprobs: Option<LogProbs>,
1074 pub finish_reason: FinishReason,
1075}
1076
1077#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1078#[serde(rename_all = "snake_case")]
1079pub enum FinishReason {
1080 Stop,
1081 Length,
1082 ContentFilter,
1083 ToolCalls,
1084}
1085
1086#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1087pub struct LogProbs {
1088 pub content: Option<Vec<LogProbContent>>,
1089}
1090
1091#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1092pub struct LogProbContent {
1093 pub token: String,
1094 pub logprob: f32,
1095 pub bytes: Option<Vec<u8>>,
1096 pub top_logprobs: Option<Vec<TokenLogprob>>,
1097}
1098
1099#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1100pub struct TokenLogprob {
1101 pub token: String,
1102 pub logprob: f32,
1103 pub bytes: Option<Vec<u8>>,
1104}
1105
1106#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1107pub struct ChatCompletionStreamResponse {
1108 pub id: String,
1109 pub object: String,
1110 pub created: u64,
1111 pub model: String,
1112 pub choices: Vec<ChatCompletionStreamChoice>,
1113 pub usage: Option<LLMTokenUsage>,
1114 pub metadata: Option<serde_json::Value>,
1115}
1116#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1117pub struct ChatCompletionStreamChoice {
1118 pub index: usize,
1119 pub delta: ChatMessageDelta,
1120 pub finish_reason: Option<FinishReason>,
1121}
1122#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1123pub struct ChatMessageDelta {
1124 #[serde(skip_serializing_if = "Option::is_none")]
1125 pub role: Option<Role>,
1126 #[serde(skip_serializing_if = "Option::is_none")]
1127 pub content: Option<String>,
1128 #[serde(skip_serializing_if = "Option::is_none")]
1129 pub tool_calls: Option<Vec<ToolCallDelta>>,
1130}
1131
1132#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1133pub struct ToolCallDelta {
1134 pub index: usize,
1135 pub id: Option<String>,
1136 pub r#type: Option<String>,
1137 pub function: Option<FunctionCallDelta>,
1138}
1139
1140#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1141pub struct FunctionCallDelta {
1142 pub name: Option<String>,
1143 pub arguments: Option<String>,
1144}
1145
1146impl From<LLMMessage> for ChatMessage {
1147 fn from(llm_message: LLMMessage) -> Self {
1148 let role = match llm_message.role.as_str() {
1149 "system" => Role::System,
1150 "user" => Role::User,
1151 "assistant" => Role::Assistant,
1152 "tool" => Role::Tool,
1153 "developer" => Role::Developer,
1155 _ => Role::User, };
1157
1158 let (content, tool_calls) = match llm_message.content {
1159 LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
1160 LLMMessageContent::List(items) => {
1161 let mut text_parts = Vec::new();
1162 let mut tool_call_parts = Vec::new();
1163
1164 for item in items {
1165 match item {
1166 LLMMessageTypedContent::Text { text } => {
1167 text_parts.push(ContentPart {
1168 r#type: "text".to_string(),
1169 text: Some(text),
1170 image_url: None,
1171 });
1172 }
1173 LLMMessageTypedContent::ToolCall { id, name, args } => {
1174 tool_call_parts.push(ToolCall {
1175 id,
1176 r#type: "function".to_string(),
1177 function: FunctionCall {
1178 name,
1179 arguments: args.to_string(),
1180 },
1181 });
1182 }
1183 LLMMessageTypedContent::ToolResult { content, .. } => {
1184 text_parts.push(ContentPart {
1185 r#type: "text".to_string(),
1186 text: Some(content),
1187 image_url: None,
1188 });
1189 }
1190 LLMMessageTypedContent::Image { source } => {
1191 text_parts.push(ContentPart {
1192 r#type: "image_url".to_string(),
1193 text: None,
1194 image_url: Some(ImageUrl {
1195 url: format!(
1196 "data:{};base64,{}",
1197 source.media_type, source.data
1198 ),
1199 detail: None,
1200 }),
1201 });
1202 }
1203 }
1204 }
1205
1206 let content = if !text_parts.is_empty() {
1207 Some(MessageContent::Array(text_parts))
1208 } else {
1209 None
1210 };
1211
1212 let tool_calls = if !tool_call_parts.is_empty() {
1213 Some(tool_call_parts)
1214 } else {
1215 None
1216 };
1217
1218 (content, tool_calls)
1219 }
1220 };
1221
1222 ChatMessage {
1223 role,
1224 content,
1225 name: None, tool_calls,
1227 tool_call_id: None, usage: None,
1229 }
1230 }
1231}
1232
1233impl From<ChatMessage> for LLMMessage {
1234 fn from(chat_message: ChatMessage) -> Self {
1235 let mut content_parts = Vec::new();
1236
1237 match chat_message.content {
1239 Some(MessageContent::String(s)) => {
1240 if !s.is_empty() {
1241 content_parts.push(LLMMessageTypedContent::Text { text: s });
1242 }
1243 }
1244 Some(MessageContent::Array(parts)) => {
1245 for part in parts {
1246 if let Some(text) = part.text {
1247 content_parts.push(LLMMessageTypedContent::Text { text });
1248 } else if let Some(image_url) = part.image_url {
1249 let (media_type, data) = if image_url.url.starts_with("data:") {
1250 let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
1251 if parts.len() == 2 {
1252 let meta = parts[0];
1253 let data = parts[1];
1254 let media_type = meta
1255 .trim_start_matches("data:")
1256 .trim_end_matches(";base64")
1257 .to_string();
1258 (media_type, data.to_string())
1259 } else {
1260 ("image/jpeg".to_string(), image_url.url)
1261 }
1262 } else {
1263 ("image/jpeg".to_string(), image_url.url)
1264 };
1265
1266 content_parts.push(LLMMessageTypedContent::Image {
1267 source: LLMMessageImageSource {
1268 r#type: "base64".to_string(),
1269 media_type,
1270 data,
1271 },
1272 });
1273 }
1274 }
1275 }
1276 None => {}
1277 }
1278
1279 if let Some(tool_calls) = chat_message.tool_calls {
1281 for tool_call in tool_calls {
1282 let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
1283 content_parts.push(LLMMessageTypedContent::ToolCall {
1284 id: tool_call.id,
1285 name: tool_call.function.name,
1286 args,
1287 });
1288 }
1289 }
1290
1291 LLMMessage {
1292 role: chat_message.role.to_string(),
1293 content: if content_parts.is_empty() {
1294 LLMMessageContent::String(String::new())
1295 } else if content_parts.len() == 1 {
1296 match &content_parts[0] {
1297 LLMMessageTypedContent::Text { text } => {
1298 LLMMessageContent::String(text.clone())
1299 }
1300 _ => LLMMessageContent::List(content_parts),
1301 }
1302 } else {
1303 LLMMessageContent::List(content_parts)
1304 },
1305 }
1306 }
1307}
1308
1309#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
1310pub enum AgentModel {
1311 #[serde(rename = "smart")]
1312 #[default]
1313 Smart,
1314 #[serde(rename = "eco")]
1315 Eco,
1316 #[serde(rename = "recovery")]
1317 Recovery,
1318}
1319
1320impl std::fmt::Display for AgentModel {
1321 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1322 match self {
1323 AgentModel::Smart => write!(f, "smart"),
1324 AgentModel::Eco => write!(f, "eco"),
1325 AgentModel::Recovery => write!(f, "recovery"),
1326 }
1327 }
1328}
1329
1330impl From<String> for AgentModel {
1331 fn from(value: String) -> Self {
1332 match value.as_str() {
1333 "eco" => AgentModel::Eco,
1334 "recovery" => AgentModel::Recovery,
1335 _ => AgentModel::Smart,
1336 }
1337 }
1338}
1339
1340impl ChatCompletionRequest {
1341 pub fn new(
1342 model: String,
1343 messages: Vec<ChatMessage>,
1344 tools: Option<Vec<Tool>>,
1345 stream: Option<bool>,
1346 ) -> Self {
1347 Self {
1348 model,
1349 messages,
1350 frequency_penalty: None,
1351 logit_bias: None,
1352 logprobs: None,
1353 max_tokens: None,
1354 n: None,
1355 presence_penalty: None,
1356 response_format: None,
1357 seed: None,
1358 stop: None,
1359 stream,
1360 temperature: None,
1361 top_p: None,
1362 tools,
1363 tool_choice: None,
1364 user: None,
1365 context: None,
1366 }
1367 }
1368}
1369
1370impl From<Tool> for LLMTool {
1371 fn from(tool: Tool) -> Self {
1372 LLMTool {
1373 name: tool.function.name,
1374 description: tool.function.description.unwrap_or_default(),
1375 input_schema: tool.function.parameters,
1376 }
1377 }
1378}
1379
1380#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1381pub enum ToolCallResultStatus {
1382 Success,
1383 Error,
1384 Cancelled,
1385}
1386
1387#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1388pub struct ToolCallResult {
1389 pub call: ToolCall,
1390 pub result: String,
1391 pub status: ToolCallResultStatus,
1392}
1393
1394#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
1395pub struct ToolCallResultProgress {
1396 pub id: Uuid,
1397 pub message: String,
1398}
1399
1400impl From<GenerationDelta> for ChatMessageDelta {
1401 fn from(delta: GenerationDelta) -> Self {
1402 match delta {
1403 GenerationDelta::Content { content } => ChatMessageDelta {
1404 role: Some(Role::Assistant),
1405 content: Some(content),
1406 tool_calls: None,
1407 },
1408 GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1409 role: Some(Role::Assistant),
1410 content: None,
1411 tool_calls: None,
1412 },
1413 GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1414 role: Some(Role::Assistant),
1415 content: None,
1416 tool_calls: Some(vec![ToolCallDelta {
1417 index: tool_use.index,
1418 id: tool_use.id,
1419 r#type: Some("function".to_string()),
1420 function: Some(FunctionCallDelta {
1421 name: tool_use.name,
1422 arguments: tool_use.input,
1423 }),
1424 }]),
1425 },
1426 _ => ChatMessageDelta {
1427 role: Some(Role::Assistant),
1428 content: None,
1429 tool_calls: None,
1430 },
1431 }
1432 }
1433}
1434
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ChatMessage` with only `role` and `content` populated.
    fn chat_message(role: Role, content: Option<MessageContent>) -> ChatMessage {
        ChatMessage {
            role,
            content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        }
    }

    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"smart\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
        assert!(json.contains("\"content\":\"You are a helpful assistant.\""));
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello!\""));
        assert!(json.contains("\"max_tokens\":100"));
        assert!(json.contains("\"temperature\":0.7"));
    }

    #[test]
    fn test_deserialize_response() {
        let json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "smart",
            "system_fingerprint": "fp_123abc",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you today?"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 9,
                "completion_tokens": 12,
                "total_tokens": 21
            }
        }"#;

        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, "chat.completion");
        assert_eq!(response.created, 1677652288);
        assert_eq!(response.model, AgentModel::Smart.to_string());
        assert_eq!(response.system_fingerprint, Some("fp_123abc".to_string()));

        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].index, 0);
        assert_eq!(response.choices[0].message.role, Role::Assistant);

        match &response.choices[0].message.content {
            Some(MessageContent::String(content)) => {
                assert_eq!(content, "Hello! How can I help you today?");
            }
            _ => panic!("Expected string content"),
        }

        assert_eq!(response.choices[0].finish_reason, FinishReason::Stop);
        assert_eq!(response.usage.prompt_tokens, 9);
        assert_eq!(response.usage.completion_tokens, 12);
        assert_eq!(response.usage.total_tokens, 21);
    }

    #[test]
    fn test_tool_calls_request_response() {
        let tools_request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![ChatMessage {
                role: Role::User,
                content: Some(MessageContent::String(
                    "What's the weather in San Francisco?".to_string(),
                )),
                name: None,
                tool_calls: None,
                tool_call_id: None,
                usage: None,
            }],
            tools: Some(vec![Tool {
                r#type: "function".to_string(),
                function: FunctionDefinition {
                    name: "get_weather".to_string(),
                    description: Some("Get the current weather in a given location".to_string()),
                    parameters: serde_json::json!({
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA"
                            }
                        },
                        "required": ["location"]
                    }),
                },
            }]),
            tool_choice: Some(ToolChoice::Auto),
            max_tokens: Some(100),
            temperature: Some(0.7),
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            top_p: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&tools_request).unwrap();

        assert!(json.contains("\"tools\":["));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
        assert!(json.contains("\"tool_choice\":\"auto\""));

        let tool_response_json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "eco",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [
                        {
                            "id": "call_abc123",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"location\":\"San Francisco, CA\"}"
                            }
                        }
                    ]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": {
                "prompt_tokens": 82,
                "completion_tokens": 17,
                "total_tokens": 99
            }
        }"#;

        let tool_response: ChatCompletionResponse =
            serde_json::from_str(tool_response_json).unwrap();
        assert_eq!(tool_response.choices[0].message.role, Role::Assistant);
        assert_eq!(tool_response.choices[0].message.content, None);
        assert!(tool_response.choices[0].message.tool_calls.is_some());

        let tool_calls = tool_response.choices[0]
            .message
            .tool_calls
            .as_ref()
            .unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc123");
        assert_eq!(tool_calls[0].r#type, "function");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(
            tool_calls[0].function.arguments,
            "{\"location\":\"San Francisco, CA\"}"
        );

        assert_eq!(
            tool_response.choices[0].finish_reason,
            FinishReason::ToolCalls,
        );
    }

    #[test]
    fn test_content_with_image() {
        // Chat Completions image parts use type "image_url", which is also what
        // `From<LLMMessage> for ChatMessage` emits.
        let message_with_image = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![
                ContentPart {
                    r#type: "text".to_string(),
                    text: Some("What's in this image?".to_string()),
                    image_url: None,
                },
                ContentPart {
                    r#type: "image_url".to_string(),
                    text: None,
                    image_url: Some(ImageUrl {
                        url: "https://example.com/cat.png".to_string(),
                        detail: Some("low".to_string()),
                    }),
                },
            ])),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        };

        let json = serde_json::to_string(&message_with_image).unwrap();

        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"type\":\"text\""));
        assert!(json.contains("\"text\":\"What's in this image?\""));
        assert!(json.contains("\"type\":\"image_url\""));
        assert!(json.contains("\"url\":\"https://example.com/cat.png\""));
        assert!(json.contains("\"detail\":\"low\""));
    }

    #[test]
    fn test_response_format() {
        let json_format_request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![ChatMessage {
                role: Role::User,
                content: Some(MessageContent::String(
                    "Generate a JSON object with name and age fields".to_string(),
                )),
                name: None,
                tool_calls: None,
                tool_call_id: None,
                usage: None,
            }],
            response_format: Some(ResponseFormat {
                r#type: "json_object".to_string(),
            }),
            max_tokens: Some(100),
            temperature: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            n: None,
            presence_penalty: None,
            seed: None,
            stop: None,
            stream: None,
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&json_format_request).unwrap();
        assert!(json.contains("\"response_format\":{\"type\":\"json_object\"}"));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
        assert_eq!(chat_message.tool_calls, None);

        let llm_message_with_tool = LLMMessage {
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolCall {
                id: "call_123".to_string(),
                name: "get_weather".to_string(),
                args: serde_json::json!({"location": "San Francisco"}),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message_with_tool);
        assert_eq!(chat_message.role, Role::Assistant);
        assert_eq!(chat_message.content, None);
        assert!(chat_message.tool_calls.is_some());

        let tool_calls = chat_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_123");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert!(tool_calls[0].function.arguments.contains("San Francisco"));

        let llm_message_mixed = LLMMessage {
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![
                LLMMessageTypedContent::Text {
                    text: "The weather is:".to_string(),
                },
                LLMMessageTypedContent::ToolCall {
                    id: "call_456".to_string(),
                    name: "get_weather".to_string(),
                    args: serde_json::json!({"location": "New York"}),
                },
            ]),
        };

        let chat_message = ChatMessage::from(llm_message_mixed);
        assert_eq!(chat_message.role, Role::Assistant);

        match &chat_message.content {
            Some(MessageContent::Array(parts)) => {
                assert_eq!(parts.len(), 1);
                assert_eq!(parts[0].r#type, "text");
                assert_eq!(parts[0].text, Some("The weather is:".to_string()));
            }
            _ => panic!("Expected array content"),
        }

        let tool_calls = chat_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_456");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert!(tool_calls[0].function.arguments.contains("New York"));
    }

    #[test]
    fn test_unknown_role_falls_back_to_user() {
        let llm_message = LLMMessage {
            role: "critic".to_string(),
            content: LLMMessageContent::String("Looks good".to_string()),
        };
        assert_eq!(ChatMessage::from(llm_message).role, Role::User);
    }

    #[test]
    fn test_chat_message_single_text_part_collapses_to_string() {
        let message = chat_message(
            Role::User,
            Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("Just text".to_string()),
                image_url: None,
            }])),
        );

        match LLMMessage::from(message).content {
            LLMMessageContent::String(text) => assert_eq!(text, "Just text"),
            _ => panic!("Expected a single text part to collapse to string content"),
        }
    }

    #[test]
    fn test_chat_message_without_content_becomes_empty_string() {
        let message = chat_message(Role::Assistant, None);

        match LLMMessage::from(message).content {
            LLMMessageContent::String(text) => assert!(text.is_empty()),
            _ => panic!("Expected empty string content"),
        }
    }

    #[test]
    fn test_chat_message_image_data_url_to_llm_message() {
        let message = chat_message(
            Role::User,
            Some(MessageContent::Array(vec![
                ContentPart {
                    r#type: "text".to_string(),
                    text: Some("Describe this".to_string()),
                    image_url: None,
                },
                ContentPart {
                    r#type: "image_url".to_string(),
                    text: None,
                    image_url: Some(ImageUrl {
                        url: "data:image/png;base64,iVBORw0KGgo=".to_string(),
                        detail: None,
                    }),
                },
            ])),
        );

        match LLMMessage::from(message).content {
            LLMMessageContent::List(items) => {
                assert_eq!(items.len(), 2);
                match &items[0] {
                    LLMMessageTypedContent::Text { text } => assert_eq!(text, "Describe this"),
                    _ => panic!("Expected text item first"),
                }
                match &items[1] {
                    LLMMessageTypedContent::Image { source } => {
                        assert_eq!(source.r#type, "base64");
                        assert_eq!(source.media_type, "image/png");
                        assert_eq!(source.data, "iVBORw0KGgo=");
                    }
                    _ => panic!("Expected image item second"),
                }
            }
            _ => panic!("Expected list content"),
        }
    }

    #[test]
    fn test_invalid_tool_arguments_fall_back_to_empty_object() {
        let mut message = chat_message(Role::Assistant, None);
        message.tool_calls = Some(vec![ToolCall {
            id: "call_789".to_string(),
            r#type: "function".to_string(),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: "{not valid json".to_string(),
            },
        }]);

        match LLMMessage::from(message).content {
            LLMMessageContent::List(items) => {
                assert_eq!(items.len(), 1);
                match &items[0] {
                    LLMMessageTypedContent::ToolCall { id, name, args } => {
                        assert_eq!(id, "call_789");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args, &serde_json::json!({}));
                    }
                    _ => panic!("Expected tool call item"),
                }
            }
            _ => panic!("Expected list content"),
        }
    }

    #[test]
    fn test_agent_model_string_round_trip() {
        for model in [AgentModel::Smart, AgentModel::Eco, AgentModel::Recovery] {
            assert_eq!(AgentModel::from(model.to_string()), model);
        }
        assert_eq!(AgentModel::from("unknown".to_string()), AgentModel::Smart);
    }

    #[test]
    fn test_generation_content_delta_to_chat_delta() {
        let delta = ChatMessageDelta::from(GenerationDelta::Content {
            content: "Hi".to_string(),
        });
        assert_eq!(
            delta,
            ChatMessageDelta {
                role: Some(Role::Assistant),
                content: Some("Hi".to_string()),
                tool_calls: None,
            }
        );

        // `None` fields are skipped rather than serialized as null.
        let json = serde_json::to_string(&delta).unwrap();
        assert!(!json.contains("tool_calls"));
    }
}