Skip to main content

simple_agent_type/
request.rs

1//! Request types for LLM completions.
2//!
3//! Provides OpenAI-compatible request structures with validation.
4
5use crate::error::{Result, ValidationError};
6use crate::message::Message;
7use crate::tool::{ToolChoice, ToolDefinition};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
/// Response format for structured outputs.
///
/// Serialized as an internally tagged object: the variant name is written in
/// snake_case under the `"type"` key (for example `{"type": "json_object"}`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Default text response
    Text,
    /// JSON object mode (no schema validation)
    JsonObject,
    /// Structured output with JSON schema (OpenAI only)
    // The explicit rename spells out the same tag `rename_all` already yields.
    #[serde(rename = "json_schema")]
    JsonSchema {
        /// The JSON schema definition, emitted under the `json_schema` key
        json_schema: JsonSchemaFormat,
    },
}
26
/// JSON schema format for structured outputs.
///
/// This is the payload carried by [`ResponseFormat::JsonSchema`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
    /// Name of the schema
    pub name: String,
    /// The JSON schema, held as an arbitrary JSON value
    pub schema: Value,
    /// Whether to use strict mode.
    ///
    /// The key is left out of the serialized output when this is `None`.
    /// [`CompletionRequestBuilder::json_schema`] sets it to `Some(true)`.
    // NOTE(review): what a provider does when the key is absent is not
    // determined by this type — confirm before relying on a default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
38
/// A completion request to an LLM provider.
///
/// `messages` and `model` are always serialized; every `Option` field is
/// omitted from the serialized output when it is `None`.
///
/// The documented ranges are enforced by [`CompletionRequest::validate`],
/// not by the types themselves, so a value constructed by hand (or obtained
/// through deserialization) is unchecked until `validate` is called.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// List of messages in the conversation (`validate` requires 1-1000)
    pub messages: Vec<Message>,
    /// Model identifier (e.g., "gpt-4", "claude-3-opus")
    pub model: String,
    /// Maximum tokens to generate (not checked by `validate`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature (0.0-2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus sampling threshold (0.0-1.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Enable streaming responses
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Number of completions to generate (not checked by `validate`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Stop sequences (not checked by `validate`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Presence penalty (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// User identifier (for abuse detection; not checked by `validate`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Response format for structured outputs (OpenAI only)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Tool definitions for tool calling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// Tool choice configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
83
84impl CompletionRequest {
85    /// Create a new builder.
86    ///
87    /// # Example
88    /// ```
89    /// use simple_agent_type::request::CompletionRequest;
90    /// use simple_agent_type::message::Message;
91    ///
92    /// let request = CompletionRequest::builder()
93    ///     .model("gpt-4")
94    ///     .message(Message::user("Hello!"))
95    ///     .build()
96    ///     .unwrap();
97    /// ```
98    pub fn builder() -> CompletionRequestBuilder {
99        CompletionRequestBuilder::default()
100    }
101
102    /// Validate the request.
103    ///
104    /// # Validation Rules
105    /// - Messages: 1-1000 items, each < 1MB
106    /// - Model: alphanumeric + `-_./` only
107    /// - Temperature: 0.0-2.0
108    /// - Top_p: 0.0-1.0
109    /// - No null bytes (security)
110    pub fn validate(&self) -> Result<()> {
111        // Validate messages
112        if self.messages.is_empty() {
113            return Err(ValidationError::Empty {
114                field: "messages".to_string(),
115            }
116            .into());
117        }
118
119        if self.messages.len() > 1000 {
120            return Err(ValidationError::TooLong {
121                field: "messages".to_string(),
122                max: 1000,
123            }
124            .into());
125        }
126
127        // Validate each message content size (max 1MB)
128        const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
129        for (i, msg) in self.messages.iter().enumerate() {
130            if msg.content.len() > MAX_MESSAGE_SIZE {
131                return Err(ValidationError::TooLong {
132                    field: format!("messages[{}].content", i),
133                    max: MAX_MESSAGE_SIZE,
134                }
135                .into());
136            }
137
138            // Security: no null bytes
139            if msg.content.contains('\0') {
140                return Err(ValidationError::InvalidFormat {
141                    field: format!("messages[{}].content", i),
142                    reason: "contains null bytes".to_string(),
143                }
144                .into());
145            }
146        }
147
148        // Validate total request size (max 10MB)
149        const MAX_TOTAL_REQUEST_SIZE: usize = 10 * 1024 * 1024;
150        let total_size: usize = self.messages.iter().map(|m| m.content.len()).sum();
151        if total_size > MAX_TOTAL_REQUEST_SIZE {
152            return Err(ValidationError::TooLong {
153                field: "total_request_size".to_string(),
154                max: MAX_TOTAL_REQUEST_SIZE,
155            }
156            .into());
157        }
158
159        // Validate model
160        if self.model.is_empty() {
161            return Err(ValidationError::Empty {
162                field: "model".to_string(),
163            }
164            .into());
165        }
166
167        // Model must be alphanumeric + `-_./`
168        if !self
169            .model
170            .chars()
171            .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/')
172        {
173            return Err(ValidationError::InvalidFormat {
174                field: "model".to_string(),
175                reason: "must be alphanumeric with -_./ only".to_string(),
176            }
177            .into());
178        }
179
180        // Validate temperature
181        if let Some(temp) = self.temperature {
182            if !(0.0..=2.0).contains(&temp) {
183                return Err(ValidationError::OutOfRange {
184                    field: "temperature".to_string(),
185                    min: 0.0,
186                    max: 2.0,
187                }
188                .into());
189            }
190        }
191
192        // Validate top_p
193        if let Some(top_p) = self.top_p {
194            if !(0.0..=1.0).contains(&top_p) {
195                return Err(ValidationError::OutOfRange {
196                    field: "top_p".to_string(),
197                    min: 0.0,
198                    max: 1.0,
199                }
200                .into());
201            }
202        }
203
204        // Validate presence_penalty
205        if let Some(penalty) = self.presence_penalty {
206            if !(-2.0..=2.0).contains(&penalty) {
207                return Err(ValidationError::OutOfRange {
208                    field: "presence_penalty".to_string(),
209                    min: -2.0,
210                    max: 2.0,
211                }
212                .into());
213            }
214        }
215
216        // Validate frequency_penalty
217        if let Some(penalty) = self.frequency_penalty {
218            if !(-2.0..=2.0).contains(&penalty) {
219                return Err(ValidationError::OutOfRange {
220                    field: "frequency_penalty".to_string(),
221                    min: -2.0,
222                    max: 2.0,
223                }
224                .into());
225            }
226        }
227
228        Ok(())
229    }
230}
231
232/// Builder for CompletionRequest.
233#[derive(Debug, Default, Clone)]
234pub struct CompletionRequestBuilder {
235    messages: Vec<Message>,
236    model: Option<String>,
237    max_tokens: Option<u32>,
238    temperature: Option<f32>,
239    top_p: Option<f32>,
240    stream: Option<bool>,
241    n: Option<u32>,
242    stop: Option<Vec<String>>,
243    presence_penalty: Option<f32>,
244    frequency_penalty: Option<f32>,
245    user: Option<String>,
246    response_format: Option<ResponseFormat>,
247    tools: Option<Vec<ToolDefinition>>,
248    tool_choice: Option<ToolChoice>,
249}
250
251impl CompletionRequestBuilder {
252    /// Set the model.
253    pub fn model(mut self, model: impl Into<String>) -> Self {
254        self.model = Some(model.into());
255        self
256    }
257
258    /// Add a message.
259    pub fn message(mut self, message: Message) -> Self {
260        self.messages.push(message);
261        self
262    }
263
264    /// Set all messages at once.
265    pub fn messages(mut self, messages: Vec<Message>) -> Self {
266        self.messages = messages;
267        self
268    }
269
270    /// Set max_tokens.
271    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
272        self.max_tokens = Some(max_tokens);
273        self
274    }
275
276    /// Set temperature.
277    pub fn temperature(mut self, temperature: f32) -> Self {
278        self.temperature = Some(temperature);
279        self
280    }
281
282    /// Set top_p.
283    pub fn top_p(mut self, top_p: f32) -> Self {
284        self.top_p = Some(top_p);
285        self
286    }
287
288    /// Enable streaming.
289    pub fn stream(mut self, stream: bool) -> Self {
290        self.stream = Some(stream);
291        self
292    }
293
294    /// Set number of completions.
295    pub fn n(mut self, n: u32) -> Self {
296        self.n = Some(n);
297        self
298    }
299
300    /// Set stop sequences.
301    pub fn stop(mut self, stop: Vec<String>) -> Self {
302        self.stop = Some(stop);
303        self
304    }
305
306    /// Set presence penalty.
307    pub fn presence_penalty(mut self, penalty: f32) -> Self {
308        self.presence_penalty = Some(penalty);
309        self
310    }
311
312    /// Set frequency penalty.
313    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
314        self.frequency_penalty = Some(penalty);
315        self
316    }
317
318    /// Set user identifier.
319    pub fn user(mut self, user: impl Into<String>) -> Self {
320        self.user = Some(user.into());
321        self
322    }
323
324    /// Set response format.
325    pub fn response_format(mut self, format: ResponseFormat) -> Self {
326        self.response_format = Some(format);
327        self
328    }
329
330    /// Set tool definitions for tool calling.
331    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
332        self.tools = Some(tools);
333        self
334    }
335
336    /// Set tool choice configuration.
337    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
338        self.tool_choice = Some(tool_choice);
339        self
340    }
341
342    /// Enable JSON object mode (no schema validation).
343    pub fn json_mode(mut self) -> Self {
344        self.response_format = Some(ResponseFormat::JsonObject);
345        self
346    }
347
348    /// Enable structured output with JSON schema.
349    pub fn json_schema(mut self, name: impl Into<String>, schema: Value) -> Self {
350        self.response_format = Some(ResponseFormat::JsonSchema {
351            json_schema: JsonSchemaFormat {
352                name: name.into(),
353                schema,
354                strict: Some(true),
355            },
356        });
357        self
358    }
359
360    /// Build and validate the request.
361    pub fn build(self) -> Result<CompletionRequest> {
362        let model = self.model.ok_or_else(|| ValidationError::Empty {
363            field: "model".to_string(),
364        })?;
365
366        let request = CompletionRequest {
367            messages: self.messages,
368            model,
369            max_tokens: self.max_tokens,
370            temperature: self.temperature,
371            top_p: self.top_p,
372            stream: self.stream,
373            n: self.n,
374            stop: self.stop,
375            presence_penalty: self.presence_penalty,
376            frequency_penalty: self.frequency_penalty,
377            user: self.user,
378            response_format: self.response_format,
379            tools: self.tools,
380            tool_choice: self.tool_choice,
381        };
382
383        request.validate()?;
384        Ok(request)
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte: the per-message content limit applied by `validate`.
    const MIB: usize = 1024 * 1024;

    /// A minimal builder invocation yields a request carrying the given values.
    #[test]
    fn test_builder_basic() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].content, "Hello");
    }

    /// Every scalar setter is reflected in the built request.
    #[test]
    fn test_builder_all_fields() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .max_tokens(100)
            .temperature(0.7)
            .top_p(0.9)
            .stream(true)
            .n(1)
            .stop(vec!["END".to_string()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .user("test-user")
            .build()
            .unwrap();

        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.stream, Some(true));
        assert_eq!(request.n, Some(1));
        assert_eq!(request.stop, Some(vec!["END".to_string()]));
        assert_eq!(request.presence_penalty, Some(0.5));
        assert_eq!(request.frequency_penalty, Some(0.5));
        assert_eq!(request.user, Some("test-user".to_string()));
    }

    /// Building without a model is rejected.
    #[test]
    fn test_builder_missing_model() {
        let result = CompletionRequest::builder()
            .message(Message::user("Hello"))
            .build();
        assert!(result.is_err());
    }

    /// Building without any message is rejected.
    #[test]
    fn test_validation_empty_messages() {
        let result = CompletionRequest::builder().model("gpt-4").build();
        assert!(result.is_err());
    }

    /// A temperature above 2.0 is rejected.
    #[test]
    fn test_validation_invalid_temperature() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(3.0)
            .build();
        assert!(result.is_err());
    }

    /// A top_p above 1.0 is rejected.
    #[test]
    fn test_validation_invalid_top_p() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .top_p(1.5)
            .build();
        assert!(result.is_err());
    }

    /// A model name containing a character outside the allowed set is rejected.
    #[test]
    fn test_validation_invalid_model_chars() {
        let result = CompletionRequest::builder()
            .model("gpt-4!")
            .message(Message::user("Hello"))
            .build();
        assert!(result.is_err());
    }

    /// A request survives a JSON round trip unchanged.
    #[test]
    fn test_serialization() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .temperature(0.7)
            .build()
            .unwrap();

        let json = serde_json::to_string(&request).unwrap();
        let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(request, parsed);
    }

    /// Unset optional fields do not appear in the serialized JSON.
    #[test]
    fn test_optional_fields_not_serialized() {
        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
            .build()
            .unwrap();

        let json = serde_json::to_value(&request).unwrap();
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("temperature").is_none());
    }

    /// A single message just over 1 MiB trips the per-message limit.
    #[test]
    fn test_validation_message_size_limit() {
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("x".repeat(MIB + 1)))
            .build();

        match result {
            Err(crate::error::SimpleAgentsError::Validation(ValidationError::TooLong {
                field,
                max,
            })) => {
                assert_eq!(field, "messages[0].content");
                assert_eq!(max, MIB);
            }
            other => panic!("expected TooLong for messages[0].content, got {:?}", other),
        }
    }

    /// Eleven messages of exactly 1 MiB each pass the per-message check but
    /// exceed the 10 MiB combined limit.
    ///
    /// Each message must stay within the per-message limit; otherwise the
    /// per-message check fails first and the combined limit is never reached.
    #[test]
    fn test_validation_total_request_size_limit() {
        let content = "x".repeat(MIB);
        let mut builder = CompletionRequest::builder().model("gpt-4");
        for _ in 0..11 {
            builder = builder.message(Message::user(content.clone()));
        }
        let result = builder.build();

        match result {
            Err(crate::error::SimpleAgentsError::Validation(ValidationError::TooLong {
                field,
                max,
            })) => {
                assert_eq!(field, "total_request_size");
                assert_eq!(max, 10 * MIB);
            }
            other => panic!("expected TooLong for total_request_size, got {:?}", other),
        }
    }

    /// Three messages of exactly 1 MiB each are accepted: both limits are
    /// inclusive and the 3 MiB total is well under 10 MiB.
    #[test]
    fn test_validation_total_request_size_within_limit() {
        let content = "x".repeat(MIB);
        let result = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user(content.clone()))
            .message(Message::user(content.clone()))
            .message(Message::user(content.clone()))
            .build();

        assert!(result.is_ok());
    }
}