1use std::marker::PhantomData;
14use std::mem;
15use std::ops::{Deref, DerefMut};
16use std::slice;
17use std::sync::Arc;
18
19use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
20
/// Memory alignment requirement for a buffer, expressed in bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Alignment {
    /// No particular requirement (1-byte alignment).
    #[default]
    None,
    /// 16-byte alignment.
    Align16,
    /// 32-byte alignment.
    Align32,
    /// 64-byte alignment.
    Align64,
    /// A caller-chosen alignment in bytes. `Custom(0)` is treated as 1 byte by
    /// [`Alignment::is_aligned`].
    Custom(usize),
}

impl Alignment {
    /// Returns the alignment in bytes (`None` is reported as 1).
    #[must_use]
    pub const fn bytes(&self) -> usize {
        match self {
            Alignment::None => 1,
            Alignment::Align16 => 16,
            Alignment::Align32 => 32,
            Alignment::Align64 => 64,
            Alignment::Custom(n) => *n,
        }
    }

    /// Returns `true` if the address of `ptr` is a multiple of this alignment.
    ///
    /// `Custom(0)` is treated as 1-byte alignment (every address qualifies) instead of
    /// panicking with a remainder-by-zero.
    #[must_use]
    pub fn is_aligned<T>(&self, ptr: *const T) -> bool {
        let align = self.bytes().max(1);
        (ptr as usize) % align == 0
    }

    /// Picks a SIMD-friendly alignment from `size_of::<T>()`.
    ///
    /// The result is 64 bytes for types of at least 8 bytes, 32 bytes for types of at
    /// least 4 bytes, and 16 bytes otherwise.
    #[must_use]
    pub const fn optimal_for_simd<T>() -> Self {
        let size = mem::size_of::<T>();
        if size >= 8 {
            Alignment::Align64
        } else if size >= 4 {
            Alignment::Align32
        } else {
            Alignment::Align16
        }
    }
}
74
75#[derive(Debug, Clone, PartialEq)]
77pub struct MemoryLayout {
78 pub shape: Vec<usize>,
80 pub strides: Vec<isize>,
82 pub len: usize,
84 pub element_size: usize,
86 pub alignment: Alignment,
88 pub is_contiguous: bool,
90 pub is_c_contiguous: bool,
92 pub is_f_contiguous: bool,
94}
95
96impl MemoryLayout {
97 #[must_use]
99 pub fn contiguous<T>(shape: &[usize]) -> Self {
100 let element_size = mem::size_of::<T>();
101 let len: usize = shape.iter().product();
102 let strides = Self::compute_c_strides(shape, element_size);
103
104 Self {
105 shape: shape.to_vec(),
106 strides,
107 len,
108 element_size,
109 alignment: Alignment::optimal_for_simd::<T>(),
110 is_contiguous: true,
111 is_c_contiguous: true,
112 is_f_contiguous: shape.len() <= 1,
113 }
114 }
115
116 #[must_use]
118 pub fn fortran_contiguous<T>(shape: &[usize]) -> Self {
119 let element_size = mem::size_of::<T>();
120 let len: usize = shape.iter().product();
121 let strides = Self::compute_f_strides(shape, element_size);
122
123 Self {
124 shape: shape.to_vec(),
125 strides,
126 len,
127 element_size,
128 alignment: Alignment::optimal_for_simd::<T>(),
129 is_contiguous: true,
130 is_c_contiguous: shape.len() <= 1,
131 is_f_contiguous: true,
132 }
133 }
134
135 fn compute_c_strides(shape: &[usize], element_size: usize) -> Vec<isize> {
137 let ndim = shape.len();
138 if ndim == 0 {
139 return vec![];
140 }
141
142 let mut strides = vec![0isize; ndim];
143 strides[ndim - 1] = element_size as isize;
144
145 for i in (0..ndim - 1).rev() {
146 strides[i] = strides[i + 1] * (shape[i + 1] as isize);
147 }
148
149 strides
150 }
151
152 fn compute_f_strides(shape: &[usize], element_size: usize) -> Vec<isize> {
154 let ndim = shape.len();
155 if ndim == 0 {
156 return vec![];
157 }
158
159 let mut strides = vec![0isize; ndim];
160 strides[0] = element_size as isize;
161
162 for i in 1..ndim {
163 strides[i] = strides[i - 1] * (shape[i - 1] as isize);
164 }
165
166 strides
167 }
168
169 #[must_use]
171 pub fn ndim(&self) -> usize {
172 self.shape.len()
173 }
174
175 #[must_use]
177 pub fn is_compatible(&self, other: &Self) -> bool {
178 self.shape == other.shape
179 && self.element_size == other.element_size
180 && self.is_contiguous
181 && other.is_contiguous
182 }
183
184 #[must_use]
186 pub fn size_bytes(&self) -> usize {
187 self.len * self.element_size
188 }
189}
190
191pub trait ContiguousMemory {
193 fn as_ptr(&self) -> *const u8;
195
196 fn layout(&self) -> &MemoryLayout;
198
199 fn is_contiguous(&self) -> bool {
201 self.layout().is_contiguous
202 }
203
204 fn size_bytes(&self) -> usize {
206 self.layout().size_bytes()
207 }
208}
209
/// Mutable access to a block of memory described by a [`MemoryLayout`].
pub trait ContiguousMemoryMut: ContiguousMemory {
    /// Returns a mutable pointer to the start of the memory block, as raw bytes.
    fn as_mut_ptr(&mut self) -> *mut u8;
}
215
216#[derive(Debug)]
218pub struct SharedArrayView<'a, T> {
219 ptr: *const T,
221 len: usize,
223 layout: MemoryLayout,
225 _marker: PhantomData<&'a T>,
227}
228
229impl<'a, T> SharedArrayView<'a, T> {
230 #[must_use]
232 pub fn from_slice(data: &'a [T]) -> Self {
233 let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
234 Self {
235 ptr: data.as_ptr(),
236 len: data.len(),
237 layout,
238 _marker: PhantomData,
239 }
240 }
241
242 pub unsafe fn from_raw_parts(ptr: *const T, layout: MemoryLayout) -> Self {
251 Self {
252 ptr,
253 len: layout.len,
254 layout,
255 _marker: PhantomData,
256 }
257 }
258
259 #[must_use]
261 pub const fn len(&self) -> usize {
262 self.len
263 }
264
265 #[must_use]
267 pub const fn is_empty(&self) -> bool {
268 self.len == 0
269 }
270
271 #[must_use]
273 pub const fn layout(&self) -> &MemoryLayout {
274 &self.layout
275 }
276
277 #[must_use]
279 pub fn shape(&self) -> &[usize] {
280 &self.layout.shape
281 }
282
283 #[must_use]
289 pub unsafe fn get_unchecked(&self, index: usize) -> &T {
290 &*self.ptr.add(index)
291 }
292
293 pub fn get(&self, index: usize) -> Option<&T> {
295 if index < self.len {
296 Some(unsafe { self.get_unchecked(index) })
298 } else {
299 None
300 }
301 }
302
303 pub fn as_slice(&self) -> Option<&'a [T]> {
305 if self.layout.is_contiguous {
306 Some(unsafe { slice::from_raw_parts(self.ptr, self.len) })
308 } else {
309 None
310 }
311 }
312
313 pub fn slice(&self, start: usize, end: usize) -> CoreResult<SharedArrayView<'a, T>> {
315 if start > end || end > self.len {
316 return Err(CoreError::ValidationError(
317 ErrorContext::new(format!(
318 "Invalid slice range [{start}, {end}) for length {len}",
319 len = self.len
320 ))
321 .with_location(ErrorLocation::new(file!(), line!())),
322 ));
323 }
324
325 let new_len = end - start;
326 let new_layout = MemoryLayout::contiguous::<T>(&[new_len]);
327
328 Ok(SharedArrayView {
330 ptr: unsafe { self.ptr.add(start) },
331 len: new_len,
332 layout: new_layout,
333 _marker: PhantomData,
334 })
335 }
336
337 #[must_use]
339 pub fn is_simd_aligned(&self) -> bool {
340 self.layout.alignment.is_aligned(self.ptr)
341 }
342}
343
344impl<'a, T> ContiguousMemory for SharedArrayView<'a, T> {
345 fn as_ptr(&self) -> *const u8 {
346 self.ptr as *const u8
347 }
348
349 fn layout(&self) -> &MemoryLayout {
350 &self.layout
351 }
352}
353
354unsafe impl<T: Send + Sync> Send for SharedArrayView<'_, T> {}
356
357unsafe impl<T: Sync> Sync for SharedArrayView<'_, T> {}
359
360#[derive(Debug)]
362pub struct SharedArrayViewMut<'a, T> {
363 ptr: *mut T,
365 len: usize,
367 layout: MemoryLayout,
369 _marker: PhantomData<&'a mut T>,
371}
372
373impl<'a, T> SharedArrayViewMut<'a, T> {
374 #[must_use]
376 pub fn from_slice(data: &'a mut [T]) -> Self {
377 let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
378 Self {
379 ptr: data.as_mut_ptr(),
380 len: data.len(),
381 layout,
382 _marker: PhantomData,
383 }
384 }
385
386 pub unsafe fn from_raw_parts(ptr: *mut T, layout: MemoryLayout) -> Self {
396 Self {
397 ptr,
398 len: layout.len,
399 layout,
400 _marker: PhantomData,
401 }
402 }
403
404 #[must_use]
406 pub const fn len(&self) -> usize {
407 self.len
408 }
409
410 #[must_use]
412 pub const fn is_empty(&self) -> bool {
413 self.len == 0
414 }
415
416 #[must_use]
418 pub const fn layout(&self) -> &MemoryLayout {
419 &self.layout
420 }
421
422 pub fn get(&self, index: usize) -> Option<&T> {
424 if index < self.len {
425 Some(unsafe { &*self.ptr.add(index) })
427 } else {
428 None
429 }
430 }
431
432 pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
434 if index < self.len {
435 Some(unsafe { &mut *self.ptr.add(index) })
437 } else {
438 None
439 }
440 }
441
442 #[must_use]
444 pub fn as_view(&self) -> SharedArrayView<'_, T> {
445 SharedArrayView {
446 ptr: self.ptr,
447 len: self.len,
448 layout: self.layout.clone(),
449 _marker: PhantomData,
450 }
451 }
452
453 pub fn as_slice(&self) -> Option<&[T]> {
455 if self.layout.is_contiguous {
456 Some(unsafe { slice::from_raw_parts(self.ptr, self.len) })
458 } else {
459 None
460 }
461 }
462
463 pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
465 if self.layout.is_contiguous {
466 Some(unsafe { slice::from_raw_parts_mut(self.ptr, self.len) })
468 } else {
469 None
470 }
471 }
472}
473
474impl<'a, T> ContiguousMemory for SharedArrayViewMut<'a, T> {
475 fn as_ptr(&self) -> *const u8 {
476 self.ptr as *const u8
477 }
478
479 fn layout(&self) -> &MemoryLayout {
480 &self.layout
481 }
482}
483
484impl<'a, T> ContiguousMemoryMut for SharedArrayViewMut<'a, T> {
485 fn as_mut_ptr(&mut self) -> *mut u8 {
486 self.ptr as *mut u8
487 }
488}
489
490unsafe impl<T: Send> Send for SharedArrayViewMut<'_, T> {}
492
493unsafe impl<T: Send + Sync> Sync for SharedArrayViewMut<'_, T> {}
495
496#[derive(Debug)]
498pub struct ZeroCopyBuffer {
499 data: Arc<[u8]>,
501 layout: MemoryLayout,
503 type_id: std::any::TypeId,
505 type_name: &'static str,
507}
508
509impl ZeroCopyBuffer {
510 pub fn from_vec<T: 'static + Clone>(data: Vec<T>) -> Self {
512 let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
513 let type_id = std::any::TypeId::of::<T>();
514 let type_name = std::any::type_name::<T>();
515
516 let byte_len = data.len() * mem::size_of::<T>();
518 let ptr = data.as_ptr() as *const u8;
519
520 let bytes = unsafe { slice::from_raw_parts(ptr, byte_len) };
522 let arc_bytes: Arc<[u8]> = bytes.into();
523
524 mem::forget(data);
526
527 Self {
528 data: arc_bytes,
529 layout,
530 type_id,
531 type_name,
532 }
533 }
534
535 pub fn as_typed<T: 'static>(&self) -> Option<&[T]> {
537 if std::any::TypeId::of::<T>() != self.type_id {
538 return None;
539 }
540
541 if !self.layout.is_contiguous {
542 return None;
543 }
544
545 Some(unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.layout.len) })
547 }
548
549 #[must_use]
551 pub const fn layout(&self) -> &MemoryLayout {
552 &self.layout
553 }
554
555 #[must_use]
557 pub const fn type_name(&self) -> &'static str {
558 self.type_name
559 }
560
561 #[must_use]
563 pub fn as_bytes(&self) -> &[u8] {
564 &self.data
565 }
566
567 #[must_use]
569 pub const fn len(&self) -> usize {
570 self.layout.len
571 }
572
573 #[must_use]
575 pub const fn is_empty(&self) -> bool {
576 self.layout.len == 0
577 }
578}
579
580impl Clone for ZeroCopyBuffer {
581 fn clone(&self) -> Self {
582 Self {
583 data: Arc::clone(&self.data),
584 layout: self.layout.clone(),
585 type_id: self.type_id,
586 type_name: self.type_name,
587 }
588 }
589}
590
591#[derive(Debug)]
593pub struct ZeroCopySlice<'a, T> {
594 buffer: &'a ZeroCopyBuffer,
596 start: usize,
598 end: usize,
600 _marker: PhantomData<T>,
602}
603
604impl<'a, T: 'static> ZeroCopySlice<'a, T> {
605 pub fn new(buffer: &'a ZeroCopyBuffer, start: usize, end: usize) -> CoreResult<Self> {
607 if std::any::TypeId::of::<T>() != buffer.type_id {
608 return Err(CoreError::ValidationError(
609 ErrorContext::new(format!(
610 "Type mismatch: buffer is {buf_type}, requested {req_type}",
611 buf_type = buffer.type_name,
612 req_type = std::any::type_name::<T>()
613 ))
614 .with_location(ErrorLocation::new(file!(), line!())),
615 ));
616 }
617
618 if start > end || end > buffer.layout.len {
619 return Err(CoreError::ValidationError(
620 ErrorContext::new(format!(
621 "Invalid slice range [{start}, {end}) for length {len}",
622 len = buffer.layout.len
623 ))
624 .with_location(ErrorLocation::new(file!(), line!())),
625 ));
626 }
627
628 Ok(Self {
629 buffer,
630 start,
631 end,
632 _marker: PhantomData,
633 })
634 }
635
636 #[must_use]
638 pub fn as_slice(&self) -> &[T] {
639 let full_slice: &[T] = self.buffer.as_typed().expect("Type already validated");
640 &full_slice[self.start..self.end]
641 }
642
643 #[must_use]
645 pub const fn len(&self) -> usize {
646 self.end - self.start
647 }
648
649 #[must_use]
651 pub const fn is_empty(&self) -> bool {
652 self.start == self.end
653 }
654}
655
656impl<'a, T: 'static> Deref for ZeroCopySlice<'a, T> {
657 type Target = [T];
658
659 fn deref(&self) -> &Self::Target {
660 self.as_slice()
661 }
662}
663
664#[derive(Debug)]
666pub struct ArrayBridge<T> {
667 data: Vec<T>,
669 layout: MemoryLayout,
671}
672
673impl<T: Clone> ArrayBridge<T> {
674 #[must_use]
676 pub fn from_vec(data: Vec<T>) -> Self {
677 let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
678 Self { data, layout }
679 }
680
681 #[must_use]
683 pub fn from_slice(data: &[T]) -> Self {
684 Self::from_vec(data.to_vec())
685 }
686
687 pub fn with_shape(data: Vec<T>, shape: &[usize]) -> CoreResult<Self> {
689 let expected_len: usize = shape.iter().product();
690 if data.len() != expected_len {
691 return Err(CoreError::ValidationError(
692 ErrorContext::new(format!(
693 "Data length {actual} does not match shape {shape:?} (expected {expected})",
694 actual = data.len(),
695 expected = expected_len
696 ))
697 .with_location(ErrorLocation::new(file!(), line!())),
698 ));
699 }
700
701 let layout = MemoryLayout::contiguous::<T>(shape);
702 Ok(Self { data, layout })
703 }
704
705 #[must_use]
707 pub fn view(&self) -> SharedArrayView<'_, T> {
708 SharedArrayView::from_slice(&self.data)
709 }
710
711 #[must_use]
713 pub fn view_mut(&mut self) -> SharedArrayViewMut<'_, T> {
714 SharedArrayViewMut::from_slice(&mut self.data)
715 }
716
717 #[must_use]
719 pub fn as_slice(&self) -> &[T] {
720 &self.data
721 }
722
723 #[must_use]
725 pub fn as_mut_slice(&mut self) -> &mut [T] {
726 &mut self.data
727 }
728
729 #[must_use]
731 pub const fn layout(&self) -> &MemoryLayout {
732 &self.layout
733 }
734
735 #[must_use]
737 pub fn shape(&self) -> &[usize] {
738 &self.layout.shape
739 }
740
741 #[must_use]
743 pub fn len(&self) -> usize {
744 self.data.len()
745 }
746
747 #[must_use]
749 pub fn is_empty(&self) -> bool {
750 self.data.is_empty()
751 }
752
753 #[must_use]
755 pub fn into_vec(self) -> Vec<T> {
756 self.data
757 }
758
759 pub fn reshape(&mut self, new_shape: &[usize]) -> CoreResult<()> {
761 let expected_len: usize = new_shape.iter().product();
762 if self.data.len() != expected_len {
763 return Err(CoreError::ValidationError(
764 ErrorContext::new(format!(
765 "Cannot reshape array of length {} to shape {new_shape:?}",
766 self.data.len()
767 ))
768 .with_location(ErrorLocation::new(file!(), line!())),
769 ));
770 }
771
772 self.layout = MemoryLayout::contiguous::<T>(new_shape);
773 Ok(())
774 }
775}
776
777impl<T: Clone> Clone for ArrayBridge<T> {
778 fn clone(&self) -> Self {
779 Self {
780 data: self.data.clone(),
781 layout: self.layout.clone(),
782 }
783 }
784}
785
786impl<T> Deref for ArrayBridge<T> {
787 type Target = [T];
788
789 fn deref(&self) -> &Self::Target {
790 &self.data
791 }
792}
793
794impl<T> DerefMut for ArrayBridge<T> {
795 fn deref_mut(&mut self) -> &mut Self::Target {
796 &mut self.data
797 }
798}
799
/// Owned, shape-aware buffer (alias of [`ArrayBridge`]).
pub type TypedBuffer<T> = ArrayBridge<T>;

/// Read-only borrowed buffer (alias of [`SharedArrayView`]).
pub type BufferRef<'a, T> = SharedArrayView<'a, T>;

/// Mutable borrowed buffer (alias of [`SharedArrayViewMut`]).
pub type BufferMut<'a, T> = SharedArrayViewMut<'a, T>;

/// Borrowed read-only array (alias of [`SharedArrayView`]).
pub type BorrowedArray<'a, T> = SharedArrayView<'a, T>;

/// Owned array (alias of [`ArrayBridge`]).
pub type OwnedArray<T> = ArrayBridge<T>;
814
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shared_array_view() {
        let samples = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let view = SharedArrayView::from_slice(&samples);

        assert_eq!(view.len(), 5);
        assert!(!view.is_empty());
        // First and last elements are reachable; one past the end is not.
        assert_eq!(view.get(0).copied(), Some(1.0));
        assert_eq!(view.get(4).copied(), Some(5.0));
        assert!(view.get(5).is_none());
    }

    #[test]
    fn test_shared_array_view_slice() {
        let samples = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let middle = SharedArrayView::from_slice(&samples)
            .slice(1, 4)
            .expect("Slice should succeed");

        assert_eq!(middle.len(), 3);
        assert_eq!(middle.get(0).copied(), Some(2.0));
        assert_eq!(middle.get(2).copied(), Some(4.0));
    }

    #[test]
    fn test_shared_array_view_mut() {
        let mut samples = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mut view = SharedArrayViewMut::from_slice(&mut samples);

        if let Some(slot) = view.get_mut(2) {
            *slot = 10.0;
        }

        assert_eq!(view.get(2).copied(), Some(10.0));
        // The write must be visible in the underlying vector.
        assert_eq!(samples[2], 10.0);
    }

    #[test]
    fn test_memory_layout() {
        let layout = MemoryLayout::contiguous::<f64>(&[3, 4]);

        assert_eq!(layout.ndim(), 2);
        assert_eq!(layout.len, 3 * 4);
        assert_eq!(layout.element_size, std::mem::size_of::<f64>());
        assert!(layout.is_contiguous);
        assert!(layout.is_c_contiguous);
    }

    #[test]
    fn test_array_bridge() {
        let mut bridge = ArrayBridge::with_shape(vec![1, 2, 3, 4, 5, 6], &[2, 3])
            .expect("Shape should be valid");

        assert_eq!(bridge.shape(), &[2, 3]);
        assert_eq!(bridge.len(), 6);

        bridge.reshape(&[3, 2]).expect("Reshape should succeed");
        assert_eq!(bridge.shape(), &[3, 2]);
    }

    #[test]
    fn test_zero_copy_buffer() {
        let buffer = ZeroCopyBuffer::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        assert_eq!(buffer.len(), 4);
        assert_eq!(buffer.type_name(), "f32");

        let typed: &[f32] = buffer.as_typed().expect("Type should match");
        assert_eq!(typed, &[1.0f32, 2.0, 3.0, 4.0]);

        // Requesting a different element type must fail rather than reinterpret bytes.
        assert!(buffer.as_typed::<f64>().is_none());
    }

    #[test]
    fn test_zero_copy_slice() {
        let buffer = ZeroCopyBuffer::from_vec(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]);
        let window =
            ZeroCopySlice::<'_, f64>::new(&buffer, 1, 4).expect("Slice should be valid");

        assert_eq!(window.len(), 3);
        assert_eq!(window.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_alignment() {
        let cases = [
            (Alignment::None, 1),
            (Alignment::Align16, 16),
            (Alignment::Align32, 32),
            (Alignment::Align64, 64),
            (Alignment::Custom(128), 128),
        ];

        for (alignment, expected) in cases {
            assert_eq!(alignment.bytes(), expected);
        }
    }

    #[test]
    fn test_contiguous_memory_trait() {
        let samples = vec![1.0f64, 2.0, 3.0];
        let view = SharedArrayView::from_slice(&samples);

        assert!(view.is_contiguous());
        // Three f64 elements of 8 bytes each.
        assert_eq!(view.size_bytes(), 3 * 8);
    }
}