Skip to main content

zeph_experiments/
types.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use ordered_float::OrderedFloat;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct Variation {
9    pub parameter: ParameterKind,
10    pub value: VariationValue,
11}
12
13#[non_exhaustive]
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ParameterKind {
17    Temperature,
18    TopP,
19    TopK,
20    FrequencyPenalty,
21    PresencePenalty,
22    RetrievalTopK,
23    SimilarityThreshold,
24    TemporalDecay,
25}
26
27impl ParameterKind {
28    #[must_use]
29    pub fn as_str(&self) -> &'static str {
30        #[allow(unreachable_patterns)]
31        match self {
32            Self::Temperature => "temperature",
33            Self::TopP => "top_p",
34            Self::TopK => "top_k",
35            Self::FrequencyPenalty => "frequency_penalty",
36            Self::PresencePenalty => "presence_penalty",
37            Self::RetrievalTopK => "retrieval_top_k",
38            Self::SimilarityThreshold => "similarity_threshold",
39            Self::TemporalDecay => "temporal_decay",
40            _ => "unknown",
41        }
42    }
43
44    /// Returns `true` if this parameter has integer semantics (e.g. `TopK`, `RetrievalTopK`).
45    #[must_use]
46    pub fn is_integer(&self) -> bool {
47        matches!(self, Self::TopK | Self::RetrievalTopK)
48    }
49}
50
51impl std::fmt::Display for ParameterKind {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.write_str(self.as_str())
54    }
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
58#[serde(tag = "type", content = "value")]
59pub enum VariationValue {
60    Float(OrderedFloat<f64>),
61    Int(i64),
62}
63
64impl VariationValue {
65    /// Return the value as `f64`. `Int` variants are cast to `f64`.
66    #[must_use]
67    pub fn as_f64(&self) -> f64 {
68        match self {
69            Self::Float(f) => f.into_inner(),
70            #[allow(clippy::cast_precision_loss)]
71            Self::Int(i) => *i as f64,
72        }
73    }
74}
75
76impl From<f64> for VariationValue {
77    fn from(v: f64) -> Self {
78        Self::Float(OrderedFloat(v))
79    }
80}
81
82impl From<i64> for VariationValue {
83    fn from(v: i64) -> Self {
84        Self::Int(v)
85    }
86}
87
88impl std::fmt::Display for VariationValue {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        match self {
91            Self::Float(v) => write!(f, "{v}"),
92            Self::Int(v) => write!(f, "{v}"),
93        }
94    }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct ExperimentResult {
99    pub id: i64,
100    pub session_id: String,
101    pub variation: Variation,
102    pub baseline_score: f64,
103    pub candidate_score: f64,
104    pub delta: f64,
105    pub latency_ms: u64,
106    pub tokens_used: u64,
107    pub accepted: bool,
108    pub source: ExperimentSource,
109    pub created_at: String,
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(rename_all = "snake_case")]
114pub enum ExperimentSource {
115    Manual,
116    Scheduled,
117}
118
119impl ExperimentSource {
120    #[must_use]
121    pub fn as_str(&self) -> &'static str {
122        match self {
123            Self::Manual => "manual",
124            Self::Scheduled => "scheduled",
125        }
126    }
127}
128
129impl std::fmt::Display for ExperimentSource {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.write_str(self.as_str())
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ParameterKind` variant paired with its expected identifier.
    const ALL_KINDS: [(ParameterKind, &str); 8] = [
        (ParameterKind::Temperature, "temperature"),
        (ParameterKind::TopP, "top_p"),
        (ParameterKind::TopK, "top_k"),
        (ParameterKind::FrequencyPenalty, "frequency_penalty"),
        (ParameterKind::PresencePenalty, "presence_penalty"),
        (ParameterKind::RetrievalTopK, "retrieval_top_k"),
        (ParameterKind::SimilarityThreshold, "similarity_threshold"),
        (ParameterKind::TemporalDecay, "temporal_decay"),
    ];

    #[test]
    fn parameter_kind_as_str_all_variants() {
        for (kind, expected) in ALL_KINDS {
            assert_eq!(kind.as_str(), expected);
            assert_eq!(kind.to_string(), expected);
        }
    }

    #[test]
    fn parameter_kind_serde_name_matches_as_str() {
        // `as_str`/`Display` and the serde representation are maintained
        // separately; this guards against them drifting apart.
        for (kind, expected) in ALL_KINDS {
            let json = serde_json::to_value(kind).expect("serialize");
            assert_eq!(json, expected);
            let back: ParameterKind = serde_json::from_value(json).expect("deserialize");
            assert_eq!(back, kind);
        }
    }

    #[test]
    fn parameter_kind_is_integer() {
        assert!(ParameterKind::TopK.is_integer());
        assert!(ParameterKind::RetrievalTopK.is_integer());
        assert!(!ParameterKind::Temperature.is_integer());
        assert!(!ParameterKind::TopP.is_integer());
        assert!(!ParameterKind::FrequencyPenalty.is_integer());
        assert!(!ParameterKind::PresencePenalty.is_integer());
        assert!(!ParameterKind::SimilarityThreshold.is_integer());
        assert!(!ParameterKind::TemporalDecay.is_integer());
    }

    #[test]
    fn variation_value_as_f64_float() {
        // 2.5 is exactly representable and not close to a named constant, so no
        // `clippy::approx_constant` allowance is needed.
        let v = VariationValue::Float(OrderedFloat(2.5));
        assert!((v.as_f64() - 2.5).abs() < f64::EPSILON);
    }

    #[test]
    fn variation_value_as_f64_int() {
        let v = VariationValue::Int(42);
        assert!((v.as_f64() - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn variation_value_from_f64() {
        let v = VariationValue::from(0.7_f64);
        assert!(matches!(v, VariationValue::Float(_)));
        assert!((v.as_f64() - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn variation_value_from_i64() {
        let v = VariationValue::from(40_i64);
        assert!(matches!(v, VariationValue::Int(40)));
        assert!((v.as_f64() - 40.0).abs() < f64::EPSILON);
    }

    #[test]
    fn variation_value_float_hash_eq() {
        use std::collections::HashSet;
        let a = VariationValue::Float(OrderedFloat(0.7));
        let b = VariationValue::Float(OrderedFloat(0.7));
        let c = VariationValue::Float(OrderedFloat(0.8));
        let mut set = HashSet::new();
        set.insert(a);
        assert!(set.contains(&b));
        assert!(!set.contains(&c));
    }

    #[test]
    fn variation_value_wire_format() {
        // Adjacently tagged: `{"type": <variant>, "value": <payload>}`.
        let float =
            serde_json::to_value(VariationValue::Float(OrderedFloat(0.5))).expect("serialize");
        assert_eq!(float, serde_json::json!({"type": "Float", "value": 0.5}));
        let int = serde_json::to_value(VariationValue::Int(40)).expect("serialize");
        assert_eq!(int, serde_json::json!({"type": "Int", "value": 40}));
    }

    #[test]
    fn variation_serde_roundtrip() {
        let cases = [
            Variation {
                parameter: ParameterKind::Temperature,
                value: VariationValue::Float(OrderedFloat(0.7)),
            },
            Variation {
                parameter: ParameterKind::TopK,
                value: VariationValue::Int(40),
            },
        ];
        for v in cases {
            let json = serde_json::to_string(&v).expect("serialize");
            let v2: Variation = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(v, v2);
        }
    }

    #[test]
    fn experiment_source_as_str() {
        assert_eq!(ExperimentSource::Manual.as_str(), "manual");
        assert_eq!(ExperimentSource::Scheduled.as_str(), "scheduled");
        assert_eq!(ExperimentSource::Manual.to_string(), "manual");
        assert_eq!(ExperimentSource::Scheduled.to_string(), "scheduled");
    }

    #[test]
    fn experiment_source_serde_name_matches_as_str() {
        for source in [ExperimentSource::Manual, ExperimentSource::Scheduled] {
            let json = serde_json::to_value(&source).expect("serialize");
            assert_eq!(json, source.as_str());
        }
    }

    #[test]
    fn variation_value_display() {
        assert_eq!(VariationValue::Int(42).to_string(), "42");
        assert_eq!(VariationValue::Float(OrderedFloat(0.7)).to_string(), "0.7");
    }

    #[test]
    fn experiment_result_serde_roundtrip() {
        let result = ExperimentResult {
            id: 1,
            session_id: "sess-abc".to_string(),
            variation: Variation {
                parameter: ParameterKind::Temperature,
                value: VariationValue::Float(OrderedFloat(0.7)),
            },
            baseline_score: 7.0,
            candidate_score: 8.0,
            delta: 1.0,
            latency_ms: 500,
            tokens_used: 1_000,
            accepted: true,
            source: ExperimentSource::Manual,
            created_at: "2026-03-07 22:00:00".to_string(),
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("parse");
        assert_eq!(parsed["id"], 1);
        assert_eq!(parsed["session_id"], "sess-abc");
        assert_eq!(parsed["accepted"], true);
        assert_eq!(parsed["source"], "manual");
        assert_eq!(parsed["variation"]["parameter"], "temperature");

        let result2: ExperimentResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(result2.id, result.id);
        assert_eq!(result2.session_id, result.session_id);
        assert_eq!(result2.variation, result.variation);
        assert!((result2.baseline_score - result.baseline_score).abs() < f64::EPSILON);
        assert!((result2.candidate_score - result.candidate_score).abs() < f64::EPSILON);
        assert!((result2.delta - result.delta).abs() < f64::EPSILON);
        assert_eq!(result2.latency_ms, result.latency_ms);
        assert_eq!(result2.tokens_used, result.tokens_used);
        assert!(result2.accepted);
        assert_eq!(result2.source, ExperimentSource::Manual);
        assert_eq!(result2.created_at, result.created_at);
    }
}