// ruvector_attention_node/async_ops.rs
//! NAPI-RS bindings for async and batch operations
//!
//! Provides Node.js bindings for:
//! - Async attention computation with tokio
//! - Batch processing utilities
//! - Parallel attention computation
8use napi::bindgen_prelude::*;
9use napi_derive::napi;
10use ruvector_attention::{
11    attention::ScaledDotProductAttention,
12    hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
13    sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
14    traits::Attention,
15};
16use std::sync::Arc;
17
18// ============================================================================
19// Batch Processing Configuration
20// ============================================================================
21
/// Batch processing configuration passed from JavaScript.
#[napi(object)]
pub struct BatchConfig {
    // Number of items to process per batch.
    pub batch_size: u32,
    // Optional worker-thread count.
    // NOTE(review): not read by any function in this file — confirm intended use.
    pub num_workers: Option<u32>,
    // Optional prefetch toggle.
    // NOTE(review): not read by any function in this file — confirm intended use.
    pub prefetch: Option<bool>,
}
29
/// Result of a batch attention computation, returned to JavaScript.
#[napi(object)]
pub struct BatchResult {
    // One attention output per input query, in input order.
    pub outputs: Vec<Float32Array>,
    // Wall-clock time for the whole batch, in milliseconds.
    pub elapsed_ms: f64,
    // Items processed per second (batch size divided by elapsed seconds).
    pub throughput: f64,
}
37
38// ============================================================================
39// Async Attention Operations
40// ============================================================================
41
42/// Async scaled dot-product attention computation
43#[napi]
44pub async fn compute_attention_async(
45    query: Float32Array,
46    keys: Vec<Float32Array>,
47    values: Vec<Float32Array>,
48    dim: u32,
49) -> Result<Float32Array> {
50    let query_vec = query.to_vec();
51    let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
52    let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
53
54    let result = tokio::task::spawn_blocking(move || {
55        let attention = ScaledDotProductAttention::new(dim as usize);
56        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
57        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
58
59        attention.compute(&query_vec, &keys_refs, &values_refs)
60    })
61    .await
62    .map_err(|e| Error::from_reason(e.to_string()))?
63    .map_err(|e| Error::from_reason(e.to_string()))?;
64
65    Ok(Float32Array::new(result))
66}
67
68/// Async flash attention computation
69#[napi]
70pub async fn compute_flash_attention_async(
71    query: Float32Array,
72    keys: Vec<Float32Array>,
73    values: Vec<Float32Array>,
74    dim: u32,
75    block_size: u32,
76) -> Result<Float32Array> {
77    let query_vec = query.to_vec();
78    let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
79    let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
80
81    let result = tokio::task::spawn_blocking(move || {
82        let attention = FlashAttention::new(dim as usize, block_size as usize);
83        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
84        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
85
86        attention.compute(&query_vec, &keys_refs, &values_refs)
87    })
88    .await
89    .map_err(|e| Error::from_reason(e.to_string()))?
90    .map_err(|e| Error::from_reason(e.to_string()))?;
91
92    Ok(Float32Array::new(result))
93}
94
95/// Async hyperbolic attention computation
96#[napi]
97pub async fn compute_hyperbolic_attention_async(
98    query: Float32Array,
99    keys: Vec<Float32Array>,
100    values: Vec<Float32Array>,
101    dim: u32,
102    curvature: f64,
103) -> Result<Float32Array> {
104    let query_vec = query.to_vec();
105    let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
106    let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
107
108    let result = tokio::task::spawn_blocking(move || {
109        let config = HyperbolicAttentionConfig {
110            dim: dim as usize,
111            curvature: curvature as f32,
112            ..Default::default()
113        };
114        let attention = HyperbolicAttention::new(config);
115        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
116        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
117
118        attention.compute(&query_vec, &keys_refs, &values_refs)
119    })
120    .await
121    .map_err(|e| Error::from_reason(e.to_string()))?
122    .map_err(|e| Error::from_reason(e.to_string()))?;
123
124    Ok(Float32Array::new(result))
125}
126
127// ============================================================================
128// Batch Processing
129// ============================================================================
130
131/// Process a batch of attention computations
132#[napi]
133pub async fn batch_attention_compute(
134    queries: Vec<Float32Array>,
135    keys: Vec<Vec<Float32Array>>,
136    values: Vec<Vec<Float32Array>>,
137    dim: u32,
138) -> Result<BatchResult> {
139    let start = std::time::Instant::now();
140    let batch_size = queries.len();
141
142    // Convert to owned vectors for thread safety
143    let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
144    let keys_vec: Vec<Vec<Vec<f32>>> = keys
145        .into_iter()
146        .map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
147        .collect();
148    let values_vec: Vec<Vec<Vec<f32>>> = values
149        .into_iter()
150        .map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
151        .collect();
152
153    let dim_usize = dim as usize;
154
155    let results = tokio::task::spawn_blocking(move || {
156        let attention = ScaledDotProductAttention::new(dim_usize);
157        let mut outputs = Vec::with_capacity(batch_size);
158
159        for i in 0..batch_size {
160            let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
161            let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
162
163            match attention.compute(&queries_vec[i], &keys_refs, &values_refs) {
164                Ok(output) => outputs.push(output),
165                Err(e) => return Err(e.to_string()),
166            }
167        }
168
169        Ok(outputs)
170    })
171    .await
172    .map_err(|e| Error::from_reason(e.to_string()))?
173    .map_err(|e| Error::from_reason(e))?;
174
175    let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
176    let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
177
178    Ok(BatchResult {
179        outputs: results.into_iter().map(Float32Array::new).collect(),
180        elapsed_ms,
181        throughput,
182    })
183}
184
185/// Process a batch with flash attention
186#[napi]
187pub async fn batch_flash_attention_compute(
188    queries: Vec<Float32Array>,
189    keys: Vec<Vec<Float32Array>>,
190    values: Vec<Vec<Float32Array>>,
191    dim: u32,
192    block_size: u32,
193) -> Result<BatchResult> {
194    let start = std::time::Instant::now();
195    let batch_size = queries.len();
196
197    let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
198    let keys_vec: Vec<Vec<Vec<f32>>> = keys
199        .into_iter()
200        .map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
201        .collect();
202    let values_vec: Vec<Vec<Vec<f32>>> = values
203        .into_iter()
204        .map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
205        .collect();
206
207    let dim_usize = dim as usize;
208    let block_usize = block_size as usize;
209
210    let results = tokio::task::spawn_blocking(move || {
211        let attention = FlashAttention::new(dim_usize, block_usize);
212        let mut outputs = Vec::with_capacity(batch_size);
213
214        for i in 0..batch_size {
215            let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
216            let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
217
218            match attention.compute(&queries_vec[i], &keys_refs, &values_refs) {
219                Ok(output) => outputs.push(output),
220                Err(e) => return Err(e.to_string()),
221            }
222        }
223
224        Ok(outputs)
225    })
226    .await
227    .map_err(|e| Error::from_reason(e.to_string()))?
228    .map_err(|e| Error::from_reason(e))?;
229
230    let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
231    let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
232
233    Ok(BatchResult {
234        outputs: results.into_iter().map(Float32Array::new).collect(),
235        elapsed_ms,
236        throughput,
237    })
238}
239
240// ============================================================================
241// Parallel Attention Computation
242// ============================================================================
243
/// Attention variant selector, exposed to JavaScript as a string enum.
#[napi(string_enum)]
pub enum AttentionType {
    // Classic softmax(QK^T / sqrt(d))V attention.
    ScaledDotProduct,
    // Blocked flash attention (uses `ParallelConfig::block_size`).
    Flash,
    // Kernel-feature linear attention (uses `ParallelConfig::num_features`).
    Linear,
    // Sparse local-window + global-token attention
    // (uses `local_window` / `global_tokens`).
    LocalGlobal,
    // Hyperbolic-space attention (uses `ParallelConfig::curvature`).
    Hyperbolic,
}
253
/// Configuration for parallel attention computation.
///
/// Only the fields relevant to the chosen `attention_type` are consulted;
/// the rest are ignored. Defaults below are those applied by
/// `parallel_attention_compute` when a field is omitted.
#[napi(object)]
pub struct ParallelConfig {
    // Which attention kernel to run.
    pub attention_type: AttentionType,
    // Embedding dimension.
    pub dim: u32,
    // Flash attention block size (default: 64).
    pub block_size: Option<u32>,
    // Linear attention feature count (default: 64).
    pub num_features: Option<u32>,
    // LocalGlobal local window size (default: 128).
    pub local_window: Option<u32>,
    // LocalGlobal global token count (default: 8).
    pub global_tokens: Option<u32>,
    // Hyperbolic curvature (default: 1.0).
    pub curvature: Option<f64>,
}
265
266/// Parallel attention computation across multiple queries
267#[napi]
268pub async fn parallel_attention_compute(
269    config: ParallelConfig,
270    queries: Vec<Float32Array>,
271    keys: Vec<Vec<Float32Array>>,
272    values: Vec<Vec<Float32Array>>,
273) -> Result<BatchResult> {
274    let start = std::time::Instant::now();
275    let batch_size = queries.len();
276
277    let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
278    let keys_vec: Vec<Vec<Vec<f32>>> = keys
279        .into_iter()
280        .map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
281        .collect();
282    let values_vec: Vec<Vec<Vec<f32>>> = values
283        .into_iter()
284        .map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
285        .collect();
286
287    let dim = config.dim as usize;
288    let attention_type = config.attention_type;
289    let block_size = config.block_size.unwrap_or(64) as usize;
290    let num_features = config.num_features.unwrap_or(64) as usize;
291    let local_window = config.local_window.unwrap_or(128) as usize;
292    let global_tokens = config.global_tokens.unwrap_or(8) as usize;
293    let curvature = config.curvature.unwrap_or(1.0) as f32;
294
295    let results = tokio::task::spawn_blocking(move || {
296        let mut outputs = Vec::with_capacity(batch_size);
297
298        for i in 0..batch_size {
299            let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
300            let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
301
302            let result = match attention_type {
303                AttentionType::ScaledDotProduct => {
304                    let attention = ScaledDotProductAttention::new(dim);
305                    attention.compute(&queries_vec[i], &keys_refs, &values_refs)
306                }
307                AttentionType::Flash => {
308                    let attention = FlashAttention::new(dim, block_size);
309                    attention.compute(&queries_vec[i], &keys_refs, &values_refs)
310                }
311                AttentionType::Linear => {
312                    let attention = LinearAttention::new(dim, num_features);
313                    attention.compute(&queries_vec[i], &keys_refs, &values_refs)
314                }
315                AttentionType::LocalGlobal => {
316                    let attention = LocalGlobalAttention::new(dim, local_window, global_tokens);
317                    attention.compute(&queries_vec[i], &keys_refs, &values_refs)
318                }
319                AttentionType::Hyperbolic => {
320                    let config = HyperbolicAttentionConfig {
321                        dim,
322                        curvature,
323                        ..Default::default()
324                    };
325                    let attention = HyperbolicAttention::new(config);
326                    attention.compute(&queries_vec[i], &keys_refs, &values_refs)
327                }
328            };
329
330            match result {
331                Ok(output) => outputs.push(output),
332                Err(e) => return Err(e.to_string()),
333            }
334        }
335
336        Ok(outputs)
337    })
338    .await
339    .map_err(|e| Error::from_reason(e.to_string()))?
340    .map_err(|e| Error::from_reason(e))?;
341
342    let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
343    let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
344
345    Ok(BatchResult {
346        outputs: results.into_iter().map(Float32Array::new).collect(),
347        elapsed_ms,
348        throughput,
349    })
350}
351
352// ============================================================================
353// Streaming Processing
354// ============================================================================
355
/// Stream processor for handling attention in chunks.
///
/// Buffers incoming vectors up to a fixed capacity, then computes attention
/// of a query against the buffer (used as both keys and values).
#[napi]
pub struct StreamProcessor {
    // Embedding dimension passed to the attention kernel.
    dim: usize,
    // Buffered vectors; `process` uses them as both keys and values.
    buffer: Vec<Vec<f32>>,
    // Maximum number of vectors `push` will accept.
    max_buffer_size: usize,
}
363
364#[napi]
365impl StreamProcessor {
366    /// Create a new stream processor
367    ///
368    /// # Arguments
369    /// * `dim` - Embedding dimension
370    /// * `max_buffer_size` - Maximum number of items to buffer
371    #[napi(constructor)]
372    pub fn new(dim: u32, max_buffer_size: u32) -> Self {
373        Self {
374            dim: dim as usize,
375            buffer: Vec::new(),
376            max_buffer_size: max_buffer_size as usize,
377        }
378    }
379
380    /// Add a vector to the buffer
381    #[napi]
382    pub fn push(&mut self, vector: Float32Array) -> bool {
383        if self.buffer.len() >= self.max_buffer_size {
384            return false;
385        }
386        self.buffer.push(vector.to_vec());
387        true
388    }
389
390    /// Process buffered vectors with attention against a query
391    #[napi]
392    pub fn process(&self, query: Float32Array) -> Result<Float32Array> {
393        if self.buffer.is_empty() {
394            return Err(Error::from_reason("Buffer is empty"));
395        }
396
397        let attention = ScaledDotProductAttention::new(self.dim);
398        let query_slice = query.as_ref();
399        let keys_refs: Vec<&[f32]> = self.buffer.iter().map(|k| k.as_slice()).collect();
400        let values_refs: Vec<&[f32]> = self.buffer.iter().map(|v| v.as_slice()).collect();
401
402        let result = attention
403            .compute(query_slice, &keys_refs, &values_refs)
404            .map_err(|e| Error::from_reason(e.to_string()))?;
405
406        Ok(Float32Array::new(result))
407    }
408
409    /// Clear the buffer
410    #[napi]
411    pub fn clear(&mut self) {
412        self.buffer.clear();
413    }
414
415    /// Get current buffer size
416    #[napi(getter)]
417    pub fn size(&self) -> u32 {
418        self.buffer.len() as u32
419    }
420
421    /// Check if buffer is full
422    #[napi(getter)]
423    pub fn is_full(&self) -> bool {
424        self.buffer.len() >= self.max_buffer_size
425    }
426}
427
428// ============================================================================
429// Benchmark Utilities
430// ============================================================================
431
/// Timing statistics for one benchmark run.
#[napi(object)]
pub struct BenchmarkResult {
    // Human-readable name of the benchmarked attention variant.
    pub name: String,
    // Number of timed compute calls.
    pub iterations: u32,
    // Sum of all per-iteration times, in milliseconds.
    pub total_ms: f64,
    // Mean per-iteration time, in milliseconds.
    pub avg_ms: f64,
    // Derived rate: 1000 / avg_ms.
    pub ops_per_sec: f64,
    // Fastest single iteration, in milliseconds.
    pub min_ms: f64,
    // Slowest single iteration, in milliseconds.
    pub max_ms: f64,
}
443
444/// Run attention benchmark
445#[napi]
446pub async fn benchmark_attention(
447    attention_type: AttentionType,
448    dim: u32,
449    seq_length: u32,
450    iterations: u32,
451) -> Result<BenchmarkResult> {
452    let dim_usize = dim as usize;
453    let seq_usize = seq_length as usize;
454    let iter_usize = iterations as usize;
455
456    let result = tokio::task::spawn_blocking(move || {
457        // Generate test data
458        let query: Vec<f32> = (0..dim_usize).map(|i| (i as f32 * 0.01).sin()).collect();
459        let keys: Vec<Vec<f32>> = (0..seq_usize)
460            .map(|j| {
461                (0..dim_usize)
462                    .map(|i| ((i + j) as f32 * 0.01).cos())
463                    .collect()
464            })
465            .collect();
466        let values: Vec<Vec<f32>> = keys.clone();
467
468        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
469        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
470
471        let name = match attention_type {
472            AttentionType::ScaledDotProduct => "ScaledDotProduct",
473            AttentionType::Flash => "Flash",
474            AttentionType::Linear => "Linear",
475            AttentionType::LocalGlobal => "LocalGlobal",
476            AttentionType::Hyperbolic => "Hyperbolic",
477        }
478        .to_string();
479
480        let mut times: Vec<f64> = Vec::with_capacity(iter_usize);
481
482        for _ in 0..iter_usize {
483            let start = std::time::Instant::now();
484
485            match attention_type {
486                AttentionType::ScaledDotProduct => {
487                    let attention = ScaledDotProductAttention::new(dim_usize);
488                    let _ = attention.compute(&query, &keys_refs, &values_refs);
489                }
490                AttentionType::Flash => {
491                    let attention = FlashAttention::new(dim_usize, 64);
492                    let _ = attention.compute(&query, &keys_refs, &values_refs);
493                }
494                AttentionType::Linear => {
495                    let attention = LinearAttention::new(dim_usize, 64);
496                    let _ = attention.compute(&query, &keys_refs, &values_refs);
497                }
498                AttentionType::LocalGlobal => {
499                    let attention = LocalGlobalAttention::new(dim_usize, 128, 8);
500                    let _ = attention.compute(&query, &keys_refs, &values_refs);
501                }
502                AttentionType::Hyperbolic => {
503                    let config = HyperbolicAttentionConfig {
504                        dim: dim_usize,
505                        curvature: 1.0,
506                        ..Default::default()
507                    };
508                    let attention = HyperbolicAttention::new(config);
509                    let _ = attention.compute(&query, &keys_refs, &values_refs);
510                }
511            }
512
513            times.push(start.elapsed().as_secs_f64() * 1000.0);
514        }
515
516        let total_ms: f64 = times.iter().sum();
517        let avg_ms = total_ms / iter_usize as f64;
518        let min_ms = times.iter().copied().fold(f64::INFINITY, f64::min);
519        let max_ms = times.iter().copied().fold(f64::NEG_INFINITY, f64::max);
520        let ops_per_sec = 1000.0 / avg_ms;
521
522        BenchmarkResult {
523            name,
524            iterations: iterations,
525            total_ms,
526            avg_ms,
527            ops_per_sec,
528            min_ms,
529            max_ms,
530        }
531    })
532    .await
533    .map_err(|e| Error::from_reason(e.to_string()))?;
534
535    Ok(result)
536}