serdes_ai_core/
settings.rs1use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
9#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
11pub struct ModelSettings {
12 #[serde(skip_serializing_if = "Option::is_none")]
14 pub max_tokens: Option<u64>,
15
16 #[serde(skip_serializing_if = "Option::is_none")]
18 pub temperature: Option<f64>,
19
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub top_p: Option<f64>,
23
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub top_k: Option<u64>,
27
28 #[serde(skip_serializing_if = "Option::is_none")]
30 pub frequency_penalty: Option<f64>,
31
32 #[serde(skip_serializing_if = "Option::is_none")]
34 pub presence_penalty: Option<f64>,
35
36 #[serde(skip_serializing_if = "Option::is_none")]
38 pub stop: Option<Vec<String>>,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub seed: Option<u64>,
43
44 #[serde(
46 skip_serializing_if = "Option::is_none",
47 with = "option_duration_serde"
48 )]
49 pub timeout: Option<Duration>,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub parallel_tool_calls: Option<bool>,
54
55 #[serde(skip_serializing_if = "Option::is_none")]
57 pub extra: Option<serde_json::Value>,
58}
59
60impl ModelSettings {
61 #[must_use]
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 #[must_use]
69 pub fn max_tokens(mut self, tokens: u64) -> Self {
70 self.max_tokens = Some(tokens);
71 self
72 }
73
74 #[must_use]
76 pub fn temperature(mut self, temp: f64) -> Self {
77 self.temperature = Some(temp);
78 self
79 }
80
81 #[must_use]
83 pub fn top_p(mut self, p: f64) -> Self {
84 self.top_p = Some(p);
85 self
86 }
87
88 #[must_use]
90 pub fn top_k(mut self, k: u64) -> Self {
91 self.top_k = Some(k);
92 self
93 }
94
95 #[must_use]
97 pub fn frequency_penalty(mut self, penalty: f64) -> Self {
98 self.frequency_penalty = Some(penalty);
99 self
100 }
101
102 #[must_use]
104 pub fn presence_penalty(mut self, penalty: f64) -> Self {
105 self.presence_penalty = Some(penalty);
106 self
107 }
108
109 #[must_use]
111 pub fn stop(mut self, sequences: Vec<String>) -> Self {
112 self.stop = Some(sequences);
113 self
114 }
115
116 #[must_use]
118 pub fn add_stop(mut self, sequence: impl Into<String>) -> Self {
119 self.stop.get_or_insert_with(Vec::new).push(sequence.into());
120 self
121 }
122
123 #[must_use]
125 pub fn seed(mut self, seed: u64) -> Self {
126 self.seed = Some(seed);
127 self
128 }
129
130 #[must_use]
132 pub fn timeout(mut self, timeout: Duration) -> Self {
133 self.timeout = Some(timeout);
134 self
135 }
136
137 #[must_use]
139 pub fn timeout_secs(self, secs: u64) -> Self {
140 self.timeout(Duration::from_secs(secs))
141 }
142
143 #[must_use]
145 pub fn parallel_tool_calls(mut self, parallel: bool) -> Self {
146 self.parallel_tool_calls = Some(parallel);
147 self
148 }
149
150 #[must_use]
152 pub fn extra(mut self, extra: serde_json::Value) -> Self {
153 self.extra = Some(extra);
154 self
155 }
156
157 #[must_use]
161 pub fn merge(&self, other: &ModelSettings) -> ModelSettings {
162 ModelSettings {
163 max_tokens: other.max_tokens.or(self.max_tokens),
164 temperature: other.temperature.or(self.temperature),
165 top_p: other.top_p.or(self.top_p),
166 top_k: other.top_k.or(self.top_k),
167 frequency_penalty: other.frequency_penalty.or(self.frequency_penalty),
168 presence_penalty: other.presence_penalty.or(self.presence_penalty),
169 stop: other.stop.clone().or_else(|| self.stop.clone()),
170 seed: other.seed.or(self.seed),
171 timeout: other.timeout.or(self.timeout),
172 parallel_tool_calls: other.parallel_tool_calls.or(self.parallel_tool_calls),
173 extra: match (&self.extra, &other.extra) {
174 (Some(a), Some(b)) => Some(merge_json(a, b)),
175 (_, Some(b)) => Some(b.clone()),
176 (Some(a), None) => Some(a.clone()),
177 (None, None) => None,
178 },
179 }
180 }
181
182 #[must_use]
184 pub fn is_empty(&self) -> bool {
185 self.max_tokens.is_none()
186 && self.temperature.is_none()
187 && self.top_p.is_none()
188 && self.top_k.is_none()
189 && self.frequency_penalty.is_none()
190 && self.presence_penalty.is_none()
191 && self.stop.is_none()
192 && self.seed.is_none()
193 && self.timeout.is_none()
194 && self.parallel_tool_calls.is_none()
195 && self.extra.is_none()
196 }
197}
198
199fn merge_json(a: &serde_json::Value, b: &serde_json::Value) -> serde_json::Value {
201 use serde_json::Value;
202 match (a, b) {
203 (Value::Object(a_obj), Value::Object(b_obj)) => {
204 let mut result = a_obj.clone();
205 for (k, v) in b_obj {
206 result.insert(
207 k.clone(),
208 if let Some(existing) = a_obj.get(k) {
209 merge_json(existing, v)
210 } else {
211 v.clone()
212 },
213 );
214 }
215 Value::Object(result)
216 }
217 (_, b) => b.clone(),
218 }
219}
220
221mod option_duration_serde {
223 use serde::{Deserialize, Deserializer, Serialize, Serializer};
224 use std::time::Duration;
225
226 pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
227 where
228 S: Serializer,
229 {
230 match duration {
231 Some(d) => d.as_secs_f64().serialize(serializer),
232 None => serializer.serialize_none(),
233 }
234 }
235
236 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
237 where
238 D: Deserializer<'de>,
239 {
240 let opt: Option<f64> = Option::deserialize(deserializer)?;
241 Ok(opt.map(Duration::from_secs_f64))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_settings_new() {
        let settings = ModelSettings::new();
        assert!(settings.is_empty());
    }

    #[test]
    fn test_model_settings_is_empty_when_any_field_set() {
        assert!(!ModelSettings::new().seed(1).is_empty());
        assert!(!ModelSettings::new().parallel_tool_calls(false).is_empty());
        assert!(!ModelSettings::new()
            .extra(serde_json::json!({}))
            .is_empty());
    }

    #[test]
    fn test_model_settings_builder() {
        let settings = ModelSettings::new()
            .max_tokens(1000)
            .temperature(0.7)
            .top_p(0.9)
            .seed(42);

        assert_eq!(settings.max_tokens, Some(1000));
        assert_eq!(settings.temperature, Some(0.7));
        assert_eq!(settings.top_p, Some(0.9));
        assert_eq!(settings.seed, Some(42));
    }

    #[test]
    fn test_model_settings_stop() {
        let settings = ModelSettings::new().add_stop("\n\n").add_stop("END");

        assert_eq!(
            settings.stop,
            Some(vec!["\n\n".to_string(), "END".to_string()])
        );
    }

    #[test]
    fn test_model_settings_merge() {
        let base = ModelSettings::new().max_tokens(1000).temperature(0.5);

        let override_settings = ModelSettings::new().temperature(0.8).top_p(0.9);

        let merged = base.merge(&override_settings);

        // Kept from base, overridden, and newly added by the override, respectively.
        assert_eq!(merged.max_tokens, Some(1000));
        assert_eq!(merged.temperature, Some(0.8));
        assert_eq!(merged.top_p, Some(0.9));
    }

    #[test]
    fn test_model_settings_merge_extra_deep() {
        let base = ModelSettings::new()
            .extra(serde_json::json!({"a": 1, "nested": {"x": 1, "y": 1}}));
        let overlay = ModelSettings::new()
            .extra(serde_json::json!({"nested": {"y": 2}, "b": [1]}));

        let merged = base.merge(&overlay);

        assert_eq!(
            merged.extra,
            Some(serde_json::json!({"a": 1, "b": [1], "nested": {"x": 1, "y": 2}}))
        );

        // One-sided `extra` is carried through unchanged in either direction.
        let only_base = base.merge(&ModelSettings::new());
        assert_eq!(only_base.extra, base.extra);
        let only_overlay = ModelSettings::new().merge(&overlay);
        assert_eq!(only_overlay.extra, overlay.extra);
    }

    #[test]
    fn test_model_settings_timeout() {
        let settings = ModelSettings::new().timeout_secs(30);
        assert_eq!(settings.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_timeout_serialized_as_seconds() {
        let settings = ModelSettings::new().timeout(Duration::from_millis(1500));
        let value = serde_json::to_value(&settings).unwrap();
        assert_eq!(value["timeout"], serde_json::json!(1.5));
    }

    #[test]
    fn test_serde_roundtrip() {
        let settings = ModelSettings::new()
            .max_tokens(1000)
            .temperature(0.7)
            .timeout_secs(30);

        let json = serde_json::to_string(&settings).unwrap();
        // Unset fields are skipped on serialization.
        assert!(!json.contains("top_p"));

        let parsed: ModelSettings = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed, settings);
        assert_eq!(parsed.timeout, Some(Duration::from_secs(30)));
    }
}