tract_linalg/frame/
element_wise_helper.rs

1use crate::LADatum;
2use std::alloc::*;
3use tract_data::TractResult;
4
5pub(crate) fn map_slice_with_alignment<T>(
6    vec: &mut [T],
7    f: impl Fn(&mut [T]),
8    nr: usize,
9    alignment_bytes: usize,
10) -> TractResult<()>
11where
12    T: LADatum,
13{
14    if vec.is_empty() {
15        return Ok(());
16    }
17    unsafe {
18        TMP.with(|buffer| {
19            let mut buffer = buffer.borrow_mut();
20            buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
21            let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr);
22            let mut compute_via_temp_buffer = |slice: &mut [T]| {
23                tmp[..slice.len()].copy_from_slice(slice);
24                f(tmp);
25                slice.copy_from_slice(&tmp[..slice.len()])
26            };
27            let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
28            if prefix_len > 0 {
29                compute_via_temp_buffer(&mut vec[..prefix_len]);
30            }
31            let aligned_len = (vec.len() - prefix_len) / nr * nr;
32            if aligned_len > 0 {
33                f(&mut vec[prefix_len..][..aligned_len]);
34            }
35            if prefix_len + aligned_len < vec.len() {
36                compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..]);
37            }
38        })
39    }
40    Ok(())
41}
42
43pub(crate) fn reduce_slice_with_alignment<T>(
44    vec: &[T],
45    f: impl Fn(&[T]) -> T,
46    nr: usize,
47    alignment_bytes: usize,
48    neutral: T,
49    reduce: impl Fn(T, T) -> T,
50) -> TractResult<T>
51where
52    T: LADatum,
53{
54    if vec.is_empty() {
55        return Ok(neutral);
56    }
57    let mut red = neutral;
58    unsafe {
59        TMP.with(|buffer| {
60            let mut buffer = buffer.borrow_mut();
61            buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
62            let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr);
63            let mut compute_via_temp_buffer = |slice: &[T], red: &mut T| {
64                tmp[..slice.len()].copy_from_slice(slice);
65                tmp[slice.len()..].fill(neutral);
66                *red = reduce(*red, f(tmp));
67            };
68            let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
69            if prefix_len > 0 {
70                compute_via_temp_buffer(&vec[..prefix_len], &mut red);
71            }
72            let aligned_len = (vec.len() - prefix_len) / nr * nr;
73            if aligned_len > 0 {
74                let t = f(&vec[prefix_len..][..aligned_len]);
75                red = reduce(red, t);
76            }
77            if prefix_len + aligned_len < vec.len() {
78                compute_via_temp_buffer(&vec[prefix_len + aligned_len..], &mut red);
79            }
80        })
81    }
82    Ok(red)
83}
84
85pub(crate) fn map_reduce_slice_with_alignment<T>(
86    vec: &mut [T],
87    f: impl Fn(&mut [T]) -> T,
88    nr: usize,
89    alignment_bytes: usize,
90    map_neutral: T,
91    neutral: T,
92    reduce: impl Fn(T, T) -> T,
93) -> TractResult<T>
94where
95    T: LADatum,
96{
97    if vec.is_empty() {
98        return Ok(neutral);
99    }
100    let mut red = neutral;
101    unsafe {
102        TMP.with(|buffer| {
103            let mut buffer = buffer.borrow_mut();
104            buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
105            let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr);
106            let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| {
107                tmp[..slice.len()].copy_from_slice(slice);
108                tmp[slice.len()..].fill(map_neutral);
109                *red = reduce(*red, f(tmp));
110                slice.copy_from_slice(&tmp[..slice.len()]);
111            };
112            let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
113            if prefix_len > 0 {
114                compute_via_temp_buffer(&mut vec[..prefix_len], &mut red);
115            }
116            let aligned_len = (vec.len() - prefix_len) / nr * nr;
117            if aligned_len > 0 {
118                let t = f(&mut vec[prefix_len..][..aligned_len]);
119                red = reduce(red, t);
120            }
121            if prefix_len + aligned_len < vec.len() {
122                compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..], &mut red);
123            }
124        })
125    }
126    Ok(red)
127}
128
129std::thread_local! {
130    static TMP: std::cell::RefCell<TempBuffer> = std::cell::RefCell::new(TempBuffer::default());
131}
132
/// A manually managed, grow-only raw allocation used as scratch space.
///
/// Invariant: `buffer` is either null (nothing allocated yet) or points to a live
/// allocation obtained from the global allocator with exactly `layout`. Both
/// fields are public; code that writes to them directly must preserve this
/// pairing, because `ensure` and `Drop` pass them to `dealloc`.
pub struct TempBuffer {
    /// Layout of the current allocation (zero-sized with alignment 1 before the
    /// first allocation).
    pub layout: Layout,
    /// Start of the current allocation, or null when nothing is allocated.
    pub buffer: *mut u8,
}

impl Default for TempBuffer {
    /// Creates an empty buffer; no memory is allocated.
    fn default() -> Self {
        TempBuffer { layout: Layout::new::<()>(), buffer: std::ptr::null_mut() }
    }
}

impl TempBuffer {
    /// Makes sure the buffer holds at least `size` bytes aligned on at least
    /// `alignment` bytes.
    ///
    /// The buffer only ever grows: when a reallocation is needed the new layout
    /// uses the maximum of the requested and the current size and alignment.
    /// Previous contents are NOT preserved across a reallocation. At least one
    /// byte is always allocated, so `buffer` is non-null once this returns,
    /// even for a zero-sized request.
    ///
    /// # Panics
    /// Panics if a reallocation is needed and the resulting size/alignment pair
    /// is not a valid `Layout` (alignment that is zero or not a power of two,
    /// or a size that overflows once rounded up to the alignment). In that case
    /// the buffer is left untouched.
    ///
    /// Allocation failure is reported through `std::alloc::handle_alloc_error`.
    pub fn ensure(&mut self, size: usize, alignment: usize) {
        let fits = !self.buffer.is_null()
            && size <= self.layout.size()
            && alignment <= self.layout.align();
        if fits {
            return;
        }
        // Zero-sized allocations are not allowed by the global allocator: keep at least a byte.
        let size = size.max(self.layout.size()).max(1);
        let alignment = alignment.max(self.layout.align());
        // Validate before touching the current allocation so a bad request leaves `self` intact.
        let layout = Layout::from_size_align(size, alignment)
            .expect("TempBuffer::ensure: invalid size or alignment");
        if !self.buffer.is_null() {
            // SAFETY: per the struct invariant, a non-null `buffer` was allocated by the
            // global allocator with `self.layout`.
            unsafe { std::alloc::dealloc(self.buffer, self.layout) };
            // Do not keep a dangling pointer around while the new block is being obtained.
            self.buffer = std::ptr::null_mut();
        }
        self.layout = layout;
        // SAFETY: `layout` has a non-zero size.
        self.buffer = unsafe { std::alloc::alloc(layout) };
        if self.buffer.is_null() {
            std::alloc::handle_alloc_error(layout);
        }
    }
}

impl Drop for TempBuffer {
    /// Releases the allocation, if any.
    fn drop(&mut self) {
        if !self.buffer.is_null() {
            // SAFETY: per the struct invariant, a non-null `buffer` was allocated by the
            // global allocator with `self.layout`.
            unsafe { std::alloc::dealloc(self.buffer, self.layout) };
        }
    }
}
169}