tf_idf_vectorizer/utils/datastruct/vector/
tf.rs1use std::{alloc::Layout, iter::FusedIterator, mem, ptr::NonNull};
2
3use num_traits::Num;
4
5#[allow(dead_code)]
6const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<u8>>();
7static_assertions::const_assert!(TF_VECTOR_SIZE == 32);
8
9pub trait TFVectorTrait<N>
10where N: Num + Copy
11{
12 fn len(&self) -> u32;
13 fn nnz(&self) -> u32;
14 fn cap(&self) -> u32;
15 fn term_sum(&self) -> u32;
16 fn new() -> Self;
17 fn new_with_capacity(capacity: u32) -> Self;
18 fn shrink_to_fit(&mut self);
19 fn raw_iter(&self) -> RawTFVectorIter<'_, N>;
20 unsafe fn from_vec(ind_vec: Vec<u32>, val_vec: Vec<N>, len: u32, term_sum: u32) -> Self;
21 unsafe fn ind_ptr(&self) -> *mut u32;
22 unsafe fn val_ptr(&self) -> *mut N;
23 #[inline(always)]
26 unsafe fn power_jump_search(&self, target: u32, start: usize) -> Option<(N, usize)>
27 where
28 N: Copy,
29 {
30 let nnz = self.nnz() as usize;
31 if start >= nnz {
32 return None;
33 }
34
35 let ind = unsafe { core::slice::from_raw_parts(self.ind_ptr(), nnz) };
36 let val = unsafe { core::slice::from_raw_parts(self.val_ptr(), nnz) };
37
38 let mut lo = start;
40 let mut hi = start;
41
42 let s = ind[hi];
43 if s == target {
44 return Some((val[hi], hi));
45 }
46 if s > target {
47 return None; }
49
50 let mut step = 1usize;
52 loop {
53 let next_hi = hi + step;
54 if next_hi >= nnz {
55 hi = nnz - 1;
56 break;
57 }
58 hi = next_hi;
59
60 if ind[hi] >= target {
61 break;
62 }
63
64 lo = hi;
65 step <<= 1;
66 }
67
68 let mut l = lo + 1;
70 let mut r = hi + 1; while l < r {
72 let m = (l + r) >> 1;
73 if ind[m] < target {
74 l = m + 1;
75 } else {
76 r = m;
77 }
78 }
79
80 if l < nnz && ind[l] == target {
81 Some((val[l], l))
82 } else {
83 None
84 }
85 }
86 #[inline(always)]
87 fn get_power_jump(&self, target: u32, cut_down: &mut usize) -> Option<N>
88 where
89 N: Copy,
90 {
91 unsafe {
92 if let Some((v, idx)) = self.power_jump_search(target, *cut_down) {
93 *cut_down = idx;
94 Some(v)
95 } else {
96 None
97 }
98 }
99 }
100 #[inline(always)]
101 fn as_val_slice(&self) -> &[N] {
102 unsafe { core::slice::from_raw_parts(self.val_ptr(), self.nnz() as usize) }
103 }
104 #[inline(always)]
105 fn as_ind_slice(&self) -> &[u32] {
106 unsafe { core::slice::from_raw_parts(self.ind_ptr(), self.nnz() as usize) }
107 }
108 #[inline(always)]
109 fn perm(&mut self, perm_idxs: &[u32]) {
110 unsafe {
111 let mut_ind_slice = core::slice::from_raw_parts_mut(self.ind_ptr(), self.nnz() as usize);
112 let mut_val_slice = core::slice::from_raw_parts_mut(self.val_ptr(), self.nnz() as usize);
113 mut_ind_slice.iter_mut().for_each(|x| {
114 *x = perm_idxs[*x as usize];
115 });
116 crate::utils::sort::radix_sort_u32_soa(mut_ind_slice, mut_val_slice);
117 }
118 }
119}
120
121impl<N> TFVectorTrait<N> for TFVector<N>
122where N: Num + Copy
123{
124 fn new() -> Self {
125 Self::low_new()
126 }
127
128 #[inline]
129 fn new_with_capacity(capacity: u32) -> Self {
130 let mut vec = Self::low_new();
131 if capacity != 0 {
132 vec.set_cap(capacity);
133 }
134 vec
135 }
136
137 #[inline]
138 fn shrink_to_fit(&mut self) {
139 if self.nnz < self.cap {
140 self.set_cap(self.nnz);
141 }
142 }
143
144 #[inline(always)]
145 fn raw_iter(&self) -> RawTFVectorIter<'_, N> {
146 RawTFVectorIter {
147 vec: self,
148 pos: 0,
149 end: self.nnz,
150 }
151 }
152
153 #[inline(always)]
154 fn nnz(&self) -> u32 {
155 self.nnz
156 }
157
158 #[inline(always)]
159 fn len(&self) -> u32 {
160 self.len
161 }
162
163 #[inline(always)]
164 fn cap(&self) -> u32 {
165 self.cap
166 }
167
168 #[inline(always)]
169 fn term_sum(&self) -> u32 {
170 self.term_sum
171 }
172
173 #[inline(always)]
174 unsafe fn from_vec(mut ind_vec: Vec<u32>, mut val_vec: Vec<N>, len: u32, term_sum: u32) -> Self {
175 debug_assert_eq!(
176 ind_vec.len(),
177 val_vec.len(),
178 "ind_vec and val_vec must have the same length"
179 );
180
181 crate::utils::sort::radix_sort_u32_soa(&mut ind_vec, &mut val_vec);
183
184 let nnz = ind_vec.len() as u32;
185
186 if nnz == 0 {
187 let mut v = TFVector::low_new();
188 v.len = len;
189 v.term_sum = term_sum;
190 return v;
191 }
192
193 let inds_box: Box<[u32]> = ind_vec.into_boxed_slice();
197 let vals_box: Box<[N]> = val_vec.into_boxed_slice();
198
199 let inds_ptr = Box::into_raw(inds_box) as *mut u32;
200 let vals_ptr = Box::into_raw(vals_box) as *mut N;
201
202 TFVector {
203 inds: unsafe { NonNull::new_unchecked(inds_ptr) },
204 vals: unsafe { NonNull::new_unchecked(vals_ptr) },
205 cap: nnz,
206 nnz,
207 len,
208 term_sum,
209 }
210 }
211
212 #[inline(always)]
213 unsafe fn ind_ptr(&self) -> *mut u32 {
214 self.inds.as_ptr()
215 }
216
217 #[inline(always)]
218 unsafe fn val_ptr(&self) -> *mut N {
219 self.vals.as_ptr()
220 }
221}
222
223
224pub struct RawTFVectorIter<'a, N>
225where
226 N: Num + 'a,
227{
228 vec: &'a TFVector<N>,
229 pos: u32, end: u32, }
232
233impl<'a, N> RawTFVectorIter<'a, N>
234where
235 N: Num + 'a,
236{
237 #[inline]
238 pub fn new(vec: &'a TFVector<N>) -> Self {
239 Self { vec, pos: 0, end: vec.nnz }
240 }
241}
242
243impl<'a, N> Iterator for RawTFVectorIter<'a, N>
244where
245 N: Num + 'a + Copy,
246{
247 type Item = (u32, N);
248
249 #[inline]
250 fn next(&mut self) -> Option<Self::Item> {
251 if self.pos >= self.end {
252 return None;
253 }
254 unsafe {
255 let i = self.pos as usize;
256 self.pos += 1;
257 let ind = *self.vec.inds.as_ptr().add(i);
258 let val = *self.vec.vals.as_ptr().add(i);
259 Some((ind, val))
260 }
261 }
262
263 #[inline]
264 fn size_hint(&self) -> (usize, Option<usize>) {
265 let remaining = (self.end - self.pos) as usize;
266 (remaining, Some(remaining))
267 }
268}
269
270impl<'a, N> DoubleEndedIterator for RawTFVectorIter<'a, N>
271where
272 N: Num + 'a + Copy,
273{
274 #[inline]
275 fn next_back(&mut self) -> Option<Self::Item> {
276 if self.pos >= self.end {
277 return None;
278 }
279 self.end -= 1;
280 unsafe {
281 let i = self.end as usize;
282 let ind = *self.vec.inds.as_ptr().add(i);
283 let val = *self.vec.vals.as_ptr().add(i);
284 Some((ind, val))
285 }
286 }
287}
288
289impl<'a, N> ExactSizeIterator for RawTFVectorIter<'a, N>
290where
291 N: Num + 'a + Copy,
292{
293 #[inline]
294 fn len(&self) -> usize {
295 (self.end - self.pos) as usize
296 }
297}
298
299impl<'a, N> FusedIterator for RawTFVectorIter<'a, N>
300where
301 N: Num + 'a + Copy,
302{}
303
304#[derive(Debug)]
306#[repr(align(32))] pub struct TFVector<N>
308where N: Num
309{
310 inds: NonNull<u32>,
311 vals: NonNull<N>,
312 cap: u32,
313 nnz: u32,
314 len: u32,
315 term_sum: u32, }
320
321impl<N> TFVector<N>
323where N: Num
324{
325 const VAL_SIZE: usize = mem::size_of::<N>();
326
327 #[inline]
328 fn low_new() -> Self {
329 debug_assert!(Self::VAL_SIZE != 0, "Zero-sized type is not supported for TFVector");
331
332 TFVector {
333 inds: NonNull::dangling(),
335 vals: NonNull::dangling(),
336 cap: 0,
337 nnz: 0,
338 len: 0,
339 term_sum: 0,
340 }
341 }
342
343
344 #[inline]
345 #[allow(dead_code)]
346 fn grow(&mut self) {
347 let new_cap = if self.cap == 0 {
348 1
349 } else {
350 self.cap.checked_mul(2).expect("TFVector capacity overflowed")
351 };
352
353 self.set_cap(new_cap);
354 }
355
356 #[inline]
357 fn set_cap(&mut self, new_cap: u32) {
358 if new_cap == 0 {
359 self.free_alloc();
361 return;
362 }
363 let new_inds_layout = Layout::array::<u32>(new_cap as usize).expect("Failed to create inds memory layout");
364 let new_vals_layout = Layout::array::<N>(new_cap as usize).expect("Failed to create vals memory layout");
365
366 if self.cap == 0 {
367 let new_inds_ptr = unsafe { std::alloc::alloc(new_inds_layout) };
368 let new_vals_ptr = unsafe { std::alloc::alloc(new_vals_layout) };
369 if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
370 if new_inds_ptr.is_null() {
371 oom(new_inds_layout);
372 } else {
373 oom(new_vals_layout);
374 }
375 }
376
377 self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
378 self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
379 self.cap = new_cap;
380 } else {
381 let old_inds_layout = Layout::array::<u32>(self.cap as usize).expect("Failed to create old inds memory layout");
382 let old_vals_layout = Layout::array::<N>(self.cap as usize).expect("Failed to create old vals memory layout");
383
384 let new_inds_ptr = unsafe { std::alloc::realloc(
385 self.inds.as_ptr().cast::<u8>(),
386 old_inds_layout,
387 new_inds_layout.size(),
388 ) };
389 let new_vals_ptr = unsafe { std::alloc::realloc(
390 self.vals.as_ptr().cast::<u8>(),
391 old_vals_layout,
392 new_vals_layout.size(),
393 ) };
394 if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
395 if new_inds_ptr.is_null() {
396 oom(new_inds_layout);
397 } else {
398 oom(new_vals_layout);
399 }
400 }
401
402 self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
403 self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
404 self.cap = new_cap;
405 }
406 }
407
408 #[inline]
409 fn free_alloc(&mut self) {
410 if self.cap != 0 {
411 unsafe {
412 let inds_layout = Layout::array::<u32>(self.cap as usize).unwrap();
413 let vals_layout = Layout::array::<N>(self.cap as usize).unwrap();
414 std::alloc::dealloc(self.inds.as_ptr().cast::<u8>(), inds_layout);
415 std::alloc::dealloc(self.vals.as_ptr().cast::<u8>(), vals_layout);
416 }
417 }
418 self.inds = NonNull::dangling();
419 self.vals = NonNull::dangling();
420 self.cap = 0;
421 }
422}
423
424unsafe impl<N: Num + Send + Sync> Send for TFVector<N> {}
425unsafe impl<N: Num + Sync> Sync for TFVector<N> {}
426
427impl<N> Drop for TFVector<N>
428where N: Num
429{
430 #[inline]
431 fn drop(&mut self) {
432 self.free_alloc();
433 }
434}
435
436impl<N> Clone for TFVector<N>
437where
438 N: Num + Copy,
439{
440 #[inline]
441 fn clone(&self) -> Self {
442 let mut new_vec = TFVector::low_new();
443 if self.nnz > 0 {
444 new_vec.set_cap(self.nnz);
445 new_vec.len = self.len;
446 new_vec.nnz = self.nnz;
447 new_vec.term_sum = self.term_sum;
448
449 unsafe {
450 std::ptr::copy_nonoverlapping(
451 self.inds.as_ptr(),
452 new_vec.inds.as_ptr(),
453 self.nnz as usize,
454 );
455 std::ptr::copy_nonoverlapping(
456 self.vals.as_ptr(),
457 new_vec.vals.as_ptr(),
458 self.nnz as usize,
459 );
460 }
461 }
462 new_vec
463 }
464}
465
466
467
/// Diverges after an allocation of `failed` could not be satisfied.
///
/// Kept out of line and marked cold so the allocation fast paths stay small. Delegates to
/// `std::alloc::handle_alloc_error`, which aborts or panics depending on global configuration.
#[cold]
#[inline(never)]
fn oom(failed: Layout) -> ! {
    std::alloc::handle_alloc_error(failed)
}