// simple_agent_type/request.rs

1//! Request types for LLM completions.
2//!
3//! Provides OpenAI-compatible request structures with validation.
4
5use crate::error::{Result, ValidationError};
6use crate::message::Message;
7use crate::tool::{ToolChoice, ToolDefinition};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11/// Response format for structured outputs
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum ResponseFormat {
15    /// Default text response
16    Text,
17    /// JSON object mode (no schema validation)
18    JsonObject,
19    /// Structured output with JSON schema (OpenAI only)
20    #[serde(rename = "json_schema")]
21    JsonSchema {
22        /// The JSON schema definition
23        json_schema: JsonSchemaFormat,
24    },
25}
26
/// JSON schema format for structured outputs.
///
/// Carried by [`ResponseFormat::JsonSchema`] to describe the shape the
/// model's JSON output must conform to.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
    /// Name of the schema (identifier echoed back by the API)
    pub name: String,
    /// The JSON schema the response must conform to
    pub schema: Value,
    /// Whether to use strict mode. Omitted from the wire format when `None`
    /// (the builder's `json_schema` helper sets `Some(true)`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
38
/// A completion request to an LLM provider.
///
/// Mirrors the OpenAI chat-completions wire format; every optional field is
/// omitted from the serialized JSON when `None`. Construct via
/// [`CompletionRequest::builder`] (validates on `build`) or
/// [`CompletionRequest::new`] (no validation until `validate` is called).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// List of messages in the conversation (validated: 1-1000 items)
    pub messages: Vec<Message>,
    /// Model identifier (e.g., "gpt-4", "claude-3-opus");
    /// validated to alphanumeric plus `-_./`
    pub model: String,
    /// Maximum tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature (0.0-2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus sampling threshold (0.0-1.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Enable streaming responses
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Number of completions to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Stop sequences
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Presence penalty (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// User identifier (for abuse detection)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Response format for structured outputs (OpenAI only)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Tool definitions for tool calling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// Tool choice configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// System instructions (Responses API)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Previous response ID for multi-turn (Responses API)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub previous_response_id: Option<String>,
    /// Whether to store the response (Responses API)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
}
92
93impl CompletionRequest {
94    /// Create a new request with the given model and default fields.
95    pub fn new(model: impl Into<String>) -> Self {
96        Self {
97            messages: Vec::new(),
98            model: model.into(),
99            max_tokens: None,
100            temperature: None,
101            top_p: None,
102            stream: None,
103            n: None,
104            stop: None,
105            presence_penalty: None,
106            frequency_penalty: None,
107            user: None,
108            response_format: None,
109            tools: None,
110            tool_choice: None,
111            instructions: None,
112            previous_response_id: None,
113            store: None,
114        }
115    }
116
117    /// Set messages on the request (builder-style).
118    pub fn messages(mut self, messages: Vec<Message>) -> Self {
119        self.messages = messages;
120        self
121    }
122
123    /// Set system instructions (Responses API).
124    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
125        self.instructions = Some(instructions.into());
126        self
127    }
128
129    /// Set previous response ID for multi-turn (Responses API).
130    pub fn previous_response_id(mut self, id: impl Into<String>) -> Self {
131        self.previous_response_id = Some(id.into());
132        self
133    }
134
135    /// Set whether to store the response (Responses API).
136    pub fn store(mut self, store: bool) -> Self {
137        self.store = Some(store);
138        self
139    }
140
141    /// Create a new builder.
142    ///
143    /// # Example
144    /// ```
145    /// use simple_agent_type::request::CompletionRequest;
146    /// use simple_agent_type::message::Message;
147    ///
148    /// let request = CompletionRequest::builder()
149    ///     .model("gpt-4")
150    ///     .message(Message::user("Hello!"))
151    ///     .build()
152    ///     .unwrap();
153    /// ```
154    pub fn builder() -> CompletionRequestBuilder {
155        CompletionRequestBuilder::default()
156    }
157
158    /// Validate the request.
159    ///
160    /// # Validation Rules
161    /// - Messages: 1-1000 items, each < 1MB
162    /// - Model: alphanumeric + `-_./` only
163    /// - Temperature: 0.0-2.0
164    /// - Top_p: 0.0-1.0
165    /// - No null bytes (security)
166    pub fn validate(&self) -> Result<()> {
167        // Validate messages
168        if self.messages.is_empty() {
169            return Err(ValidationError::Empty {
170                field: "messages".to_string(),
171            }
172            .into());
173        }
174
175        if self.messages.len() > 1000 {
176            return Err(ValidationError::TooLong {
177                field: "messages".to_string(),
178                max: 1000,
179            }
180            .into());
181        }
182
183        // Validate each message content size (max 1MB)
184        const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
185        for (i, msg) in self.messages.iter().enumerate() {
186            if msg.content.text_len() > MAX_MESSAGE_SIZE {
187                return Err(ValidationError::TooLong {
188                    field: format!("messages[{}].content", i),
189                    max: MAX_MESSAGE_SIZE,
190                }
191                .into());
192            }
193
194            if msg.content.contains_null() {
195                return Err(ValidationError::InvalidFormat {
196                    field: format!("messages[{}].content", i),
197                    reason: "contains null bytes".to_string(),
198                }
199                .into());
200            }
201        }
202
203        // Validate total request size (max 10MB)
204        const MAX_TOTAL_REQUEST_SIZE: usize = 10 * 1024 * 1024;
205        let total_size: usize = self.messages.iter().map(|m| m.content.text_len()).sum();
206        if total_size > MAX_TOTAL_REQUEST_SIZE {
207            return Err(ValidationError::TooLong {
208                field: "total_request_size".to_string(),
209                max: MAX_TOTAL_REQUEST_SIZE,
210            }
211            .into());
212        }
213
214        // Validate model
215        if self.model.is_empty() {
216            return Err(ValidationError::Empty {
217                field: "model".to_string(),
218            }
219            .into());
220        }
221
222        // Model must be alphanumeric + `-_./`
223        if !self
224            .model
225            .chars()
226            .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/')
227        {
228            return Err(ValidationError::InvalidFormat {
229                field: "model".to_string(),
230                reason: "must be alphanumeric with -_./ only".to_string(),
231            }
232            .into());
233        }
234
235        // Validate temperature
236        if let Some(temp) = self.temperature {
237            if !(0.0..=2.0).contains(&temp) {
238                return Err(ValidationError::OutOfRange {
239                    field: "temperature".to_string(),
240                    min: 0.0,
241                    max: 2.0,
242                }
243                .into());
244            }
245        }
246
247        // Validate top_p
248        if let Some(top_p) = self.top_p {
249            if !(0.0..=1.0).contains(&top_p) {
250                return Err(ValidationError::OutOfRange {
251                    field: "top_p".to_string(),
252                    min: 0.0,
253                    max: 1.0,
254                }
255                .into());
256            }
257        }
258
259        // Validate presence_penalty
260        if let Some(penalty) = self.presence_penalty {
261            if !(-2.0..=2.0).contains(&penalty) {
262                return Err(ValidationError::OutOfRange {
263                    field: "presence_penalty".to_string(),
264                    min: -2.0,
265                    max: 2.0,
266                }
267                .into());
268            }
269        }
270
271        // Validate frequency_penalty
272        if let Some(penalty) = self.frequency_penalty {
273            if !(-2.0..=2.0).contains(&penalty) {
274                return Err(ValidationError::OutOfRange {
275                    field: "frequency_penalty".to_string(),
276                    min: -2.0,
277                    max: 2.0,
278                }
279                .into());
280            }
281        }
282
283        Ok(())
284    }
285}
286
/// Builder for [`CompletionRequest`].
///
/// All fields start unset (`Default`); `build` requires a model and runs
/// `CompletionRequest::validate` on the assembled request.
#[derive(Debug, Default, Clone)]
pub struct CompletionRequestBuilder {
    // Conversation messages; starts empty, filled via `message`/`messages`.
    messages: Vec<Message>,
    // Required at build time; `build` errors if still `None`.
    model: Option<String>,
    // The remaining fields map 1:1 onto the optional
    // `CompletionRequest` fields of the same name.
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    stream: Option<bool>,
    n: Option<u32>,
    stop: Option<Vec<String>>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    user: Option<String>,
    response_format: Option<ResponseFormat>,
    tools: Option<Vec<ToolDefinition>>,
    tool_choice: Option<ToolChoice>,
    instructions: Option<String>,
    previous_response_id: Option<String>,
    store: Option<bool>,
}
308
309impl CompletionRequestBuilder {
310    /// Set the model.
311    pub fn model(mut self, model: impl Into<String>) -> Self {
312        self.model = Some(model.into());
313        self
314    }
315
316    /// Add a message.
317    pub fn message(mut self, message: Message) -> Self {
318        self.messages.push(message);
319        self
320    }
321
322    /// Set all messages at once.
323    pub fn messages(mut self, messages: Vec<Message>) -> Self {
324        self.messages = messages;
325        self
326    }
327
328    /// Set max_tokens.
329    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
330        self.max_tokens = Some(max_tokens);
331        self
332    }
333
334    /// Set temperature.
335    pub fn temperature(mut self, temperature: f32) -> Self {
336        self.temperature = Some(temperature);
337        self
338    }
339
340    /// Set top_p.
341    pub fn top_p(mut self, top_p: f32) -> Self {
342        self.top_p = Some(top_p);
343        self
344    }
345
346    /// Enable streaming.
347    pub fn stream(mut self, stream: bool) -> Self {
348        self.stream = Some(stream);
349        self
350    }
351
352    /// Set number of completions.
353    pub fn n(mut self, n: u32) -> Self {
354        self.n = Some(n);
355        self
356    }
357
358    /// Set stop sequences.
359    pub fn stop(mut self, stop: Vec<String>) -> Self {
360        self.stop = Some(stop);
361        self
362    }
363
364    /// Set presence penalty.
365    pub fn presence_penalty(mut self, penalty: f32) -> Self {
366        self.presence_penalty = Some(penalty);
367        self
368    }
369
370    /// Set frequency penalty.
371    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
372        self.frequency_penalty = Some(penalty);
373        self
374    }
375
376    /// Set user identifier.
377    pub fn user(mut self, user: impl Into<String>) -> Self {
378        self.user = Some(user.into());
379        self
380    }
381
382    /// Set response format.
383    pub fn response_format(mut self, format: ResponseFormat) -> Self {
384        self.response_format = Some(format);
385        self
386    }
387
388    /// Set tool definitions for tool calling.
389    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
390        self.tools = Some(tools);
391        self
392    }
393
394    /// Set tool choice configuration.
395    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
396        self.tool_choice = Some(tool_choice);
397        self
398    }
399
400    /// Set system instructions (Responses API).
401    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
402        self.instructions = Some(instructions.into());
403        self
404    }
405
406    /// Set previous response ID for multi-turn (Responses API).
407    pub fn previous_response_id(mut self, id: impl Into<String>) -> Self {
408        self.previous_response_id = Some(id.into());
409        self
410    }
411
412    /// Set whether to store the response (Responses API).
413    pub fn store(mut self, store: bool) -> Self {
414        self.store = Some(store);
415        self
416    }
417
418    /// Enable JSON object mode (no schema validation).
419    pub fn json_mode(mut self) -> Self {
420        self.response_format = Some(ResponseFormat::JsonObject);
421        self
422    }
423
424    /// Enable structured output with JSON schema.
425    pub fn json_schema(mut self, name: impl Into<String>, schema: Value) -> Self {
426        self.response_format = Some(ResponseFormat::JsonSchema {
427            json_schema: JsonSchemaFormat {
428                name: name.into(),
429                schema,
430                strict: Some(true),
431            },
432        });
433        self
434    }
435
436    /// Build and validate the request.
437    pub fn build(self) -> Result<CompletionRequest> {
438        let model = self.model.ok_or_else(|| ValidationError::Empty {
439            field: "model".to_string(),
440        })?;
441
442        let request = CompletionRequest {
443            messages: self.messages,
444            model,
445            max_tokens: self.max_tokens,
446            temperature: self.temperature,
447            top_p: self.top_p,
448            stream: self.stream,
449            n: self.n,
450            stop: self.stop,
451            presence_penalty: self.presence_penalty,
452            frequency_penalty: self.frequency_penalty,
453            user: self.user,
454            response_format: self.response_format,
455            tools: self.tools,
456            tool_choice: self.tool_choice,
457            instructions: self.instructions,
458            previous_response_id: self.previous_response_id,
459            store: self.store,
460        };
461
462        request.validate()?;
463        Ok(request)
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::MessageContent;

    /// Builder seeded with a valid model and one user message.
    fn seeded() -> CompletionRequestBuilder {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hello"))
    }

    #[test]
    fn test_builder_basic() {
        let request = seeded().build().unwrap();

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].content,
            MessageContent::Text("Hello".to_string())
        );
    }

    #[test]
    fn test_builder_all_fields() {
        let request = seeded()
            .max_tokens(100)
            .temperature(0.7)
            .top_p(0.9)
            .stream(true)
            .n(1)
            .stop(vec!["END".to_string()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .user("test-user")
            .build()
            .unwrap();

        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.stream, Some(true));
        assert_eq!(request.n, Some(1));
        assert_eq!(request.stop, Some(vec!["END".to_string()]));
        assert_eq!(request.presence_penalty, Some(0.5));
        assert_eq!(request.frequency_penalty, Some(0.5));
        assert_eq!(request.user, Some("test-user".to_string()));
    }

    #[test]
    fn test_builder_missing_model() {
        // No model set -> build must fail.
        let attempt = CompletionRequest::builder()
            .message(Message::user("Hello"))
            .build();
        assert!(attempt.is_err());
    }

    #[test]
    fn test_validation_empty_messages() {
        // No messages -> build must fail.
        assert!(CompletionRequest::builder().model("gpt-4").build().is_err());
    }

    #[test]
    fn test_validation_invalid_temperature() {
        // 3.0 exceeds the 0.0-2.0 range.
        assert!(seeded().temperature(3.0).build().is_err());
    }

    #[test]
    fn test_validation_invalid_top_p() {
        // 1.5 exceeds the 0.0-1.0 range.
        assert!(seeded().top_p(1.5).build().is_err());
    }

    #[test]
    fn test_validation_invalid_model_chars() {
        // '!' is outside the allowed alphanumeric + -_./ set.
        let attempt = CompletionRequest::builder()
            .model("gpt-4!")
            .message(Message::user("Hello"))
            .build();
        assert!(attempt.is_err());
    }

    #[test]
    fn test_serialization() {
        let request = seeded().temperature(0.7).build().unwrap();

        let encoded = serde_json::to_string(&request).unwrap();
        let round_trip: CompletionRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(request, round_trip);
    }

    #[test]
    fn test_optional_fields_not_serialized() {
        let json = serde_json::to_value(seeded().build().unwrap()).unwrap();

        assert!(json.get("max_tokens").is_none());
        assert!(json.get("temperature").is_none());
    }

    #[test]
    fn test_validation_total_request_size_limit() {
        // Six 2MB messages = 12MB total, which exceeds the 10MB cap.
        let chunk = "x".repeat(2 * 1024 * 1024);
        let mut builder = CompletionRequest::builder().model("gpt-4");
        for _ in 0..6 {
            builder = builder.message(Message::user(chunk.clone()));
        }

        let result = builder.build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            crate::error::SimpleAgentsError::Validation(ValidationError::TooLong { .. })
        ));
    }

    #[test]
    fn test_responses_api_fields() {
        let request = CompletionRequest::new("gpt-4o")
            .messages(vec![Message::user("hello")])
            .instructions("You are helpful")
            .store(true)
            .previous_response_id("resp_abc");

        assert_eq!(request.instructions.as_deref(), Some("You are helpful"));
        assert_eq!(request.store, Some(true));
        assert_eq!(request.previous_response_id.as_deref(), Some("resp_abc"));
    }

    #[test]
    fn test_validation_total_request_size_within_limit() {
        // Three 1MB messages = 3MB total, well under the 10MB cap.
        let chunk = "x".repeat(1024 * 1024);
        let mut builder = CompletionRequest::builder().model("gpt-4");
        for _ in 0..3 {
            builder = builder.message(Message::user(chunk.clone()));
        }

        assert!(builder.build().is_ok());
    }
}