use scirs2_core::ndarray::{s, Array2, ArrayView2};
use sklears_core::{
    error::Result as SklResult,
    prelude::{SklearsError, Transform},
    traits::{Estimator, Fit},
    types::Float,
};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::rc::Rc;
use std::sync::{Arc, Mutex, RwLock, Weak};
use std::time::{Duration, Instant};

pub struct ZeroCostStep<T, S> {
    inner: T,
    _state: PhantomData<S>,
}

impl<T, S> ZeroCostStep<T, S> {
    #[inline(always)]
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            _state: PhantomData,
        }
    }

    #[inline(always)]
    pub fn inner(&self) -> &T {
        &self.inner
    }

    #[inline(always)]
    pub fn into_inner(self) -> T {
        self.inner
    }
}

pub struct ZeroCostPipeline<const N: usize, T> {
    steps: [T; N],
}

impl<const N: usize, T> ZeroCostPipeline<N, T> {
    #[inline(always)]
    pub const fn new(steps: [T; N]) -> Self {
        Self { steps }
    }

    #[inline(always)]
    pub fn steps(&self) -> &[T; N] {
        &self.steps
    }

    #[inline(always)]
    pub fn execute<I>(&self, input: I) -> I
    where
        T: Fn(I) -> I,
    {
        let mut result = input;
        for step in &self.steps {
            result = step(result);
        }
        result
    }
}

pub struct ZeroCostFeatureUnion<const N: usize, T> {
    transformers: [T; N],
}

impl<const N: usize, T> ZeroCostFeatureUnion<N, T> {
    #[inline(always)]
    pub const fn new(transformers: [T; N]) -> Self {
        Self { transformers }
    }

    pub fn transform(&self, input: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>>
    where
        T: for<'a> Transform<ArrayView2<'a, Float>, Array2<f64>>,
    {
        if N == 0 {
            return Ok(input.mapv(|v| v));
        }

        let mut results = Vec::with_capacity(N);

        for transformer in &self.transformers {
            results.push(transformer.transform(input)?);
        }

        // Concatenate the transformed blocks column-wise.
        let total_features: usize = results.iter().map(|r| r.ncols()).sum();
        let n_samples = results[0].nrows();

        let mut concatenated = Array2::zeros((n_samples, total_features));
        let mut col_idx = 0;

        for result in results {
            let end_idx = col_idx + result.ncols();
            concatenated
                .slice_mut(s![.., col_idx..end_idx])
                .assign(&result);
            col_idx = end_idx;
        }

        Ok(concatenated)
    }
}

pub struct ZeroCostEstimator<E> {
    estimator: E,
}

impl<E> ZeroCostEstimator<E> {
    #[inline(always)]
    pub const fn new(estimator: E) -> Self {
        Self { estimator }
    }

    #[inline(always)]
    pub fn estimator(&self) -> &E {
        &self.estimator
    }
}

impl<E> Estimator for ZeroCostEstimator<E>
where
    E: Estimator,
{
    type Config = E::Config;
    type Error = E::Error;
    type Float = E::Float;

    #[inline(always)]
    fn config(&self) -> &Self::Config {
        self.estimator.config()
    }
}

impl<E, X, Y> Fit<X, Y> for ZeroCostEstimator<E>
where
    E: Estimator + Fit<X, Y>,
    E::Error: Into<SklearsError>,
{
    type Fitted = ZeroCostEstimator<E::Fitted>;

    #[inline(always)]
    fn fit(self, x: &X, y: &Y) -> SklResult<Self::Fitted> {
        self.estimator.fit(x, y).map(ZeroCostEstimator::new)
    }
}

pub struct ZeroCostBuilder<T> {
    _phantom: PhantomData<T>,
}

impl Default for ZeroCostBuilder<()> {
    fn default() -> Self {
        Self::new()
    }
}

impl ZeroCostBuilder<()> {
    #[inline(always)]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> ZeroCostBuilder<T> {
    #[inline(always)]
    pub fn step<S>(self, _step: S) -> ZeroCostBuilder<(T, S)> {
        ZeroCostBuilder {
            _phantom: PhantomData,
        }
    }

    #[inline(always)]
    #[must_use]
    pub fn build(self) -> ZeroCostBuilder<T> {
        self
    }
}

pub struct ZeroCostConditional<const CONDITION: bool, T, F> {
    true_branch: T,
    false_branch: F,
}

impl<const CONDITION: bool, T, F> ZeroCostConditional<CONDITION, T, F> {
    #[inline(always)]
    pub const fn new(true_branch: T, false_branch: F) -> Self {
        Self {
            true_branch,
            false_branch,
        }
    }

    #[inline(always)]
    pub fn execute<I>(&self, input: I) -> I
    where
        T: Fn(I) -> I,
        F: Fn(I) -> I,
    {
        if CONDITION {
            (self.true_branch)(input)
        } else {
            (self.false_branch)(input)
        }
    }
}

pub struct ZeroCostFeatureSelector<const FEATURES: u64> {
    _phantom: PhantomData<u64>,
}

impl<const FEATURES: u64> ZeroCostFeatureSelector<FEATURES> {
    #[inline(always)]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }

    pub fn select(&self, input: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let n_features = FEATURES as usize;
        if input.ncols() < n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Input has {} features, but {} are required",
                input.ncols(),
                n_features
            )));
        }

        Ok(input.slice(s![.., ..n_features]).mapv(|v| v))
    }
}

impl<const FEATURES: u64> Default for ZeroCostFeatureSelector<FEATURES> {
    fn default() -> Self {
        Self::new()
    }
}

pub struct ZeroCostComposition<F, G> {
    first: F,
    second: G,
}

impl<F, G> ZeroCostComposition<F, G> {
    #[inline(always)]
    pub const fn new(first: F, second: G) -> Self {
        Self { first, second }
    }

    #[inline(always)]
    pub fn apply<I>(&self, input: I) -> I
    where
        F: Fn(I) -> I,
        G: Fn(I) -> I,
    {
        (self.second)((self.first)(input))
    }
}

pub trait ZeroCostCompose<Other> {
    type Output;

    fn compose(self, other: Other) -> Self::Output;
}

impl<F, G> ZeroCostCompose<G> for F
where
    F: Fn(f64) -> f64,
    G: Fn(f64) -> f64,
{
    type Output = ZeroCostComposition<F, G>;

    #[inline(always)]
    fn compose(self, other: G) -> Self::Output {
        ZeroCostComposition::new(self, other)
    }
}

pub struct ZeroCostParallel<const N: usize> {
    _phantom: PhantomData<[(); N]>,
}

impl<const N: usize> ZeroCostParallel<N> {
    #[inline(always)]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }

    /// Executes the fixed-arity task array. Note: in the current implementation
    /// the tasks run sequentially; the const parameter only fixes the arity.
    pub fn execute<T, F, R>(&self, tasks: [F; N]) -> [R; N]
    where
        F: Fn() -> R + Send,
        R: Send,
        T: Send,
    {
        tasks.map(|task| task())
    }
}

impl<const N: usize> Default for ZeroCostParallel<N> {
    fn default() -> Self {
        Self::new()
    }
}

pub struct ZeroCostLayout<T> {
    data: T,
}

impl<T> ZeroCostLayout<T> {
    #[inline(always)]
    pub const fn new(data: T) -> Self {
        Self { data }
    }

    #[inline(always)]
    pub fn data(&self) -> &T {
        &self.data
    }

    #[inline(always)]
    pub fn into_data(self) -> T {
        self.data
    }
}

#[repr(C)]
pub struct ZeroCostBuffer<T, const SIZE: usize> {
    data: Vec<T>,
    _phantom: PhantomData<[T; SIZE]>,
}

impl<T, const SIZE: usize> Default for ZeroCostBuffer<T, SIZE> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T, const SIZE: usize> ZeroCostBuffer<T, SIZE> {
    #[inline(always)]
    #[must_use]
    pub fn new() -> Self {
        Self {
            data: Vec::with_capacity(SIZE),
            _phantom: PhantomData,
        }
    }

    #[inline(always)]
    pub fn push(&mut self, item: T) -> Result<(), &'static str> {
        if self.data.len() >= SIZE {
            return Err("Buffer full");
        }
        self.data.push(item);
        Ok(())
    }

    #[inline(always)]
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    #[inline(always)]
    pub fn clear(&mut self) {
        self.data.clear();
    }
}

pub struct ZeroCopyView<'a, T> {
    data: &'a [T],
    shape: (usize, usize),
    strides: (usize, usize),
}

impl<'a, T> ZeroCopyView<'a, T> {
    #[inline(always)]
    pub fn new(data: &'a [T], shape: (usize, usize), strides: (usize, usize)) -> Self {
        Self {
            data,
            shape,
            strides,
        }
    }

    #[inline(always)]
    #[must_use]
    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    #[inline(always)]
    #[must_use]
    pub fn strides(&self) -> (usize, usize) {
        self.strides
    }

    #[inline(always)]
    #[must_use]
    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
        if row >= self.shape.0 || col >= self.shape.1 {
            return None;
        }
        let index = row * self.strides.0 + col * self.strides.1;
        self.data.get(index)
    }

    #[inline(always)]
    #[must_use]
    pub fn slice(
        &self,
        row_range: std::ops::Range<usize>,
        col_range: std::ops::Range<usize>,
    ) -> Option<ZeroCopyView<'a, T>> {
        if row_range.end > self.shape.0 || col_range.end > self.shape.1 {
            return None;
        }

        let start_index = row_range.start * self.strides.0 + col_range.start * self.strides.1;
        let new_shape = (row_range.len(), col_range.len());
        let new_data = &self.data[start_index..];

        Some(ZeroCopyView::new(new_data, new_shape, self.strides))
    }
}

#[derive(Debug)]
pub struct SharedData<T> {
    data: Rc<T>,
}

impl<T> SharedData<T> {
    #[inline(always)]
    pub fn new(data: T) -> Self {
        Self {
            data: Rc::new(data),
        }
    }

    #[inline(always)]
    #[must_use]
    pub fn ref_count(&self) -> usize {
        Rc::strong_count(&self.data)
    }

    #[inline(always)]
    pub fn try_unwrap(self) -> Result<T, Self> {
        Rc::try_unwrap(self.data).map_err(|data| Self { data })
    }
}

impl<T> Clone for SharedData<T> {
    #[inline(always)]
    fn clone(&self) -> Self {
        Self {
            data: Rc::clone(&self.data),
        }
    }
}

impl<T> std::ops::Deref for SharedData<T> {
    type Target = T;

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

pub struct CowData<T: Clone + 'static> {
    data: std::borrow::Cow<'static, T>,
}

impl<T: Clone + 'static> CowData<T> {
    #[inline(always)]
    pub fn borrowed(data: &'static T) -> Self {
        Self {
            data: std::borrow::Cow::Borrowed(data),
        }
    }

    #[inline(always)]
    pub fn owned(data: T) -> Self {
        Self {
            data: std::borrow::Cow::Owned(data),
        }
    }

    #[inline(always)]
    pub fn to_mut(&mut self) -> &mut T {
        self.data.to_mut()
    }

    #[inline(always)]
    pub fn into_owned(self) -> T {
        self.data.into_owned()
    }
}

impl<T: Clone + 'static> std::ops::Deref for CowData<T> {
    type Target = T;

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

pub struct Arena<T> {
    chunks: Vec<Vec<T>>,
    current_chunk: usize,
    current_offset: usize,
    chunk_size: usize,
}

impl<T> Arena<T> {
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        Self {
            chunks: vec![Vec::with_capacity(chunk_size)],
            current_chunk: 0,
            current_offset: 0,
            chunk_size,
        }
    }

    pub fn alloc(&mut self, item: T) -> &mut T {
        if self.current_offset >= self.chunk_size {
            self.chunks.push(Vec::with_capacity(self.chunk_size));
            self.current_chunk += 1;
            self.current_offset = 0;
        }

        let chunk = &mut self.chunks[self.current_chunk];
        chunk.push(item);
        self.current_offset += 1;

        chunk.last_mut().unwrap()
    }

    pub fn alloc_slice(&mut self, items: &[T]) -> &mut [T]
    where
        T: Clone,
    {
        if self.current_offset + items.len() > self.chunk_size {
            self.chunks
                .push(Vec::with_capacity(self.chunk_size.max(items.len())));
            self.current_chunk += 1;
            self.current_offset = 0;
        }

        // Record where the slice starts in the chunk that will actually hold it,
        // after any switch to a fresh chunk.
        let chunk = &mut self.chunks[self.current_chunk];
        let start_len = chunk.len();
        chunk.extend_from_slice(items);
        self.current_offset += items.len();

        &mut chunk[start_len..]
    }

    pub fn clear(&mut self) {
        for chunk in &mut self.chunks {
            chunk.clear();
        }
        self.current_chunk = 0;
        self.current_offset = 0;
    }

    #[must_use]
    pub fn len(&self) -> usize {
        self.chunks.iter().map(std::vec::Vec::len).sum()
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T> Default for Arena<T> {
    fn default() -> Self {
        Self::new(1024)
    }
}

pub struct MemoryPool<T> {
    free_buffers: Vec<Vec<T>>,
    min_capacity: usize,
    max_capacity: usize,
}

impl<T> MemoryPool<T> {
    #[must_use]
    pub fn new(min_capacity: usize, max_capacity: usize) -> Self {
        Self {
            free_buffers: Vec::new(),
            min_capacity,
            max_capacity,
        }
    }

    pub fn get_buffer(&mut self, capacity: usize) -> Vec<T> {
        for i in 0..self.free_buffers.len() {
            if self.free_buffers[i].capacity() >= capacity {
                let mut buffer = self.free_buffers.swap_remove(i);
                buffer.clear();
                return buffer;
            }
        }

        Vec::with_capacity(capacity.clamp(self.min_capacity, self.max_capacity))
    }

    pub fn return_buffer(&mut self, mut buffer: Vec<T>) {
        buffer.clear();

        if buffer.capacity() >= self.min_capacity
            && buffer.capacity() <= self.max_capacity
            && self.free_buffers.len() < 16
        {
            self.free_buffers.push(buffer);
        }
    }

    pub fn clear(&mut self) {
        self.free_buffers.clear();
    }

    #[must_use]
    pub fn pool_size(&self) -> usize {
        self.free_buffers.len()
    }
}

impl<T> Default for MemoryPool<T> {
    fn default() -> Self {
        Self::new(64, 4096)
    }
}

pub struct PooledBuffer<T> {
    buffer: Option<Vec<T>>,
    // Raw pointer back to the owning pool; the buffer is returned on drop.
    // The caller must ensure the pool outlives this handle and is not moved.
    pool: *mut MemoryPool<T>,
}

impl<T> PooledBuffer<T> {
    pub(crate) fn new(buffer: Vec<T>, pool: &mut MemoryPool<T>) -> Self {
        Self {
            buffer: Some(buffer),
            pool: pool as *mut MemoryPool<T>,
        }
    }

    pub fn buffer_mut(&mut self) -> &mut Vec<T> {
        self.buffer.as_mut().unwrap()
    }

    #[must_use]
    pub fn buffer_ref(&self) -> &Vec<T> {
        self.buffer.as_ref().unwrap()
    }
}

impl<T> std::ops::Deref for PooledBuffer<T> {
    type Target = Vec<T>;

    fn deref(&self) -> &Self::Target {
        self.buffer.as_ref().unwrap()
    }
}

impl<T> std::ops::DerefMut for PooledBuffer<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.buffer.as_mut().unwrap()
    }
}

impl<T> Drop for PooledBuffer<T> {
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            unsafe {
                (*self.pool).return_buffer(buffer);
            }
        }
    }
}

pub trait ZeroCopySlice<T> {
    fn zero_copy_slice(&self, start: usize, end: usize) -> Option<&[T]>;

    fn zero_copy_iter(&self) -> std::slice::Iter<'_, T>;

    fn zero_copy_chunks(&self, chunk_size: usize) -> std::slice::Chunks<'_, T>;
}

impl<T> ZeroCopySlice<T> for [T] {
    #[inline(always)]
    fn zero_copy_slice(&self, start: usize, end: usize) -> Option<&[T]> {
        self.get(start..end)
    }

    #[inline(always)]
    fn zero_copy_iter(&self) -> std::slice::Iter<'_, T> {
        self.iter()
    }

    #[inline(always)]
    fn zero_copy_chunks(&self, chunk_size: usize) -> std::slice::Chunks<'_, T> {
        self.chunks(chunk_size)
    }
}

impl<T> ZeroCopySlice<T> for Vec<T> {
    #[inline(always)]
    fn zero_copy_slice(&self, start: usize, end: usize) -> Option<&[T]> {
        self.as_slice().zero_copy_slice(start, end)
    }

    #[inline(always)]
    fn zero_copy_iter(&self) -> std::slice::Iter<'_, T> {
        self.iter()
    }

    #[inline(always)]
    fn zero_copy_chunks(&self, chunk_size: usize) -> std::slice::Chunks<'_, T> {
        self.chunks(chunk_size)
    }
}

impl<T> MemoryPool<T> {
    pub fn get_pooled_buffer(&mut self, capacity: usize) -> PooledBuffer<T> {
        let buffer = self.get_buffer(capacity);
        PooledBuffer::new(buffer, self)
    }
}

#[derive(Debug)]
pub struct MemoryLeakDetector {
    allocations: Mutex<HashMap<u64, AllocationInfo>>,
    next_id: std::sync::atomic::AtomicU64,
    config: MemoryLeakConfig,
}

#[derive(Debug, Clone)]
pub struct AllocationInfo {
    pub id: u64,
    pub size: usize,
    pub timestamp: Instant,
    pub stack_trace: Option<String>,
    pub type_name: &'static str,
    pub location: &'static str,
}

#[derive(Debug, Clone)]
pub struct MemoryLeakConfig {
    pub collect_stack_traces: bool,
    pub max_age: Duration,
    pub panic_on_leak: bool,
    pub max_tracked_allocations: usize,
}

impl Default for MemoryLeakConfig {
    fn default() -> Self {
        Self {
            collect_stack_traces: false,
            max_age: Duration::from_secs(300),
            panic_on_leak: false,
            max_tracked_allocations: 10000,
        }
    }
}

impl MemoryLeakDetector {
    #[must_use]
    pub fn new(config: MemoryLeakConfig) -> Self {
        Self {
            allocations: Mutex::new(HashMap::new()),
            next_id: std::sync::atomic::AtomicU64::new(1),
            config,
        }
    }

    pub fn track_allocation<T>(
        &self,
        size: usize,
        location: &'static str,
    ) -> TrackedAllocation<'_> {
        let id = self
            .next_id
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);

        let info = AllocationInfo {
            id,
            size,
            timestamp: Instant::now(),
            stack_trace: if self.config.collect_stack_traces {
                Some(self.collect_stack_trace())
            } else {
                None
            },
            type_name: std::any::type_name::<T>(),
            location,
        };

        if let Ok(mut allocations) = self.allocations.lock() {
            if allocations.len() < self.config.max_tracked_allocations {
                allocations.insert(id, info);
            }
        }

        TrackedAllocation { id, detector: self }
    }

    fn untrack_allocation(&self, id: u64) {
        if let Ok(mut allocations) = self.allocations.lock() {
            allocations.remove(&id);
        }
    }

    pub fn check_leaks(&self) -> Vec<AllocationInfo> {
        let mut leaks = Vec::new();
        let now = Instant::now();

        if let Ok(allocations) = self.allocations.lock() {
            for info in allocations.values() {
                if now.duration_since(info.timestamp) > self.config.max_age {
                    leaks.push(info.clone());
                }
            }
        }

        if self.config.panic_on_leak {
            assert!(
                leaks.is_empty(),
                "Memory leaks detected: {} allocations",
                leaks.len()
            );
        }

        leaks
    }

    pub fn get_stats(&self) -> MemoryStats {
        if let Ok(allocations) = self.allocations.lock() {
            let total_allocations = allocations.len();
            let total_size = allocations.values().map(|info| info.size).sum();
            let oldest_age = allocations
                .values()
                .map(|info| info.timestamp.elapsed())
                .max()
                .unwrap_or_default();

            MemoryStats {
                total_allocations,
                total_size,
                oldest_age,
            }
        } else {
            MemoryStats::default()
        }
    }

    fn collect_stack_trace(&self) -> String {
        "Stack trace collection not implemented".to_string()
    }
}

pub struct TrackedAllocation<'a> {
    id: u64,
    detector: &'a MemoryLeakDetector,
}

impl Drop for TrackedAllocation<'_> {
    fn drop(&mut self) {
        self.detector.untrack_allocation(self.id);
    }
}

#[derive(Debug, Default, Clone)]
pub struct MemoryStats {
    pub total_allocations: usize,
    pub total_size: usize,
    pub oldest_age: Duration,
}

#[derive(Debug)]
pub struct SafeConcurrentData<T> {
    data: RwLock<T>,
    stats: Arc<Mutex<ConcurrencyStats>>,
}

#[derive(Debug, Default, Clone)]
pub struct ConcurrencyStats {
    pub read_locks: u64,
    pub write_locks: u64,
    pub total_wait_time: Duration,
    pub contentions: u64,
}

impl<T> SafeConcurrentData<T> {
    pub fn new(data: T) -> Self {
        Self {
            data: RwLock::new(data),
            stats: Arc::new(Mutex::new(ConcurrencyStats::default())),
        }
    }

    pub fn read<F, R>(&self, f: F) -> SklResult<R>
    where
        F: FnOnce(&T) -> R,
    {
        let start = Instant::now();

        if let Ok(guard) = self.data.read() {
            self.update_stats(true, start.elapsed(), false);
            Ok(f(&*guard))
        } else {
            self.update_stats(true, start.elapsed(), true);
            Err(SklearsError::InvalidOperation("Lock poisoned".to_string()))
        }
    }

    pub fn write<F, R>(&self, f: F) -> SklResult<R>
    where
        F: FnOnce(&mut T) -> R,
    {
        let start = Instant::now();

        if let Ok(mut guard) = self.data.write() {
            self.update_stats(false, start.elapsed(), false);
            Ok(f(&mut *guard))
        } else {
            self.update_stats(false, start.elapsed(), true);
            Err(SklearsError::InvalidOperation("Lock poisoned".to_string()))
        }
    }

    pub fn try_read<F, R>(&self, f: F) -> SklResult<Option<R>>
    where
        F: FnOnce(&T) -> R,
    {
        let start = Instant::now();

        match self.data.try_read() {
            Ok(guard) => {
                self.update_stats(true, start.elapsed(), false);
                Ok(Some(f(&*guard)))
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                self.update_stats(true, start.elapsed(), true);
                Ok(None)
            }
            Err(std::sync::TryLockError::Poisoned(_)) => {
                self.update_stats(true, start.elapsed(), true);
                Err(SklearsError::InvalidOperation("Lock poisoned".to_string()))
            }
        }
    }

    pub fn try_write<F, R>(&self, f: F) -> SklResult<Option<R>>
    where
        F: FnOnce(&mut T) -> R,
    {
        let start = Instant::now();

        match self.data.try_write() {
            Ok(mut guard) => {
                self.update_stats(false, start.elapsed(), false);
                Ok(Some(f(&mut *guard)))
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                self.update_stats(false, start.elapsed(), true);
                Ok(None)
            }
            Err(std::sync::TryLockError::Poisoned(_)) => {
                self.update_stats(false, start.elapsed(), true);
                Err(SklearsError::InvalidOperation("Lock poisoned".to_string()))
            }
        }
    }

    pub fn get_stats(&self) -> ConcurrencyStats {
        self.stats
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }

    fn update_stats(&self, is_read: bool, wait_time: Duration, contention: bool) {
        if let Ok(mut stats) = self.stats.lock() {
            if is_read {
                stats.read_locks += 1;
            } else {
                stats.write_locks += 1;
            }
            stats.total_wait_time += wait_time;
            if contention {
                stats.contentions += 1;
            }
        }
    }
}

#[derive(Debug)]
pub struct AtomicRcData<T> {
    data: Arc<T>,
}

impl<T> AtomicRcData<T> {
    pub fn new(data: T) -> Self {
        Self {
            data: Arc::new(data),
        }
    }

    #[must_use]
    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.data)
    }

    #[must_use]
    pub fn weak_count(&self) -> usize {
        Arc::weak_count(&self.data)
    }

    #[must_use]
    pub fn downgrade(&self) -> WeakRcData<T> {
        WeakRcData {
            weak: Arc::downgrade(&self.data),
        }
    }

    pub fn try_unwrap(self) -> Result<T, Self> {
        Arc::try_unwrap(self.data).map_err(|data| Self { data })
    }
}

impl<T> Clone for AtomicRcData<T> {
    fn clone(&self) -> Self {
        Self {
            data: Arc::clone(&self.data),
        }
    }
}

impl<T> std::ops::Deref for AtomicRcData<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

#[derive(Debug)]
pub struct WeakRcData<T> {
    weak: Weak<T>,
}

impl<T> WeakRcData<T> {
    #[must_use]
    pub fn upgrade(&self) -> Option<AtomicRcData<T>> {
        self.weak.upgrade().map(|data| AtomicRcData { data })
    }

    #[must_use]
    pub fn weak_count(&self) -> usize {
        self.weak.weak_count()
    }

    #[must_use]
    pub fn strong_count(&self) -> usize {
        self.weak.strong_count()
    }
}

impl<T> Clone for WeakRcData<T> {
    fn clone(&self) -> Self {
        Self {
            weak: Weak::clone(&self.weak),
        }
    }
}

#[derive(Debug)]
pub struct LockFreeQueue<T> {
    // Note: currently backed by a Mutex<VecDeque> rather than a lock-free structure.
    inner: Mutex<std::collections::VecDeque<T>>,
    stats: Arc<Mutex<QueueStats>>,
}

#[derive(Debug, Default, Clone)]
pub struct QueueStats {
    pub enqueues: u64,
    pub dequeues: u64,
    pub current_size: usize,
    pub max_size: usize,
    pub contentions: u64,
}

impl<T> LockFreeQueue<T> {
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(std::collections::VecDeque::new()),
            stats: Arc::new(Mutex::new(QueueStats::default())),
        }
    }

    pub fn enqueue(&self, item: T) -> SklResult<()> {
        match self.inner.lock() {
            Ok(mut queue) => {
                queue.push_back(item);

                if let Ok(mut stats) = self.stats.lock() {
                    stats.enqueues += 1;
                    stats.current_size = queue.len();
                    stats.max_size = stats.max_size.max(queue.len());
                }

                Ok(())
            }
            Err(_) => Err(SklearsError::InvalidOperation(
                "Queue lock poisoned".to_string(),
            )),
        }
    }

    pub fn dequeue(&self) -> SklResult<Option<T>> {
        match self.inner.lock() {
            Ok(mut queue) => {
                let item = queue.pop_front();

                if let Ok(mut stats) = self.stats.lock() {
                    if item.is_some() {
                        stats.dequeues += 1;
                    }
                    stats.current_size = queue.len();
                }

                Ok(item)
            }
            Err(_) => Err(SklearsError::InvalidOperation(
                "Queue lock poisoned".to_string(),
            )),
        }
    }

    pub fn try_dequeue(&self) -> SklResult<Option<T>> {
        match self.inner.try_lock() {
            Ok(mut queue) => {
                let item = queue.pop_front();

                if let Ok(mut stats) = self.stats.lock() {
                    if item.is_some() {
                        stats.dequeues += 1;
                    }
                    stats.current_size = queue.len();
                }

                Ok(item)
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                if let Ok(mut stats) = self.stats.lock() {
                    stats.contentions += 1;
                }
                Ok(None)
            }
            Err(std::sync::TryLockError::Poisoned(_)) => Err(SklearsError::InvalidOperation(
                "Queue lock poisoned".to_string(),
            )),
        }
    }

    pub fn len(&self) -> usize {
        self.inner.lock().map(|queue| queue.len()).unwrap_or(0)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn get_stats(&self) -> QueueStats {
        self.stats
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}

impl<T> Default for LockFreeQueue<T> {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug)]
pub struct WorkStealingDeque<T> {
    items: Mutex<std::collections::VecDeque<T>>,
    stats: Arc<Mutex<WorkStealingStats>>,
}

#[derive(Debug, Default, Clone)]
pub struct WorkStealingStats {
    pub pushes: u64,
    pub pops: u64,
    pub steal_attempts: u64,
    pub successful_steals: u64,
    pub current_size: usize,
}

impl<T> WorkStealingDeque<T> {
    #[must_use]
    pub fn new() -> Self {
        Self {
            items: Mutex::new(std::collections::VecDeque::new()),
            stats: Arc::new(Mutex::new(WorkStealingStats::default())),
        }
    }

    pub fn push(&self, item: T) -> SklResult<()> {
        match self.items.lock() {
            Ok(mut items) => {
                items.push_back(item);

                if let Ok(mut stats) = self.stats.lock() {
                    stats.pushes += 1;
                    stats.current_size = items.len();
                }

                Ok(())
            }
            Err(_) => Err(SklearsError::InvalidOperation(
                "Deque lock poisoned".to_string(),
            )),
        }
    }

    /// Owner-side pop: takes from the back of the deque (LIFO order).
    pub fn pop(&self) -> SklResult<Option<T>> {
        match self.items.lock() {
            Ok(mut items) => {
                let item = items.pop_back();

                if let Ok(mut stats) = self.stats.lock() {
                    if item.is_some() {
                        stats.pops += 1;
                    }
                    stats.current_size = items.len();
                }

                Ok(item)
            }
            Err(_) => Err(SklearsError::InvalidOperation(
                "Deque lock poisoned".to_string(),
            )),
        }
    }

    /// Thief-side steal: takes from the front of the deque (FIFO order) and
    /// never blocks; contention simply yields `Ok(None)`.
    pub fn steal(&self) -> SklResult<Option<T>> {
        if let Ok(mut stats) = self.stats.lock() {
            stats.steal_attempts += 1;
        }

        match self.items.try_lock() {
            Ok(mut items) => {
                let item = items.pop_front();

                if let Ok(mut stats) = self.stats.lock() {
                    if item.is_some() {
                        stats.successful_steals += 1;
                    }
                    stats.current_size = items.len();
                }

                Ok(item)
            }
            Err(_) => Ok(None),
        }
    }

    pub fn len(&self) -> usize {
        self.items.lock().map(|items| items.len()).unwrap_or(0)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn get_stats(&self) -> WorkStealingStats {
        self.stats
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}

impl<T> Default for WorkStealingDeque<T> {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_cost_step() {
        let step: ZeroCostStep<i32, ()> = ZeroCostStep::new(42);
        assert_eq!(*step.inner(), 42);
        assert_eq!(step.into_inner(), 42);
    }

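    // Illustrative sketch (not from the original tests): the `S` parameter of
    // `ZeroCostStep` can carry a typestate marker so that stage mix-ups become
    // compile errors. The `Raw`/`Scaled` markers are hypothetical names local to
    // this test.
    #[test]
    fn test_zero_cost_step_typestate_sketch() {
        struct Raw;
        struct Scaled;

        fn scale(step: ZeroCostStep<Vec<f64>, Raw>) -> ZeroCostStep<Vec<f64>, Scaled> {
            let scaled = step.into_inner().iter().map(|x| x * 2.0).collect();
            ZeroCostStep::new(scaled)
        }

        let raw: ZeroCostStep<Vec<f64>, Raw> = ZeroCostStep::new(vec![1.0, 2.0]);
        let scaled = scale(raw);
        assert_eq!(scaled.inner(), &vec![2.0, 4.0]);
    }
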
    #[test]
    fn test_zero_cost_pipeline() {
        fn add_one(x: i32) -> i32 {
            x + 1
        }
        fn mul_two(x: i32) -> i32 {
            x * 2
        }

        let pipeline =
            ZeroCostPipeline::new([add_one as fn(i32) -> i32, mul_two as fn(i32) -> i32]);
        let result = pipeline.execute(5);
        assert_eq!(result, 12); // (5 + 1) * 2
    }

    #[test]
    fn test_zero_cost_builder() {
        let builder = ZeroCostBuilder::new();
        let _pipeline = builder.step("transform").step("estimate").build();
    }

    #[test]
    fn test_zero_cost_conditional() {
        let add_one = |x: i32| x + 1;
        let mul_two = |x: i32| x * 2;

        let conditional_true = ZeroCostConditional::<true, _, _>::new(add_one, mul_two);
        assert_eq!(conditional_true.execute(5), 6);

        let conditional_false = ZeroCostConditional::<false, _, _>::new(add_one, mul_two);
        assert_eq!(conditional_false.execute(5), 10);
    }

    #[test]
    fn test_zero_cost_composition() {
        let add_one = |x: f64| x + 1.0;
        let mul_two = |x: f64| x * 2.0;

        let composition = add_one.compose(mul_two);
        assert_eq!(composition.apply(5.0), 12.0); // (5.0 + 1.0) * 2.0
    }

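    // Illustrative sketch (not from the original tests): ZeroCostFeatureSelector
    // keeps the first FEATURES columns, with the width fixed at compile time.
    // Assumes scirs2_core::ndarray re-exports ndarray's Array2::from_shape_vec
    // and that the crate's Float alias is f64, as the select() signature implies.
    #[test]
    fn test_zero_cost_feature_selector_sketch() {
        let data =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let selector: ZeroCostFeatureSelector<2> = ZeroCostFeatureSelector::new();
        let selected = selector.select(&data.view()).unwrap();
        assert_eq!(selected.shape(), &[2, 2]);
        assert_eq!(selected[[1, 0]], 4.0);

        // Requesting more columns than the input provides is rejected.
        let too_many: ZeroCostFeatureSelector<5> = ZeroCostFeatureSelector::new();
        assert!(too_many.select(&data.view()).is_err());
    }
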
    #[test]
    fn test_zero_cost_buffer() {
        let mut buffer: ZeroCostBuffer<i32, 4> = ZeroCostBuffer::new();

        assert!(buffer.push(1).is_ok());
        assert!(buffer.push(2).is_ok());
        assert_eq!(buffer.as_slice(), &[1, 2]);

        buffer.clear();
        let empty: &[i32] = &[];
        assert_eq!(buffer.as_slice(), empty);
    }

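    // Illustrative sketch (not from the original tests): pushes beyond the const
    // SIZE are rejected rather than reallocating.
    #[test]
    fn test_zero_cost_buffer_capacity_limit_sketch() {
        let mut buffer: ZeroCostBuffer<u8, 2> = ZeroCostBuffer::new();
        assert!(buffer.push(1).is_ok());
        assert!(buffer.push(2).is_ok());

        // The third push exceeds SIZE = 2 and fails.
        assert_eq!(buffer.push(3), Err("Buffer full"));
        assert_eq!(buffer.as_slice(), &[1, 2]);
    }
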
    #[test]
    fn test_zero_cost_parallel() {
        let parallel: ZeroCostParallel<2> = ZeroCostParallel::new();
        let tasks = [|| 1 + 1, || 2 * 2];
        let results = parallel.execute::<(), _, _>(tasks);
        assert_eq!(results, [2, 4]);
    }

    #[test]
    fn test_zero_cost_layout() {
        let layout = ZeroCostLayout::new(vec![1, 2, 3]);
        assert_eq!(layout.data(), &vec![1, 2, 3]);
        assert_eq!(layout.into_data(), vec![1, 2, 3]);
    }

    #[test]
    fn test_zero_copy_view() {
        let data = vec![1, 2, 3, 4, 5, 6];
        let view = ZeroCopyView::new(&data, (2, 3), (3, 1));

        assert_eq!(view.shape(), (2, 3));
        assert_eq!(view.strides(), (3, 1));
        assert_eq!(view.get(0, 0), Some(&1));
        assert_eq!(view.get(1, 2), Some(&6));
        assert_eq!(view.get(2, 0), None); // out of bounds
    }

    #[test]
    fn test_zero_copy_view_slice() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let view = ZeroCopyView::new(&data, (3, 3), (3, 1));

        let sub_view = view.slice(1..3, 1..3).unwrap();
        assert_eq!(sub_view.shape(), (2, 2));
        assert_eq!(sub_view.get(0, 0), Some(&5)); // element (1, 1) of the parent view
    }

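    // Illustrative sketch (not from the original tests): swapping the strides
    // exposes the same buffer in transposed order without copying anything.
    #[test]
    fn test_zero_copy_view_transposed_strides_sketch() {
        let data = vec![1, 2, 3, 4, 5, 6];

        // Row-major 2x3 view and a 3x2 "transposed" view over identical storage.
        let view = ZeroCopyView::new(&data, (2, 3), (3, 1));
        let transposed = ZeroCopyView::new(&data, (3, 2), (1, 3));

        assert_eq!(view.get(0, 1), Some(&2));
        assert_eq!(transposed.get(1, 0), Some(&2));
        assert_eq!(transposed.get(2, 1), Some(&6));
    }
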
    #[test]
    fn test_shared_data() {
        let data = SharedData::new(vec![1, 2, 3]);
        assert_eq!(data.ref_count(), 1);

        let cloned = data.clone();
        assert_eq!(data.ref_count(), 2);
        assert_eq!(cloned.ref_count(), 2);

        drop(cloned);
        assert_eq!(data.ref_count(), 1);

        let recovered = data.try_unwrap().unwrap();
        assert_eq!(recovered, vec![1, 2, 3]);
    }

    #[test]
    fn test_cow_data() {
        let mut cow = CowData::owned(vec![1, 2, 3]);

        assert_eq!(cow.len(), 3);

        cow.to_mut().push(42);
        assert_eq!(cow.len(), 4);
        assert_eq!(cow[3], 42);

        let owned = cow.into_owned();
        assert_eq!(owned, vec![1, 2, 3, 42]);
    }

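    // Illustrative sketch (not from the original tests): a CowData built from a
    // 'static borrow stays borrowed until the first mutation, which clones it
    // into an owned value.
    #[test]
    fn test_cow_data_borrowed_sketch() {
        static BASE: [i32; 3] = [1, 2, 3];

        let mut cow = CowData::borrowed(&BASE);
        assert_eq!(*cow, [1, 2, 3]);

        // The first mutation triggers the clone; BASE itself is untouched.
        cow.to_mut()[0] = 10;
        assert_eq!(cow.into_owned(), [10, 2, 3]);
        assert_eq!(BASE, [1, 2, 3]);
    }
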
    #[test]
    fn test_arena() {
        let mut arena = Arena::new(4);

        arena.alloc(10);
        arena.alloc(20);
        assert_eq!(arena.len(), 2);

        arena.alloc_slice(&[30, 40, 50]);
        assert_eq!(arena.len(), 5);

        arena.alloc(60);
        assert_eq!(arena.len(), 6);

        arena.clear();
        assert_eq!(arena.len(), 0);
        assert!(arena.is_empty());
    }

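    // Illustrative sketch (not from the original tests): a slice larger than the
    // chunk size gets a dedicated chunk sized max(chunk_size, items.len()), and
    // the returned slice covers every copied element.
    #[test]
    fn test_arena_large_slice_sketch() {
        let mut arena = Arena::new(2);

        let slice = arena.alloc_slice(&[1, 2, 3, 4, 5]);
        assert_eq!(slice.to_vec(), vec![1, 2, 3, 4, 5]);
        assert_eq!(arena.len(), 5);
    }
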
    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::new(4, 16);

        let mut buffer1 = pool.get_buffer(8);
        buffer1.extend_from_slice(&[1, 2, 3]);
        assert_eq!(buffer1, vec![1, 2, 3]);
        assert!(buffer1.capacity() >= 8);

        pool.return_buffer(buffer1);
        assert_eq!(pool.pool_size(), 1);

        let buffer2 = pool.get_buffer(6);
        assert!(buffer2.is_empty());
        assert!(buffer2.capacity() >= 6);
    }

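    // Illustrative sketch (not from the original tests): a returned buffer is
    // recycled by the next compatible request, so the pool shrinks again.
    #[test]
    fn test_memory_pool_reuse_sketch() {
        let mut pool: MemoryPool<u8> = MemoryPool::new(4, 16);

        let buffer = pool.get_buffer(8);
        pool.return_buffer(buffer);
        assert_eq!(pool.pool_size(), 1);

        // The recycled buffer satisfies this request, leaving the pool empty.
        let reused = pool.get_buffer(4);
        assert!(reused.capacity() >= 4);
        assert_eq!(pool.pool_size(), 0);
    }
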
    #[test]
    fn test_pooled_buffer() {
        let mut pool = MemoryPool::new(4, 16);

        {
            let mut pooled = pool.get_pooled_buffer(8);
            pooled.push(42);
            assert_eq!(pooled[0], 42);
            assert_eq!(pool.pool_size(), 0);
        }

        // Dropping the PooledBuffer returns its storage to the pool.
        assert_eq!(pool.pool_size(), 1);
    }

    #[test]
    fn test_zero_copy_slice_trait() {
        let data = vec![1, 2, 3, 4, 5];

        let slice = data.zero_copy_slice(1, 4).unwrap();
        assert_eq!(slice, &[2, 3, 4]);

        let chunks: Vec<_> = data.zero_copy_chunks(2).collect();
        assert_eq!(chunks, vec![&[1, 2][..], &[3, 4][..], &[5][..]]);

        let values: Vec<_> = data.zero_copy_iter().copied().collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_memory_leak_detector() {
        let config = MemoryLeakConfig {
            collect_stack_traces: false,
            max_age: Duration::from_millis(100),
            panic_on_leak: false,
            max_tracked_allocations: 100,
        };

        let detector = MemoryLeakDetector::new(config);

        let _tracked = detector.track_allocation::<Vec<i32>>(100, "test_location");

        let stats = detector.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_size, 100);

        let leaks = detector.check_leaks();
        assert!(leaks.is_empty());

        std::thread::sleep(Duration::from_millis(150));

        let leaks = detector.check_leaks();
        assert_eq!(leaks.len(), 1);
        assert_eq!(leaks[0].size, 100);

        drop(_tracked);

        let stats = detector.get_stats();
        assert_eq!(stats.total_allocations, 0);
    }

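    // Illustrative sketch (not from the original tests): with panic_on_leak
    // enabled, check_leaks() panics once a tracked allocation outlives max_age.
    #[test]
    #[should_panic(expected = "Memory leaks detected")]
    fn test_memory_leak_detector_panic_on_leak_sketch() {
        let detector = MemoryLeakDetector::new(MemoryLeakConfig {
            panic_on_leak: true,
            max_age: Duration::from_millis(10),
            ..MemoryLeakConfig::default()
        });

        let _tracked = detector.track_allocation::<Vec<u8>>(64, "leak_sketch");
        std::thread::sleep(Duration::from_millis(50));
        detector.check_leaks();
    }
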
    #[test]
    fn test_safe_concurrent_data() {
        let data = SafeConcurrentData::new(vec![1, 2, 3]);

        let result = data.read(|v| v.len()).unwrap();
        assert_eq!(result, 3);

        let result = data
            .write(|v| {
                v.push(4);
                v.len()
            })
            .unwrap();
        assert_eq!(result, 4);

        let stats = data.get_stats();
        assert_eq!(stats.read_locks, 1);
        assert_eq!(stats.write_locks, 1);
        assert_eq!(stats.contentions, 0);

        let result = data.try_read(|v| v.len()).unwrap();
        assert_eq!(result, Some(4));

        let final_result = data.read(|v| v.clone()).unwrap();
        assert_eq!(final_result, vec![1, 2, 3, 4]);
    }

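    // Illustrative sketch (not from the original tests): try_write never blocks.
    // Attempting it while a read closure still holds the lock yields Ok(None)
    // and is counted as a contention.
    #[test]
    fn test_safe_concurrent_data_try_write_contention_sketch() {
        let data = SafeConcurrentData::new(0_i32);

        let blocked = data
            .read(|_| data.try_write(|v| *v += 1).unwrap())
            .unwrap();
        assert_eq!(blocked, None);

        let stats = data.get_stats();
        assert_eq!(stats.contentions, 1);
        assert_eq!(data.read(|v| *v).unwrap(), 0);
    }
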
    #[test]
    fn test_atomic_rc_data() {
        let data = AtomicRcData::new(vec![1, 2, 3]);
        assert_eq!(data.strong_count(), 1);
        assert_eq!(data.weak_count(), 0);

        let weak = data.downgrade();
        assert_eq!(data.weak_count(), 1);
        assert_eq!(weak.strong_count(), 1);

        let cloned = data.clone();
        assert_eq!(data.strong_count(), 2);
        assert_eq!(cloned.strong_count(), 2);

        let upgraded = weak.upgrade().unwrap();
        assert_eq!(upgraded.strong_count(), 3);

        drop(cloned);
        drop(upgraded);
        assert_eq!(data.strong_count(), 1);

        let recovered = data.try_unwrap().unwrap();
        assert_eq!(recovered, vec![1, 2, 3]);
    }

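    // Illustrative sketch (not from the original tests): unlike the Rc-based
    // SharedData, AtomicRcData clones can be moved into other threads.
    #[test]
    fn test_atomic_rc_data_across_threads_sketch() {
        let data = AtomicRcData::new(vec![1, 2, 3]);
        let cloned = data.clone();

        let sum = std::thread::spawn(move || cloned.iter().sum::<i32>())
            .join()
            .unwrap();

        assert_eq!(sum, 6);
        assert_eq!(data.strong_count(), 1);
    }
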
    #[test]
    fn test_lock_free_queue() {
        let queue = LockFreeQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);

        queue.enqueue(1).unwrap();
        queue.enqueue(2).unwrap();
        queue.enqueue(3).unwrap();

        assert_eq!(queue.len(), 3);
        assert!(!queue.is_empty());

        assert_eq!(queue.dequeue().unwrap(), Some(1));
        assert_eq!(queue.dequeue().unwrap(), Some(2));
        assert_eq!(queue.len(), 1);

        assert_eq!(queue.try_dequeue().unwrap(), Some(3));
        assert_eq!(queue.try_dequeue().unwrap(), None);

        assert!(queue.is_empty());

        let stats = queue.get_stats();
        assert_eq!(stats.enqueues, 3);
        assert_eq!(stats.dequeues, 3);
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.max_size, 3);
    }

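    // Illustrative sketch (not from the original tests): the queue can be shared
    // across threads behind an Arc; each worker enqueues its id and the main
    // thread drains everything afterwards.
    #[test]
    fn test_lock_free_queue_across_threads_sketch() {
        let queue = Arc::new(LockFreeQueue::new());

        let handles: Vec<_> = (0..4)
            .map(|i| {
                let queue = Arc::clone(&queue);
                std::thread::spawn(move || queue.enqueue(i).unwrap())
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        let mut drained = Vec::new();
        while let Some(item) = queue.dequeue().unwrap() {
            drained.push(item);
        }
        drained.sort_unstable();
        assert_eq!(drained, vec![0, 1, 2, 3]);
    }
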
    #[test]
    fn test_work_stealing_deque() {
        let deque = WorkStealingDeque::new();
        assert!(deque.is_empty());
        assert_eq!(deque.len(), 0);

        deque.push(1).unwrap();
        deque.push(2).unwrap();
        deque.push(3).unwrap();

        assert_eq!(deque.len(), 3);
        assert!(!deque.is_empty());

        assert_eq!(deque.pop().unwrap(), Some(3));
        assert_eq!(deque.len(), 2);

        assert_eq!(deque.steal().unwrap(), Some(1));
        assert_eq!(deque.len(), 1);

        assert_eq!(deque.pop().unwrap(), Some(2));
        assert!(deque.is_empty());

        assert_eq!(deque.steal().unwrap(), None);

        let stats = deque.get_stats();
        assert_eq!(stats.pushes, 3);
        assert_eq!(stats.pops, 2);
        assert_eq!(stats.steal_attempts, 2);
        assert_eq!(stats.successful_steals, 1);
        assert_eq!(stats.current_size, 0);
    }
}