1use std::collections::BTreeSet;
21
22use serde::{Deserialize, Serialize};
23
24use crate::composite::{AgentSubmission, Run};
25
26#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct TaggedRun {
30 pub window_id: String,
31 pub run: Run,
32}
33
34#[derive(Clone, Debug, Serialize, Deserialize)]
37pub struct TaggedSubmission {
38 pub agent_id: String,
39 pub runs: Vec<TaggedRun>,
40 #[serde(default)]
42 pub in_sample_trials: u32,
43 #[serde(default)]
45 pub candidates: Vec<Vec<f64>>,
46}
47
48impl TaggedSubmission {
49 pub fn completed_windows(&self) -> BTreeSet<String> {
51 self.runs.iter().map(|r| r.window_id.clone()).collect()
52 }
53}
54
55#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
57pub struct ComparisonSet {
58 pub roster: Vec<String>,
60 pub shared_windows: Vec<String>,
62}
63
64pub fn comparison_set(roster: &[String], subs: &[TaggedSubmission]) -> ComparisonSet {
72 let mut shared: Option<BTreeSet<String>> = None;
73 for agent_id in roster {
74 let completed = subs
75 .iter()
76 .find(|s| &s.agent_id == agent_id)
77 .map(TaggedSubmission::completed_windows)
78 .unwrap_or_default();
79 shared = Some(match shared {
80 None => completed,
81 Some(acc) => acc.intersection(&completed).cloned().collect(),
82 });
83 }
84 ComparisonSet {
85 roster: roster.to_vec(),
86 shared_windows: shared.unwrap_or_default().into_iter().collect(),
87 }
88}
89
90pub fn qualifies(set: &ComparisonSet, sub: &TaggedSubmission, min_shared_windows: usize) -> bool {
96 let completed = sub.completed_windows();
97 let shared_completed = set
98 .shared_windows
99 .iter()
100 .filter(|w| completed.contains(*w))
101 .count();
102 shared_completed >= min_shared_windows
103}
104
105pub fn restrict_to_shared(set: &ComparisonSet, sub: &TaggedSubmission) -> AgentSubmission {
109 let shared: BTreeSet<&str> = set.shared_windows.iter().map(String::as_str).collect();
110 let runs = sub
111 .runs
112 .iter()
113 .filter(|r| shared.contains(r.window_id.as_str()))
114 .map(|r| r.run.clone())
115 .collect();
116 AgentSubmission {
117 agent_id: sub.agent_id.clone(),
118 runs,
119 in_sample_trials: sub.in_sample_trials,
120 candidates: sub.candidates.clone(),
121 }
122}
123
124pub fn restrict_field(roster: &[String], subs: &[TaggedSubmission]) -> Vec<AgentSubmission> {
130 let set = comparison_set(roster, subs);
131 roster
132 .iter()
133 .map(
134 |agent_id| match subs.iter().find(|s| &s.agent_id == agent_id) {
135 Some(sub) => restrict_to_shared(&set, sub),
136 None => AgentSubmission {
137 agent_id: agent_id.clone(),
138 runs: Vec::new(),
139 in_sample_trials: 0,
140 candidates: Vec::new(),
141 },
142 },
143 )
144 .collect()
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic run of `n` returns oscillating around `mean_ret`.
    fn run(mean_ret: f64, n: usize) -> Run {
        let returns = (0..n)
            .map(|i| mean_ret + 0.0005 * (i as f64 * 0.7).sin())
            .collect();
        Run {
            returns,
            trace: Default::default(),
            confidences: Vec::new(),
            outcomes: Vec::new(),
            cost: 0.0,
        }
    }

    /// Builds a submission for `agent_id` with one run per entry of `windows`.
    fn tagged(agent_id: &str, windows: &[&str]) -> TaggedSubmission {
        let mut runs = Vec::with_capacity(windows.len());
        for window in windows {
            runs.push(TaggedRun {
                window_id: (*window).to_owned(),
                run: run(0.002, 40),
            });
        }
        TaggedSubmission {
            agent_id: agent_id.to_owned(),
            runs,
            in_sample_trials: 0,
            candidates: Vec::new(),
        }
    }

    /// Converts string literals into owned `String`s.
    fn names(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_owned()).collect()
    }

    #[test]
    fn shared_is_intersection_sorted() {
        let roster = names(&["vet", "new"]);
        let subs = [
            tagged("vet", &["w3", "w1", "w2", "w4"]),
            tagged("new", &["w2", "w3"]),
        ];
        let set = comparison_set(&roster, &subs);
        assert_eq!(set.shared_windows, names(&["w2", "w3"]));
        assert_eq!(set.roster, roster);
    }

    #[test]
    fn missing_roster_member_empties_shared() {
        let roster = names(&["vet", "ghost"]);
        let subs = [tagged("vet", &["w1", "w2"])];
        let set = comparison_set(&roster, &subs);
        assert!(set.shared_windows.is_empty());
    }

    #[test]
    fn restrict_keeps_only_shared_runs() {
        let roster = names(&["vet", "new"]);
        let subs = [
            tagged("vet", &["w1", "w2", "w3"]),
            tagged("new", &["w2", "w3", "w9"]),
        ];
        let field = restrict_field(&roster, &subs);
        assert_eq!(field.len(), 2);
        // Only w2 and w3 are common to both agents.
        assert_eq!(field[0].agent_id, "vet");
        assert_eq!(field[0].runs.len(), 2);
        assert_eq!(field[1].runs.len(), 2);
    }

    #[test]
    fn qualifies_on_min_shared() {
        let roster = names(&["vet", "new"]);
        let subs = [
            tagged("vet", &["w1", "w2", "w3"]),
            tagged("new", &["w2", "w3"]),
        ];
        let set = comparison_set(&roster, &subs);
        let entrant = &subs[1];
        assert!(qualifies(&set, entrant, 2));
        // Only two windows are shared, so a threshold of three cannot be met.
        assert!(!qualifies(&set, entrant, 3));

        // An outsider overlapping on a single shared window (w2).
        let thin = tagged("thin", &["w2", "w7"]);
        assert!(!qualifies(&set, &thin, 2));
        assert!(qualifies(&set, &thin, 1));
    }

    #[test]
    fn empty_roster_yields_empty_set() {
        let set = comparison_set(&[], &[]);
        assert!(set.shared_windows.is_empty());
        assert!(set.roster.is_empty());
        assert!(restrict_field(&[], &[]).is_empty());
    }

    #[test]
    fn multiple_runs_per_window_all_survive() {
        let mut sub = tagged("multi", &["w1", "w1", "w2"]);
        let other = tagged("other", &["w1", "w2"]);
        let roster = names(&["multi", "other"]);
        let set = comparison_set(&roster, &[sub.clone(), other]);
        assert_eq!(set.shared_windows, names(&["w1", "w2"]));

        // Both w1 runs are retained alongside the w2 run.
        let filtered = restrict_to_shared(&set, &sub);
        assert_eq!(filtered.runs.len(), 3);

        // A run on a non-shared window is dropped.
        sub.runs.push(TaggedRun {
            window_id: "w9".to_string(),
            run: run(0.002, 40),
        });
        assert_eq!(restrict_to_shared(&set, &sub).runs.len(), 3);
    }
}