tract_linalg/frame/element_wise_helper.rs

use crate::LADatum;
use std::alloc::*;
use tract_data::TractResult;

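/// Applies `f` in place over `vec` while honoring the kernel's requirements:
/// `f` is only ever called on slices that start at an address aligned to
/// `alignment_bytes` and whose length is a multiple of `nr`. The unaligned
/// prefix and the trailing remainder are copied into a thread-local aligned
/// scratch buffer of `nr` elements, processed there, and the live lanes are
/// copied back.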
pub(crate) fn map_slice_with_alignment<T>(
    vec: &mut [T],
    f: impl Fn(&mut [T]),
    nr: usize,
    alignment_bytes: usize,
) -> TractResult<()>
where
    T: LADatum,
{
    if vec.is_empty() {
        return Ok(());
    }
    unsafe {
        TMP.with(|buffer| {
            let mut buffer = buffer.borrow_mut();
            buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
            let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr);
            let mut compute_via_temp_buffer = |slice: &mut [T]| {
                tmp[..slice.len()].copy_from_slice(slice);
                f(tmp);
                slice.copy_from_slice(&tmp[..slice.len()])
            };
            let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
            if prefix_len > 0 {
                compute_via_temp_buffer(&mut vec[..prefix_len]);
            }
            let aligned_len = (vec.len() - prefix_len) / nr * nr;
            if aligned_len > 0 {
                f(&mut vec[prefix_len..][..aligned_len]);
            }
            if prefix_len + aligned_len < vec.len() {
                compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..]);
            }
        })
    }
    Ok(())
}

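/// Reduces `vec` to a single value using the same alignment strategy as
/// `map_slice_with_alignment`: aligned, `nr`-multiple chunks are handed to `f`
/// directly, while the unaligned prefix and the remainder go through the
/// scratch buffer, padded with `neutral` so the extra lanes do not affect the
/// result. Partial results are folded with `reduce`, starting from `neutral`.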
pub(crate) fn reduce_slice_with_alignment<T>(
    vec: &[T],
    f: impl Fn(&[T]) -> T,
    nr: usize,
    alignment_bytes: usize,
    neutral: T,
    reduce: impl Fn(T, T) -> T,
) -> TractResult<T>
where
    T: LADatum,
{
    if vec.is_empty() {
        return Ok(neutral);
    }
    let mut red = neutral;
    unsafe {
        TMP.with(|buffer| {
            let mut buffer = buffer.borrow_mut();
            buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
            let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr);
            let mut compute_via_temp_buffer = |slice: &[T], red: &mut T| {
                tmp[..slice.len()].copy_from_slice(slice);
                tmp[slice.len()..].fill(neutral);
                *red = reduce(*red, f(tmp));
            };
            let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
            if prefix_len > 0 {
                compute_via_temp_buffer(&vec[..prefix_len], &mut red);
            }
            let aligned_len = (vec.len() - prefix_len) / nr * nr;
            if aligned_len > 0 {
                let t = f(&vec[prefix_len..][..aligned_len]);
                red = reduce(red, t);
            }
            if prefix_len + aligned_len < vec.len() {
                compute_via_temp_buffer(&vec[prefix_len + aligned_len..], &mut red);
            }
        })
    }
    Ok(red)
}

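/// Combines the two helpers above: `f` rewrites each chunk in place and also
/// returns a partial reduction value. Scratch-buffer lanes beyond the live
/// data are filled with `map_neutral` before calling `f`, and the partial
/// results are folded with `reduce`, starting from `neutral`.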
pub(crate) fn map_reduce_slice_with_alignment<T>(
    vec: &mut [T],
    f: impl Fn(&mut [T]) -> T,
    nr: usize,
    alignment_bytes: usize,
    map_neutral: T,
    neutral: T,
    reduce: impl Fn(T, T) -> T,
) -> TractResult<T>
where
    T: LADatum,
{
    if vec.is_empty() {
        return Ok(neutral);
    }
    let mut red = neutral;
    unsafe {
        TMP.with(|buffer| {
            let mut buffer = buffer.borrow_mut();
            buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
            let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr);
            let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| {
                tmp[..slice.len()].copy_from_slice(slice);
                tmp[slice.len()..].fill(map_neutral);
                *red = reduce(*red, f(tmp));
                slice.copy_from_slice(&tmp[..slice.len()]);
            };
            let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
            if prefix_len > 0 {
                compute_via_temp_buffer(&mut vec[..prefix_len], &mut red);
            }
            let aligned_len = (vec.len() - prefix_len) / nr * nr;
            if aligned_len > 0 {
                let t = f(&mut vec[prefix_len..][..aligned_len]);
                red = reduce(red, t);
            }
            if prefix_len + aligned_len < vec.len() {
                compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..], &mut red);
            }
        })
    }
    Ok(red)
}

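// Per-thread scratch buffer shared by the helpers above: once it has grown to
// the required size, no further allocation happens on the hot path.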
std::thread_local! {
    static TMP: std::cell::RefCell<TempBuffer> = std::cell::RefCell::new(TempBuffer::default());
}

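/// Raw, manually managed scratch allocation that can be grown and re-aligned
/// on demand. `buffer` stays null until `ensure` performs the first allocation.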
pub struct TempBuffer {
    pub layout: Layout,
    pub buffer: *mut u8,
}

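// A fresh TempBuffer owns no memory: zero-sized layout and a null pointer.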
impl Default for TempBuffer {
    fn default() -> Self {
        TempBuffer { layout: Layout::new::<()>(), buffer: std::ptr::null_mut() }
    }
}

impl TempBuffer {
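    /// Makes sure the buffer is at least `size` bytes large and at least
    /// `alignment`-aligned, reallocating if needed. Existing contents are not
    /// preserved across a reallocation.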
    pub fn ensure(&mut self, size: usize, alignment: usize) {
        unsafe {
            if size > self.layout.size() || alignment > self.layout.align() {
                let size = size.max(self.layout.size());
                let alignment = alignment.max(self.layout.align());
                if !self.buffer.is_null() {
                    std::alloc::dealloc(self.buffer, self.layout);
                }
                self.layout = Layout::from_size_align_unchecked(size, alignment);
                self.buffer = std::alloc::alloc(self.layout);
                assert!(!self.buffer.is_null());
            }
        }
    }
}

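// Frees the allocation, if any, with the same layout it was allocated with.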
impl Drop for TempBuffer {
    fn drop(&mut self) {
        unsafe {
            if !self.buffer.is_null() {
                std::alloc::dealloc(self.buffer, self.layout);
            }
        }
    }
}
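
// Minimal usage sketch (not part of the original file), assuming an f32 kernel
// that works on blocks of 8 lanes and expects 32-byte aligned input; the `nr`
// and alignment values here are illustrative only:
//
//     let mut data = vec![1f32; 100];
//     map_slice_with_alignment(&mut data, |chunk| chunk.iter_mut().for_each(|x| *x *= *x), 8, 32)?;
//     let sum = reduce_slice_with_alignment(&data, |chunk| chunk.iter().sum(), 8, 32, 0f32, |a, b| a + b)?;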