umap_rs/umap/
fuzzy_simplicial_set.rs

1use crate::umap::smooth_knn_dist::SmoothKnnDist;
2use crate::utils::parallel_vec::ParallelVec;
3use dashmap::DashSet;
4use ndarray::Array1;
5use ndarray::ArrayView1;
6use ndarray::ArrayView2;
7use rayon::prelude::*;
8use sprs::CsMatI;
9use std::sync::atomic::AtomicU32;
10use std::sync::atomic::Ordering;
11use std::time::Instant;
12use tracing::info;
13use typed_builder::TypedBuilder;
14
15/// Sparse matrix with u32 column indices to save memory (4 bytes vs 8 bytes per index).
16/// Uses usize for indptr since nnz can exceed u32::MAX for very large datasets.
17/// Valid for n_samples < 2^32 (~4 billion).
18pub type SparseMat = CsMatI<f32, u32, usize>;
19
/// Column-compressed *structure* of a sparse matrix: column pointers and row indices only.
///
/// No values are stored. Callers that need `A[row, col]` look it up in the original CSR
/// matrix (binary search within the row, O(log k)).
struct CscStructure {
  /// `indptr[c]..indptr[c + 1]` is the span of `indices` belonging to column `c`
  /// (`usize` because the number of stored entries can exceed `u32::MAX`).
  indptr: Vec<usize>,
  /// Row index of every stored entry, grouped by column.
  indices: Vec<u32>,
}

impl CscStructure {
  /// Returns the row indices of the stored entries in column `col`.
  ///
  /// # Panics
  /// Panics if `col + 1` is out of bounds for `indptr`.
  fn col_row_indices(&self, col: usize) -> &[u32] {
    let span = self.indptr[col]..self.indptr[col + 1];
    &self.indices[span]
  }
}
34
35/*
36  Given a set of data X, a neighborhood size, and a measure of distance
37  compute the fuzzy simplicial set (here represented as a fuzzy graph in
38  the form of a sparse matrix) associated to the data. This is done by
39  locally approximating geodesic distance at each point, creating a fuzzy
40  simplicial set for each such point, and then combining all the local
41  fuzzy simplicial sets into a global one via a fuzzy union.
42
43  Parameters
44  ----------
45  X: array of shape (n_samples, n_features)
46      The data to be modelled as a fuzzy simplicial set.
47
48  n_neighbors: int
49      The number of neighbors to use to approximate geodesic distance.
50      Larger numbers induce more global estimates of the manifold that can
51      miss finer detail, while smaller values will focus on fine manifold
52      structure to the detriment of the larger picture.
53
54  knn_indices: array of shape (n_samples, n_neighbors) (optional)
55      If the k-nearest neighbors of each point has already been calculated
56      you can pass them in here to save computation time. This should be
57      an array with the indices of the k-nearest neighbors as a row for
58      each data point.
59
60  knn_dists: array of shape (n_samples, n_neighbors) (optional)
61      If the k-nearest neighbors of each point has already been calculated
62      you can pass them in here to save computation time. This should be
63      an array with the distances of the k-nearest neighbors as a row for
64      each data point.
65
66  set_op_mix_ratio: float (optional, default 1.0)
67      Interpolate between (fuzzy) union and intersection as the set operation
68      used to combine local fuzzy simplicial sets to obtain a global fuzzy
69      simplicial sets. Both fuzzy set operations use the product t-norm.
70      The value of this parameter should be between 0.0 and 1.0; a value of
71      1.0 will use a pure fuzzy union, while 0.0 will use a pure fuzzy
72      intersection.
73
74  local_connectivity: int (optional, default 1)
75      The local connectivity required -- i.e. the number of nearest
76      neighbors that should be assumed to be connected at a local level.
77      The higher this value the more connected the manifold becomes
78      locally. In practice this should be not more than the local intrinsic
79      dimension of the manifold.
80
81  verbose: bool (optional, default False)
82      Whether to report information on the current progress of the algorithm.
83
84  return_dists: bool or None (optional, default None)
85      Whether to return the pairwise distance associated with each edge.
86
87  Returns
88  -------
89  fuzzy_simplicial_set: coo_matrix
90      A fuzzy simplicial set represented as a sparse matrix. The (i,
91      j) entry of the matrix represents the membership strength of the
92      1-simplex between the ith and jth sample points.
93*/
94#[derive(TypedBuilder, Debug)]
95pub struct FuzzySimplicialSet<'a, 'd> {
96  n_samples: usize,
97  n_neighbors: usize,
98  knn_indices: ArrayView2<'a, u32>,
99  knn_dists: ArrayView2<'a, f32>,
100  knn_disconnections: &'d DashSet<(usize, usize)>,
101  #[builder(default = 1.0)]
102  set_op_mix_ratio: f32,
103  #[builder(default = 1.0)]
104  local_connectivity: f32,
105  #[builder(default = true)]
106  apply_set_operations: bool,
107}
108
109impl<'a, 'd> FuzzySimplicialSet<'a, 'd> {
110  pub fn exec(self) -> (SparseMat, Array1<f32>, Array1<f32>) {
111    assert!(
112      self.n_samples < u32::MAX as usize,
113      "n_samples must be < 2^32 for u32 indices"
114    );
115
116    // Extract the fields we need
117    let knn_dists = self.knn_dists;
118    let knn_indices = self.knn_indices;
119    let knn_disconnections = self.knn_disconnections;
120    let n_neighbors = self.n_neighbors;
121    let n_samples = self.n_samples;
122    let local_connectivity = self.local_connectivity;
123    let set_op_mix_ratio = self.set_op_mix_ratio;
124    let apply_set_operations = self.apply_set_operations;
125
126    let started = Instant::now();
127    let (sigmas, rhos) = SmoothKnnDist::builder()
128      .distances(knn_dists)
129      .k(n_neighbors)
130      .local_connectivity(local_connectivity)
131      .build()
132      .exec();
133    info!(
134      duration_ms = started.elapsed().as_millis(),
135      "smooth_knn_dist complete"
136    );
137
138    // Build CSR directly - no intermediate allocations
139    // Uses u32 indices to halve index memory
140    let started = Instant::now();
141    let mut result = build_membership_csr(
142      n_samples,
143      n_neighbors,
144      knn_indices,
145      knn_dists,
146      knn_disconnections,
147      &sigmas.view(),
148      &rhos.view(),
149    );
150    info!(
151      duration_ms = started.elapsed().as_millis(),
152      nnz = result.nnz(),
153      "build_membership_csr complete"
154    );
155
156    if apply_set_operations {
157      let started = Instant::now();
158      result = apply_set_operations_parallel(&result, set_op_mix_ratio);
159      info!(
160        duration_ms = started.elapsed().as_millis(),
161        "set_operations complete"
162      );
163    }
164
165    (result, sigmas, rhos)
166  }
167}
168
169/// Build CSR matrix directly from KNN data without intermediate allocations.
170/// Uses u32 indices to halve index memory (4 bytes vs 8 bytes per index).
171fn build_membership_csr(
172  n_samples: usize,
173  n_neighbors: usize,
174  knn_indices: ArrayView2<u32>,
175  knn_dists: ArrayView2<f32>,
176  knn_disconnections: &DashSet<(usize, usize)>,
177  sigmas: &ArrayView1<f32>,
178  rhos: &ArrayView1<f32>,
179) -> SparseMat {
180  // Step 1: Count valid (non-zero) entries per row in parallel
181  let started = Instant::now();
182  let row_counts: Vec<u32> = (0..n_samples)
183    .into_par_iter()
184    .map(|i| {
185      let mut count = 0u32;
186      for j in 0..n_neighbors {
187        if knn_disconnections.contains(&(i, j)) {
188          continue;
189        }
190        let knn_idx = knn_indices[(i, j)] as usize;
191        // Skip self-loops and sentinel values (e.g. u32::MAX used when KNN couldn't find k neighbors)
192        if knn_idx == i || knn_idx >= n_samples {
193          continue;
194        }
195        let val = compute_membership_strength(i, j, knn_dists, rhos, sigmas);
196        if val != 0.0 {
197          count += 1;
198        }
199      }
200      count
201    })
202    .collect();
203  info!(
204    duration_ms = started.elapsed().as_millis(),
205    "csr row_counts complete"
206  );
207
208  // Step 2: Build indptr from prefix sum (usize since nnz can exceed u32::MAX)
209  let started = Instant::now();
210  let mut indptr: Vec<usize> = Vec::with_capacity(n_samples + 1);
211  indptr.push(0);
212  let mut total = 0usize;
213  for &count in &row_counts {
214    total += count as usize;
215    indptr.push(total);
216  }
217  let nnz = total;
218  info!(
219    duration_ms = started.elapsed().as_millis(),
220    nnz, "csr indptr complete"
221  );
222
223  // Step 3: Pre-allocate indices and data, wrap in UnsafeCell for parallel access
224  // SAFETY: Each row i writes only to [indptr[i]..indptr[i+1]], which are disjoint
225  let indices_vec = ParallelVec::new(vec![0u32; nnz]);
226  let data_vec = ParallelVec::new(vec![0.0f32; nnz]);
227
228  // Step 4: Fill indices and data in parallel (each row writes to its own section)
229  // No false sharing: each row is ~256 elements (~1KB with u32), writes are sequential within row.
230  // Threads work on different rows, not adjacent elements.
231  let started = Instant::now();
232  (0..n_samples).into_par_iter().for_each(|i| {
233    let row_start = indptr[i];
234    let mut offset = 0;
235
236    for j in 0..n_neighbors {
237      if knn_disconnections.contains(&(i, j)) {
238        continue;
239      }
240      let knn_idx = knn_indices[(i, j)];
241      // Skip self-loops and sentinel values (must match count phase exactly)
242      if knn_idx as usize == i || knn_idx as usize >= n_samples {
243        continue;
244      }
245      let val = compute_membership_strength(i, j, knn_dists, rhos, sigmas);
246      if val != 0.0 {
247        // SAFETY: Each row writes to disjoint section [indptr[i]..indptr[i+1]]
248        unsafe {
249          indices_vec.write(row_start + offset, knn_idx);
250          data_vec.write(row_start + offset, val);
251        }
252        offset += 1;
253      }
254    }
255  });
256  info!(
257    duration_ms = started.elapsed().as_millis(),
258    "csr fill complete"
259  );
260
261  // Step 5: Sort column indices within each row (required for valid CSR)
262  // Each row can be sorted independently in parallel
263  let started = Instant::now();
264  (0..n_samples).into_par_iter().for_each(|i| {
265    let row_start = indptr[i];
266    let row_len = indptr[i + 1] - indptr[i];
267    if row_len > 0 {
268      // SAFETY: Each row accesses disjoint section [indptr[i]..indptr[i+1]]
269      let row_indices = unsafe { indices_vec.get_mut_slice(row_start, row_len) };
270      let row_data = unsafe { data_vec.get_mut_slice(row_start, row_len) };
271
272      // Sort by column index (insertion sort is fast for small rows, ~256 elements)
273      for k in 1..row_len {
274        let mut m = k;
275        while m > 0 && row_indices[m - 1] > row_indices[m] {
276          row_indices.swap(m - 1, m);
277          row_data.swap(m - 1, m);
278          m -= 1;
279        }
280      }
281    }
282  });
283  info!(
284    duration_ms = started.elapsed().as_millis(),
285    "csr row_sort complete"
286  );
287
288  // Extract Vecs from UnsafeCell wrappers and build CSR
289  let indices = indices_vec.into_inner();
290  let data = data_vec.into_inner();
291  CsMatI::new((n_samples, n_samples), indptr, indices, data)
292}
293
294fn compute_membership_strength(
295  i: usize,
296  j: usize,
297  knn_dists: ArrayView2<f32>,
298  rhos: &ArrayView1<f32>,
299  sigmas: &ArrayView1<f32>,
300) -> f32 {
301  if knn_dists[(i, j)] - rhos[i] <= 0.0 || sigmas[i] == 0.0 {
302    1.0
303  } else {
304    f32::exp(-(knn_dists[(i, j)] - rhos[i]) / sigmas[i])
305  }
306}
307
308/// Build CSC structure (indptr + indices only, no data) for transpose traversal.
309/// Values are looked up in original CSR when needed via binary search O(log k).
310fn build_csc_structure(csr: &SparseMat) -> CscStructure {
311  let n_rows = csr.shape().0;
312  let n_cols = csr.shape().1;
313  let nnz = csr.nnz();
314
315  // Step 1: Count entries per column (parallel with atomics)
316  let started = Instant::now();
317  let col_counts: Vec<AtomicU32> = (0..n_cols).map(|_| AtomicU32::new(0)).collect();
318
319  (0..n_rows).into_par_iter().for_each(|row| {
320    let row_start = csr.indptr().index(row) as usize;
321    let row_end = csr.indptr().index(row + 1) as usize;
322    for &col in &csr.indices()[row_start..row_end] {
323      col_counts[col as usize].fetch_add(1, Ordering::Relaxed);
324    }
325  });
326  info!(
327    duration_ms = started.elapsed().as_millis(),
328    "csc col_counts complete"
329  );
330
331  // Step 2: Build column pointers (prefix sum, usize since nnz can exceed u32::MAX)
332  let started = Instant::now();
333  let mut indptr: Vec<usize> = Vec::with_capacity(n_cols + 1);
334  indptr.push(0);
335  let mut total = 0usize;
336  for count in &col_counts {
337    total += count.load(Ordering::Relaxed) as usize;
338    indptr.push(total);
339  }
340  assert_eq!(total, nnz);
341  info!(
342    duration_ms = started.elapsed().as_millis(),
343    "csc indptr complete"
344  );
345
346  // Step 3: Fill indices only (sequential to avoid atomic contention)
347  // No data array - values will be looked up in CSR when needed.
348  let started = Instant::now();
349  let mut indices: Vec<u32> = vec![0; nnz];
350  let mut col_offsets: Vec<usize> = vec![0; n_cols];
351
352  for row in 0..n_rows {
353    let row_start = csr.indptr().index(row);
354    let row_end = csr.indptr().index(row + 1);
355    let row_indices = &csr.indices()[row_start..row_end];
356
357    for &col in row_indices {
358      let write_pos = indptr[col as usize] + col_offsets[col as usize];
359      indices[write_pos] = row as u32;
360      col_offsets[col as usize] += 1;
361    }
362  }
363  info!(
364    duration_ms = started.elapsed().as_millis(),
365    "csc fill complete"
366  );
367  // No sorting needed: iterating rows in order guarantees sorted row indices per column
368
369  CscStructure { indptr, indices }
370}
371
372/// Binary search for value A[row, col] in CSR matrix. Returns 0.0 if not found.
373fn csr_get(csr: &SparseMat, row: usize, col: u32) -> f32 {
374  let row_start = csr.indptr().index(row);
375  let row_end = csr.indptr().index(row + 1);
376  let row_indices = &csr.indices()[row_start..row_end];
377  let row_data = &csr.data()[row_start..row_end];
378
379  match row_indices.binary_search(&col) {
380    Ok(idx) => row_data[idx],
381    Err(_) => 0.0,
382  }
383}
384
385/// Apply fuzzy set union/intersection operations, building CSR directly.
386///
387/// Computes: set_op_mix_ratio * (A + A^T) + (1 - 2*set_op_mix_ratio) * (A ⊙ A^T)
388/// where ⊙ is the Hadamard (elementwise) product.
389///
390/// The result is symmetric: for each pair (i,j) where A[i,j] OR A[j,i] is non-zero,
391/// both output[i,j] and output[j,i] are set to the same computed value.
392fn apply_set_operations_parallel(input: &SparseMat, set_op_mix_ratio: f32) -> SparseMat {
393  let n_samples = input.shape().0;
394  let prod_coeff = 1.0 - 2.0 * set_op_mix_ratio;
395
396  // Build CSC structure (no data) for efficient transpose traversal
397  // Values are looked up in original CSR via binary search (avoids duplicating data array)
398  let started = Instant::now();
399  let csc = build_csc_structure(input);
400  info!(
401    duration_ms = started.elapsed().as_millis(),
402    "set_operations csc_structure complete"
403  );
404
405  // Step 1: Count output entries per row
406  // For row r, entries come from:
407  //   - A's row r (direct entries)
408  //   - A's column r where A[r,c] doesn't exist (transpose entries without direct counterpart)
409  let started = Instant::now();
410  let row_counts: Vec<u32> = (0..n_samples)
411    .into_par_iter()
412    .map(|row| {
413      // Count from A's row (direct entries)
414      let row_start = input.indptr().index(row);
415      let row_end = input.indptr().index(row + 1);
416      let row_indices = &input.indices()[row_start..row_end];
417      let row_data = &input.data()[row_start..row_end];
418
419      let mut count = 0u32;
420      for (&col, &val_rc) in row_indices.iter().zip(row_data) {
421        let val_cr = csr_get(input, col as usize, row as u32);
422        let final_val =
423          set_op_mix_ratio * val_rc + set_op_mix_ratio * val_cr + prod_coeff * val_rc * val_cr;
424        if final_val != 0.0 {
425          count += 1;
426        }
427      }
428
429      // Count from A's column (transpose entries without direct counterpart)
430      // CSC tells us which rows c have entries in column `row` (i.e., A[c, row] exists)
431      for &c in csc.col_row_indices(row) {
432        // Skip if direct entry A[row, c] exists (already counted above)
433        if csr_get(input, row, c) != 0.0 {
434          continue;
435        }
436        // val_rc = 0 since no direct entry, val_cr = A[c, row]
437        let val_cr = csr_get(input, c as usize, row as u32);
438        let final_val = set_op_mix_ratio * val_cr; // Simplified: 0 + mix*val_cr + 0
439        if final_val != 0.0 {
440          count += 1;
441        }
442      }
443
444      count
445    })
446    .collect();
447  info!(
448    duration_ms = started.elapsed().as_millis(),
449    "set_operations row_counts complete"
450  );
451
452  // Step 2: Build indptr (usize since output nnz can exceed u32::MAX)
453  let started = Instant::now();
454  let mut indptr: Vec<usize> = Vec::with_capacity(n_samples + 1);
455  indptr.push(0);
456  let mut total = 0usize;
457  for &count in &row_counts {
458    total += count as usize;
459    indptr.push(total);
460  }
461  let nnz = total;
462  info!(
463    duration_ms = started.elapsed().as_millis(),
464    nnz, "set_operations indptr complete"
465  );
466
467  // Step 3: Pre-allocate and wrap in UnsafeCell for parallel access
468  // SAFETY: Each row writes only to [indptr[row]..indptr[row+1]], which are disjoint
469  // No false sharing: each row section is ~512 elements (~2KB after symmetrization with u32),
470  // writes are sequential within row. Threads work on different rows.
471  let indices_vec = ParallelVec::new(vec![0u32; nnz]);
472  let data_vec = ParallelVec::new(vec![0.0f32; nnz]);
473
474  let started = Instant::now();
475  (0..n_samples).into_par_iter().for_each(|row| {
476    let out_start = indptr[row];
477    let mut offset = 0;
478
479    // Fill from A's row (direct entries)
480    let row_start = input.indptr().index(row);
481    let row_end = input.indptr().index(row + 1);
482    let row_indices = &input.indices()[row_start..row_end];
483    let row_data = &input.data()[row_start..row_end];
484
485    for (&col, &val_rc) in row_indices.iter().zip(row_data) {
486      let val_cr = csr_get(input, col as usize, row as u32);
487      let final_val =
488        set_op_mix_ratio * val_rc + set_op_mix_ratio * val_cr + prod_coeff * val_rc * val_cr;
489      if final_val != 0.0 {
490        // SAFETY: Each row writes to disjoint section [indptr[row]..indptr[row+1]]
491        unsafe {
492          indices_vec.write(out_start + offset, col);
493          data_vec.write(out_start + offset, final_val);
494        }
495        offset += 1;
496      }
497    }
498
499    // Fill from A's column (transpose entries without direct counterpart)
500    for &c in csc.col_row_indices(row) {
501      // Skip if direct entry exists (already filled above)
502      if csr_get(input, row, c) != 0.0 {
503        continue;
504      }
505      let val_cr = csr_get(input, c as usize, row as u32);
506      let final_val = set_op_mix_ratio * val_cr;
507      if final_val != 0.0 {
508        // SAFETY: Each row writes to disjoint section
509        unsafe {
510          indices_vec.write(out_start + offset, c);
511          data_vec.write(out_start + offset, final_val);
512        }
513        offset += 1;
514      }
515    }
516  });
517  info!(
518    duration_ms = started.elapsed().as_millis(),
519    "set_operations fill complete"
520  );
521
522  // Step 4: Sort columns within each row (entries may be unsorted after combining)
523  let started = Instant::now();
524  (0..n_samples).into_par_iter().for_each(|row| {
525    let row_start = indptr[row];
526    let row_len = indptr[row + 1] - indptr[row];
527    if row_len > 1 {
528      // SAFETY: Each row accesses disjoint section
529      let row_indices = unsafe { indices_vec.get_mut_slice(row_start, row_len) };
530      let row_data = unsafe { data_vec.get_mut_slice(row_start, row_len) };
531
532      // Insertion sort (rows are small after set operations)
533      for k in 1..row_len {
534        let mut m = k;
535        while m > 0 && row_indices[m - 1] > row_indices[m] {
536          row_indices.swap(m - 1, m);
537          row_data.swap(m - 1, m);
538          m -= 1;
539        }
540      }
541    }
542  });
543  info!(
544    duration_ms = started.elapsed().as_millis(),
545    "set_operations row_sort complete"
546  );
547
548  // Extract Vecs from UnsafeCell wrappers and build CSR
549  let indices = indices_vec.into_inner();
550  let data = data_vec.into_inner();
551  CsMatI::new((n_samples, n_samples), indptr, indices, data)
552}