Skip to main content

rvf_federation/
types.rs

//! Federation segment payload types.
//!
//! Four new RVF segment types (0x33-0x36) defined in ADR-057.

5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};

// ── Segment type constants ──────────────────────────────────────────

/// Segment type tag identifying a `FederatedManifest` payload.
pub const SEG_FEDERATED_MANIFEST: u8 = 0x33;
/// Segment type tag identifying a `DiffPrivacyProof` payload.
pub const SEG_DIFF_PRIVACY_PROOF: u8 = 0x34;
/// Segment type tag identifying a `RedactionLog` payload.
pub const SEG_REDACTION_LOG: u8 = 0x35;
/// Segment type tag identifying an `AggregateWeights` payload.
pub const SEG_AGGREGATE_WEIGHTS: u8 = 0x36;

// ── FederatedManifest (0x33) ────────────────────────────────────────

/// Manifest describing one federated learning export.
///
/// Written as the first segment of every federation RVF file.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FederatedManifest {
    /// Version of the manifest format; `new` sets this to 1.
    pub format_version: u32,
    /// Contributor pseudonym; must never carry the real identity.
    pub contributor_pseudonym: String,
    /// Export creation time as a UNIX timestamp in seconds.
    pub export_timestamp_s: u64,
    /// IDs of the segments bundled into this export.
    pub included_segment_ids: Vec<u64>,
    /// Cumulative differential-privacy budget (epsilon) spent so far.
    pub privacy_budget_spent: f64,
    /// Identifier of the domain the export applies to.
    pub domain_id: String,
    /// Compatibility tag naming the RVF format version.
    pub rvf_version_tag: String,
    /// How many trajectories the exported learning summarizes.
    pub trajectory_count: u64,
    /// Mean quality score across the exported trajectories.
    pub avg_quality_score: f64,
}

impl FederatedManifest {
    /// Builds a manifest for the given contributor pseudonym and domain.
    ///
    /// All remaining fields start at neutral values: format version 1,
    /// version tag `"rvf-v1"`, no segment IDs, and zero for the timestamp,
    /// privacy budget, trajectory count and quality score. Callers are
    /// expected to fill those in before the manifest is written out.
    pub fn new(contributor_pseudonym: String, domain_id: String) -> Self {
        let rvf_version_tag = "rvf-v1".to_owned();
        let included_segment_ids = Vec::new();

        Self {
            format_version: 1,
            contributor_pseudonym,
            export_timestamp_s: 0,
            included_segment_ids,
            privacy_budget_spent: 0.0,
            domain_id,
            rvf_version_tag,
            trajectory_count: 0,
            avg_quality_score: 0.0,
        }
    }
}

// ── DiffPrivacyProof (0x34) ─────────────────────────────────────────

/// Which noise distribution was used to achieve differential privacy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NoiseMechanism {
    /// Gaussian noise, giving (epsilon, delta)-DP.
    Gaussian,
    /// Laplace noise, giving pure epsilon-DP.
    Laplace,
}

/// Attestation of the differential privacy applied to an export.
///
/// Captures the privacy parameters together with the noise calibration
/// that was derived from them.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DiffPrivacyProof {
    /// Privacy loss parameter (epsilon).
    pub epsilon: f64,
    /// Probability of privacy failure (delta); 0.0 for the Laplace mechanism.
    pub delta: f64,
    /// Mechanism that produced the noise.
    pub mechanism: NoiseMechanism,
    /// L2 sensitivity bound the noise was calibrated against.
    pub sensitivity: f64,
    /// Norm at which gradients were clipped.
    pub clipping_norm: f64,
    /// Scale of the noise: sigma for Gaussian, b for Laplace.
    pub noise_scale: f64,
    /// Count of parameters that received noise.
    pub noised_parameter_count: u64,
}

impl DiffPrivacyProof {
    /// Builds a proof for the Gaussian mechanism.
    ///
    /// The noise scale is
    /// `sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon`.
    /// Inputs are used exactly as given; no range checks are performed, so
    /// out-of-range values (for example a non-positive `delta` or a zero
    /// `epsilon`) propagate into a non-finite or NaN `noise_scale`.
    /// `noised_parameter_count` starts at 0.
    pub fn gaussian(epsilon: f64, delta: f64, sensitivity: f64, clipping_norm: f64) -> Self {
        let log_term = (1.25_f64 / delta).ln();
        let sigma = sensitivity * (2.0_f64 * log_term).sqrt() / epsilon;
        Self::unapplied(
            NoiseMechanism::Gaussian,
            epsilon,
            delta,
            sensitivity,
            clipping_norm,
            sigma,
        )
    }

    /// Builds a proof for the Laplace mechanism.
    ///
    /// The noise scale is `b = sensitivity / epsilon` and `delta` is fixed
    /// at 0.0. Inputs are not validated. `noised_parameter_count` starts
    /// at 0.
    pub fn laplace(epsilon: f64, sensitivity: f64, clipping_norm: f64) -> Self {
        Self::unapplied(
            NoiseMechanism::Laplace,
            epsilon,
            0.0,
            sensitivity,
            clipping_norm,
            sensitivity / epsilon,
        )
    }

    /// Assembles a proof whose noise has not yet been applied to any
    /// parameter (`noised_parameter_count` is 0).
    fn unapplied(
        mechanism: NoiseMechanism,
        epsilon: f64,
        delta: f64,
        sensitivity: f64,
        clipping_norm: f64,
        noise_scale: f64,
    ) -> Self {
        Self {
            epsilon,
            delta,
            mechanism,
            sensitivity,
            clipping_norm,
            noise_scale,
            noised_parameter_count: 0,
        }
    }
}

// ── RedactionLog (0x35) ─────────────────────────────────────────────

/// One redaction event: how many times a rule redacted a PII category.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RedactionEntry {
    /// Category of PII detected (e.g. "path", "ip", "email", "api_key").
    pub category: String,
    /// Number of occurrences redacted.
    pub count: u32,
    /// Identifier of the rule that triggered the redaction.
    pub rule_id: String,
}

/// PII stripping attestation.
///
/// Proves that PII scanning was performed without revealing the original
/// content.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RedactionLog {
    /// Individual redaction entries, in the order they were added.
    pub entries: Vec<RedactionEntry>,
    /// SHAKE-256 hash of the pre-redaction content (32 bytes).
    pub pre_redaction_hash: [u8; 32],
    /// Total number of fields scanned.
    pub fields_scanned: u64,
    /// Total number of redactions applied; `add_entry` keeps this in step
    /// with the entries it appends.
    pub total_redactions: u64,
    /// UNIX timestamp (seconds) when redaction was performed.
    pub timestamp_s: u64,
}

impl RedactionLog {
    /// Creates an empty log: no entries, an all-zero hash, and zeroed
    /// counters and timestamp.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            pre_redaction_hash: [0u8; 32],
            fields_scanned: 0,
            total_redactions: 0,
            timestamp_s: 0,
        }
    }

    /// Appends a redaction entry and adds `count` to `total_redactions`.
    ///
    /// The running total saturates at `u64::MAX` instead of overflowing.
    /// `total_redactions` is a public field and may hold any value (for
    /// example after deserialization), so a plain `+=` could panic in
    /// debug builds or silently wrap in release builds.
    pub fn add_entry(&mut self, category: &str, count: u32, rule_id: &str) {
        // `u64::from` is a lossless widening; no truncating `as` cast needed.
        self.total_redactions = self.total_redactions.saturating_add(u64::from(count));
        self.entries.push(RedactionEntry {
            category: category.to_owned(),
            count,
            rule_id: rule_id.to_owned(),
        });
    }
}

impl Default for RedactionLog {
    /// Equivalent to [`RedactionLog::new`].
    fn default() -> Self {
        Self::new()
    }
}

// ── AggregateWeights (0x36) ─────────────────────────────────────────

/// Weight vector produced by federated averaging, plus round metadata.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AggregateWeights {
    /// Round number of the federated averaging run.
    pub round: u64,
    /// How many participants contributed to this round.
    pub participation_count: u32,
    /// Aggregated LoRA delta weights, flattened into one vector.
    pub lora_deltas: Vec<f64>,
    /// Confidence score for each weight.
    pub confidences: Vec<f64>,
    /// Mean of the participants' losses.
    pub mean_loss: f64,
    /// Variance of the participants' losses.
    pub loss_variance: f64,
    /// Identifier of the domain these weights belong to.
    pub domain_id: String,
    /// `true` when Byzantine outlier removal was applied.
    pub byzantine_filtered: bool,
    /// Count of contributions discarded as outliers.
    pub outliers_removed: u32,
}

impl AggregateWeights {
    /// Creates an empty aggregate for `domain_id` at the given `round`.
    ///
    /// Both weight vectors start empty, every counter and statistic starts
    /// at zero, and `byzantine_filtered` starts as `false`.
    pub fn new(domain_id: String, round: u64) -> Self {
        let lora_deltas = Vec::new();
        let confidences = Vec::new();

        Self {
            round,
            participation_count: 0,
            lora_deltas,
            confidences,
            mean_loss: 0.0,
            loss_variance: 0.0,
            domain_id,
            byzantine_filtered: false,
            outliers_removed: 0,
        }
    }
}

// ── BetaParams (local copy for federation) ──────────────────────────

/// Beta distribution parameters for Thompson Sampling priors.
///
/// Mirrors the type in `ruvector-domain-expansion` to avoid a cross-crate
/// dependency.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BetaParams {
    /// Alpha (success count + 1).
    pub alpha: f64,
    /// Beta (failure count + 1).
    pub beta: f64,
}

impl BetaParams {
    /// Creates Beta parameters from explicit `alpha` and `beta` values.
    pub fn new(alpha: f64, beta: f64) -> Self {
        Self { alpha, beta }
    }

    /// Uniform (uninformative) prior, Beta(1, 1).
    pub fn uniform() -> Self {
        Self::new(1.0, 1.0)
    }

    /// Mean of the distribution, `alpha / (alpha + beta)`.
    pub fn mean(&self) -> f64 {
        let total = self.alpha + self.beta;
        self.alpha / total
    }

    /// Total observations, `alpha + beta - 2`: the pseudo-counts that remain
    /// once the Beta(1, 1) prior is removed.
    pub fn observations(&self) -> f64 {
        self.alpha + self.beta - 2.0
    }

    /// Merges two Beta posteriors.
    ///
    /// The parameters are summed and 1.0 is subtracted from each, so the
    /// uniform prior both inputs start from is counted only once.
    pub fn merge(&self, other: &BetaParams) -> BetaParams {
        Self::new(
            self.alpha + other.alpha - 1.0,
            self.beta + other.beta - 1.0,
        )
    }

    /// Dampens this prior by shrinking it linearly toward the uniform prior.
    ///
    /// `factor` is clamped to `[0.0, 1.0]`, then each parameter becomes
    /// `1 + (param - 1) * factor`. A factor of 1.0 returns the prior
    /// unchanged and 0.0 returns Beta(1, 1). The scaling is linear in
    /// `factor`; no square root is applied here, so a caller that wants
    /// sqrt-scaling must pass an already-transformed factor.
    pub fn dampen(&self, factor: f64) -> BetaParams {
        let weight = factor.clamp(0.0, 1.0);
        let shrink = |param: f64| 1.0 + (param - 1.0) * weight;
        Self::new(shrink(self.alpha), shrink(self.beta))
    }
}

impl Default for BetaParams {
    /// The uniform prior, Beta(1, 1).
    fn default() -> Self {
        Self::uniform()
    }
}

// ── TransferPrior (local copy for federation) ───────────────────────

293/// Compact summary of learned priors for a single context bucket.
294#[derive(Clone, Debug, PartialEq)]
295#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
296pub struct TransferPriorEntry {
297    /// Context bucket identifier.
298    pub bucket_id: String,
299    /// Arm identifier.
300    pub arm_id: String,
301    /// Beta posterior parameters.
302    pub params: BetaParams,
303    /// Number of observations backing this prior.
304    pub observation_count: u64,
305}
306
307/// Collection of transfer priors from a trained domain.
308#[derive(Clone, Debug, PartialEq)]
309#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
310pub struct TransferPriorSet {
311    /// Source domain identifier.
312    pub source_domain: String,
313    /// Individual prior entries.
314    pub entries: Vec<TransferPriorEntry>,
315    /// EMA cost at time of extraction.
316    pub cost_ema: f64,
317}

// ── PolicyKernelSnapshot ────────────────────────────────────────────

/// Point-in-time copy of a policy kernel's configuration, for federation
/// export.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PolicyKernelSnapshot {
    /// Identifier of the kernel.
    pub kernel_id: String,
    /// Values of the tunable knobs.
    pub knobs: Vec<f64>,
    /// Fitness score of the kernel.
    pub fitness: f64,
    /// Generation number of the kernel.
    pub generation: u64,
}

// ── CostCurveSnapshot ───────────────────────────────────────────────

/// Point-in-time copy of a domain's cost curve, for federation export.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CostCurveSnapshot {
    /// Identifier of the domain.
    pub domain_id: String,
    /// Curve samples as ordered `(step, cost)` pairs.
    pub points: Vec<(u64, f64)>,
    /// Acceleration factor; a value above 1.0 means transfer helped.
    pub acceleration: f64,
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOLERANCE: f64 = 1e-10;

    /// Returns `true` when `actual` is within `TOLERANCE` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// The segment tags must keep the values assigned to them.
    #[test]
    fn segment_type_constants() {
        assert_eq!(SEG_FEDERATED_MANIFEST, 0x33);
        assert_eq!(SEG_DIFF_PRIVACY_PROOF, 0x34);
        assert_eq!(SEG_REDACTION_LOG, 0x35);
        assert_eq!(SEG_AGGREGATE_WEIGHTS, 0x36);
    }

    /// A fresh manifest carries its arguments and version 1.
    #[test]
    fn federated_manifest_new() {
        let manifest = FederatedManifest::new("alice".into(), "genomics".into());
        assert_eq!(manifest.format_version, 1);
        assert_eq!(manifest.contributor_pseudonym, "alice");
        assert_eq!(manifest.domain_id, "genomics");
        assert_eq!(manifest.trajectory_count, 0);
    }

    /// The Gaussian constructor tags the mechanism and yields positive noise.
    #[test]
    fn diff_privacy_proof_gaussian() {
        let proof = DiffPrivacyProof::gaussian(1.0, 1e-5, 1.0, 1.0);
        assert_eq!(proof.mechanism, NoiseMechanism::Gaussian);
        assert!(proof.noise_scale > 0.0);
        assert_eq!(proof.epsilon, 1.0);
    }

    /// With unit sensitivity and epsilon the Laplace scale is 1.
    #[test]
    fn diff_privacy_proof_laplace() {
        let proof = DiffPrivacyProof::laplace(1.0, 1.0, 1.0);
        assert_eq!(proof.mechanism, NoiseMechanism::Laplace);
        assert!(close(proof.noise_scale, 1.0));
    }

    /// Adding entries grows the list and the running total.
    #[test]
    fn redaction_log_add_entry() {
        let mut redactions = RedactionLog::new();
        redactions.add_entry("path", 3, "rule_path_unix");
        redactions.add_entry("ip", 2, "rule_ipv4");
        assert_eq!(redactions.entries.len(), 2);
        assert_eq!(redactions.total_redactions, 5);
    }

    /// A fresh aggregate keeps its round and has no participants or filtering.
    #[test]
    fn aggregate_weights_new() {
        let weights = AggregateWeights::new("code_review".into(), 1);
        assert_eq!(weights.round, 1);
        assert_eq!(weights.participation_count, 0);
        assert!(!weights.byzantine_filtered);
    }

    /// Merging sums the parameters and removes one uniform prior.
    #[test]
    fn beta_params_merge() {
        let left = BetaParams::new(10.0, 5.0);
        let right = BetaParams::new(8.0, 3.0);
        let combined = left.merge(&right);
        assert!(close(combined.alpha, 17.0));
        assert!(close(combined.beta, 7.0));
    }

    /// Dampening scales the excess over the uniform prior by the factor.
    #[test]
    fn beta_params_dampen() {
        let prior = BetaParams::new(10.0, 5.0);
        let softened = prior.dampen(0.25);
        // alpha: 1 + (10 - 1) * 0.25 = 3.25
        assert!(close(softened.alpha, 3.25));
        // beta: 1 + (5 - 1) * 0.25 = 2.0
        assert!(close(softened.beta, 2.0));
    }

    /// Equal alpha and beta give a mean of one half.
    #[test]
    fn beta_params_mean() {
        let balanced = BetaParams::new(10.0, 10.0);
        assert!(close(balanced.mean(), 0.5));
    }
}