// simple_agent_type/request.rs

//! Request types for LLM completions.
//!
//! Provides OpenAI-compatible request structures with validation.

use crate::error::{Result, ValidationError};
use crate::message::Message;
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Response format for structured outputs.
///
/// Serialized with an internal `"type"` tag, matching the OpenAI
/// `response_format` wire shape (e.g. `{"type": "json_object"}` or
/// `{"type": "json_schema", "json_schema": {...}}`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Default text response
    Text,
    /// JSON object mode (no schema validation)
    JsonObject,
    /// Structured output with JSON schema (OpenAI only)
    // NOTE: `rename` is redundant with `rename_all = "snake_case"` above,
    // but kept to make the wire name explicit.
    #[serde(rename = "json_schema")]
    JsonSchema {
        /// The JSON schema definition
        json_schema: JsonSchemaFormat,
    },
}

/// JSON schema format for structured outputs.
///
/// Payload of [`ResponseFormat::JsonSchema`]; serialized under the
/// `json_schema` key of the request's `response_format`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
    /// Name of the schema
    pub name: String,
    /// The JSON schema
    pub schema: Value,
    /// Whether to use strict mode (default: true).
    /// Omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

/// A completion request to an LLM provider.
///
/// Mirrors the OpenAI chat-completions request shape; every optional field
/// is omitted from the serialized JSON when unset. Prefer constructing via
/// [`CompletionRequest::builder`], which runs [`CompletionRequest::validate`]
/// on `build()`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// List of messages in the conversation
    pub messages: Vec<Message>,
    /// Model identifier (e.g., "gpt-4", "claude-3-opus")
    pub model: String,
    /// Maximum tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature (0.0-2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus sampling threshold (0.0-1.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Enable streaming responses
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Number of completions to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Stop sequences
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Presence penalty (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// User identifier (for abuse detection)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Response format for structured outputs (OpenAI only)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
}

77impl CompletionRequest {
78    /// Create a new builder.
79    ///
80    /// # Example
81    /// ```
82    /// use simple_agent_type::request::CompletionRequest;
83    /// use simple_agent_type::message::Message;
84    ///
85    /// let request = CompletionRequest::builder()
86    ///     .model("gpt-4")
87    ///     .message(Message::user("Hello!"))
88    ///     .build()
89    ///     .unwrap();
90    /// ```
91    pub fn builder() -> CompletionRequestBuilder {
92        CompletionRequestBuilder::default()
93    }
94
95    /// Validate the request.
96    ///
97    /// # Validation Rules
98    /// - Messages: 1-1000 items, each < 1MB
99    /// - Model: alphanumeric + `-_./` only
100    /// - Temperature: 0.0-2.0
101    /// - Top_p: 0.0-1.0
102    /// - No null bytes (security)
103    pub fn validate(&self) -> Result<()> {
104        // Validate messages
105        if self.messages.is_empty() {
106            return Err(ValidationError::Empty {
107                field: "messages".to_string(),
108            }
109            .into());
110        }
111
112        if self.messages.len() > 1000 {
113            return Err(ValidationError::TooLong {
114                field: "messages".to_string(),
115                max: 1000,
116            }
117            .into());
118        }
119
120        // Validate each message content size (max 1MB)
121        const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
122        for (i, msg) in self.messages.iter().enumerate() {
123            if msg.content.len() > MAX_MESSAGE_SIZE {
124                return Err(ValidationError::TooLong {
125                    field: format!("messages[{}].content", i),
126                    max: MAX_MESSAGE_SIZE,
127                }
128                .into());
129            }
130
131            // Security: no null bytes
132            if msg.content.contains('\0') {
133                return Err(ValidationError::InvalidFormat {
134                    field: format!("messages[{}].content", i),
135                    reason: "contains null bytes".to_string(),
136                }
137                .into());
138            }
139        }
140
141        // Validate total request size (max 10MB)
142        const MAX_TOTAL_REQUEST_SIZE: usize = 10 * 1024 * 1024;
143        let total_size: usize = self.messages.iter().map(|m| m.content.len()).sum();
144        if total_size > MAX_TOTAL_REQUEST_SIZE {
145            return Err(ValidationError::TooLong {
146                field: "total_request_size".to_string(),
147                max: MAX_TOTAL_REQUEST_SIZE,
148            }
149            .into());
150        }
151
152        // Validate model
153        if self.model.is_empty() {
154            return Err(ValidationError::Empty {
155                field: "model".to_string(),
156            }
157            .into());
158        }
159
160        // Model must be alphanumeric + `-_./`
161        if !self
162            .model
163            .chars()
164            .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/')
165        {
166            return Err(ValidationError::InvalidFormat {
167                field: "model".to_string(),
168                reason: "must be alphanumeric with -_./ only".to_string(),
169            }
170            .into());
171        }
172
173        // Validate temperature
174        if let Some(temp) = self.temperature {
175            if !(0.0..=2.0).contains(&temp) {
176                return Err(ValidationError::OutOfRange {
177                    field: "temperature".to_string(),
178                    min: 0.0,
179                    max: 2.0,
180                }
181                .into());
182            }
183        }
184
185        // Validate top_p
186        if let Some(top_p) = self.top_p {
187            if !(0.0..=1.0).contains(&top_p) {
188                return Err(ValidationError::OutOfRange {
189                    field: "top_p".to_string(),
190                    min: 0.0,
191                    max: 1.0,
192                }
193                .into());
194            }
195        }
196
197        // Validate presence_penalty
198        if let Some(penalty) = self.presence_penalty {
199            if !(-2.0..=2.0).contains(&penalty) {
200                return Err(ValidationError::OutOfRange {
201                    field: "presence_penalty".to_string(),
202                    min: -2.0,
203                    max: 2.0,
204                }
205                .into());
206            }
207        }
208
209        // Validate frequency_penalty
210        if let Some(penalty) = self.frequency_penalty {
211            if !(-2.0..=2.0).contains(&penalty) {
212                return Err(ValidationError::OutOfRange {
213                    field: "frequency_penalty".to_string(),
214                    min: -2.0,
215                    max: 2.0,
216                }
217                .into());
218            }
219        }
220
221        Ok(())
222    }
223}
224
/// Builder for CompletionRequest.
///
/// Fields mirror [`CompletionRequest`]; `model` is held as an `Option`
/// because it is required and its absence is reported at `build()` time.
#[derive(Debug, Default, Clone)]
pub struct CompletionRequestBuilder {
    // Accumulated conversation; starts empty.
    messages: Vec<Message>,
    // Required; `build()` errors if still `None`.
    model: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    stream: Option<bool>,
    n: Option<u32>,
    stop: Option<Vec<String>>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    user: Option<String>,
    response_format: Option<ResponseFormat>,
}

242impl CompletionRequestBuilder {
243    /// Set the model.
244    pub fn model(mut self, model: impl Into<String>) -> Self {
245        self.model = Some(model.into());
246        self
247    }
248
249    /// Add a message.
250    pub fn message(mut self, message: Message) -> Self {
251        self.messages.push(message);
252        self
253    }
254
255    /// Set all messages at once.
256    pub fn messages(mut self, messages: Vec<Message>) -> Self {
257        self.messages = messages;
258        self
259    }
260
261    /// Set max_tokens.
262    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
263        self.max_tokens = Some(max_tokens);
264        self
265    }
266
267    /// Set temperature.
268    pub fn temperature(mut self, temperature: f32) -> Self {
269        self.temperature = Some(temperature);
270        self
271    }
272
273    /// Set top_p.
274    pub fn top_p(mut self, top_p: f32) -> Self {
275        self.top_p = Some(top_p);
276        self
277    }
278
279    /// Enable streaming.
280    pub fn stream(mut self, stream: bool) -> Self {
281        self.stream = Some(stream);
282        self
283    }
284
285    /// Set number of completions.
286    pub fn n(mut self, n: u32) -> Self {
287        self.n = Some(n);
288        self
289    }
290
291    /// Set stop sequences.
292    pub fn stop(mut self, stop: Vec<String>) -> Self {
293        self.stop = Some(stop);
294        self
295    }
296
297    /// Set presence penalty.
298    pub fn presence_penalty(mut self, penalty: f32) -> Self {
299        self.presence_penalty = Some(penalty);
300        self
301    }
302
303    /// Set frequency penalty.
304    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
305        self.frequency_penalty = Some(penalty);
306        self
307    }
308
309    /// Set user identifier.
310    pub fn user(mut self, user: impl Into<String>) -> Self {
311        self.user = Some(user.into());
312        self
313    }
314
315    /// Set response format.
316    pub fn response_format(mut self, format: ResponseFormat) -> Self {
317        self.response_format = Some(format);
318        self
319    }
320
321    /// Enable JSON object mode (no schema validation).
322    pub fn json_mode(mut self) -> Self {
323        self.response_format = Some(ResponseFormat::JsonObject);
324        self
325    }
326
327    /// Enable structured output with JSON schema.
328    pub fn json_schema(mut self, name: impl Into<String>, schema: Value) -> Self {
329        self.response_format = Some(ResponseFormat::JsonSchema {
330            json_schema: JsonSchemaFormat {
331                name: name.into(),
332                schema,
333                strict: Some(true),
334            },
335        });
336        self
337    }
338
339    /// Build and validate the request.
340    pub fn build(self) -> Result<CompletionRequest> {
341        let model = self.model.ok_or_else(|| ValidationError::Empty {
342            field: "model".to_string(),
343        })?;
344
345        let request = CompletionRequest {
346            messages: self.messages,
347            model,
348            max_tokens: self.max_tokens,
349            temperature: self.temperature,
350            top_p: self.top_p,
351            stream: self.stream,
352            n: self.n,
353            stop: self.stop,
354            presence_penalty: self.presence_penalty,
355            frequency_penalty: self.frequency_penalty,
356            user: self.user,
357            response_format: self.response_format,
358        };
359
360        request.validate()?;
361        Ok(request)
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_basic() {
        let req = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].content, "Hello");
    }

    #[test]
    fn test_builder_all_fields() {
        let req = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .max_tokens(100)
            .temperature(0.7)
            .top_p(0.9)
            .stream(true)
            .n(1)
            .stop(vec!["END".to_string()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .user("test-user")
            .build()
            .unwrap();

        assert_eq!(req.max_tokens, Some(100));
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.top_p, Some(0.9));
        assert_eq!(req.stream, Some(true));
        assert_eq!(req.n, Some(1));
        assert_eq!(req.stop, Some(vec!["END".to_string()]));
        assert_eq!(req.presence_penalty, Some(0.5));
        assert_eq!(req.frequency_penalty, Some(0.5));
        assert_eq!(req.user, Some("test-user".to_string()));
    }

    #[test]
    fn test_builder_missing_model() {
        // Model is required; build() must fail without it.
        let res = CompletionRequest::builder()
            .message(Message::user("Hello"))
            .build();
        assert!(res.is_err());
    }

    #[test]
    fn test_validation_empty_messages() {
        // At least one message is required.
        let res = CompletionRequest::builder().model("gpt-4").build();
        assert!(res.is_err());
    }

    #[test]
    fn test_validation_invalid_temperature() {
        // 3.0 is outside the 0.0-2.0 range.
        let res = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(3.0)
            .build();
        assert!(res.is_err());
    }

    #[test]
    fn test_validation_invalid_top_p() {
        // 1.5 is outside the 0.0-1.0 range.
        let res = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .top_p(1.5)
            .build();
        assert!(res.is_err());
    }

    #[test]
    fn test_validation_invalid_model_chars() {
        // '!' is outside the allowed alphanumeric + -_./ set.
        let res = CompletionRequest::builder()
            .model("gpt-4!")
            .message(Message::user("Hello"))
            .build();
        assert!(res.is_err());
    }

    #[test]
    fn test_serialization() {
        // Round-trip through JSON must be lossless.
        let req = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(0.7)
            .build()
            .unwrap();

        let encoded = serde_json::to_string(&req).unwrap();
        let decoded: CompletionRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(req, decoded);
    }

    #[test]
    fn test_optional_fields_not_serialized() {
        // Unset optional fields must be omitted, not emitted as null.
        let req = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        let value = serde_json::to_value(&req).unwrap();
        assert!(value.get("max_tokens").is_none());
        assert!(value.get("temperature").is_none());
    }

    #[test]
    fn test_validation_total_request_size_limit() {
        // Six 2MB messages = 12MB total, which exceeds the 10MB cap even
        // though each message is individually over the 1MB per-message cap
        // too — expect a TooLong error.
        let big = "x".repeat(2 * 1024 * 1024);
        let mut builder = CompletionRequest::builder().model("gpt-4");
        for _ in 0..6 {
            builder = builder.message(Message::user(big.clone()));
        }
        let res = builder.build();

        assert!(res.is_err());
        assert!(matches!(
            res.unwrap_err(),
            crate::error::SimpleAgentsError::Validation(ValidationError::TooLong { .. })
        ));
    }

    #[test]
    fn test_validation_total_request_size_within_limit() {
        // Three 1MB messages = 3MB total, comfortably under the 10MB cap.
        let body = "x".repeat(1024 * 1024);
        let mut builder = CompletionRequest::builder().model("gpt-4");
        for _ in 0..3 {
            builder = builder.message(Message::user(body.clone()));
        }

        assert!(builder.build().is_ok());
    }
}