1use crate::error::{Result, ValidationError};
6use crate::message::Message;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
/// Output format selector stored in [`CompletionRequest::response_format`].
///
/// Serde writes this as an internally tagged enum: the variant name, converted to
/// `snake_case`, is stored under the `"type"` key (`"text"`, `"json_object"`,
/// `"json_schema"`), and any variant fields sit next to that key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Tagged `"text"`; carries no additional data.
    Text,
    /// Tagged `"json_object"`; carries no additional data.
    JsonObject,
    /// Tagged `"json_schema"`. The explicit rename produces the same tag that
    /// `rename_all = "snake_case"` would derive from the variant name.
    #[serde(rename = "json_schema")]
    JsonSchema {
        /// Schema description, serialized under the `"json_schema"` key beside `"type"`.
        json_schema: JsonSchemaFormat,
    },
}
25
/// Payload of [`ResponseFormat::JsonSchema`]: a named schema plus an optional `strict` flag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
    /// Name attached to the schema.
    pub name: String,
    /// The schema, held as an arbitrary JSON value; this type does not inspect its shape.
    pub schema: Value,
    /// Optional flag, left out of the serialized output when `None`.
    /// NOTE(review): presumably asks the provider to enforce the schema strictly — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
37
/// A completion request: the conversation, the target model, and optional tuning fields.
///
/// Every `Option` field is omitted from the serialized output when it is `None`.
/// Construct through [`CompletionRequest::builder`] to have [`CompletionRequest::validate`]
/// run automatically; a value built by hand (or deserialized) is not validated until
/// `validate` is called on it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Conversation messages. `validate` requires 1 to 1000 entries, each with at most
    /// 1 MiB of content and no NUL characters, and at most 10 MiB of content overall.
    pub messages: Vec<Message>,
    /// Model identifier. `validate` requires it to be non-empty and made only of
    /// alphanumeric characters plus `-`, `_`, `.` and `/`.
    pub model: String,
    /// Token limit; not checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature; `validate` accepts `0.0..=2.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling parameter; `validate` accepts `0.0..=1.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Streaming flag; not checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Not checked by `validate`.
    /// NOTE(review): presumably the number of completions to generate — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Stop strings; neither their count nor their content is checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Presence penalty; `validate` accepts `-2.0..=2.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty; `validate` accepts `-2.0..=2.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Caller-supplied user string; not checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Requested output format; not checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
}
76
77impl CompletionRequest {
78 pub fn builder() -> CompletionRequestBuilder {
92 CompletionRequestBuilder::default()
93 }
94
95 pub fn validate(&self) -> Result<()> {
104 if self.messages.is_empty() {
106 return Err(ValidationError::Empty {
107 field: "messages".to_string(),
108 }
109 .into());
110 }
111
112 if self.messages.len() > 1000 {
113 return Err(ValidationError::TooLong {
114 field: "messages".to_string(),
115 max: 1000,
116 }
117 .into());
118 }
119
120 const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
122 for (i, msg) in self.messages.iter().enumerate() {
123 if msg.content.len() > MAX_MESSAGE_SIZE {
124 return Err(ValidationError::TooLong {
125 field: format!("messages[{}].content", i),
126 max: MAX_MESSAGE_SIZE,
127 }
128 .into());
129 }
130
131 if msg.content.contains('\0') {
133 return Err(ValidationError::InvalidFormat {
134 field: format!("messages[{}].content", i),
135 reason: "contains null bytes".to_string(),
136 }
137 .into());
138 }
139 }
140
141 const MAX_TOTAL_REQUEST_SIZE: usize = 10 * 1024 * 1024;
143 let total_size: usize = self.messages.iter().map(|m| m.content.len()).sum();
144 if total_size > MAX_TOTAL_REQUEST_SIZE {
145 return Err(ValidationError::TooLong {
146 field: "total_request_size".to_string(),
147 max: MAX_TOTAL_REQUEST_SIZE,
148 }
149 .into());
150 }
151
152 if self.model.is_empty() {
154 return Err(ValidationError::Empty {
155 field: "model".to_string(),
156 }
157 .into());
158 }
159
160 if !self
162 .model
163 .chars()
164 .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/')
165 {
166 return Err(ValidationError::InvalidFormat {
167 field: "model".to_string(),
168 reason: "must be alphanumeric with -_./ only".to_string(),
169 }
170 .into());
171 }
172
173 if let Some(temp) = self.temperature {
175 if !(0.0..=2.0).contains(&temp) {
176 return Err(ValidationError::OutOfRange {
177 field: "temperature".to_string(),
178 min: 0.0,
179 max: 2.0,
180 }
181 .into());
182 }
183 }
184
185 if let Some(top_p) = self.top_p {
187 if !(0.0..=1.0).contains(&top_p) {
188 return Err(ValidationError::OutOfRange {
189 field: "top_p".to_string(),
190 min: 0.0,
191 max: 1.0,
192 }
193 .into());
194 }
195 }
196
197 if let Some(penalty) = self.presence_penalty {
199 if !(-2.0..=2.0).contains(&penalty) {
200 return Err(ValidationError::OutOfRange {
201 field: "presence_penalty".to_string(),
202 min: -2.0,
203 max: 2.0,
204 }
205 .into());
206 }
207 }
208
209 if let Some(penalty) = self.frequency_penalty {
211 if !(-2.0..=2.0).contains(&penalty) {
212 return Err(ValidationError::OutOfRange {
213 field: "frequency_penalty".to_string(),
214 min: -2.0,
215 max: 2.0,
216 }
217 .into());
218 }
219 }
220
221 Ok(())
222 }
223}
224
/// Step-by-step constructor for [`CompletionRequest`].
///
/// The derived `Default` starts with an empty message list and every other field `None`.
/// The fields mirror those of `CompletionRequest`, except that `model` is optional here
/// so that `build` can report it as missing.
#[derive(Debug, Default, Clone)]
pub struct CompletionRequestBuilder {
    // Messages collected so far, in insertion order.
    messages: Vec<Message>,
    // `None` until `model` is called; `build` fails while it is unset.
    model: Option<String>,
    // The remaining fields are copied into the request unchanged by `build`.
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    stream: Option<bool>,
    n: Option<u32>,
    stop: Option<Vec<String>>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    user: Option<String>,
    response_format: Option<ResponseFormat>,
}
241
242impl CompletionRequestBuilder {
243 pub fn model(mut self, model: impl Into<String>) -> Self {
245 self.model = Some(model.into());
246 self
247 }
248
249 pub fn message(mut self, message: Message) -> Self {
251 self.messages.push(message);
252 self
253 }
254
255 pub fn messages(mut self, messages: Vec<Message>) -> Self {
257 self.messages = messages;
258 self
259 }
260
261 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
263 self.max_tokens = Some(max_tokens);
264 self
265 }
266
267 pub fn temperature(mut self, temperature: f32) -> Self {
269 self.temperature = Some(temperature);
270 self
271 }
272
273 pub fn top_p(mut self, top_p: f32) -> Self {
275 self.top_p = Some(top_p);
276 self
277 }
278
279 pub fn stream(mut self, stream: bool) -> Self {
281 self.stream = Some(stream);
282 self
283 }
284
285 pub fn n(mut self, n: u32) -> Self {
287 self.n = Some(n);
288 self
289 }
290
291 pub fn stop(mut self, stop: Vec<String>) -> Self {
293 self.stop = Some(stop);
294 self
295 }
296
297 pub fn presence_penalty(mut self, penalty: f32) -> Self {
299 self.presence_penalty = Some(penalty);
300 self
301 }
302
303 pub fn frequency_penalty(mut self, penalty: f32) -> Self {
305 self.frequency_penalty = Some(penalty);
306 self
307 }
308
309 pub fn user(mut self, user: impl Into<String>) -> Self {
311 self.user = Some(user.into());
312 self
313 }
314
315 pub fn response_format(mut self, format: ResponseFormat) -> Self {
317 self.response_format = Some(format);
318 self
319 }
320
321 pub fn json_mode(mut self) -> Self {
323 self.response_format = Some(ResponseFormat::JsonObject);
324 self
325 }
326
327 pub fn json_schema(mut self, name: impl Into<String>, schema: Value) -> Self {
329 self.response_format = Some(ResponseFormat::JsonSchema {
330 json_schema: JsonSchemaFormat {
331 name: name.into(),
332 schema,
333 strict: Some(true),
334 },
335 });
336 self
337 }
338
339 pub fn build(self) -> Result<CompletionRequest> {
341 let model = self.model.ok_or_else(|| ValidationError::Empty {
342 field: "model".to_string(),
343 })?;
344
345 let request = CompletionRequest {
346 messages: self.messages,
347 model,
348 max_tokens: self.max_tokens,
349 temperature: self.temperature,
350 top_p: self.top_p,
351 stream: self.stream,
352 n: self.n,
353 stop: self.stop,
354 presence_penalty: self.presence_penalty,
355 frequency_penalty: self.frequency_penalty,
356 user: self.user,
357 response_format: self.response_format,
358 };
359
360 request.validate()?;
361 Ok(request)
362 }
363}
364
#[cfg(test)]
mod tests {
    //! Unit tests for the request builder, validation rules and serde round-tripping.

    use super::*;

    /// Size of one test payload that sits exactly on the per-message cap (1 MiB).
    const ONE_MIB: usize = 1024 * 1024;

    #[test]
    fn test_builder_basic() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].content, "Hello");
    }

    #[test]
    fn test_builder_all_fields() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .max_tokens(100)
            .temperature(0.7)
            .top_p(0.9)
            .stream(true)
            .n(1)
            .stop(vec!["END".to_string()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .user("test-user")
            .build()
            .unwrap();

        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.stream, Some(true));
        assert_eq!(request.n, Some(1));
        assert_eq!(request.stop, Some(vec!["END".to_string()]));
        assert_eq!(request.presence_penalty, Some(0.5));
        assert_eq!(request.frequency_penalty, Some(0.5));
        assert_eq!(request.user, Some("test-user".to_string()));
    }

    #[test]
    fn test_builder_missing_model() {
        let result = CompletionRequest::builder()
            .message(Message::user("Hello"))
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_empty_messages() {
        let result = CompletionRequest::builder().model("gpt-4").build();
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_invalid_temperature() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(3.0)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_invalid_top_p() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .top_p(1.5)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_invalid_model_chars() {
        let result = CompletionRequest::builder()
            .model("gpt-4!")
            .message(Message::user("Hello"))
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_serialization() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(0.7)
            .build()
            .unwrap();

        let json = serde_json::to_string(&request).unwrap();
        let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(request, parsed);
    }

    #[test]
    fn test_optional_fields_not_serialized() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        let json = serde_json::to_value(&request).unwrap();
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("temperature").is_none());
    }

    /// A single message one byte over the per-message cap is rejected, and the error
    /// names that message rather than the aggregate size.
    #[test]
    fn test_validation_message_size_limit() {
        let oversized = "x".repeat(ONE_MIB + 1);
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user(oversized))
            .build();

        assert!(matches!(
            result.unwrap_err(),
            crate::error::SimpleAgentsError::Validation(ValidationError::TooLong { field, .. })
                if field == "messages[0].content"
        ));
    }

    /// Eleven messages of exactly 1 MiB each pass the per-message check (which uses `>`),
    /// so only the 10 MiB aggregate cap can reject the request. Messages larger than
    /// 1 MiB would be rejected by the per-message check first and never reach the
    /// aggregate check this test is meant to cover.
    #[test]
    fn test_validation_total_request_size_limit() {
        let content = "x".repeat(ONE_MIB);
        let messages: Vec<Message> = (0..11).map(|_| Message::user(content.clone())).collect();
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .messages(messages)
            .build();

        assert!(result.is_err());
        // Pin the reported field so a per-message failure cannot satisfy this test.
        assert!(matches!(
            result.unwrap_err(),
            crate::error::SimpleAgentsError::Validation(ValidationError::TooLong { field, .. })
                if field == "total_request_size"
        ));
    }

    /// Three messages of exactly 1 MiB (3 MiB in total) stay inside both caps.
    #[test]
    fn test_validation_total_request_size_within_limit() {
        let content = "x".repeat(ONE_MIB);
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user(content.clone()))
            .message(Message::user(content.clone()))
            .message(Message::user(content.clone()))
            .build();

        assert!(result.is_ok());
    }
}