1use ordered_float::OrderedFloat;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct Variation {
9 pub parameter: ParameterKind,
10 pub value: VariationValue,
11}
12
13#[non_exhaustive]
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ParameterKind {
17 Temperature,
18 TopP,
19 TopK,
20 FrequencyPenalty,
21 PresencePenalty,
22 RetrievalTopK,
23 SimilarityThreshold,
24 TemporalDecay,
25}
26
27impl ParameterKind {
28 #[must_use]
29 pub fn as_str(&self) -> &'static str {
30 #[allow(unreachable_patterns)]
31 match self {
32 Self::Temperature => "temperature",
33 Self::TopP => "top_p",
34 Self::TopK => "top_k",
35 Self::FrequencyPenalty => "frequency_penalty",
36 Self::PresencePenalty => "presence_penalty",
37 Self::RetrievalTopK => "retrieval_top_k",
38 Self::SimilarityThreshold => "similarity_threshold",
39 Self::TemporalDecay => "temporal_decay",
40 _ => "unknown",
41 }
42 }
43
44 #[must_use]
46 pub fn is_integer(&self) -> bool {
47 matches!(self, Self::TopK | Self::RetrievalTopK)
48 }
49}
50
51impl std::fmt::Display for ParameterKind {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.write_str(self.as_str())
54 }
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
58#[serde(tag = "type", content = "value")]
59pub enum VariationValue {
60 Float(OrderedFloat<f64>),
61 Int(i64),
62}
63
64impl VariationValue {
65 #[must_use]
67 pub fn as_f64(&self) -> f64 {
68 match self {
69 Self::Float(f) => f.into_inner(),
70 #[allow(clippy::cast_precision_loss)]
71 Self::Int(i) => *i as f64,
72 }
73 }
74}
75
76impl From<f64> for VariationValue {
77 fn from(v: f64) -> Self {
78 Self::Float(OrderedFloat(v))
79 }
80}
81
82impl From<i64> for VariationValue {
83 fn from(v: i64) -> Self {
84 Self::Int(v)
85 }
86}
87
88impl std::fmt::Display for VariationValue {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 match self {
91 Self::Float(v) => write!(f, "{v}"),
92 Self::Int(v) => write!(f, "{v}"),
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct ExperimentResult {
99 pub id: i64,
100 pub session_id: String,
101 pub variation: Variation,
102 pub baseline_score: f64,
103 pub candidate_score: f64,
104 pub delta: f64,
105 pub latency_ms: u64,
106 pub tokens_used: u64,
107 pub accepted: bool,
108 pub source: ExperimentSource,
109 pub created_at: String,
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(rename_all = "snake_case")]
114pub enum ExperimentSource {
115 Manual,
116 Scheduled,
117}
118
119impl ExperimentSource {
120 #[must_use]
121 pub fn as_str(&self) -> &'static str {
122 match self {
123 Self::Manual => "manual",
124 Self::Scheduled => "scheduled",
125 }
126 }
127}
128
129impl std::fmt::Display for ExperimentSource {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.write_str(self.as_str())
132 }
133}
134
/// Unit tests for the string names, numeric conversions, hashing behavior and
/// serde representations of the types defined in this file.
#[cfg(test)]
mod tests {
    // Float literals such as 3.14 below are arbitrary sample values, not
    // approximations of mathematical constants.
    #![allow(clippy::approx_constant)]

    use super::*;

    /// `as_str` and `Display` agree on the snake_case name of every variant.
    #[test]
    fn parameter_kind_as_str_all_variants() {
        let cases = [
            (ParameterKind::Temperature, "temperature"),
            (ParameterKind::TopP, "top_p"),
            (ParameterKind::TopK, "top_k"),
            (ParameterKind::FrequencyPenalty, "frequency_penalty"),
            (ParameterKind::PresencePenalty, "presence_penalty"),
            (ParameterKind::RetrievalTopK, "retrieval_top_k"),
            (ParameterKind::SimilarityThreshold, "similarity_threshold"),
            (ParameterKind::TemporalDecay, "temporal_decay"),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.as_str(), expected);
            assert_eq!(kind.to_string(), expected);
        }
    }

    /// Only the two top-k parameters are integer-valued.
    #[test]
    fn parameter_kind_is_integer() {
        assert!(ParameterKind::TopK.is_integer());
        assert!(ParameterKind::RetrievalTopK.is_integer());
        assert!(!ParameterKind::Temperature.is_integer());
        assert!(!ParameterKind::TopP.is_integer());
        assert!(!ParameterKind::FrequencyPenalty.is_integer());
        assert!(!ParameterKind::PresencePenalty.is_integer());
        assert!(!ParameterKind::SimilarityThreshold.is_integer());
        assert!(!ParameterKind::TemporalDecay.is_integer());
    }

    #[test]
    fn variation_value_as_f64_float() {
        let v = VariationValue::Float(OrderedFloat(3.14));
        assert!((v.as_f64() - 3.14).abs() < f64::EPSILON);
    }

    #[test]
    fn variation_value_as_f64_int() {
        let v = VariationValue::Int(42);
        assert!((v.as_f64() - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn variation_value_from_f64() {
        let v = VariationValue::from(0.7_f64);
        assert!(matches!(v, VariationValue::Float(_)));
        assert!((v.as_f64() - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn variation_value_from_i64() {
        let v = VariationValue::from(40_i64);
        assert!(matches!(v, VariationValue::Int(40)));
        assert!((v.as_f64() - 40.0).abs() < f64::EPSILON);
    }

    /// Equal floats hash alike, so a set lookup by an equal value succeeds.
    #[test]
    fn variation_value_float_hash_eq() {
        use std::collections::HashSet;
        let a = VariationValue::Float(OrderedFloat(0.7));
        let b = VariationValue::Float(OrderedFloat(0.7));
        let c = VariationValue::Float(OrderedFloat(0.8));
        let mut set = HashSet::new();
        // `a` is not used again, so it is moved into the set rather than cloned.
        set.insert(a);
        assert!(set.contains(&b));
        assert!(!set.contains(&c));
    }

    #[test]
    fn variation_serde_roundtrip() {
        let v = Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.7)),
        };
        let json = serde_json::to_string(&v).expect("serialize");
        let v2: Variation = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(v, v2);
    }

    /// The integer variant uses the adjacently tagged `type`/`value` layout
    /// and survives a round trip.
    #[test]
    fn variation_int_serde_roundtrip() {
        let v = Variation {
            parameter: ParameterKind::TopK,
            value: VariationValue::Int(40),
        };
        let json = serde_json::to_string(&v).expect("serialize");
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("parse");
        assert_eq!(parsed["parameter"], "top_k");
        assert_eq!(parsed["value"]["type"], "Int");
        assert_eq!(parsed["value"]["value"], 40);

        let v2: Variation = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(v, v2);
    }

    #[test]
    fn experiment_source_as_str() {
        assert_eq!(ExperimentSource::Manual.as_str(), "manual");
        assert_eq!(ExperimentSource::Scheduled.as_str(), "scheduled");
        assert_eq!(ExperimentSource::Manual.to_string(), "manual");
        assert_eq!(ExperimentSource::Scheduled.to_string(), "scheduled");
    }

    #[test]
    fn variation_value_int_display() {
        let v = VariationValue::Int(42);
        assert_eq!(v.to_string(), "42");
    }

    #[test]
    fn variation_value_float_display() {
        let v = VariationValue::Float(OrderedFloat(0.7));
        assert_eq!(v.to_string(), "0.7");
    }

    /// Checks both the JSON field names and that every field survives a
    /// serialize/deserialize round trip.
    #[test]
    fn experiment_result_serde_roundtrip() {
        let result = ExperimentResult {
            id: 1,
            session_id: "sess-abc".to_string(),
            variation: Variation {
                parameter: ParameterKind::Temperature,
                value: VariationValue::Float(OrderedFloat(0.7)),
            },
            baseline_score: 7.0,
            candidate_score: 8.0,
            delta: 1.0,
            latency_ms: 500,
            tokens_used: 1_000,
            accepted: true,
            source: ExperimentSource::Manual,
            created_at: "2026-03-07 22:00:00".to_string(),
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("parse");
        assert_eq!(parsed["id"], 1);
        assert_eq!(parsed["session_id"], "sess-abc");
        assert_eq!(parsed["accepted"], true);
        assert_eq!(parsed["source"], "manual");
        assert_eq!(parsed["variation"]["parameter"], "temperature");

        let result2: ExperimentResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(result2.id, result.id);
        assert_eq!(result2.session_id, result.session_id);
        assert_eq!(result2.variation, result.variation);
        assert!((result2.baseline_score - result.baseline_score).abs() < f64::EPSILON);
        assert!((result2.candidate_score - result.candidate_score).abs() < f64::EPSILON);
        assert!((result2.delta - result.delta).abs() < f64::EPSILON);
        assert_eq!(result2.latency_ms, result.latency_ms);
        assert_eq!(result2.tokens_used, result.tokens_used);
        assert!(result2.accepted);
        assert_eq!(result2.source, ExperimentSource::Manual);
        assert_eq!(result2.created_at, result.created_at);
    }
}