umap_rs/
optimizer.rs

1use crate::config::UmapConfig;
2use crate::embedding::FittedUmap;
3use crate::layout::optimize_layout_euclidean::optimize_layout_euclidean_single_epoch_stateful;
4use crate::layout::optimize_layout_generic::optimize_layout_generic_single_epoch_stateful;
5use crate::manifold::LearnedManifold;
6use crate::metric::Metric;
7use crate::metric::MetricType;
8use crate::umap::make_epochs_per_sample::make_epochs_per_sample;
9use crate::utils::parallel_vec::ParallelVec;
10use ndarray::Array1;
11use ndarray::Array2;
12use ndarray::ArrayView2;
13use rayon::prelude::*;
14use serde::Deserialize;
15use serde::Serialize;
16use std::time::Instant;
17use tracing::info;
18
19/// Active optimization state for UMAP embedding.
20///
21/// This contains all the state needed to run and resume stochastic gradient
22/// descent optimization. It's large and mutable, meant to be used during
23/// training and then converted to a lightweight `FittedUmap` when done.
24///
25/// The optimizer can be serialized mid-training to enable fault-tolerant
26/// training with checkpoints.
27#[derive(Debug, Serialize, Deserialize)]
28pub struct Optimizer {
29  // Reference to learned manifold
30  manifold: LearnedManifold,
31
32  // Preprocessed graph structures for optimization
33  head: Array1<u32>,
34  tail: Array1<u32>,
35  epochs_per_sample: Array1<f64>,
36
37  // Current embedding state
38  embedding: Array2<f32>,
39
40  // SGD scheduling state
41  epoch_of_next_sample: Array1<f64>,
42  epoch_of_next_negative_sample: Array1<f64>,
43  epochs_per_negative_sample: Array1<f64>,
44
45  // Progress tracking
46  current_epoch: usize,
47  total_epochs: usize,
48
49  // Optimization parameters
50  gamma: f32,
51  initial_alpha: f32,
52  negative_sample_rate: f64,
53
54  // Metric type (determined once at creation)
55  metric_type: MetricType,
56}
57
58impl Optimizer {
59  /// Create a new optimizer from a learned manifold.
60  ///
61  /// This performs preprocessing:
62  /// - Filters weak edges from the graph
63  /// - Extracts head/tail edge lists
64  /// - Computes epoch sampling schedules
65  /// - Normalizes the initial embedding to [0, 10]
66  ///
67  /// # Arguments
68  ///
69  /// * `manifold` - The learned manifold structure
70  /// * `init` - Initial embedding (will be normalized)
71  /// * `total_epochs` - Total number of epochs to run
72  /// * `opt_params` - Optimization parameters (learning rate, negative sampling, etc.)
73  /// * `metric_type` - Type of distance metric being used
74  pub fn new(
75    manifold: LearnedManifold,
76    init: Array2<f32>,
77    total_epochs: usize,
78    opt_params: &UmapConfig,
79    metric_type: MetricType,
80  ) -> Self {
81    let gamma = opt_params.optimization.repulsion_strength;
82    let initial_alpha = opt_params.optimization.learning_rate;
83    let negative_sample_rate = opt_params.optimization.negative_sample_rate;
84
85    let graph = &manifold.graph;
86    let n_samples = graph.shape().0;
87
88    // Determine epoch threshold for filtering weak edges
89    let started = Instant::now();
90    let max_val = graph
91      .data()
92      .par_iter()
93      .copied()
94      .reduce(|| 0.0f32, |a, b| a.max(b));
95
96    let default_epochs = if n_samples <= 10000 { 500 } else { 200 };
97    let threshold_epochs = if total_epochs > 10 {
98      total_epochs
99    } else {
100      default_epochs
101    };
102    let threshold = max_val / threshold_epochs as f32;
103    info!(
104      duration_ms = started.elapsed().as_millis(),
105      max_val, threshold, "optimizer threshold computed"
106    );
107
108    // Count edges per row that pass threshold (parallel)
109    let started = Instant::now();
110    let row_counts: Vec<usize> = (0..n_samples)
111      .into_par_iter()
112      .map(|row| {
113        let row_start = graph.indptr().index(row);
114        let row_end = graph.indptr().index(row + 1);
115        let row_data = &graph.data()[row_start..row_end];
116        row_data.iter().filter(|&&v| v >= threshold).count()
117      })
118      .collect();
119
120    // Prefix sum for edge offsets
121    let mut edge_offsets: Vec<usize> = Vec::with_capacity(n_samples + 1);
122    edge_offsets.push(0);
123    let mut total_edges = 0usize;
124    for &count in &row_counts {
125      total_edges += count;
126      edge_offsets.push(total_edges);
127    }
128    info!(
129      duration_ms = started.elapsed().as_millis(),
130      total_edges, "optimizer edge filtering complete"
131    );
132
133    // Extract head, tail, and weights in parallel
134    let started = Instant::now();
135    let head_vec = ParallelVec::new(vec![0u32; total_edges]);
136    let tail_vec = ParallelVec::new(vec![0u32; total_edges]);
137    let weights_vec = ParallelVec::new(vec![0.0f32; total_edges]);
138
139    (0..n_samples).into_par_iter().for_each(|row| {
140      let row_start = graph.indptr().index(row);
141      let row_end = graph.indptr().index(row + 1);
142      let row_indices = &graph.indices()[row_start..row_end];
143      let row_data = &graph.data()[row_start..row_end];
144
145      let out_start = edge_offsets[row];
146      let mut offset = 0;
147
148      for (&col, &val) in row_indices.iter().zip(row_data) {
149        if val >= threshold {
150          // SAFETY: Each row writes to disjoint section [edge_offsets[row]..edge_offsets[row+1]]
151          unsafe {
152            head_vec.write(out_start + offset, row as u32);
153            tail_vec.write(out_start + offset, col);
154            weights_vec.write(out_start + offset, val);
155          }
156          offset += 1;
157        }
158      }
159    });
160
161    let head = head_vec.into_inner();
162    let tail = tail_vec.into_inner();
163    let weights = weights_vec.into_inner();
164    info!(
165      duration_ms = started.elapsed().as_millis(),
166      "optimizer edge extraction complete"
167    );
168
169    // Compute epochs per sample from edge weights
170    let started = Instant::now();
171    let weights_array = Array1::from(weights);
172    let epochs_per_sample = make_epochs_per_sample(&weights_array.view(), total_epochs);
173
174    let head = Array1::from(head);
175    let tail = Array1::from(tail);
176    info!(
177      duration_ms = started.elapsed().as_millis(),
178      "optimizer epochs_per_sample complete"
179    );
180
181    // Normalize embedding to [0, 10] range.
182    // Uses row indices + flat slice because columns aren't contiguous in row-major arrays.
183    let started = Instant::now();
184    let mut embedding = init;
185    let n_rows = embedding.shape()[0];
186    let n_dims = embedding.shape()[1];
187
188    // Compute min/max per dimension using fold/reduce (no per-row allocations)
189    let (mins, maxs) = (0..n_rows)
190      .into_par_iter()
191      .fold(
192        || (vec![f32::INFINITY; n_dims], vec![f32::NEG_INFINITY; n_dims]),
193        |(mut mins, mut maxs), i| {
194          let row = embedding.row(i);
195          for (d, &v) in row.iter().enumerate() {
196            mins[d] = mins[d].min(v);
197            maxs[d] = maxs[d].max(v);
198          }
199          (mins, maxs)
200        },
201      )
202      .reduce(
203        || (vec![f32::INFINITY; n_dims], vec![f32::NEG_INFINITY; n_dims]),
204        |(mut mins1, mut maxs1), (mins2, maxs2)| {
205          for d in 0..mins1.len() {
206            mins1[d] = mins1[d].min(mins2[d]);
207            maxs1[d] = maxs1[d].max(maxs2[d]);
208          }
209          (mins1, maxs1)
210        },
211      );
212
213    // Compute scales
214    let scales: Vec<f32> = mins
215      .iter()
216      .zip(&maxs)
217      .map(|(&min, &max)| {
218        let range = max - min;
219        if range > 0.0 { 10.0 / range } else { 0.0 }
220      })
221      .collect();
222
223    // Apply normalization (parallel over flat array since it's contiguous)
224    let flat = embedding.as_slice_mut().unwrap();
225    flat.par_iter_mut().enumerate().for_each(|(idx, v)| {
226      let d = idx % n_dims;
227      if scales[d] > 0.0 {
228        *v = (*v - mins[d]) * scales[d];
229      }
230    });
231    info!(
232      duration_ms = started.elapsed().as_millis(),
233      "optimizer embedding normalization complete"
234    );
235
236    // Initialize epoch scheduling (one at a time to avoid memory spike).
237    // No perf loss: each array is still computed in parallel, just not allocated simultaneously.
238    let started = Instant::now();
239    let neg_rate = negative_sample_rate as f64;
240    let eps_slice = epochs_per_sample.as_slice().unwrap();
241
242    let epoch_of_next_sample = Array1::from(eps_slice.par_iter().copied().collect::<Vec<_>>());
243    info!(
244      duration_ms = started.elapsed().as_millis(),
245      "optimizer epoch_of_next_sample complete"
246    );
247
248    let started = Instant::now();
249    let epochs_per_negative_sample = Array1::from(
250      eps_slice
251        .par_iter()
252        .map(|&eps| eps / neg_rate)
253        .collect::<Vec<_>>(),
254    );
255    info!(
256      duration_ms = started.elapsed().as_millis(),
257      "optimizer epochs_per_negative_sample complete"
258    );
259
260    let started = Instant::now();
261    let epoch_of_next_negative_sample = Array1::from(
262      epochs_per_negative_sample
263        .as_slice()
264        .unwrap()
265        .par_iter()
266        .copied()
267        .collect::<Vec<_>>(),
268    );
269    info!(
270      duration_ms = started.elapsed().as_millis(),
271      "optimizer epoch_of_next_negative_sample complete"
272    );
273
274    Self {
275      manifold,
276      head,
277      tail,
278      epochs_per_sample,
279      embedding,
280      epoch_of_next_sample,
281      epoch_of_next_negative_sample,
282      epochs_per_negative_sample,
283      current_epoch: 0,
284      total_epochs,
285      gamma,
286      initial_alpha,
287      negative_sample_rate: negative_sample_rate as f64,
288      metric_type,
289    }
290  }
291
292  /// Run n more epochs of stochastic gradient descent.
293  ///
294  /// # Panics
295  ///
296  /// Panics if this would exceed total_epochs. Check remaining_epochs() first.
297  pub fn step_epochs(&mut self, n: usize, output_metric: &dyn Metric) {
298    assert!(
299      self.current_epoch + n <= self.total_epochs,
300      "Cannot step {} epochs: would exceed total_epochs {} (current: {})",
301      n,
302      self.total_epochs,
303      self.current_epoch
304    );
305
306    let start_epoch = self.current_epoch;
307    let end_epoch = self.current_epoch + n;
308
309    let n_vertices = self.manifold.n_vertices;
310    let a = self.manifold.a;
311    let b = self.manifold.b;
312
313    // Run the optimization epochs
314    let mut embedding_copy = self.embedding.clone();
315
316    for epoch in start_epoch..end_epoch {
317      let alpha = self.initial_alpha * (1.0 - (epoch as f32 / self.total_epochs as f32));
318
319      match self.metric_type {
320        MetricType::Euclidean => {
321          // Euclidean specialization with parallelization
322          optimize_layout_euclidean_single_epoch_stateful(
323            &mut self.embedding.view_mut(),
324            &mut embedding_copy.view_mut(),
325            &self.head.view(),
326            &self.tail.view(),
327            n_vertices,
328            &self.epochs_per_sample.view(),
329            a,
330            b,
331            self.gamma,
332            alpha,
333            &mut self.epochs_per_negative_sample,
334            &mut self.epoch_of_next_sample,
335            &mut self.epoch_of_next_negative_sample,
336            epoch,
337            true, // parallel
338            true, // move_other
339          );
340        }
341        MetricType::Generic => {
342          // Generic metric path
343          optimize_layout_generic_single_epoch_stateful(
344            &mut self.embedding.view_mut(),
345            &mut embedding_copy.view_mut(),
346            &self.head.view(),
347            &self.tail.view(),
348            n_vertices,
349            &self.epochs_per_sample.view(),
350            a,
351            b,
352            self.gamma,
353            alpha,
354            &mut self.epochs_per_negative_sample,
355            &mut self.epoch_of_next_sample,
356            &mut self.epoch_of_next_negative_sample,
357            epoch,
358            true, // move_other
359            output_metric,
360          );
361        }
362      }
363    }
364
365    self.current_epoch = end_epoch;
366  }
367
368  /// Get the current epoch number.
369  pub fn current_epoch(&self) -> usize {
370    self.current_epoch
371  }
372
373  /// Get the total epochs this optimizer is configured for.
374  pub fn total_epochs(&self) -> usize {
375    self.total_epochs
376  }
377
378  /// Get the number of remaining epochs.
379  pub fn remaining_epochs(&self) -> usize {
380    self.total_epochs - self.current_epoch
381  }
382
383  /// Get a view of the current embedding.
384  pub fn embedding(&self) -> ArrayView2<'_, f32> {
385    self.embedding.view()
386  }
387
388  /// Get a reference to the learned manifold.
389  pub fn manifold(&self) -> &LearnedManifold {
390    &self.manifold
391  }
392
393  /// Consume the optimizer and return a lightweight fitted model.
394  ///
395  /// This drops all the optimization state (epoch counters, preprocessed
396  /// arrays) and keeps only the manifold and final embedding.
397  pub fn into_fitted(self, config: UmapConfig) -> FittedUmap {
398    FittedUmap {
399      manifold: self.manifold,
400      embedding: self.embedding,
401      config,
402    }
403  }
404}