1use crate::umap::smooth_knn_dist::SmoothKnnDist;
2use crate::utils::parallel_vec::ParallelVec;
3use dashmap::DashSet;
4use ndarray::Array1;
5use ndarray::ArrayView1;
6use ndarray::ArrayView2;
7use rayon::prelude::*;
8use sprs::CsMatI;
9use std::sync::atomic::AtomicU32;
10use std::sync::atomic::Ordering;
11use std::time::Instant;
12use tracing::info;
13use typed_builder::TypedBuilder;
14
/// Sparse matrix type used throughout this file: `f32` values, `u32` inner
/// (column) indices and `usize` index pointers.
///
/// The functions below build and read it in row-compressed (CSR) layout.
pub type SparseMat = CsMatI<f32, u32, usize>;
19
/// Column-compressed sparsity pattern (structure only, no values).
///
/// The row indices belonging to column `c` are stored in
/// `indices[indptr[c]..indptr[c + 1]]`.
struct CscStructure {
    /// Offsets into `indices`: one entry per column plus a trailing end offset.
    indptr: Vec<usize>,
    /// Row indices of the stored entries, grouped column by column.
    indices: Vec<u32>,
}

impl CscStructure {
    /// Returns the row indices stored for column `col`.
    ///
    /// # Panics
    ///
    /// Panics if `col` or `col + 1` is out of bounds for `indptr`, or if the
    /// resulting range is not a valid range of `indices`.
    fn col_row_indices(&self, col: usize) -> &[u32] {
        let span = self.indptr[col]..self.indptr[col + 1];
        &self.indices[span]
    }
}
34
35#[derive(TypedBuilder, Debug)]
95pub struct FuzzySimplicialSet<'a, 'd> {
96 n_samples: usize,
97 n_neighbors: usize,
98 knn_indices: ArrayView2<'a, u32>,
99 knn_dists: ArrayView2<'a, f32>,
100 knn_disconnections: &'d DashSet<(usize, usize)>,
101 #[builder(default = 1.0)]
102 set_op_mix_ratio: f32,
103 #[builder(default = 1.0)]
104 local_connectivity: f32,
105 #[builder(default = true)]
106 apply_set_operations: bool,
107}
108
109impl<'a, 'd> FuzzySimplicialSet<'a, 'd> {
110 pub fn exec(self) -> (SparseMat, Array1<f32>, Array1<f32>) {
111 assert!(
112 self.n_samples < u32::MAX as usize,
113 "n_samples must be < 2^32 for u32 indices"
114 );
115
116 let knn_dists = self.knn_dists;
118 let knn_indices = self.knn_indices;
119 let knn_disconnections = self.knn_disconnections;
120 let n_neighbors = self.n_neighbors;
121 let n_samples = self.n_samples;
122 let local_connectivity = self.local_connectivity;
123 let set_op_mix_ratio = self.set_op_mix_ratio;
124 let apply_set_operations = self.apply_set_operations;
125
126 let started = Instant::now();
127 let (sigmas, rhos) = SmoothKnnDist::builder()
128 .distances(knn_dists)
129 .k(n_neighbors)
130 .local_connectivity(local_connectivity)
131 .build()
132 .exec();
133 info!(
134 duration_ms = started.elapsed().as_millis(),
135 "smooth_knn_dist complete"
136 );
137
138 let started = Instant::now();
141 let mut result = build_membership_csr(
142 n_samples,
143 n_neighbors,
144 knn_indices,
145 knn_dists,
146 knn_disconnections,
147 &sigmas.view(),
148 &rhos.view(),
149 );
150 info!(
151 duration_ms = started.elapsed().as_millis(),
152 nnz = result.nnz(),
153 "build_membership_csr complete"
154 );
155
156 if apply_set_operations {
157 let started = Instant::now();
158 result = apply_set_operations_parallel(&result, set_op_mix_ratio);
159 info!(
160 duration_ms = started.elapsed().as_millis(),
161 "set_operations complete"
162 );
163 }
164
165 (result, sigmas, rhos)
166 }
167}
168
169fn build_membership_csr(
172 n_samples: usize,
173 n_neighbors: usize,
174 knn_indices: ArrayView2<u32>,
175 knn_dists: ArrayView2<f32>,
176 knn_disconnections: &DashSet<(usize, usize)>,
177 sigmas: &ArrayView1<f32>,
178 rhos: &ArrayView1<f32>,
179) -> SparseMat {
180 let started = Instant::now();
182 let row_counts: Vec<u32> = (0..n_samples)
183 .into_par_iter()
184 .map(|i| {
185 let mut count = 0u32;
186 for j in 0..n_neighbors {
187 if knn_disconnections.contains(&(i, j)) {
188 continue;
189 }
190 let knn_idx = knn_indices[(i, j)] as usize;
191 if knn_idx == i || knn_idx >= n_samples {
193 continue;
194 }
195 let val = compute_membership_strength(i, j, knn_dists, rhos, sigmas);
196 if val != 0.0 {
197 count += 1;
198 }
199 }
200 count
201 })
202 .collect();
203 info!(
204 duration_ms = started.elapsed().as_millis(),
205 "csr row_counts complete"
206 );
207
208 let started = Instant::now();
210 let mut indptr: Vec<usize> = Vec::with_capacity(n_samples + 1);
211 indptr.push(0);
212 let mut total = 0usize;
213 for &count in &row_counts {
214 total += count as usize;
215 indptr.push(total);
216 }
217 let nnz = total;
218 info!(
219 duration_ms = started.elapsed().as_millis(),
220 nnz, "csr indptr complete"
221 );
222
223 let indices_vec = ParallelVec::new(vec![0u32; nnz]);
226 let data_vec = ParallelVec::new(vec![0.0f32; nnz]);
227
228 let started = Instant::now();
232 (0..n_samples).into_par_iter().for_each(|i| {
233 let row_start = indptr[i];
234 let mut offset = 0;
235
236 for j in 0..n_neighbors {
237 if knn_disconnections.contains(&(i, j)) {
238 continue;
239 }
240 let knn_idx = knn_indices[(i, j)];
241 if knn_idx as usize == i || knn_idx as usize >= n_samples {
243 continue;
244 }
245 let val = compute_membership_strength(i, j, knn_dists, rhos, sigmas);
246 if val != 0.0 {
247 unsafe {
249 indices_vec.write(row_start + offset, knn_idx);
250 data_vec.write(row_start + offset, val);
251 }
252 offset += 1;
253 }
254 }
255 });
256 info!(
257 duration_ms = started.elapsed().as_millis(),
258 "csr fill complete"
259 );
260
261 let started = Instant::now();
264 (0..n_samples).into_par_iter().for_each(|i| {
265 let row_start = indptr[i];
266 let row_len = indptr[i + 1] - indptr[i];
267 if row_len > 0 {
268 let row_indices = unsafe { indices_vec.get_mut_slice(row_start, row_len) };
270 let row_data = unsafe { data_vec.get_mut_slice(row_start, row_len) };
271
272 for k in 1..row_len {
274 let mut m = k;
275 while m > 0 && row_indices[m - 1] > row_indices[m] {
276 row_indices.swap(m - 1, m);
277 row_data.swap(m - 1, m);
278 m -= 1;
279 }
280 }
281 }
282 });
283 info!(
284 duration_ms = started.elapsed().as_millis(),
285 "csr row_sort complete"
286 );
287
288 let indices = indices_vec.into_inner();
290 let data = data_vec.into_inner();
291 CsMatI::new((n_samples, n_samples), indptr, indices, data)
292}
293
294fn compute_membership_strength(
295 i: usize,
296 j: usize,
297 knn_dists: ArrayView2<f32>,
298 rhos: &ArrayView1<f32>,
299 sigmas: &ArrayView1<f32>,
300) -> f32 {
301 if knn_dists[(i, j)] - rhos[i] <= 0.0 || sigmas[i] == 0.0 {
302 1.0
303 } else {
304 f32::exp(-(knn_dists[(i, j)] - rhos[i]) / sigmas[i])
305 }
306}
307
308fn build_csc_structure(csr: &SparseMat) -> CscStructure {
311 let n_rows = csr.shape().0;
312 let n_cols = csr.shape().1;
313 let nnz = csr.nnz();
314
315 let started = Instant::now();
317 let col_counts: Vec<AtomicU32> = (0..n_cols).map(|_| AtomicU32::new(0)).collect();
318
319 (0..n_rows).into_par_iter().for_each(|row| {
320 let row_start = csr.indptr().index(row) as usize;
321 let row_end = csr.indptr().index(row + 1) as usize;
322 for &col in &csr.indices()[row_start..row_end] {
323 col_counts[col as usize].fetch_add(1, Ordering::Relaxed);
324 }
325 });
326 info!(
327 duration_ms = started.elapsed().as_millis(),
328 "csc col_counts complete"
329 );
330
331 let started = Instant::now();
333 let mut indptr: Vec<usize> = Vec::with_capacity(n_cols + 1);
334 indptr.push(0);
335 let mut total = 0usize;
336 for count in &col_counts {
337 total += count.load(Ordering::Relaxed) as usize;
338 indptr.push(total);
339 }
340 assert_eq!(total, nnz);
341 info!(
342 duration_ms = started.elapsed().as_millis(),
343 "csc indptr complete"
344 );
345
346 let started = Instant::now();
349 let mut indices: Vec<u32> = vec![0; nnz];
350 let mut col_offsets: Vec<usize> = vec![0; n_cols];
351
352 for row in 0..n_rows {
353 let row_start = csr.indptr().index(row);
354 let row_end = csr.indptr().index(row + 1);
355 let row_indices = &csr.indices()[row_start..row_end];
356
357 for &col in row_indices {
358 let write_pos = indptr[col as usize] + col_offsets[col as usize];
359 indices[write_pos] = row as u32;
360 col_offsets[col as usize] += 1;
361 }
362 }
363 info!(
364 duration_ms = started.elapsed().as_millis(),
365 "csc fill complete"
366 );
367 CscStructure { indptr, indices }
370}
371
372fn csr_get(csr: &SparseMat, row: usize, col: u32) -> f32 {
374 let row_start = csr.indptr().index(row);
375 let row_end = csr.indptr().index(row + 1);
376 let row_indices = &csr.indices()[row_start..row_end];
377 let row_data = &csr.data()[row_start..row_end];
378
379 match row_indices.binary_search(&col) {
380 Ok(idx) => row_data[idx],
381 Err(_) => 0.0,
382 }
383}
384
385fn apply_set_operations_parallel(input: &SparseMat, set_op_mix_ratio: f32) -> SparseMat {
393 let n_samples = input.shape().0;
394 let prod_coeff = 1.0 - 2.0 * set_op_mix_ratio;
395
396 let started = Instant::now();
399 let csc = build_csc_structure(input);
400 info!(
401 duration_ms = started.elapsed().as_millis(),
402 "set_operations csc_structure complete"
403 );
404
405 let started = Instant::now();
410 let row_counts: Vec<u32> = (0..n_samples)
411 .into_par_iter()
412 .map(|row| {
413 let row_start = input.indptr().index(row);
415 let row_end = input.indptr().index(row + 1);
416 let row_indices = &input.indices()[row_start..row_end];
417 let row_data = &input.data()[row_start..row_end];
418
419 let mut count = 0u32;
420 for (&col, &val_rc) in row_indices.iter().zip(row_data) {
421 let val_cr = csr_get(input, col as usize, row as u32);
422 let final_val =
423 set_op_mix_ratio * val_rc + set_op_mix_ratio * val_cr + prod_coeff * val_rc * val_cr;
424 if final_val != 0.0 {
425 count += 1;
426 }
427 }
428
429 for &c in csc.col_row_indices(row) {
432 if csr_get(input, row, c) != 0.0 {
434 continue;
435 }
436 let val_cr = csr_get(input, c as usize, row as u32);
438 let final_val = set_op_mix_ratio * val_cr; if final_val != 0.0 {
440 count += 1;
441 }
442 }
443
444 count
445 })
446 .collect();
447 info!(
448 duration_ms = started.elapsed().as_millis(),
449 "set_operations row_counts complete"
450 );
451
452 let started = Instant::now();
454 let mut indptr: Vec<usize> = Vec::with_capacity(n_samples + 1);
455 indptr.push(0);
456 let mut total = 0usize;
457 for &count in &row_counts {
458 total += count as usize;
459 indptr.push(total);
460 }
461 let nnz = total;
462 info!(
463 duration_ms = started.elapsed().as_millis(),
464 nnz, "set_operations indptr complete"
465 );
466
467 let indices_vec = ParallelVec::new(vec![0u32; nnz]);
472 let data_vec = ParallelVec::new(vec![0.0f32; nnz]);
473
474 let started = Instant::now();
475 (0..n_samples).into_par_iter().for_each(|row| {
476 let out_start = indptr[row];
477 let mut offset = 0;
478
479 let row_start = input.indptr().index(row);
481 let row_end = input.indptr().index(row + 1);
482 let row_indices = &input.indices()[row_start..row_end];
483 let row_data = &input.data()[row_start..row_end];
484
485 for (&col, &val_rc) in row_indices.iter().zip(row_data) {
486 let val_cr = csr_get(input, col as usize, row as u32);
487 let final_val =
488 set_op_mix_ratio * val_rc + set_op_mix_ratio * val_cr + prod_coeff * val_rc * val_cr;
489 if final_val != 0.0 {
490 unsafe {
492 indices_vec.write(out_start + offset, col);
493 data_vec.write(out_start + offset, final_val);
494 }
495 offset += 1;
496 }
497 }
498
499 for &c in csc.col_row_indices(row) {
501 if csr_get(input, row, c) != 0.0 {
503 continue;
504 }
505 let val_cr = csr_get(input, c as usize, row as u32);
506 let final_val = set_op_mix_ratio * val_cr;
507 if final_val != 0.0 {
508 unsafe {
510 indices_vec.write(out_start + offset, c);
511 data_vec.write(out_start + offset, final_val);
512 }
513 offset += 1;
514 }
515 }
516 });
517 info!(
518 duration_ms = started.elapsed().as_millis(),
519 "set_operations fill complete"
520 );
521
522 let started = Instant::now();
524 (0..n_samples).into_par_iter().for_each(|row| {
525 let row_start = indptr[row];
526 let row_len = indptr[row + 1] - indptr[row];
527 if row_len > 1 {
528 let row_indices = unsafe { indices_vec.get_mut_slice(row_start, row_len) };
530 let row_data = unsafe { data_vec.get_mut_slice(row_start, row_len) };
531
532 for k in 1..row_len {
534 let mut m = k;
535 while m > 0 && row_indices[m - 1] > row_indices[m] {
536 row_indices.swap(m - 1, m);
537 row_data.swap(m - 1, m);
538 m -= 1;
539 }
540 }
541 }
542 });
543 info!(
544 duration_ms = started.elapsed().as_millis(),
545 "set_operations row_sort complete"
546 );
547
548 let indices = indices_vec.into_inner();
550 let data = data_vec.into_inner();
551 CsMatI::new((n_samples, n_samples), indptr, indices, data)
552}