Skip to main content

serdes_ai_core/
settings.rs

1//! Model settings and configuration.
2//!
3//! This module provides the `ModelSettings` type for configuring model behavior,
4//! including temperature, token limits, and other generation parameters.
5
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
9/// Settings for model generation.
10#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
11pub struct ModelSettings {
12    /// Maximum tokens to generate.
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub max_tokens: Option<u64>,
15
16    /// Sampling temperature (0.0 to 2.0 typically).
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub temperature: Option<f64>,
19
20    /// Top-p (nucleus) sampling.
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub top_p: Option<f64>,
23
24    /// Top-k sampling.
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub top_k: Option<u64>,
27
28    /// Frequency penalty (-2.0 to 2.0).
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub frequency_penalty: Option<f64>,
31
32    /// Presence penalty (-2.0 to 2.0).
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub presence_penalty: Option<f64>,
35
36    /// Stop sequences.
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub stop: Option<Vec<String>>,
39
40    /// Random seed for reproducibility.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub seed: Option<u64>,
43
44    /// Request timeout.
45    #[serde(
46        skip_serializing_if = "Option::is_none",
47        with = "option_duration_serde"
48    )]
49    pub timeout: Option<Duration>,
50
51    /// Whether to allow parallel tool calls.
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub parallel_tool_calls: Option<bool>,
54
55    /// Extra provider-specific settings.
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub extra: Option<serde_json::Value>,
58}
59
60impl ModelSettings {
61    /// Create new empty settings.
62    #[must_use]
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    /// Set max tokens.
68    #[must_use]
69    pub fn max_tokens(mut self, tokens: u64) -> Self {
70        self.max_tokens = Some(tokens);
71        self
72    }
73
74    /// Set temperature.
75    #[must_use]
76    pub fn temperature(mut self, temp: f64) -> Self {
77        self.temperature = Some(temp);
78        self
79    }
80
81    /// Set top-p.
82    #[must_use]
83    pub fn top_p(mut self, p: f64) -> Self {
84        self.top_p = Some(p);
85        self
86    }
87
88    /// Set top-k.
89    #[must_use]
90    pub fn top_k(mut self, k: u64) -> Self {
91        self.top_k = Some(k);
92        self
93    }
94
95    /// Set frequency penalty.
96    #[must_use]
97    pub fn frequency_penalty(mut self, penalty: f64) -> Self {
98        self.frequency_penalty = Some(penalty);
99        self
100    }
101
102    /// Set presence penalty.
103    #[must_use]
104    pub fn presence_penalty(mut self, penalty: f64) -> Self {
105        self.presence_penalty = Some(penalty);
106        self
107    }
108
109    /// Set stop sequences.
110    #[must_use]
111    pub fn stop(mut self, sequences: Vec<String>) -> Self {
112        self.stop = Some(sequences);
113        self
114    }
115
116    /// Add a stop sequence.
117    #[must_use]
118    pub fn add_stop(mut self, sequence: impl Into<String>) -> Self {
119        self.stop.get_or_insert_with(Vec::new).push(sequence.into());
120        self
121    }
122
123    /// Set seed.
124    #[must_use]
125    pub fn seed(mut self, seed: u64) -> Self {
126        self.seed = Some(seed);
127        self
128    }
129
130    /// Set timeout.
131    #[must_use]
132    pub fn timeout(mut self, timeout: Duration) -> Self {
133        self.timeout = Some(timeout);
134        self
135    }
136
137    /// Set timeout in seconds.
138    #[must_use]
139    pub fn timeout_secs(self, secs: u64) -> Self {
140        self.timeout(Duration::from_secs(secs))
141    }
142
143    /// Set parallel tool calls.
144    #[must_use]
145    pub fn parallel_tool_calls(mut self, parallel: bool) -> Self {
146        self.parallel_tool_calls = Some(parallel);
147        self
148    }
149
150    /// Set extra settings.
151    #[must_use]
152    pub fn extra(mut self, extra: serde_json::Value) -> Self {
153        self.extra = Some(extra);
154        self
155    }
156
157    /// Merge with another settings, preferring values from `other`.
158    ///
159    /// Values in `other` override values in `self` when both are present.
160    #[must_use]
161    pub fn merge(&self, other: &ModelSettings) -> ModelSettings {
162        ModelSettings {
163            max_tokens: other.max_tokens.or(self.max_tokens),
164            temperature: other.temperature.or(self.temperature),
165            top_p: other.top_p.or(self.top_p),
166            top_k: other.top_k.or(self.top_k),
167            frequency_penalty: other.frequency_penalty.or(self.frequency_penalty),
168            presence_penalty: other.presence_penalty.or(self.presence_penalty),
169            stop: other.stop.clone().or_else(|| self.stop.clone()),
170            seed: other.seed.or(self.seed),
171            timeout: other.timeout.or(self.timeout),
172            parallel_tool_calls: other.parallel_tool_calls.or(self.parallel_tool_calls),
173            extra: match (&self.extra, &other.extra) {
174                (Some(a), Some(b)) => Some(merge_json(a, b)),
175                (_, Some(b)) => Some(b.clone()),
176                (Some(a), None) => Some(a.clone()),
177                (None, None) => None,
178            },
179        }
180    }
181
182    /// Check if all settings are None.
183    #[must_use]
184    pub fn is_empty(&self) -> bool {
185        self.max_tokens.is_none()
186            && self.temperature.is_none()
187            && self.top_p.is_none()
188            && self.top_k.is_none()
189            && self.frequency_penalty.is_none()
190            && self.presence_penalty.is_none()
191            && self.stop.is_none()
192            && self.seed.is_none()
193            && self.timeout.is_none()
194            && self.parallel_tool_calls.is_none()
195            && self.extra.is_none()
196    }
197}
198
199/// Merge two JSON values, with `b` taking precedence.
200fn merge_json(a: &serde_json::Value, b: &serde_json::Value) -> serde_json::Value {
201    use serde_json::Value;
202    match (a, b) {
203        (Value::Object(a_obj), Value::Object(b_obj)) => {
204            let mut result = a_obj.clone();
205            for (k, v) in b_obj {
206                result.insert(
207                    k.clone(),
208                    if let Some(existing) = a_obj.get(k) {
209                        merge_json(existing, v)
210                    } else {
211                        v.clone()
212                    },
213                );
214            }
215            Value::Object(result)
216        }
217        (_, b) => b.clone(),
218    }
219}
220
221/// Serde helper for optional Duration.
222mod option_duration_serde {
223    use serde::{Deserialize, Deserializer, Serialize, Serializer};
224    use std::time::Duration;
225
226    pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
227    where
228        S: Serializer,
229    {
230        match duration {
231            Some(d) => d.as_secs_f64().serialize(serializer),
232            None => serializer.serialize_none(),
233        }
234    }
235
236    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
237    where
238        D: Deserializer<'de>,
239    {
240        let opt: Option<f64> = Option::deserialize(deserializer)?;
241        Ok(opt.map(Duration::from_secs_f64))
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_settings_new() {
        let settings = ModelSettings::new();
        assert!(settings.is_empty());
    }

    #[test]
    fn test_model_settings_builder() {
        let settings = ModelSettings::new()
            .max_tokens(1000)
            .temperature(0.7)
            .top_p(0.9)
            .seed(42);

        assert_eq!(settings.max_tokens, Some(1000));
        assert_eq!(settings.temperature, Some(0.7));
        assert_eq!(settings.top_p, Some(0.9));
        assert_eq!(settings.seed, Some(42));
        assert!(!settings.is_empty());
    }

    #[test]
    fn test_model_settings_stop() {
        let settings = ModelSettings::new().add_stop("\n\n").add_stop("END");

        assert_eq!(
            settings.stop,
            Some(vec!["\n\n".to_string(), "END".to_string()])
        );
    }

    #[test]
    fn test_model_settings_merge() {
        let base = ModelSettings::new().max_tokens(1000).temperature(0.5);

        let override_settings = ModelSettings::new().temperature(0.8).top_p(0.9);

        let merged = base.merge(&override_settings);

        assert_eq!(merged.max_tokens, Some(1000)); // from base
        assert_eq!(merged.temperature, Some(0.8)); // overridden
        assert_eq!(merged.top_p, Some(0.9)); // from override
    }

    #[test]
    fn test_model_settings_merge_with_empty_keeps_base() {
        let base = ModelSettings::new().max_tokens(10).add_stop("A").seed(7);

        assert_eq!(base.merge(&ModelSettings::new()), base);
        assert_eq!(ModelSettings::new().merge(&base), base);
    }

    #[test]
    fn test_model_settings_merge_replaces_stop_list() {
        // `stop` is taken from one side only; lists are not concatenated.
        let base = ModelSettings::new().add_stop("A");
        let other = ModelSettings::new().add_stop("B");

        assert_eq!(base.merge(&other).stop, Some(vec!["B".to_string()]));
    }

    #[test]
    fn test_model_settings_merge_deep_merges_extra() {
        let base = ModelSettings::new().extra(serde_json::json!({
            "a": 1,
            "nested": { "x": 1, "y": 2 }
        }));
        let other = ModelSettings::new().extra(serde_json::json!({
            "b": 2,
            "nested": { "y": 3 }
        }));

        let merged = base.merge(&other);

        assert_eq!(
            merged.extra,
            Some(serde_json::json!({
                "a": 1,
                "b": 2,
                "nested": { "x": 1, "y": 3 }
            }))
        );
    }

    #[test]
    fn test_model_settings_timeout() {
        let settings = ModelSettings::new().timeout_secs(30);
        assert_eq!(settings.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_serde_roundtrip() {
        let settings = ModelSettings::new()
            .max_tokens(1000)
            .temperature(0.7)
            .timeout_secs(30);

        let json = serde_json::to_string(&settings).unwrap();
        let parsed: ModelSettings = serde_json::from_str(&json).unwrap();

        assert_eq!(settings.max_tokens, parsed.max_tokens);
        assert_eq!(settings.temperature, parsed.temperature);
        // A whole number of seconds is exactly representable as f64, so the
        // timeout must survive the seconds-as-float encoding unchanged. An
        // `is_some()` check would miss a wrong unit conversion.
        assert_eq!(settings.timeout, parsed.timeout);
    }
}