1use crate::error::{Result, ValidationError};
6use crate::message::Message;
7use crate::tool::{ToolChoice, ToolDefinition};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
/// Response format selector carried by [`CompletionRequest::response_format`].
///
/// Serialized by serde as an internally tagged enum: the variant is identified
/// by a `"type"` field whose value is the snake_case variant name.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Serialized as `"type": "text"`.
    Text,
    /// Serialized as `"type": "json_object"`.
    JsonObject,
    /// Serialized as `"type": "json_schema"`, with the schema description in a
    /// sibling `json_schema` field.
    #[serde(rename = "json_schema")]
    JsonSchema {
        /// Name, schema document and optional `strict` flag for this format.
        json_schema: JsonSchemaFormat,
    },
}
26
/// Payload of [`ResponseFormat::JsonSchema`]: a named JSON schema document.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
    /// Name attached to the schema.
    pub name: String,
    /// The schema itself, held as an arbitrary JSON value (not inspected here).
    pub schema: Value,
    /// Optional strictness flag; omitted from the serialized output when `None`.
    // NOTE(review): presumably asks the provider to enforce the schema exactly — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
38
/// A completion request: the messages, the target model and optional tuning fields.
///
/// Every `Option` field is skipped during serialization when it is `None`.
/// Limits mentioned on individual fields are the ones enforced by
/// [`CompletionRequest::validate`]; fields without a stated limit are not
/// checked there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Messages of the request. Validation requires between 1 and 1000 entries,
    /// each with a `text_len()` of at most 1 MiB and no null bytes, and a summed
    /// `text_len()` of at most 10 MiB.
    pub messages: Vec<Message>,
    /// Model identifier. Validation requires it to be non-empty and to consist
    /// only of alphanumeric characters and `-`, `_`, `.`, `/`.
    pub model: String,
    /// Optional token limit.
    // NOTE(review): presumably the cap on generated tokens — confirm against the provider API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional temperature; validation requires `0.0..=2.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional `top_p`; validation requires `0.0..=1.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Optional streaming flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Optional `n` value.
    // NOTE(review): presumably the number of completions to generate — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Optional list of stop strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Optional presence penalty; validation requires `-2.0..=2.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Optional frequency penalty; validation requires `-2.0..=2.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Optional user identifier string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Optional response format selector.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Optional tool definitions sent with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// Optional tool choice setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Optional instructions string.
    // NOTE(review): this and the two fields below look like Responses-API style fields — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Optional identifier of a previous response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub previous_response_id: Option<String>,
    /// Optional `store` flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
}
92
93impl CompletionRequest {
94 pub fn new(model: impl Into<String>) -> Self {
96 Self {
97 messages: Vec::new(),
98 model: model.into(),
99 max_tokens: None,
100 temperature: None,
101 top_p: None,
102 stream: None,
103 n: None,
104 stop: None,
105 presence_penalty: None,
106 frequency_penalty: None,
107 user: None,
108 response_format: None,
109 tools: None,
110 tool_choice: None,
111 instructions: None,
112 previous_response_id: None,
113 store: None,
114 }
115 }
116
117 pub fn messages(mut self, messages: Vec<Message>) -> Self {
119 self.messages = messages;
120 self
121 }
122
123 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
125 self.instructions = Some(instructions.into());
126 self
127 }
128
129 pub fn previous_response_id(mut self, id: impl Into<String>) -> Self {
131 self.previous_response_id = Some(id.into());
132 self
133 }
134
135 pub fn store(mut self, store: bool) -> Self {
137 self.store = Some(store);
138 self
139 }
140
141 pub fn builder() -> CompletionRequestBuilder {
155 CompletionRequestBuilder::default()
156 }
157
158 pub fn validate(&self) -> Result<()> {
167 if self.messages.is_empty() {
169 return Err(ValidationError::Empty {
170 field: "messages".to_string(),
171 }
172 .into());
173 }
174
175 if self.messages.len() > 1000 {
176 return Err(ValidationError::TooLong {
177 field: "messages".to_string(),
178 max: 1000,
179 }
180 .into());
181 }
182
183 const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
185 for (i, msg) in self.messages.iter().enumerate() {
186 if msg.content.text_len() > MAX_MESSAGE_SIZE {
187 return Err(ValidationError::TooLong {
188 field: format!("messages[{}].content", i),
189 max: MAX_MESSAGE_SIZE,
190 }
191 .into());
192 }
193
194 if msg.content.contains_null() {
195 return Err(ValidationError::InvalidFormat {
196 field: format!("messages[{}].content", i),
197 reason: "contains null bytes".to_string(),
198 }
199 .into());
200 }
201 }
202
203 const MAX_TOTAL_REQUEST_SIZE: usize = 10 * 1024 * 1024;
205 let total_size: usize = self.messages.iter().map(|m| m.content.text_len()).sum();
206 if total_size > MAX_TOTAL_REQUEST_SIZE {
207 return Err(ValidationError::TooLong {
208 field: "total_request_size".to_string(),
209 max: MAX_TOTAL_REQUEST_SIZE,
210 }
211 .into());
212 }
213
214 if self.model.is_empty() {
216 return Err(ValidationError::Empty {
217 field: "model".to_string(),
218 }
219 .into());
220 }
221
222 if !self
224 .model
225 .chars()
226 .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/')
227 {
228 return Err(ValidationError::InvalidFormat {
229 field: "model".to_string(),
230 reason: "must be alphanumeric with -_./ only".to_string(),
231 }
232 .into());
233 }
234
235 if let Some(temp) = self.temperature {
237 if !(0.0..=2.0).contains(&temp) {
238 return Err(ValidationError::OutOfRange {
239 field: "temperature".to_string(),
240 min: 0.0,
241 max: 2.0,
242 }
243 .into());
244 }
245 }
246
247 if let Some(top_p) = self.top_p {
249 if !(0.0..=1.0).contains(&top_p) {
250 return Err(ValidationError::OutOfRange {
251 field: "top_p".to_string(),
252 min: 0.0,
253 max: 1.0,
254 }
255 .into());
256 }
257 }
258
259 if let Some(penalty) = self.presence_penalty {
261 if !(-2.0..=2.0).contains(&penalty) {
262 return Err(ValidationError::OutOfRange {
263 field: "presence_penalty".to_string(),
264 min: -2.0,
265 max: 2.0,
266 }
267 .into());
268 }
269 }
270
271 if let Some(penalty) = self.frequency_penalty {
273 if !(-2.0..=2.0).contains(&penalty) {
274 return Err(ValidationError::OutOfRange {
275 field: "frequency_penalty".to_string(),
276 min: -2.0,
277 max: 2.0,
278 }
279 .into());
280 }
281 }
282
283 Ok(())
284 }
285}
286
/// Builder for [`CompletionRequest`].
///
/// Holds the same fields as the request, except that `model` is optional
/// until `build` is called. The derived `Default` starts with no messages and
/// every other field set to `None`.
#[derive(Debug, Default, Clone)]
pub struct CompletionRequestBuilder {
    // Accumulated messages; appended to by `message`, replaced by `messages`.
    messages: Vec<Message>,
    // Required at build time: `build` fails when this is still `None`.
    model: Option<String>,
    // The remaining fields are copied unchanged into the request by `build`.
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    stream: Option<bool>,
    n: Option<u32>,
    stop: Option<Vec<String>>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    user: Option<String>,
    // Also set by the `json_mode` and `json_schema` shortcuts.
    response_format: Option<ResponseFormat>,
    tools: Option<Vec<ToolDefinition>>,
    tool_choice: Option<ToolChoice>,
    instructions: Option<String>,
    previous_response_id: Option<String>,
    store: Option<bool>,
}
308
309impl CompletionRequestBuilder {
310 pub fn model(mut self, model: impl Into<String>) -> Self {
312 self.model = Some(model.into());
313 self
314 }
315
316 pub fn message(mut self, message: Message) -> Self {
318 self.messages.push(message);
319 self
320 }
321
322 pub fn messages(mut self, messages: Vec<Message>) -> Self {
324 self.messages = messages;
325 self
326 }
327
328 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
330 self.max_tokens = Some(max_tokens);
331 self
332 }
333
334 pub fn temperature(mut self, temperature: f32) -> Self {
336 self.temperature = Some(temperature);
337 self
338 }
339
340 pub fn top_p(mut self, top_p: f32) -> Self {
342 self.top_p = Some(top_p);
343 self
344 }
345
346 pub fn stream(mut self, stream: bool) -> Self {
348 self.stream = Some(stream);
349 self
350 }
351
352 pub fn n(mut self, n: u32) -> Self {
354 self.n = Some(n);
355 self
356 }
357
358 pub fn stop(mut self, stop: Vec<String>) -> Self {
360 self.stop = Some(stop);
361 self
362 }
363
364 pub fn presence_penalty(mut self, penalty: f32) -> Self {
366 self.presence_penalty = Some(penalty);
367 self
368 }
369
370 pub fn frequency_penalty(mut self, penalty: f32) -> Self {
372 self.frequency_penalty = Some(penalty);
373 self
374 }
375
376 pub fn user(mut self, user: impl Into<String>) -> Self {
378 self.user = Some(user.into());
379 self
380 }
381
382 pub fn response_format(mut self, format: ResponseFormat) -> Self {
384 self.response_format = Some(format);
385 self
386 }
387
388 pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
390 self.tools = Some(tools);
391 self
392 }
393
394 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
396 self.tool_choice = Some(tool_choice);
397 self
398 }
399
400 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
402 self.instructions = Some(instructions.into());
403 self
404 }
405
406 pub fn previous_response_id(mut self, id: impl Into<String>) -> Self {
408 self.previous_response_id = Some(id.into());
409 self
410 }
411
412 pub fn store(mut self, store: bool) -> Self {
414 self.store = Some(store);
415 self
416 }
417
418 pub fn json_mode(mut self) -> Self {
420 self.response_format = Some(ResponseFormat::JsonObject);
421 self
422 }
423
424 pub fn json_schema(mut self, name: impl Into<String>, schema: Value) -> Self {
426 self.response_format = Some(ResponseFormat::JsonSchema {
427 json_schema: JsonSchemaFormat {
428 name: name.into(),
429 schema,
430 strict: Some(true),
431 },
432 });
433 self
434 }
435
436 pub fn build(self) -> Result<CompletionRequest> {
438 let model = self.model.ok_or_else(|| ValidationError::Empty {
439 field: "model".to_string(),
440 })?;
441
442 let request = CompletionRequest {
443 messages: self.messages,
444 model,
445 max_tokens: self.max_tokens,
446 temperature: self.temperature,
447 top_p: self.top_p,
448 stream: self.stream,
449 n: self.n,
450 stop: self.stop,
451 presence_penalty: self.presence_penalty,
452 frequency_penalty: self.frequency_penalty,
453 user: self.user,
454 response_format: self.response_format,
455 tools: self.tools,
456 tool_choice: self.tool_choice,
457 instructions: self.instructions,
458 previous_response_id: self.previous_response_id,
459 store: self.store,
460 };
461
462 request.validate()?;
463 Ok(request)
464 }
465}
466
#[cfg(test)]
mod tests {
    //! Unit tests for the request builder, validation rules and serde round-trips.

    use super::*;
    use crate::message::MessageContent;

    /// A model plus one message is enough to build a valid request.
    #[test]
    fn test_builder_basic() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].content,
            MessageContent::Text("Hello".to_string())
        );
    }

    /// Every setter's value ends up in the built request.
    #[test]
    fn test_builder_all_fields() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .max_tokens(100)
            .temperature(0.7)
            .top_p(0.9)
            .stream(true)
            .n(1)
            .stop(vec!["END".to_string()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .user("test-user")
            .build()
            .unwrap();

        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.stream, Some(true));
        assert_eq!(request.n, Some(1));
        assert_eq!(request.stop, Some(vec!["END".to_string()]));
        assert_eq!(request.presence_penalty, Some(0.5));
        assert_eq!(request.frequency_penalty, Some(0.5));
        assert_eq!(request.user, Some("test-user".to_string()));
    }

    /// `build` fails when no model was set.
    #[test]
    fn test_builder_missing_model() {
        let result = CompletionRequest::builder()
            .message(Message::user("Hello"))
            .build();
        assert!(result.is_err());
    }

    /// `build` fails when there are no messages.
    #[test]
    fn test_validation_empty_messages() {
        let result = CompletionRequest::builder().model("gpt-4").build();
        assert!(result.is_err());
    }

    /// A temperature above 2.0 is rejected.
    #[test]
    fn test_validation_invalid_temperature() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(3.0)
            .build();
        assert!(result.is_err());
    }

    /// A `top_p` above 1.0 is rejected.
    #[test]
    fn test_validation_invalid_top_p() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .top_p(1.5)
            .build();
        assert!(result.is_err());
    }

    /// A model name containing a character outside the allowed set is rejected.
    #[test]
    fn test_validation_invalid_model_chars() {
        let result = CompletionRequest::builder()
            .model("gpt-4!")
            .message(Message::user("Hello"))
            .build();
        assert!(result.is_err());
    }

    /// A request survives a JSON serialize/deserialize round-trip unchanged.
    #[test]
    fn test_serialization() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(0.7)
            .build()
            .unwrap();

        let json = serde_json::to_string(&request).unwrap();
        let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(request, parsed);
    }

    /// Unset optional fields are left out of the serialized JSON.
    #[test]
    fn test_optional_fields_not_serialized() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        let json = serde_json::to_value(&request).unwrap();
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("temperature").is_none());
    }

    /// A single message above the 1 MiB per-message cap is reported against
    /// that message, not against the total request size.
    #[test]
    fn test_validation_message_size_limit() {
        let oversized = "x".repeat(2 * 1024 * 1024);
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user(oversized))
            .build();

        match result.unwrap_err() {
            crate::error::SimpleAgentsError::Validation(ValidationError::TooLong {
                field, ..
            }) => assert_eq!(field, "messages[0].content"),
            other => panic!("expected TooLong for messages[0].content, got {:?}", other),
        }
    }

    /// Messages that each fit the per-message cap but together exceed 10 MiB
    /// are rejected by the total-size check.
    #[test]
    fn test_validation_total_request_size_limit() {
        // Exactly 1 MiB passes the per-message check (the limit is inclusive),
        // so eleven of them reach the total-size check and exceed it.
        let chunk = "x".repeat(1024 * 1024);
        let messages: Vec<Message> = (0..11).map(|_| Message::user(chunk.clone())).collect();

        let result = CompletionRequest::builder()
            .model("gpt-4")
            .messages(messages)
            .build();

        match result.unwrap_err() {
            crate::error::SimpleAgentsError::Validation(ValidationError::TooLong {
                field, ..
            }) => assert_eq!(field, "total_request_size"),
            other => panic!("expected TooLong for total_request_size, got {:?}", other),
        }
    }

    /// The chained setters on `CompletionRequest` itself store their values.
    #[test]
    fn test_responses_api_fields() {
        let req = CompletionRequest::new("gpt-4o")
            .messages(vec![Message::user("hello")])
            .instructions("You are helpful")
            .store(true)
            .previous_response_id("resp_abc");
        assert_eq!(req.instructions.as_deref(), Some("You are helpful"));
        assert_eq!(req.store, Some(true));
        assert_eq!(req.previous_response_id.as_deref(), Some("resp_abc"));
    }

    /// Three 1 MiB messages stay within both the per-message and total limits.
    #[test]
    fn test_validation_total_request_size_within_limit() {
        let content = "x".repeat(1024 * 1024);
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user(content.clone()))
            .message(Message::user(content.clone()))
            .message(Message::user(content))
            .build();

        assert!(result.is_ok());
    }
}