tf_idf_vectorizer/utils/datastruct/vector/
tf.rs1use std::{alloc::Layout, iter::FusedIterator, mem, ptr::NonNull};
2
3use num_traits::Num;
4
5#[allow(dead_code)]
6const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<u8>>();
7static_assertions::const_assert!(TF_VECTOR_SIZE == 32);
8
9pub trait TFVectorTrait<N>
10where N: Num + Copy
11{
12 fn len(&self) -> u32;
13 fn nnz(&self) -> u32;
14 fn term_sum(&self) -> u32;
15 fn new() -> Self;
16 fn new_with_capacity(capacity: u32) -> Self;
17 fn shrink_to_fit(&mut self);
18 fn raw_iter(&self) -> RawTFVectorIter<'_, N>;
19 unsafe fn from_vec(ind_vec: Vec<u32>, val_vec: Vec<N>, len: u32, term_sum: u32) -> Self;
20 unsafe fn ind_ptr(&self) -> *mut u32;
21 unsafe fn val_ptr(&self) -> *mut N;
22 #[inline(always)]
25 unsafe fn power_jump_search(&self, target: u32, start: usize) -> Option<(N, usize)>
26 where
27 N: Copy,
28 {
29 let nnz = self.nnz() as usize;
30 if start >= nnz {
31 return None;
32 }
33
34 let ind = unsafe { core::slice::from_raw_parts(self.ind_ptr(), nnz) };
35 let val = unsafe { core::slice::from_raw_parts(self.val_ptr(), nnz) };
36
37 let mut lo = start;
39 let mut hi = start;
40
41 let s = ind[hi];
42 if s == target {
43 return Some((val[hi], hi));
44 }
45 if s > target {
46 return None; }
48
49 let mut step = 1usize;
51 loop {
52 let next_hi = hi + step;
53 if next_hi >= nnz {
54 hi = nnz - 1;
55 break;
56 }
57 hi = next_hi;
58
59 if ind[hi] >= target {
60 break;
61 }
62
63 lo = hi;
64 step <<= 1;
65 }
66
67 let mut l = lo + 1;
69 let mut r = hi + 1; while l < r {
71 let m = (l + r) >> 1;
72 if ind[m] < target {
73 l = m + 1;
74 } else {
75 r = m;
76 }
77 }
78
79 if l < nnz && ind[l] == target {
80 Some((val[l], l))
81 } else {
82 None
83 }
84 }
85 #[inline(always)]
86 fn get_power_jump(&self, target: u32, cut_down: &mut usize) -> Option<N>
87 where
88 N: Copy,
89 {
90 unsafe {
91 if let Some((v, idx)) = self.power_jump_search(target, *cut_down) {
92 *cut_down = idx;
93 Some(v)
94 } else {
95 None
96 }
97 }
98 }
99 #[inline(always)]
100 fn as_val_slice(&self) -> &[N] {
101 unsafe { core::slice::from_raw_parts(self.val_ptr(), self.nnz() as usize) }
102 }
103 #[inline(always)]
104 fn as_ind_slice(&self) -> &[u32] {
105 unsafe { core::slice::from_raw_parts(self.ind_ptr(), self.nnz() as usize) }
106 }
107}
108
109impl<N> TFVectorTrait<N> for TFVector<N>
110where N: Num + Copy
111{
112 fn new() -> Self {
113 Self::low_new()
114 }
115
116 #[inline]
117 fn new_with_capacity(capacity: u32) -> Self {
118 let mut vec = Self::low_new();
119 if capacity != 0 {
120 vec.set_cap(capacity);
121 }
122 vec
123 }
124
125 #[inline]
126 fn shrink_to_fit(&mut self) {
127 if self.nnz < self.cap {
128 self.set_cap(self.nnz);
129 }
130 }
131
132 #[inline(always)]
133 fn raw_iter(&self) -> RawTFVectorIter<'_, N> {
134 RawTFVectorIter {
135 vec: self,
136 pos: 0,
137 end: self.nnz,
138 }
139 }
140
141 #[inline(always)]
142 fn nnz(&self) -> u32 {
143 self.nnz
144 }
145
146 #[inline(always)]
147 fn len(&self) -> u32 {
148 self.len
149 }
150
151 #[inline(always)]
152 fn term_sum(&self) -> u32 {
153 self.term_sum
154 }
155
156 #[inline(always)]
157 unsafe fn from_vec(mut ind_vec: Vec<u32>, mut val_vec: Vec<N>, len: u32, term_sum: u32) -> Self {
158 debug_assert_eq!(
159 ind_vec.len(),
160 val_vec.len(),
161 "ind_vec and val_vec must have the same length"
162 );
163
164 crate::utils::sort::radix_sort_u32_soa(&mut ind_vec, &mut val_vec);
166
167 let nnz = ind_vec.len() as u32;
168
169 if nnz == 0 {
170 let mut v = TFVector::low_new();
171 v.len = len;
172 v.term_sum = term_sum;
173 return v;
174 }
175
176 let inds_box: Box<[u32]> = ind_vec.into_boxed_slice();
180 let vals_box: Box<[N]> = val_vec.into_boxed_slice();
181
182 let inds_ptr = Box::into_raw(inds_box) as *mut u32;
183 let vals_ptr = Box::into_raw(vals_box) as *mut N;
184
185 TFVector {
186 inds: unsafe { NonNull::new_unchecked(inds_ptr) },
187 vals: unsafe { NonNull::new_unchecked(vals_ptr) },
188 cap: nnz,
189 nnz,
190 len,
191 term_sum,
192 }
193 }
194
195 #[inline(always)]
196 unsafe fn ind_ptr(&self) -> *mut u32 {
197 self.inds.as_ptr()
198 }
199
200 #[inline(always)]
201 unsafe fn val_ptr(&self) -> *mut N {
202 self.vals.as_ptr()
203 }
204}
205
206
207pub struct RawTFVectorIter<'a, N>
208where
209 N: Num + 'a,
210{
211 vec: &'a TFVector<N>,
212 pos: u32, end: u32, }
215
216impl<'a, N> RawTFVectorIter<'a, N>
217where
218 N: Num + 'a,
219{
220 #[inline]
221 pub fn new(vec: &'a TFVector<N>) -> Self {
222 Self { vec, pos: 0, end: vec.nnz }
223 }
224}
225
226impl<'a, N> Iterator for RawTFVectorIter<'a, N>
227where
228 N: Num + 'a + Copy,
229{
230 type Item = (u32, N);
231
232 #[inline]
233 fn next(&mut self) -> Option<Self::Item> {
234 if self.pos >= self.end {
235 return None;
236 }
237 unsafe {
238 let i = self.pos as usize;
239 self.pos += 1;
240 let ind = *self.vec.inds.as_ptr().add(i);
241 let val = *self.vec.vals.as_ptr().add(i);
242 Some((ind, val))
243 }
244 }
245
246 #[inline]
247 fn size_hint(&self) -> (usize, Option<usize>) {
248 let remaining = (self.end - self.pos) as usize;
249 (remaining, Some(remaining))
250 }
251}
252
253impl<'a, N> DoubleEndedIterator for RawTFVectorIter<'a, N>
254where
255 N: Num + 'a + Copy,
256{
257 #[inline]
258 fn next_back(&mut self) -> Option<Self::Item> {
259 if self.pos >= self.end {
260 return None;
261 }
262 self.end -= 1;
263 unsafe {
264 let i = self.end as usize;
265 let ind = *self.vec.inds.as_ptr().add(i);
266 let val = *self.vec.vals.as_ptr().add(i);
267 Some((ind, val))
268 }
269 }
270}
271
272impl<'a, N> ExactSizeIterator for RawTFVectorIter<'a, N>
273where
274 N: Num + 'a + Copy,
275{
276 #[inline]
277 fn len(&self) -> usize {
278 (self.end - self.pos) as usize
279 }
280}
281
282impl<'a, N> FusedIterator for RawTFVectorIter<'a, N>
283where
284 N: Num + 'a + Copy,
285{}
286
287#[derive(Debug)]
289#[repr(align(32))] pub struct TFVector<N>
291where N: Num
292{
293 inds: NonNull<u32>,
294 vals: NonNull<N>,
295 cap: u32,
296 nnz: u32,
297 len: u32,
298 term_sum: u32, }
303
304impl<N> TFVector<N>
306where N: Num
307{
308 const VAL_SIZE: usize = mem::size_of::<N>();
309
310 #[inline]
311 fn low_new() -> Self {
312 debug_assert!(Self::VAL_SIZE != 0, "Zero-sized type is not supported for TFVector");
314
315 TFVector {
316 inds: NonNull::dangling(),
318 vals: NonNull::dangling(),
319 cap: 0,
320 nnz: 0,
321 len: 0,
322 term_sum: 0,
323 }
324 }
325
326
327 #[inline]
328 #[allow(dead_code)]
329 fn grow(&mut self) {
330 let new_cap = if self.cap == 0 {
331 1
332 } else {
333 self.cap.checked_mul(2).expect("TFVector capacity overflowed")
334 };
335
336 self.set_cap(new_cap);
337 }
338
339 #[inline]
340 fn set_cap(&mut self, new_cap: u32) {
341 if new_cap == 0 {
342 self.free_alloc();
344 return;
345 }
346 let new_inds_layout = Layout::array::<u32>(new_cap as usize).expect("Failed to create inds memory layout");
347 let new_vals_layout = Layout::array::<N>(new_cap as usize).expect("Failed to create vals memory layout");
348
349 if self.cap == 0 {
350 let new_inds_ptr = unsafe { std::alloc::alloc(new_inds_layout) };
351 let new_vals_ptr = unsafe { std::alloc::alloc(new_vals_layout) };
352 if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
353 if new_inds_ptr.is_null() {
354 oom(new_inds_layout);
355 } else {
356 oom(new_vals_layout);
357 }
358 }
359
360 self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
361 self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
362 self.cap = new_cap;
363 } else {
364 let old_inds_layout = Layout::array::<u32>(self.cap as usize).expect("Failed to create old inds memory layout");
365 let old_vals_layout = Layout::array::<N>(self.cap as usize).expect("Failed to create old vals memory layout");
366
367 let new_inds_ptr = unsafe { std::alloc::realloc(
368 self.inds.as_ptr().cast::<u8>(),
369 old_inds_layout,
370 new_inds_layout.size(),
371 ) };
372 let new_vals_ptr = unsafe { std::alloc::realloc(
373 self.vals.as_ptr().cast::<u8>(),
374 old_vals_layout,
375 new_vals_layout.size(),
376 ) };
377 if new_inds_ptr.is_null() || new_vals_ptr.is_null() {
378 if new_inds_ptr.is_null() {
379 oom(new_inds_layout);
380 } else {
381 oom(new_vals_layout);
382 }
383 }
384
385 self.inds = unsafe { NonNull::new_unchecked(new_inds_ptr as *mut u32) };
386 self.vals = unsafe { NonNull::new_unchecked(new_vals_ptr as *mut N) };
387 self.cap = new_cap;
388 }
389 }
390
391 #[inline]
392 fn free_alloc(&mut self) {
393 if self.cap != 0 {
394 unsafe {
395 let inds_layout = Layout::array::<u32>(self.cap as usize).unwrap();
396 let vals_layout = Layout::array::<N>(self.cap as usize).unwrap();
397 std::alloc::dealloc(self.inds.as_ptr().cast::<u8>(), inds_layout);
398 std::alloc::dealloc(self.vals.as_ptr().cast::<u8>(), vals_layout);
399 }
400 }
401 self.inds = NonNull::dangling();
402 self.vals = NonNull::dangling();
403 self.cap = 0;
404 }
405}
406
407unsafe impl<N: Num + Send + Sync> Send for TFVector<N> {}
408unsafe impl<N: Num + Sync> Sync for TFVector<N> {}
409
410impl<N> Drop for TFVector<N>
411where N: Num
412{
413 #[inline]
414 fn drop(&mut self) {
415 self.free_alloc();
416 }
417}
418
419impl<N> Clone for TFVector<N>
420where
421 N: Num + Copy,
422{
423 #[inline]
424 fn clone(&self) -> Self {
425 let mut new_vec = TFVector::low_new();
426 if self.nnz > 0 {
427 new_vec.set_cap(self.nnz);
428 new_vec.len = self.len;
429 new_vec.nnz = self.nnz;
430 new_vec.term_sum = self.term_sum;
431
432 unsafe {
433 std::ptr::copy_nonoverlapping(
434 self.inds.as_ptr(),
435 new_vec.inds.as_ptr(),
436 self.nnz as usize,
437 );
438 std::ptr::copy_nonoverlapping(
439 self.vals.as_ptr(),
440 new_vec.vals.as_ptr(),
441 self.nnz as usize,
442 );
443 }
444 }
445 new_vec
446 }
447}
448
449
450
/// Reports an allocation failure for `layout` through
/// `std::alloc::handle_alloc_error`; never returns.
///
/// Marked cold and never-inlined to keep the failure path out of line from the
/// allocation code that calls it.
#[cold]
#[inline(never)]
fn oom(layout: Layout) -> ! {
    std::alloc::handle_alloc_error(layout)
}