1use crate::error::{Result, ValidationError};
6use crate::message::Message;
7use crate::tool::{ToolChoice, ToolDefinition};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum ResponseFormat {
15 Text,
17 JsonObject,
19 #[serde(rename = "json_schema")]
21 JsonSchema {
22 json_schema: JsonSchemaFormat,
24 },
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
29pub struct JsonSchemaFormat {
30 pub name: String,
32 pub schema: Value,
34 #[serde(skip_serializing_if = "Option::is_none")]
36 pub strict: Option<bool>,
37}
38
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub struct CompletionRequest {
42 pub messages: Vec<Message>,
44 pub model: String,
46 #[serde(skip_serializing_if = "Option::is_none")]
48 pub max_tokens: Option<u32>,
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub temperature: Option<f32>,
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub top_p: Option<f32>,
55 #[serde(skip_serializing_if = "Option::is_none")]
57 pub stream: Option<bool>,
58 #[serde(skip_serializing_if = "Option::is_none")]
60 pub n: Option<u32>,
61 #[serde(skip_serializing_if = "Option::is_none")]
63 pub stop: Option<Vec<String>>,
64 #[serde(skip_serializing_if = "Option::is_none")]
66 pub presence_penalty: Option<f32>,
67 #[serde(skip_serializing_if = "Option::is_none")]
69 pub frequency_penalty: Option<f32>,
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub user: Option<String>,
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub response_format: Option<ResponseFormat>,
76 #[serde(skip_serializing_if = "Option::is_none")]
78 pub tools: Option<Vec<ToolDefinition>>,
79 #[serde(skip_serializing_if = "Option::is_none")]
81 pub tool_choice: Option<ToolChoice>,
82}
83
84impl CompletionRequest {
85 pub fn builder() -> CompletionRequestBuilder {
99 CompletionRequestBuilder::default()
100 }
101
102 pub fn validate(&self) -> Result<()> {
111 if self.messages.is_empty() {
113 return Err(ValidationError::Empty {
114 field: "messages".to_string(),
115 }
116 .into());
117 }
118
119 if self.messages.len() > 1000 {
120 return Err(ValidationError::TooLong {
121 field: "messages".to_string(),
122 max: 1000,
123 }
124 .into());
125 }
126
127 const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
129 for (i, msg) in self.messages.iter().enumerate() {
130 if msg.content.len() > MAX_MESSAGE_SIZE {
131 return Err(ValidationError::TooLong {
132 field: format!("messages[{}].content", i),
133 max: MAX_MESSAGE_SIZE,
134 }
135 .into());
136 }
137
138 if msg.content.contains('\0') {
140 return Err(ValidationError::InvalidFormat {
141 field: format!("messages[{}].content", i),
142 reason: "contains null bytes".to_string(),
143 }
144 .into());
145 }
146 }
147
148 const MAX_TOTAL_REQUEST_SIZE: usize = 10 * 1024 * 1024;
150 let total_size: usize = self.messages.iter().map(|m| m.content.len()).sum();
151 if total_size > MAX_TOTAL_REQUEST_SIZE {
152 return Err(ValidationError::TooLong {
153 field: "total_request_size".to_string(),
154 max: MAX_TOTAL_REQUEST_SIZE,
155 }
156 .into());
157 }
158
159 if self.model.is_empty() {
161 return Err(ValidationError::Empty {
162 field: "model".to_string(),
163 }
164 .into());
165 }
166
167 if !self
169 .model
170 .chars()
171 .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/')
172 {
173 return Err(ValidationError::InvalidFormat {
174 field: "model".to_string(),
175 reason: "must be alphanumeric with -_./ only".to_string(),
176 }
177 .into());
178 }
179
180 if let Some(temp) = self.temperature {
182 if !(0.0..=2.0).contains(&temp) {
183 return Err(ValidationError::OutOfRange {
184 field: "temperature".to_string(),
185 min: 0.0,
186 max: 2.0,
187 }
188 .into());
189 }
190 }
191
192 if let Some(top_p) = self.top_p {
194 if !(0.0..=1.0).contains(&top_p) {
195 return Err(ValidationError::OutOfRange {
196 field: "top_p".to_string(),
197 min: 0.0,
198 max: 1.0,
199 }
200 .into());
201 }
202 }
203
204 if let Some(penalty) = self.presence_penalty {
206 if !(-2.0..=2.0).contains(&penalty) {
207 return Err(ValidationError::OutOfRange {
208 field: "presence_penalty".to_string(),
209 min: -2.0,
210 max: 2.0,
211 }
212 .into());
213 }
214 }
215
216 if let Some(penalty) = self.frequency_penalty {
218 if !(-2.0..=2.0).contains(&penalty) {
219 return Err(ValidationError::OutOfRange {
220 field: "frequency_penalty".to_string(),
221 min: -2.0,
222 max: 2.0,
223 }
224 .into());
225 }
226 }
227
228 Ok(())
229 }
230}
231
232#[derive(Debug, Default, Clone)]
234pub struct CompletionRequestBuilder {
235 messages: Vec<Message>,
236 model: Option<String>,
237 max_tokens: Option<u32>,
238 temperature: Option<f32>,
239 top_p: Option<f32>,
240 stream: Option<bool>,
241 n: Option<u32>,
242 stop: Option<Vec<String>>,
243 presence_penalty: Option<f32>,
244 frequency_penalty: Option<f32>,
245 user: Option<String>,
246 response_format: Option<ResponseFormat>,
247 tools: Option<Vec<ToolDefinition>>,
248 tool_choice: Option<ToolChoice>,
249}
250
251impl CompletionRequestBuilder {
252 pub fn model(mut self, model: impl Into<String>) -> Self {
254 self.model = Some(model.into());
255 self
256 }
257
258 pub fn message(mut self, message: Message) -> Self {
260 self.messages.push(message);
261 self
262 }
263
264 pub fn messages(mut self, messages: Vec<Message>) -> Self {
266 self.messages = messages;
267 self
268 }
269
270 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
272 self.max_tokens = Some(max_tokens);
273 self
274 }
275
276 pub fn temperature(mut self, temperature: f32) -> Self {
278 self.temperature = Some(temperature);
279 self
280 }
281
282 pub fn top_p(mut self, top_p: f32) -> Self {
284 self.top_p = Some(top_p);
285 self
286 }
287
288 pub fn stream(mut self, stream: bool) -> Self {
290 self.stream = Some(stream);
291 self
292 }
293
294 pub fn n(mut self, n: u32) -> Self {
296 self.n = Some(n);
297 self
298 }
299
300 pub fn stop(mut self, stop: Vec<String>) -> Self {
302 self.stop = Some(stop);
303 self
304 }
305
306 pub fn presence_penalty(mut self, penalty: f32) -> Self {
308 self.presence_penalty = Some(penalty);
309 self
310 }
311
312 pub fn frequency_penalty(mut self, penalty: f32) -> Self {
314 self.frequency_penalty = Some(penalty);
315 self
316 }
317
318 pub fn user(mut self, user: impl Into<String>) -> Self {
320 self.user = Some(user.into());
321 self
322 }
323
324 pub fn response_format(mut self, format: ResponseFormat) -> Self {
326 self.response_format = Some(format);
327 self
328 }
329
330 pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
332 self.tools = Some(tools);
333 self
334 }
335
336 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
338 self.tool_choice = Some(tool_choice);
339 self
340 }
341
342 pub fn json_mode(mut self) -> Self {
344 self.response_format = Some(ResponseFormat::JsonObject);
345 self
346 }
347
348 pub fn json_schema(mut self, name: impl Into<String>, schema: Value) -> Self {
350 self.response_format = Some(ResponseFormat::JsonSchema {
351 json_schema: JsonSchemaFormat {
352 name: name.into(),
353 schema,
354 strict: Some(true),
355 },
356 });
357 self
358 }
359
360 pub fn build(self) -> Result<CompletionRequest> {
362 let model = self.model.ok_or_else(|| ValidationError::Empty {
363 field: "model".to_string(),
364 })?;
365
366 let request = CompletionRequest {
367 messages: self.messages,
368 model,
369 max_tokens: self.max_tokens,
370 temperature: self.temperature,
371 top_p: self.top_p,
372 stream: self.stream,
373 n: self.n,
374 stop: self.stop,
375 presence_penalty: self.presence_penalty,
376 frequency_penalty: self.frequency_penalty,
377 user: self.user,
378 response_format: self.response_format,
379 tools: self.tools,
380 tool_choice: self.tool_choice,
381 };
382
383 request.validate()?;
384 Ok(request)
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SimpleAgentsError;

    /// Unwraps a failed build into its `ValidationError`. Panics if the build
    /// succeeded or failed with a non-validation error.
    fn expect_validation_error(result: Result<CompletionRequest>) -> ValidationError {
        match result {
            Err(SimpleAgentsError::Validation(err)) => err,
            Err(other) => panic!("expected a validation error, got {:?}", other),
            Ok(request) => panic!("expected validation to fail, got {:?}", request),
        }
    }

    #[test]
    fn test_builder_basic() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].content, "Hello");
    }

    #[test]
    fn test_builder_all_fields() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .max_tokens(100)
            .temperature(0.7)
            .top_p(0.9)
            .stream(true)
            .n(1)
            .stop(vec!["END".to_string()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .user("test-user")
            .build()
            .unwrap();

        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.stream, Some(true));
        assert_eq!(request.n, Some(1));
        assert_eq!(request.stop, Some(vec!["END".to_string()]));
        assert_eq!(request.presence_penalty, Some(0.5));
        assert_eq!(request.frequency_penalty, Some(0.5));
        assert_eq!(request.user, Some("test-user".to_string()));
    }

    #[test]
    fn test_builder_missing_model() {
        let result = CompletionRequest::builder()
            .message(Message::user("Hello"))
            .build();

        let err = expect_validation_error(result);
        assert!(
            matches!(&err, ValidationError::Empty { field } if field == "model"),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_validation_empty_messages() {
        let result = CompletionRequest::builder().model("gpt-4").build();

        let err = expect_validation_error(result);
        assert!(
            matches!(&err, ValidationError::Empty { field } if field == "messages"),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_validation_too_many_messages() {
        let messages: Vec<Message> = (0..1001).map(|_| Message::user("hi")).collect();
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .messages(messages)
            .build();

        let err = expect_validation_error(result);
        assert!(
            matches!(&err, ValidationError::TooLong { field, .. } if field == "messages"),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_validation_invalid_temperature() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(3.0)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_invalid_top_p() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .top_p(1.5)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_invalid_model_chars() {
        let result = CompletionRequest::builder()
            .model("gpt-4!")
            .message(Message::user("Hello"))
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_serialization() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(0.7)
            .build()
            .unwrap();

        let json = serde_json::to_string(&request).unwrap();
        let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(request, parsed);
    }

    #[test]
    fn test_optional_fields_not_serialized() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        let json = serde_json::to_value(&request).unwrap();
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("temperature").is_none());
    }

    #[test]
    fn test_validation_message_size_limit() {
        // One byte over the 1 MiB per-message limit.
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("x".repeat(1024 * 1024 + 1)))
            .build();

        let err = expect_validation_error(result);
        assert!(
            matches!(&err, ValidationError::TooLong { field, .. } if field == "messages[0].content"),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_validation_null_byte_in_content() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("he\0llo"))
            .build();

        let err = expect_validation_error(result);
        assert!(
            matches!(&err, ValidationError::InvalidFormat { field, .. } if field == "messages[0].content"),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_validation_total_request_size_limit() {
        // Each message sits exactly at the 1 MiB per-message limit (which is
        // allowed), so only the 10 MiB aggregate limit can reject these 11 MiB.
        // Larger messages would trip the per-message check first and never
        // reach the aggregate check.
        let content = "x".repeat(1024 * 1024);
        let messages: Vec<Message> = (0..11).map(|_| Message::user(content.clone())).collect();
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .messages(messages)
            .build();

        let err = expect_validation_error(result);
        assert!(
            matches!(&err, ValidationError::TooLong { field, .. } if field == "total_request_size"),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_validation_total_request_size_within_limit() {
        let content = "x".repeat(1024 * 1024);
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user(content.clone()))
            .message(Message::user(content.clone()))
            .message(Message::user(content.clone()))
            .build();

        assert!(result.is_ok());
    }
}