sevensense_embedding/application/
services.rs

1//! Application services for embedding generation.
2//!
3//! Provides high-level services for generating embeddings from audio
4//! spectrograms using the Perch 2.0 ONNX model.
5
6use std::sync::Arc;
7use std::time::Instant;
8
9use ndarray::Array3;
10use rayon::prelude::*;
11use tracing::{debug, info, instrument, warn};
12
13use crate::domain::entities::{
14    Embedding, EmbeddingBatch, EmbeddingMetadata, SegmentId, StorageTier,
15};
16use crate::infrastructure::model_manager::ModelManager;
17use crate::normalization;
18use crate::{EmbeddingError, EMBEDDING_DIM, MEL_BINS, MEL_FRAMES};
19
/// Input spectrogram for embedding generation.
///
/// Represents a mel spectrogram with shape [1, MEL_FRAMES, MEL_BINS] = [1, 500, 128],
/// bundled with the ID of the segment it belongs to and optional descriptive
/// metadata.
///
/// All fields are public, so a value assembled with a struct literal bypasses
/// the shape checks performed by [`Spectrogram::new`] and
/// [`Spectrogram::from_array3`].
#[derive(Debug, Clone)]
pub struct Spectrogram {
    /// The spectrogram data as a 3D array [batch, frames, bins]
    pub data: Array3<f32>,

    /// ID of the audio segment this spectrogram is associated with
    pub segment_id: SegmentId,

    /// Additional metadata (sample rate, duration, SNR); every field is optional
    pub metadata: SpectrogramMetadata,
}
34
/// Optional descriptive metadata attached to a [`Spectrogram`].
///
/// Every field is an `Option`, and the derived `Default` leaves all of them
/// as `None`.
#[derive(Debug, Clone, Default)]
pub struct SpectrogramMetadata {
    /// Sample rate of the original audio
    /// (unit not stated here, presumably Hz; confirm with the producer)
    pub sample_rate: Option<u32>,

    /// Duration of the audio segment in seconds
    pub duration_secs: Option<f32>,

    /// SNR of the audio segment
    /// (scale not stated here, presumably dB; confirm with the producer)
    pub snr: Option<f32>,
}
47
48impl Spectrogram {
49    /// Create a new spectrogram from raw data.
50    ///
51    /// # Arguments
52    ///
53    /// * `data` - 2D array of shape [MEL_FRAMES, MEL_BINS] (will be expanded to 3D)
54    /// * `segment_id` - ID of the source audio segment
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if the data dimensions are incorrect.
59    pub fn new(
60        data: ndarray::Array2<f32>,
61        segment_id: SegmentId,
62    ) -> Result<Self, EmbeddingError> {
63        let shape = data.shape();
64        if shape[0] != MEL_FRAMES || shape[1] != MEL_BINS {
65            return Err(EmbeddingError::InvalidDimensions {
66                expected: MEL_FRAMES * MEL_BINS,
67                actual: shape[0] * shape[1],
68            });
69        }
70
71        // Expand to 3D: [1, frames, bins]
72        let data = data.insert_axis(ndarray::Axis(0));
73
74        Ok(Self {
75            data,
76            segment_id,
77            metadata: SpectrogramMetadata::default(),
78        })
79    }
80
81    /// Create from a 3D array directly
82    pub fn from_array3(data: Array3<f32>, segment_id: SegmentId) -> Result<Self, EmbeddingError> {
83        let shape = data.shape();
84        if shape[1] != MEL_FRAMES || shape[2] != MEL_BINS {
85            return Err(EmbeddingError::InvalidDimensions {
86                expected: MEL_FRAMES * MEL_BINS,
87                actual: shape[1] * shape[2],
88            });
89        }
90
91        Ok(Self {
92            data,
93            segment_id,
94            metadata: SpectrogramMetadata::default(),
95        })
96    }
97
98    /// Set metadata for the spectrogram
99    pub fn with_metadata(mut self, metadata: SpectrogramMetadata) -> Self {
100        self.metadata = metadata;
101        self
102    }
103}
104
/// Output from the embedding service
///
/// Returned by `EmbeddingService::embed_segment` and, one per input, by
/// `EmbeddingService::embed_batch`.
#[derive(Debug, Clone)]
pub struct EmbeddingOutput {
    /// The generated embedding
    pub embedding: Embedding,

    /// Whether GPU was used for inference (as reported by the inference session)
    pub gpu_used: bool,

    /// Inference latency in milliseconds.
    ///
    /// For `embed_segment` this is the time measured from the start of the
    /// call; for `embed_batch` it is the chunk's inference time divided by
    /// the number of items in that chunk.
    pub latency_ms: f32,
}
117
/// Configuration for the embedding service
///
/// See the `Default` implementation for the baseline values.
#[derive(Debug, Clone)]
pub struct EmbeddingServiceConfig {
    /// Maximum batch size for inference: `embed_batch` splits its input into
    /// chunks of at most this many spectrograms. Should be at least 1.
    pub batch_size: usize,

    /// Whether to L2 normalize embeddings
    pub normalize: bool,

    /// Default storage tier for new embeddings
    pub default_tier: StorageTier,

    /// Whether to validate embeddings after generation
    /// (dimension, NaN and infinity checks, run after optional normalization)
    pub validate_embeddings: bool,

    /// Maximum allowed sparsity (fraction of near-zero values).
    ///
    /// Exceeding this threshold only logs a warning during validation; the
    /// embedding is not rejected.
    pub max_sparsity: f32,
}
136
137impl Default for EmbeddingServiceConfig {
138    fn default() -> Self {
139        Self {
140            batch_size: 8,
141            normalize: true,
142            default_tier: StorageTier::Hot,
143            validate_embeddings: true,
144            max_sparsity: 0.9,
145        }
146    }
147}
148
/// Service for generating embeddings from spectrograms.
///
/// This is the main application service for the embedding bounded context.
/// It coordinates between the model manager, ONNX inference, and domain entities.
///
/// Construct it with `EmbeddingService::new`, `EmbeddingService::with_config`,
/// or through `EmbeddingServiceBuilder`.
pub struct EmbeddingService {
    /// Shared model manager; supplies the inference session, the current
    /// model version and the readiness state
    model_manager: Arc<ModelManager>,

    /// Configuration for the service (batching, normalization, validation, tier)
    config: EmbeddingServiceConfig,
}
160
161impl EmbeddingService {
162    /// Create a new embedding service.
163    ///
164    /// # Arguments
165    ///
166    /// * `model_manager` - The model manager for ONNX model access
167    /// * `batch_size` - Maximum batch size for inference
168    #[must_use]
169    pub fn new(model_manager: Arc<ModelManager>, batch_size: usize) -> Self {
170        Self {
171            model_manager,
172            config: EmbeddingServiceConfig {
173                batch_size,
174                ..Default::default()
175            },
176        }
177    }
178
179    /// Create with custom configuration
180    #[must_use]
181    pub fn with_config(model_manager: Arc<ModelManager>, config: EmbeddingServiceConfig) -> Self {
182        Self {
183            model_manager,
184            config,
185        }
186    }
187
188    /// Generate an embedding from a single spectrogram.
189    ///
190    /// # Arguments
191    ///
192    /// * `spectrogram` - The input spectrogram
193    ///
194    /// # Errors
195    ///
196    /// Returns an error if inference fails or the embedding is invalid.
197    #[instrument(skip(self, spectrogram), fields(segment_id = %spectrogram.segment_id))]
198    pub async fn embed_segment(
199        &self,
200        spectrogram: &Spectrogram,
201    ) -> Result<EmbeddingOutput, EmbeddingError> {
202        let start = Instant::now();
203
204        // Get the inference session
205        let inference = self.model_manager.get_inference().await?;
206        let model_version = self.model_manager.current_version();
207
208        // Run inference
209        let raw_embedding = inference.run(&spectrogram.data)?;
210
211        // Convert to vector
212        let mut vector: Vec<f32> = raw_embedding.iter().copied().collect();
213
214        // Calculate original norm before normalization
215        let original_norm = normalization::compute_norm(&vector);
216
217        // L2 normalize if configured
218        if self.config.normalize {
219            normalization::l2_normalize(&mut vector);
220        }
221
222        // Validate embedding
223        if self.config.validate_embeddings {
224            self.validate_embedding(&vector)?;
225        }
226
227        // Calculate sparsity
228        let sparsity = normalization::compute_sparsity(&vector);
229
230        // Create embedding entity
231        let mut embedding = Embedding::new(
232            spectrogram.segment_id,
233            vector,
234            model_version.full_version(),
235        )?;
236
237        // Set metadata
238        let latency_ms = start.elapsed().as_secs_f32() * 1000.0;
239        embedding.metadata = EmbeddingMetadata {
240            inference_latency_ms: Some(latency_ms),
241            batch_id: None,
242            gpu_used: inference.is_gpu(),
243            original_norm: Some(original_norm),
244            sparsity: Some(sparsity),
245            quality_score: Some(self.compute_quality_score(&embedding)),
246        };
247
248        embedding.tier = self.config.default_tier;
249
250        debug!(
251            latency_ms = latency_ms,
252            norm = embedding.norm(),
253            sparsity = sparsity,
254            "Generated embedding"
255        );
256
257        Ok(EmbeddingOutput {
258            embedding,
259            gpu_used: inference.is_gpu(),
260            latency_ms,
261        })
262    }
263
264    /// Generate embeddings for multiple spectrograms in batches.
265    ///
266    /// This is more efficient than calling `embed_segment` multiple times
267    /// as it uses batched inference.
268    ///
269    /// # Arguments
270    ///
271    /// * `spectrograms` - Slice of input spectrograms
272    ///
273    /// # Errors
274    ///
275    /// Returns an error if any inference fails. Partial results are not returned.
276    #[instrument(skip(self, spectrograms), fields(count = spectrograms.len()))]
277    pub async fn embed_batch(
278        &self,
279        spectrograms: &[Spectrogram],
280    ) -> Result<Vec<EmbeddingOutput>, EmbeddingError> {
281        if spectrograms.is_empty() {
282            return Ok(Vec::new());
283        }
284
285        let total_start = Instant::now();
286        let batch_id = uuid::Uuid::new_v4().to_string();
287
288        info!(
289            batch_id = %batch_id,
290            total_segments = spectrograms.len(),
291            batch_size = self.config.batch_size,
292            "Starting batch embedding"
293        );
294
295        // Get the inference session
296        let inference = self.model_manager.get_inference().await?;
297        let model_version = self.model_manager.current_version();
298
299        // Process in batches
300        let mut all_outputs = Vec::with_capacity(spectrograms.len());
301
302        for (batch_idx, chunk) in spectrograms.chunks(self.config.batch_size).enumerate() {
303            let batch_start = Instant::now();
304
305            // Prepare batch input
306            let inputs: Vec<&Array3<f32>> = chunk.iter().map(|s| &s.data).collect();
307
308            // Run batched inference
309            let raw_embeddings = inference.run_batch(&inputs)?;
310
311            let batch_latency_ms = batch_start.elapsed().as_secs_f32() * 1000.0;
312            let per_item_latency = batch_latency_ms / chunk.len() as f32;
313
314            // Process each embedding in the batch (parallelize normalization)
315            let outputs: Vec<Result<EmbeddingOutput, EmbeddingError>> = chunk
316                .par_iter()
317                .zip(raw_embeddings.par_iter())
318                .map(|(spectrogram, raw_emb)| {
319                    let mut vector: Vec<f32> = raw_emb.iter().copied().collect();
320                    let original_norm = normalization::compute_norm(&vector);
321
322                    if self.config.normalize {
323                        normalization::l2_normalize(&mut vector);
324                    }
325
326                    if self.config.validate_embeddings {
327                        self.validate_embedding(&vector)?;
328                    }
329
330                    let sparsity = normalization::compute_sparsity(&vector);
331
332                    let mut embedding = Embedding::new(
333                        spectrogram.segment_id,
334                        vector,
335                        model_version.full_version(),
336                    )?;
337
338                    embedding.metadata = EmbeddingMetadata {
339                        inference_latency_ms: Some(per_item_latency),
340                        batch_id: Some(batch_id.clone()),
341                        gpu_used: inference.is_gpu(),
342                        original_norm: Some(original_norm),
343                        sparsity: Some(sparsity),
344                        quality_score: Some(self.compute_quality_score(&embedding)),
345                    };
346
347                    embedding.tier = self.config.default_tier;
348
349                    Ok(EmbeddingOutput {
350                        embedding,
351                        gpu_used: inference.is_gpu(),
352                        latency_ms: per_item_latency,
353                    })
354                })
355                .collect();
356
357            // Check for errors
358            let batch_outputs: Result<Vec<_>, _> = outputs.into_iter().collect();
359            all_outputs.extend(batch_outputs?);
360
361            debug!(
362                batch_idx = batch_idx,
363                batch_size = chunk.len(),
364                latency_ms = batch_latency_ms,
365                "Completed batch"
366            );
367        }
368
369        let total_latency_ms = total_start.elapsed().as_secs_f32() * 1000.0;
370        let throughput = spectrograms.len() as f32 / (total_latency_ms / 1000.0);
371
372        info!(
373            batch_id = %batch_id,
374            total_segments = spectrograms.len(),
375            total_latency_ms = total_latency_ms,
376            throughput_per_sec = throughput,
377            "Completed batch embedding"
378        );
379
380        Ok(all_outputs)
381    }
382
383    /// Create a batch tracking object for monitoring progress.
384    #[must_use]
385    pub fn create_batch(&self, segment_ids: Vec<SegmentId>) -> EmbeddingBatch {
386        EmbeddingBatch::new(segment_ids)
387    }
388
389    /// Validate an embedding vector.
390    fn validate_embedding(&self, vector: &[f32]) -> Result<(), EmbeddingError> {
391        // Check dimensions
392        if vector.len() != EMBEDDING_DIM {
393            return Err(EmbeddingError::InvalidDimensions {
394                expected: EMBEDDING_DIM,
395                actual: vector.len(),
396            });
397        }
398
399        // Check for NaN values
400        if vector.iter().any(|x| x.is_nan()) {
401            return Err(EmbeddingError::Validation(
402                "Embedding contains NaN values".to_string(),
403            ));
404        }
405
406        // Check for infinite values
407        if vector.iter().any(|x| x.is_infinite()) {
408            return Err(EmbeddingError::Validation(
409                "Embedding contains infinite values".to_string(),
410            ));
411        }
412
413        // Check sparsity
414        let sparsity = normalization::compute_sparsity(vector);
415        if sparsity > self.config.max_sparsity {
416            warn!(
417                sparsity = sparsity,
418                max_sparsity = self.config.max_sparsity,
419                "Embedding has high sparsity"
420            );
421        }
422
423        Ok(())
424    }
425
426    /// Compute a quality score for an embedding.
427    fn compute_quality_score(&self, embedding: &Embedding) -> f32 {
428        let mut score = 1.0_f32;
429
430        // Penalize deviation from unit norm
431        let norm = embedding.norm();
432        let norm_deviation = (norm - 1.0).abs();
433        score -= norm_deviation * 0.5;
434
435        // Penalize high sparsity
436        if let Some(sparsity) = embedding.metadata.sparsity {
437            score -= sparsity * 0.3;
438        }
439
440        score.clamp(0.0, 1.0)
441    }
442
443    /// Get the current model version being used.
444    #[must_use]
445    pub fn model_version(&self) -> String {
446        self.model_manager.current_version().full_version()
447    }
448
449    /// Check if the service is ready for inference.
450    pub async fn is_ready(&self) -> bool {
451        self.model_manager.is_ready().await
452    }
453}
454
/// Builder for creating embedding service instances
///
/// Starts from the default configuration; individual settings are overridden
/// through the chained setter methods, and `build` fails if no model manager
/// was supplied.
#[derive(Debug)]
pub struct EmbeddingServiceBuilder {
    /// Model manager to hand to the service; `None` until `model_manager` is called
    model_manager: Option<Arc<ModelManager>>,
    /// Configuration accumulated by the setter methods
    config: EmbeddingServiceConfig,
}
461
462impl EmbeddingServiceBuilder {
463    /// Create a new builder
464    #[must_use]
465    pub fn new() -> Self {
466        Self {
467            model_manager: None,
468            config: EmbeddingServiceConfig::default(),
469        }
470    }
471
472    /// Set the model manager
473    #[must_use]
474    pub fn model_manager(mut self, manager: Arc<ModelManager>) -> Self {
475        self.model_manager = Some(manager);
476        self
477    }
478
479    /// Set the batch size
480    #[must_use]
481    pub fn batch_size(mut self, size: usize) -> Self {
482        self.config.batch_size = size;
483        self
484    }
485
486    /// Set whether to normalize embeddings
487    #[must_use]
488    pub fn normalize(mut self, normalize: bool) -> Self {
489        self.config.normalize = normalize;
490        self
491    }
492
493    /// Set the default storage tier
494    #[must_use]
495    pub fn default_tier(mut self, tier: StorageTier) -> Self {
496        self.config.default_tier = tier;
497        self
498    }
499
500    /// Set whether to validate embeddings
501    #[must_use]
502    pub fn validate_embeddings(mut self, validate: bool) -> Self {
503        self.config.validate_embeddings = validate;
504        self
505    }
506
507    /// Build the embedding service
508    ///
509    /// # Errors
510    ///
511    /// Returns an error if the model manager is not set.
512    pub fn build(self) -> Result<EmbeddingService, EmbeddingError> {
513        let model_manager = self.model_manager.ok_or_else(|| {
514            EmbeddingError::Validation("Model manager is required".to_string())
515        })?;
516
517        Ok(EmbeddingService::with_config(model_manager, self.config))
518    }
519}
520
521impl Default for EmbeddingServiceBuilder {
522    fn default() -> Self {
523        Self::new()
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// A 2D array with the expected frame and bin counts is accepted.
    #[test]
    fn test_spectrogram_creation() {
        let mel = Array2::zeros((MEL_FRAMES, MEL_BINS));
        let result = Spectrogram::new(mel, SegmentId::new());
        assert!(result.is_ok());
    }

    /// A 2D array with the wrong shape is rejected.
    #[test]
    fn test_spectrogram_invalid_dimensions() {
        // Wrong dimensions
        let mel = Array2::zeros((100, 100));
        let result = Spectrogram::new(mel, SegmentId::new());
        assert!(result.is_err());
    }

    /// The default configuration batches by 8 with normalization and
    /// validation switched on.
    #[test]
    fn test_service_config_default() {
        let EmbeddingServiceConfig {
            batch_size,
            normalize,
            validate_embeddings,
            ..
        } = EmbeddingServiceConfig::default();

        assert_eq!(batch_size, 8);
        assert!(normalize);
        assert!(validate_embeddings);
    }

    /// Builder setters are reflected in the accumulated configuration.
    #[test]
    fn test_service_builder() {
        let configured = EmbeddingServiceBuilder::new()
            .batch_size(16)
            .normalize(false)
            .default_tier(StorageTier::Warm);

        let settings = &configured.config;
        assert_eq!(settings.batch_size, 16);
        assert!(!settings.normalize);
        assert_eq!(settings.default_tier, StorageTier::Warm);
    }
}