1use napi::bindgen_prelude::*;
9use napi_derive::napi;
10use ruvector_attention::{
11 attention::ScaledDotProductAttention,
12 hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
13 sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
14 traits::Attention,
15};
16use std::sync::Arc;
17
#[napi(object)]
/// Batch-execution options exposed to JavaScript callers.
///
/// NOTE(review): none of these fields are read anywhere in this chunk —
/// confirm against the JS side / other modules before relying on semantics.
pub struct BatchConfig {
    // Number of items to process per batch.
    pub batch_size: u32,
    // Worker-count hint; presumably for a thread/worker pool — TODO confirm.
    pub num_workers: Option<u32>,
    // Whether inputs should be prefetched — TODO confirm consumer.
    pub prefetch: Option<bool>,
}
29
#[napi(object)]
/// Result of a batched attention computation returned to JavaScript.
pub struct BatchResult {
    // One output vector per input query, in input order.
    pub outputs: Vec<Float32Array>,
    // Total wall-clock time for the batch, in milliseconds.
    pub elapsed_ms: f64,
    // Items processed per second (batch size / elapsed seconds).
    pub throughput: f64,
}
37
38#[napi]
44pub async fn compute_attention_async(
45 query: Float32Array,
46 keys: Vec<Float32Array>,
47 values: Vec<Float32Array>,
48 dim: u32,
49) -> Result<Float32Array> {
50 let query_vec = query.to_vec();
51 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
52 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
53
54 let result = tokio::task::spawn_blocking(move || {
55 let attention = ScaledDotProductAttention::new(dim as usize);
56 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
57 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
58
59 attention.compute(&query_vec, &keys_refs, &values_refs)
60 })
61 .await
62 .map_err(|e| Error::from_reason(e.to_string()))?
63 .map_err(|e| Error::from_reason(e.to_string()))?;
64
65 Ok(Float32Array::new(result))
66}
67
68#[napi]
70pub async fn compute_flash_attention_async(
71 query: Float32Array,
72 keys: Vec<Float32Array>,
73 values: Vec<Float32Array>,
74 dim: u32,
75 block_size: u32,
76) -> Result<Float32Array> {
77 let query_vec = query.to_vec();
78 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
79 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
80
81 let result = tokio::task::spawn_blocking(move || {
82 let attention = FlashAttention::new(dim as usize, block_size as usize);
83 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
84 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
85
86 attention.compute(&query_vec, &keys_refs, &values_refs)
87 })
88 .await
89 .map_err(|e| Error::from_reason(e.to_string()))?
90 .map_err(|e| Error::from_reason(e.to_string()))?;
91
92 Ok(Float32Array::new(result))
93}
94
95#[napi]
97pub async fn compute_hyperbolic_attention_async(
98 query: Float32Array,
99 keys: Vec<Float32Array>,
100 values: Vec<Float32Array>,
101 dim: u32,
102 curvature: f64,
103) -> Result<Float32Array> {
104 let query_vec = query.to_vec();
105 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
106 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
107
108 let result = tokio::task::spawn_blocking(move || {
109 let config = HyperbolicAttentionConfig {
110 dim: dim as usize,
111 curvature: curvature as f32,
112 ..Default::default()
113 };
114 let attention = HyperbolicAttention::new(config);
115 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
116 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
117
118 attention.compute(&query_vec, &keys_refs, &values_refs)
119 })
120 .await
121 .map_err(|e| Error::from_reason(e.to_string()))?
122 .map_err(|e| Error::from_reason(e.to_string()))?;
123
124 Ok(Float32Array::new(result))
125}
126
127#[napi]
133pub async fn batch_attention_compute(
134 queries: Vec<Float32Array>,
135 keys: Vec<Vec<Float32Array>>,
136 values: Vec<Vec<Float32Array>>,
137 dim: u32,
138) -> Result<BatchResult> {
139 let start = std::time::Instant::now();
140 let batch_size = queries.len();
141
142 let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
144 let keys_vec: Vec<Vec<Vec<f32>>> = keys
145 .into_iter()
146 .map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
147 .collect();
148 let values_vec: Vec<Vec<Vec<f32>>> = values
149 .into_iter()
150 .map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
151 .collect();
152
153 let dim_usize = dim as usize;
154
155 let results = tokio::task::spawn_blocking(move || {
156 let attention = ScaledDotProductAttention::new(dim_usize);
157 let mut outputs = Vec::with_capacity(batch_size);
158
159 for i in 0..batch_size {
160 let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
161 let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
162
163 match attention.compute(&queries_vec[i], &keys_refs, &values_refs) {
164 Ok(output) => outputs.push(output),
165 Err(e) => return Err(e.to_string()),
166 }
167 }
168
169 Ok(outputs)
170 })
171 .await
172 .map_err(|e| Error::from_reason(e.to_string()))?
173 .map_err(|e| Error::from_reason(e))?;
174
175 let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
176 let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
177
178 Ok(BatchResult {
179 outputs: results.into_iter().map(Float32Array::new).collect(),
180 elapsed_ms,
181 throughput,
182 })
183}
184
185#[napi]
187pub async fn batch_flash_attention_compute(
188 queries: Vec<Float32Array>,
189 keys: Vec<Vec<Float32Array>>,
190 values: Vec<Vec<Float32Array>>,
191 dim: u32,
192 block_size: u32,
193) -> Result<BatchResult> {
194 let start = std::time::Instant::now();
195 let batch_size = queries.len();
196
197 let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
198 let keys_vec: Vec<Vec<Vec<f32>>> = keys
199 .into_iter()
200 .map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
201 .collect();
202 let values_vec: Vec<Vec<Vec<f32>>> = values
203 .into_iter()
204 .map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
205 .collect();
206
207 let dim_usize = dim as usize;
208 let block_usize = block_size as usize;
209
210 let results = tokio::task::spawn_blocking(move || {
211 let attention = FlashAttention::new(dim_usize, block_usize);
212 let mut outputs = Vec::with_capacity(batch_size);
213
214 for i in 0..batch_size {
215 let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
216 let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
217
218 match attention.compute(&queries_vec[i], &keys_refs, &values_refs) {
219 Ok(output) => outputs.push(output),
220 Err(e) => return Err(e.to_string()),
221 }
222 }
223
224 Ok(outputs)
225 })
226 .await
227 .map_err(|e| Error::from_reason(e.to_string()))?
228 .map_err(|e| Error::from_reason(e))?;
229
230 let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
231 let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
232
233 Ok(BatchResult {
234 outputs: results.into_iter().map(Float32Array::new).collect(),
235 elapsed_ms,
236 throughput,
237 })
238}
239
#[napi(string_enum)]
/// Selects which attention implementation to run.
///
/// Exposed to JavaScript as string values matching the variant names.
pub enum AttentionType {
    // Classic scaled dot-product attention.
    ScaledDotProduct,
    // Block-tiled flash attention.
    Flash,
    // Feature-map based linear attention.
    Linear,
    // Sliding local window plus designated global tokens.
    LocalGlobal,
    // Attention in hyperbolic space with configurable curvature.
    Hyperbolic,
}
253
#[napi(object)]
/// Configuration for `parallel_attention_compute`.
///
/// Optional fields only apply to the matching `attention_type`; unused ones
/// are ignored. Defaults (applied in `parallel_attention_compute`):
/// `block_size` 64, `num_features` 64, `local_window` 128,
/// `global_tokens` 8, `curvature` 1.0.
pub struct ParallelConfig {
    // Which attention implementation to run.
    pub attention_type: AttentionType,
    // Vector dimensionality shared by queries, keys and values.
    pub dim: u32,
    // Flash only: tile/block size.
    pub block_size: Option<u32>,
    // Linear only: number of random features.
    pub num_features: Option<u32>,
    // LocalGlobal only: width of the local window.
    pub local_window: Option<u32>,
    // LocalGlobal only: number of global tokens.
    pub global_tokens: Option<u32>,
    // Hyperbolic only: space curvature (stored as f32 internally).
    pub curvature: Option<f64>,
}
265
266#[napi]
268pub async fn parallel_attention_compute(
269 config: ParallelConfig,
270 queries: Vec<Float32Array>,
271 keys: Vec<Vec<Float32Array>>,
272 values: Vec<Vec<Float32Array>>,
273) -> Result<BatchResult> {
274 let start = std::time::Instant::now();
275 let batch_size = queries.len();
276
277 let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
278 let keys_vec: Vec<Vec<Vec<f32>>> = keys
279 .into_iter()
280 .map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
281 .collect();
282 let values_vec: Vec<Vec<Vec<f32>>> = values
283 .into_iter()
284 .map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
285 .collect();
286
287 let dim = config.dim as usize;
288 let attention_type = config.attention_type;
289 let block_size = config.block_size.unwrap_or(64) as usize;
290 let num_features = config.num_features.unwrap_or(64) as usize;
291 let local_window = config.local_window.unwrap_or(128) as usize;
292 let global_tokens = config.global_tokens.unwrap_or(8) as usize;
293 let curvature = config.curvature.unwrap_or(1.0) as f32;
294
295 let results = tokio::task::spawn_blocking(move || {
296 let mut outputs = Vec::with_capacity(batch_size);
297
298 for i in 0..batch_size {
299 let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
300 let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
301
302 let result = match attention_type {
303 AttentionType::ScaledDotProduct => {
304 let attention = ScaledDotProductAttention::new(dim);
305 attention.compute(&queries_vec[i], &keys_refs, &values_refs)
306 }
307 AttentionType::Flash => {
308 let attention = FlashAttention::new(dim, block_size);
309 attention.compute(&queries_vec[i], &keys_refs, &values_refs)
310 }
311 AttentionType::Linear => {
312 let attention = LinearAttention::new(dim, num_features);
313 attention.compute(&queries_vec[i], &keys_refs, &values_refs)
314 }
315 AttentionType::LocalGlobal => {
316 let attention = LocalGlobalAttention::new(dim, local_window, global_tokens);
317 attention.compute(&queries_vec[i], &keys_refs, &values_refs)
318 }
319 AttentionType::Hyperbolic => {
320 let config = HyperbolicAttentionConfig {
321 dim,
322 curvature,
323 ..Default::default()
324 };
325 let attention = HyperbolicAttention::new(config);
326 attention.compute(&queries_vec[i], &keys_refs, &values_refs)
327 }
328 };
329
330 match result {
331 Ok(output) => outputs.push(output),
332 Err(e) => return Err(e.to_string()),
333 }
334 }
335
336 Ok(outputs)
337 })
338 .await
339 .map_err(|e| Error::from_reason(e.to_string()))?
340 .map_err(|e| Error::from_reason(e))?;
341
342 let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
343 let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
344
345 Ok(BatchResult {
346 outputs: results.into_iter().map(Float32Array::new).collect(),
347 elapsed_ms,
348 throughput,
349 })
350}
351
#[napi]
/// Incremental attention over a bounded buffer of vectors pushed from JS.
///
/// Vectors are accumulated via `push` and attended over in `process`, where
/// the buffered vectors serve as both keys and values.
pub struct StreamProcessor {
    // Dimensionality of each stored vector.
    dim: usize,
    // Buffered vectors; used as both keys and values in `process`.
    buffer: Vec<Vec<f32>>,
    // Capacity cap enforced by `push` / reported by `is_full`.
    max_buffer_size: usize,
}
363
364#[napi]
365impl StreamProcessor {
366 #[napi(constructor)]
372 pub fn new(dim: u32, max_buffer_size: u32) -> Self {
373 Self {
374 dim: dim as usize,
375 buffer: Vec::new(),
376 max_buffer_size: max_buffer_size as usize,
377 }
378 }
379
380 #[napi]
382 pub fn push(&mut self, vector: Float32Array) -> bool {
383 if self.buffer.len() >= self.max_buffer_size {
384 return false;
385 }
386 self.buffer.push(vector.to_vec());
387 true
388 }
389
390 #[napi]
392 pub fn process(&self, query: Float32Array) -> Result<Float32Array> {
393 if self.buffer.is_empty() {
394 return Err(Error::from_reason("Buffer is empty"));
395 }
396
397 let attention = ScaledDotProductAttention::new(self.dim);
398 let query_slice = query.as_ref();
399 let keys_refs: Vec<&[f32]> = self.buffer.iter().map(|k| k.as_slice()).collect();
400 let values_refs: Vec<&[f32]> = self.buffer.iter().map(|v| v.as_slice()).collect();
401
402 let result = attention
403 .compute(query_slice, &keys_refs, &values_refs)
404 .map_err(|e| Error::from_reason(e.to_string()))?;
405
406 Ok(Float32Array::new(result))
407 }
408
409 #[napi]
411 pub fn clear(&mut self) {
412 self.buffer.clear();
413 }
414
415 #[napi(getter)]
417 pub fn size(&self) -> u32 {
418 self.buffer.len() as u32
419 }
420
421 #[napi(getter)]
423 pub fn is_full(&self) -> bool {
424 self.buffer.len() >= self.max_buffer_size
425 }
426}
427
#[napi(object)]
/// Aggregate timing statistics produced by `benchmark_attention`.
pub struct BenchmarkResult {
    // Human-readable name of the attention variant benchmarked.
    pub name: String,
    // Number of timed compute passes.
    pub iterations: u32,
    // Sum of all per-iteration times, in milliseconds.
    pub total_ms: f64,
    // Mean per-iteration time, in milliseconds.
    pub avg_ms: f64,
    // Derived rate: 1000 / avg_ms.
    pub ops_per_sec: f64,
    // Fastest single iteration, in milliseconds.
    pub min_ms: f64,
    // Slowest single iteration, in milliseconds.
    pub max_ms: f64,
}
443
444#[napi]
446pub async fn benchmark_attention(
447 attention_type: AttentionType,
448 dim: u32,
449 seq_length: u32,
450 iterations: u32,
451) -> Result<BenchmarkResult> {
452 let dim_usize = dim as usize;
453 let seq_usize = seq_length as usize;
454 let iter_usize = iterations as usize;
455
456 let result = tokio::task::spawn_blocking(move || {
457 let query: Vec<f32> = (0..dim_usize).map(|i| (i as f32 * 0.01).sin()).collect();
459 let keys: Vec<Vec<f32>> = (0..seq_usize)
460 .map(|j| {
461 (0..dim_usize)
462 .map(|i| ((i + j) as f32 * 0.01).cos())
463 .collect()
464 })
465 .collect();
466 let values: Vec<Vec<f32>> = keys.clone();
467
468 let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
469 let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
470
471 let name = match attention_type {
472 AttentionType::ScaledDotProduct => "ScaledDotProduct",
473 AttentionType::Flash => "Flash",
474 AttentionType::Linear => "Linear",
475 AttentionType::LocalGlobal => "LocalGlobal",
476 AttentionType::Hyperbolic => "Hyperbolic",
477 }
478 .to_string();
479
480 let mut times: Vec<f64> = Vec::with_capacity(iter_usize);
481
482 for _ in 0..iter_usize {
483 let start = std::time::Instant::now();
484
485 match attention_type {
486 AttentionType::ScaledDotProduct => {
487 let attention = ScaledDotProductAttention::new(dim_usize);
488 let _ = attention.compute(&query, &keys_refs, &values_refs);
489 }
490 AttentionType::Flash => {
491 let attention = FlashAttention::new(dim_usize, 64);
492 let _ = attention.compute(&query, &keys_refs, &values_refs);
493 }
494 AttentionType::Linear => {
495 let attention = LinearAttention::new(dim_usize, 64);
496 let _ = attention.compute(&query, &keys_refs, &values_refs);
497 }
498 AttentionType::LocalGlobal => {
499 let attention = LocalGlobalAttention::new(dim_usize, 128, 8);
500 let _ = attention.compute(&query, &keys_refs, &values_refs);
501 }
502 AttentionType::Hyperbolic => {
503 let config = HyperbolicAttentionConfig {
504 dim: dim_usize,
505 curvature: 1.0,
506 ..Default::default()
507 };
508 let attention = HyperbolicAttention::new(config);
509 let _ = attention.compute(&query, &keys_refs, &values_refs);
510 }
511 }
512
513 times.push(start.elapsed().as_secs_f64() * 1000.0);
514 }
515
516 let total_ms: f64 = times.iter().sum();
517 let avg_ms = total_ms / iter_usize as f64;
518 let min_ms = times.iter().copied().fold(f64::INFINITY, f64::min);
519 let max_ms = times.iter().copied().fold(f64::NEG_INFINITY, f64::max);
520 let ops_per_sec = 1000.0 / avg_ms;
521
522 BenchmarkResult {
523 name,
524 iterations: iterations,
525 total_ms,
526 avg_ms,
527 ops_per_sec,
528 min_ms,
529 max_ms,
530 }
531 })
532 .await
533 .map_err(|e| Error::from_reason(e.to_string()))?;
534
535 Ok(result)
536}