1use crate::random::core::{thread_rng, Random};
8use rand::seq::SliceRandom;
9use rand::Rng;
10use rand_distr::Uniform;
11
12pub use rand::seq::SliceRandom as SliceRandomExt;
14
15pub trait ScientificSliceRandom<T> {
17 fn scientific_shuffle<R: Rng>(&mut self, rng: &mut Random<R>);
19
20 fn scientific_choose<R: Rng>(&self, rng: &mut Random<R>) -> Option<&T>;
22
23 fn scientific_choose_multiple<R: Rng>(&self, rng: &mut Random<R>, amount: usize) -> Vec<&T>;
25
26 fn scientific_sample_with_replacement<R: Rng>(
28 &self,
29 rng: &mut Random<R>,
30 amount: usize,
31 ) -> Vec<&T>;
32
33 fn scientific_weighted_sample<R: Rng, W>(
35 &self,
36 rng: &mut Random<R>,
37 weights: &[W],
38 amount: usize,
39 ) -> Result<Vec<&T>, String>
40 where
41 W: Into<f64> + Copy;
42
43 fn scientific_reservoir_sample<R: Rng>(&self, rng: &mut Random<R>, k: usize) -> Vec<&T>;
45}
46
47impl<T> ScientificSliceRandom<T> for [T] {
48 fn scientific_shuffle<R: Rng>(&mut self, rng: &mut Random<R>) {
49 for i in (1..self.len()).rev() {
51 let j = rng.sample(Uniform::new(0, i + 1).expect("Operation failed"));
52 self.swap(i, j);
53 }
54 }
55
56 fn scientific_choose<R: Rng>(&self, rng: &mut Random<R>) -> Option<&T> {
57 if self.is_empty() {
58 None
59 } else {
60 let index = rng.random_range(0..self.len());
61 Some(&self[index])
62 }
63 }
64
65 fn scientific_choose_multiple<R: Rng>(&self, rng: &mut Random<R>, amount: usize) -> Vec<&T> {
66 if amount >= self.len() {
67 return self.iter().collect();
68 }
69
70 let mut selected = std::collections::HashSet::new();
72 let n = self.len();
73 let k = amount;
74
75 for i in (n - k)..n {
76 let mut j = rng.random_range(0..=i);
77 if selected.contains(&j) {
78 j = i;
79 }
80 selected.insert(j);
81 }
82
83 selected.into_iter().map(|i| &self[i]).collect()
84 }
85
86 fn scientific_sample_with_replacement<R: Rng>(
87 &self,
88 rng: &mut Random<R>,
89 amount: usize,
90 ) -> Vec<&T> {
91 (0..amount)
92 .map(|_| &self[rng.random_range(0..self.len())])
93 .collect()
94 }
95
96 fn scientific_weighted_sample<R: Rng, W>(
97 &self,
98 rng: &mut Random<R>,
99 weights: &[W],
100 amount: usize,
101 ) -> Result<Vec<&T>, String>
102 where
103 W: Into<f64> + Copy,
104 {
105 if self.len() != weights.len() {
106 return Err("Items and weights must have the same length".to_string());
107 }
108
109 if self.is_empty() {
110 return Ok(Vec::new());
111 }
112
113 let weights_f64: Vec<f64> = weights.iter().map(|&w| w.into()).collect();
115 let total_weight: f64 = weights_f64.iter().sum();
116
117 if total_weight <= 0.0 {
118 return Err("Total weight must be positive".to_string());
119 }
120
121 let mut cumulative = Vec::with_capacity(weights_f64.len());
122 let mut cum_sum = 0.0;
123 for &weight in &weights_f64 {
124 cum_sum += weight / total_weight;
125 cumulative.push(cum_sum);
126 }
127
128 let mut result = Vec::with_capacity(amount);
129 for _ in 0..amount {
130 let u = rng.random_range(0.0..1.0);
131 match cumulative.binary_search_by(|&x| x.partial_cmp(&u).expect("Operation failed")) {
132 Ok(idx) => result.push(&self[idx]),
133 Err(idx) => result.push(&self[idx.min(self.len() - 1)]),
134 }
135 }
136
137 Ok(result)
138 }
139
140 fn scientific_reservoir_sample<R: Rng>(&self, rng: &mut Random<R>, k: usize) -> Vec<&T> {
141 if k >= self.len() {
142 return self.iter().collect();
143 }
144
145 let mut reservoir: Vec<&T> = Vec::with_capacity(k);
146
147 for item in self.iter().take(k) {
149 reservoir.push(item);
150 }
151
152 for (i, item) in self.iter().enumerate().skip(k) {
154 let j = rng.sample(Uniform::new(0, i + 1).expect("Operation failed"));
155 if j < k {
156 reservoir[j] = item;
157 }
158 }
159
160 reservoir
161 }
162}
163
164impl<T> ScientificSliceRandom<T> for Vec<T> {
165 fn scientific_shuffle<R: Rng>(&mut self, rng: &mut Random<R>) {
166 self.as_mut_slice().scientific_shuffle(rng);
167 }
168
169 fn scientific_choose<R: Rng>(&self, rng: &mut Random<R>) -> Option<&T> {
170 self.as_slice().scientific_choose(rng)
171 }
172
173 fn scientific_choose_multiple<R: Rng>(&self, rng: &mut Random<R>, amount: usize) -> Vec<&T> {
174 self.as_slice().scientific_choose_multiple(rng, amount)
175 }
176
177 fn scientific_sample_with_replacement<R: Rng>(
178 &self,
179 rng: &mut Random<R>,
180 amount: usize,
181 ) -> Vec<&T> {
182 self.as_slice()
183 .scientific_sample_with_replacement(rng, amount)
184 }
185
186 fn scientific_weighted_sample<R: Rng, W>(
187 &self,
188 rng: &mut Random<R>,
189 weights: &[W],
190 amount: usize,
191 ) -> Result<Vec<&T>, String>
192 where
193 W: Into<f64> + Copy,
194 {
195 self.as_slice()
196 .scientific_weighted_sample(rng, weights, amount)
197 }
198
199 fn scientific_reservoir_sample<R: Rng>(&self, rng: &mut Random<R>, k: usize) -> Vec<&T> {
200 self.as_slice().scientific_reservoir_sample(rng, k)
201 }
202}
203
204pub mod convenience {
206 use super::*;
207
208 pub fn shuffle<T>(slice: &mut [T]) {
210 use rand::seq::SliceRandom as _;
211 let mut rng = thread_rng();
212 slice.shuffle(&mut rng.rng);
213 }
214
215 pub fn sample<T>(slice: &[T], n: usize) -> Vec<T>
217 where
218 T: Clone,
219 {
220 let mut rng = thread_rng();
222 let mut indices: Vec<usize> = (0..slice.len()).collect();
223 indices.shuffle(&mut rng.rng);
224 indices
225 .into_iter()
226 .take(n)
227 .map(|i| slice[i].clone())
228 .collect()
229 }
230
231 pub fn choose<T>(slice: &[T]) -> Option<&T> {
233 if slice.is_empty() {
235 None
236 } else {
237 let mut rng = thread_rng();
238 let index = rng.random_range(0..slice.len());
239 Some(&slice[index])
240 }
241 }
242
243 pub fn scientific_shuffle<T>(slice: &mut [T]) {
245 let mut rng = thread_rng();
246 slice.scientific_shuffle(&mut rng);
247 }
248
249 pub fn scientific_sample<T>(slice: &[T], n: usize) -> Vec<&T> {
251 let mut rng = thread_rng();
252 slice.scientific_choose_multiple(&mut rng, n)
253 }
254
255 pub fn scientific_weighted_sample<'a, T, W>(
257 slice: &'a [T],
258 weights: &[W],
259 n: usize,
260 ) -> Result<Vec<&'a T>, String>
261 where
262 W: Into<f64> + Copy,
263 {
264 let mut rng = thread_rng();
265 slice.scientific_weighted_sample(&mut rng, weights, n)
266 }
267
268 pub fn reservoir_sample<T>(slice: &[T], k: usize) -> Vec<&T> {
270 let mut rng = thread_rng();
271 slice.scientific_reservoir_sample(&mut rng, k)
272 }
273}
274
275pub mod algorithms {
277 use super::*;
278 use std::collections::HashMap;
279
280 pub fn stratified_sample<'a, T, K>(
282 data: &'a [(T, K)],
283 strata_sizes: &HashMap<K, usize>,
284 rng: &mut Random<impl Rng>,
285 ) -> Vec<&'a T>
286 where
287 K: Eq + std::hash::Hash + Clone,
288 {
289 let mut result = Vec::new();
290 let mut strata: HashMap<K, Vec<&T>> = HashMap::new();
291
292 for (item, key) in data {
294 strata.entry(key.clone()).or_default().push(item);
295 }
296
297 for (key, desired_size) in strata_sizes {
299 if let Some(stratum_data) = strata.get(key) {
300 let sample = stratum_data.scientific_choose_multiple(rng, *desired_size);
301 result.extend(sample);
302 }
303 }
304
305 result
306 }
307
308 pub fn systematic_sample<'a, T>(
310 data: &'a [T],
311 n: usize,
312 rng: &mut Random<impl Rng>,
313 ) -> Vec<&'a T> {
314 if n == 0 || data.is_empty() {
315 return Vec::new();
316 }
317
318 if n >= data.len() {
319 return data.iter().collect();
320 }
321
322 let interval = data.len() as f64 / n as f64;
323 let start = rng.random_range(0.0..interval);
324
325 (0..n)
326 .map(|i| {
327 let index = (start + i as f64 * interval) as usize;
328 &data[index.min(data.len() - 1)]
329 })
330 .collect()
331 }
332
333 pub fn cluster_sample<'a, T, C>(
335 clusters: &'a [(C, Vec<T>)],
336 n_clusters: usize,
337 rng: &mut Random<impl Rng>,
338 ) -> Vec<&'a T>
339 where
340 C: Clone,
341 {
342 if clusters.is_empty() || n_clusters == 0 {
343 return Vec::new();
344 }
345
346 let cluster_refs: Vec<&(C, Vec<T>)> = clusters.iter().collect();
347 let selected_clusters = cluster_refs.scientific_choose_multiple(rng, n_clusters);
348
349 let mut result = Vec::new();
350 for (_, cluster_data) in selected_clusters {
351 result.extend(cluster_data.iter());
352 }
353
354 result
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::core::seeded_rng;

    /// Shuffling keeps the length and the checked elements.
    #[test]
    fn test_scientific_shuffle() {
        let mut rng = seeded_rng(42);
        let mut values = vec![1, 2, 3, 4, 5];

        values.scientific_shuffle(&mut rng);

        assert_eq!(values.len(), 5);
        assert!(values.contains(&1));
        assert!(values.contains(&5));
    }

    /// A choice from a non-empty array is one of its elements.
    #[test]
    fn test_scientific_choose() {
        let mut rng = seeded_rng(123);
        let values = [1, 2, 3, 4, 5];

        let choice = values.scientific_choose(&mut rng);
        assert!(choice.is_some());

        let picked = choice.expect("Operation failed");
        assert!(values.contains(picked));
    }

    /// Choosing three of ten elements yields three distinct values.
    #[test]
    fn test_scientific_choose_multiple() {
        let mut rng = seeded_rng(456);
        let values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        let choices = values.scientific_choose_multiple(&mut rng, 3);
        assert_eq!(choices.len(), 3);

        let distinct: std::collections::HashSet<_> = choices.iter().map(|&&v| v).collect();
        assert_eq!(distinct.len(), 3);
    }

    /// Weighted sampling returns the requested count of known items.
    #[test]
    fn test_weighted_sampling() {
        let mut rng = seeded_rng(789);
        let items = ["A", "B", "C"];
        let weights = [0.1, 0.3, 0.6];

        let samples = items
            .scientific_weighted_sample(&mut rng, &weights, 100)
            .expect("Operation failed");
        assert_eq!(samples.len(), 100);

        assert!(samples.iter().all(|&drawn| items.contains(drawn)));
    }

    /// Reservoir sampling of ten from a thousand yields ten distinct values.
    #[test]
    fn test_reservoir_sampling() {
        let mut rng = seeded_rng(101112);
        let values: Vec<i32> = (0..1000).collect();

        let sample = values.scientific_reservoir_sample(&mut rng, 10);
        assert_eq!(sample.len(), 10);

        let distinct: std::collections::HashSet<_> = sample.iter().map(|&&v| v).collect();
        assert_eq!(distinct.len(), 10);
    }

    /// One item requested from each of three strata gives three items.
    #[test]
    fn test_stratified_sampling() {
        let mut rng = seeded_rng(131415);
        let labelled = vec![(1, "A"), (2, "A"), (3, "B"), (4, "B"), (5, "C"), (6, "C")];

        let mut strata_sizes = std::collections::HashMap::new();
        for label in ["A", "B", "C"] {
            strata_sizes.insert(label, 1);
        }

        let sample = algorithms::stratified_sample(&labelled, &strata_sizes, &mut rng);

        assert_eq!(sample.len(), 3);
    }

    /// Systematic sampling of 10 from 100 yields roughly evenly spaced items.
    #[test]
    fn test_systematic_sampling() {
        let mut rng = seeded_rng(161718);
        let values: Vec<i32> = (0..100).collect();

        let sample = algorithms::systematic_sample(&values, 10, &mut rng);
        assert_eq!(sample.len(), 10);

        // The values equal their own positions, so they double as indices.
        let positions: Vec<usize> = sample.iter().map(|&&x| x as usize).collect();
        for pair in positions.windows(2) {
            let gap = pair[1] - pair[0];
            assert!((8..=12).contains(&gap));
        }
    }

    /// Smoke test for the generator-free convenience wrappers.
    #[test]
    fn test_convenience_functions() {
        let mut shuffled = vec![1, 2, 3, 4, 5];
        convenience::shuffle(&mut shuffled);
        assert_eq!(shuffled.len(), 5);

        let original = vec![1, 2, 3, 4, 5];
        assert!(convenience::choose(&original).is_some());

        let subset = convenience::sample(&original, 3);
        assert_eq!(subset.len(), 3);
    }
}