ruvector_dag_wasm/
lib.rs

1//! Minimal WASM DAG library optimized for browser and embedded systems
2//!
3//! Size optimizations:
4//! - u8/u32/f32 instead of larger types
5//! - Inline hot paths
6//! - Minimal error handling
7//! - No string operations in critical paths
8//! - Optional wee_alloc for smaller binary
9
10use wasm_bindgen::prelude::*;
11use serde::{Serialize, Deserialize};
12
// Use wee_alloc as the global allocator when the crate is built with the
// `wee_alloc` feature, to shrink the WASM binary (~10KB per the author's
// estimate). Without the feature this item is compiled out and the default
// allocator is used.
// NOTE(review): wee_alloc is no longer maintained (RUSTSEC-2022-0054) —
// confirm this feature is still wanted.
#[cfg(feature = "wee_alloc")]
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
17
18/// Minimal DAG node - 9 bytes (u32 + u8 + f32)
19#[derive(Serialize, Deserialize, Clone, Copy)]
20struct WasmNode {
21    id: u32,
22    op: u8,
23    cost: f32,
24}
25
26/// Minimal DAG structure for WASM
27/// Self-contained with no external dependencies beyond wasm-bindgen
28#[wasm_bindgen]
29pub struct WasmDag {
30    nodes: Vec<WasmNode>,
31    edges: Vec<(u32, u32)>,
32}
33
34#[wasm_bindgen]
35impl WasmDag {
36    /// Create new empty DAG
37    #[wasm_bindgen(constructor)]
38    pub fn new() -> Self {
39        Self {
40            nodes: Vec::new(),
41            edges: Vec::new(),
42        }
43    }
44
45    /// Add a node with operator type and cost
46    /// Returns node ID
47    #[inline]
48    pub fn add_node(&mut self, op: u8, cost: f32) -> u32 {
49        let id = self.nodes.len() as u32;
50        self.nodes.push(WasmNode { id, op, cost });
51        id
52    }
53
54    /// Add edge from -> to
55    /// Returns false if creates cycle (simple check)
56    #[inline]
57    pub fn add_edge(&mut self, from: u32, to: u32) -> bool {
58        // Basic validation - nodes must exist
59        if from >= self.nodes.len() as u32 || to >= self.nodes.len() as u32 {
60            return false;
61        }
62
63        // Simple cycle check: to must not reach from
64        if self.has_path(to, from) {
65            return false;
66        }
67
68        self.edges.push((from, to));
69        true
70    }
71
72    /// Get number of nodes
73    #[inline]
74    pub fn node_count(&self) -> u32 {
75        self.nodes.len() as u32
76    }
77
78    /// Get number of edges
79    #[inline]
80    pub fn edge_count(&self) -> u32 {
81        self.edges.len() as u32
82    }
83
84    /// Topological sort using Kahn's algorithm
85    /// Returns node IDs in topological order
86    pub fn topo_sort(&self) -> Vec<u32> {
87        let n = self.nodes.len();
88        let mut in_degree = vec![0u32; n];
89
90        // Calculate in-degrees
91        for &(_, to) in &self.edges {
92            in_degree[to as usize] += 1;
93        }
94
95        // Find nodes with no incoming edges
96        let mut queue: Vec<u32> = (0..n as u32)
97            .filter(|&i| in_degree[i as usize] == 0)
98            .collect();
99
100        let mut result = Vec::with_capacity(n);
101
102        while let Some(node) = queue.pop() {
103            result.push(node);
104
105            // Reduce in-degree for neighbors
106            for &(from, to) in &self.edges {
107                if from == node {
108                    in_degree[to as usize] -= 1;
109                    if in_degree[to as usize] == 0 {
110                        queue.push(to);
111                    }
112                }
113            }
114        }
115
116        result
117    }
118
119    /// Find critical path (longest path by cost)
120    /// Returns JSON: {"path": [node_ids], "cost": total}
121    pub fn critical_path(&self) -> JsValue {
122        let topo = self.topo_sort();
123        let n = self.nodes.len();
124
125        // dist[i] = (max_cost_to_i, predecessor)
126        let mut dist = vec![(0.0f32, u32::MAX); n];
127
128        // Initialize starting nodes
129        for &node in &topo {
130            if !self.has_incoming(node) {
131                dist[node as usize] = (self.nodes[node as usize].cost, u32::MAX);
132            }
133        }
134
135        // Relax edges in topological order
136        for &from in &topo {
137            let from_cost = dist[from as usize].0;
138
139            for &(f, to) in &self.edges {
140                if f == from {
141                    let new_cost = from_cost + self.nodes[to as usize].cost;
142                    if new_cost > dist[to as usize].0 {
143                        dist[to as usize] = (new_cost, from);
144                    }
145                }
146            }
147        }
148
149        // Find node with maximum cost
150        let (max_idx, (max_cost, _)) = dist.iter()
151            .enumerate()
152            .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap())
153            .unwrap();
154
155        // Backtrack to build path
156        let mut path = Vec::new();
157        let mut current = max_idx as u32;
158
159        while current != u32::MAX {
160            path.push(current);
161            current = dist[current as usize].1;
162        }
163
164        path.reverse();
165
166        // Convert to JSON manually to avoid serde_json dependency
167        let path_str = path.iter()
168            .map(|id| id.to_string())
169            .collect::<Vec<_>>()
170            .join(",");
171
172        let json = format!("{{\"path\":[{}],\"cost\":{}}}", path_str, max_cost);
173        JsValue::from_str(&json)
174    }
175
176    /// Compute attention scores for nodes
177    /// mechanism: 0=topological, 1=critical_path, 2=uniform
178    pub fn attention(&self, mechanism: u8) -> Vec<f32> {
179        compute_attention(self, mechanism)
180    }
181
182    /// Serialize to bytes (bincode format)
183    pub fn to_bytes(&self) -> Vec<u8> {
184        #[derive(Serialize)]
185        struct SerDag<'a> {
186            nodes: &'a [WasmNode],
187            edges: &'a [(u32, u32)],
188        }
189
190        let data = SerDag {
191            nodes: &self.nodes,
192            edges: &self.edges,
193        };
194
195        bincode::serialize(&data).unwrap_or_default()
196    }
197
198    /// Deserialize from bytes
199    pub fn from_bytes(data: &[u8]) -> Result<WasmDag, JsValue> {
200        #[derive(Deserialize)]
201        struct SerDag {
202            nodes: Vec<WasmNode>,
203            edges: Vec<(u32, u32)>,
204        }
205
206        bincode::deserialize::<SerDag>(data)
207            .map(|d| WasmDag {
208                nodes: d.nodes,
209                edges: d.edges,
210            })
211            .map_err(|e| JsValue::from_str(&format!("Deserialize error: {}", e)))
212    }
213
214    /// Serialize to JSON
215    pub fn to_json(&self) -> String {
216        #[derive(Serialize)]
217        struct SerDag<'a> {
218            nodes: &'a [WasmNode],
219            edges: &'a [(u32, u32)],
220        }
221
222        let data = SerDag {
223            nodes: &self.nodes,
224            edges: &self.edges,
225        };
226
227        serde_json::to_string(&data).unwrap_or_else(|_| String::from("{}"))
228    }
229
230    /// Deserialize from JSON
231    pub fn from_json(json: &str) -> Result<WasmDag, JsValue> {
232        #[derive(Deserialize)]
233        struct SerDag {
234            nodes: Vec<WasmNode>,
235            edges: Vec<(u32, u32)>,
236        }
237
238        serde_json::from_str::<SerDag>(json)
239            .map(|d| WasmDag {
240                nodes: d.nodes,
241                edges: d.edges,
242            })
243            .map_err(|e| JsValue::from_str(&format!("JSON error: {}", e)))
244    }
245}
246
247// Internal helper methods (not exported to WASM)
248impl WasmDag {
249    /// Check if there's a path from 'from' to 'to' (for cycle detection)
250    #[inline(always)]
251    fn has_path(&self, from: u32, to: u32) -> bool {
252        if from == to {
253            return true;
254        }
255
256        let mut visited = vec![false; self.nodes.len()];
257        let mut stack = Vec::with_capacity(8);
258        stack.push(from);
259
260        while let Some(node) = stack.pop() {
261            if visited[node as usize] {
262                continue;
263            }
264            visited[node as usize] = true;
265
266            for &(f, t) in &self.edges {
267                if f == node {
268                    if t == to {
269                        return true;
270                    }
271                    stack.push(t);
272                }
273            }
274        }
275
276        false
277    }
278
279    /// Check if node has incoming edges
280    #[inline(always)]
281    fn has_incoming(&self, node: u32) -> bool {
282        self.edges.iter().any(|&(_, to)| to == node)
283    }
284}
285
286/// Compute attention scores based on mechanism
287///
288/// Mechanisms:
289/// - 0: Topological (position in topo sort)
290/// - 1: Critical path (distance from critical path)
291/// - 2: Uniform (all equal)
292#[inline]
293fn compute_attention(dag: &WasmDag, mechanism: u8) -> Vec<f32> {
294    let n = dag.nodes.len();
295
296    match mechanism {
297        0 => {
298            // Topological attention - earlier nodes get higher scores
299            let topo = dag.topo_sort();
300            let mut scores = vec![0.0f32; n];
301
302            for (i, &node_id) in topo.iter().enumerate() {
303                scores[node_id as usize] = 1.0 - (i as f32 / n as f32);
304            }
305
306            scores
307        }
308
309        1 => {
310            // Critical path attention - nodes on/near critical path get higher scores
311            let topo = dag.topo_sort();
312            let mut dist = vec![0.0f32; n];
313
314            // Forward pass - compute longest path to each node
315            for &from in &topo {
316                for &(f, to) in &dag.edges {
317                    if f == from {
318                        let new_dist = dist[from as usize] + dag.nodes[to as usize].cost;
319                        if new_dist > dist[to as usize] {
320                            dist[to as usize] = new_dist;
321                        }
322                    }
323                }
324            }
325
326            // Normalize to [0, 1]
327            let max_dist = dist.iter().fold(0.0f32, |a, &b| a.max(b));
328            if max_dist > 0.0 {
329                dist.iter_mut().for_each(|d| *d /= max_dist);
330            }
331
332            dist
333        }
334
335        _ => {
336            // Uniform attention
337            vec![1.0f32 / n as f32; n]
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a DAG with one node per entry of `costs` (op code = index) and the
    /// given edges, asserting that every edge is accepted.
    fn build(costs: &[f32], edges: &[(u32, u32)]) -> WasmDag {
        let mut dag = WasmDag::new();
        for (i, &cost) in costs.iter().enumerate() {
            dag.add_node(i as u8, cost);
        }
        for &(from, to) in edges {
            assert!(dag.add_edge(from, to), "edge {} -> {} was rejected", from, to);
        }
        dag
    }

    /// Assert that `order` lists every node exactly once and that every edge
    /// points forward in it.
    fn assert_topological(dag: &WasmDag, order: &[u32]) {
        let n = dag.node_count() as usize;
        assert_eq!(order.len(), n);
        let mut position = vec![usize::MAX; n];
        for (pos, &id) in order.iter().enumerate() {
            assert_eq!(position[id as usize], usize::MAX, "node {} listed twice", id);
            position[id as usize] = pos;
        }
        for &(from, to) in &dag.edges {
            assert!(
                position[from as usize] < position[to as usize],
                "edge {} -> {} points backwards",
                from,
                to
            );
        }
    }

    /// Assert two DAGs hold the same nodes (field by field) and the same edges.
    fn assert_same_contents(a: &WasmDag, b: &WasmDag) {
        assert_eq!(a.nodes.len(), b.nodes.len());
        for (x, y) in a.nodes.iter().zip(&b.nodes) {
            assert_eq!((x.id, x.op, x.cost), (y.id, y.op, y.cost));
        }
        assert_eq!(a.edges, b.edges);
    }

    #[test]
    fn test_basic_dag() {
        let mut dag = WasmDag::new();

        let n0 = dag.add_node(1, 1.0);
        let n1 = dag.add_node(2, 2.0);
        let n2 = dag.add_node(3, 3.0);

        assert_eq!(dag.node_count(), 3);

        assert!(dag.add_edge(n0, n1));
        assert!(dag.add_edge(n1, n2));
        assert_eq!(dag.edge_count(), 2);

        // Should detect cycle
        assert!(!dag.add_edge(n2, n0));
    }

    #[test]
    fn test_add_edge_rejects_invalid_edges() {
        let mut dag = build(&[1.0, 1.0], &[]);
        assert!(!dag.add_edge(0, 2)); // missing target
        assert!(!dag.add_edge(2, 0)); // missing source
        assert!(!dag.add_edge(1, 1)); // self-loop
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_topo_sort() {
        let chain = build(&[1.0, 1.0, 1.0], &[(0, 1), (1, 2)]);
        assert_eq!(chain.topo_sort(), vec![0, 1, 2]);

        let diamond = build(&[1.0, 2.0, 3.0, 4.0], &[(0, 1), (0, 2), (1, 3), (2, 3)]);
        assert_topological(&diamond, &diamond.topo_sort());

        // Duplicate edges are stored separately and must not stall the sort.
        let doubled = build(&[1.0, 1.0], &[(0, 1), (0, 1)]);
        assert_eq!(doubled.edge_count(), 2);
        assert_eq!(doubled.topo_sort(), vec![0, 1]);
    }

    #[test]
    fn test_empty_dag() {
        let dag = WasmDag::new();
        assert_eq!(dag.node_count(), 0);
        assert_eq!(dag.edge_count(), 0);
        assert!(dag.topo_sort().is_empty());
        for mechanism in 0..3 {
            assert!(dag.attention(mechanism).is_empty());
        }
    }

    #[test]
    fn test_attention() {
        let unconnected = build(&[1.0, 2.0, 3.0], &[]);

        // Uniform
        let uniform = unconnected.attention(2);
        assert_eq!(uniform.len(), 3);
        assert!((uniform[0] - 0.333).abs() < 0.01);

        // Topological
        assert_eq!(unconnected.attention(0).len(), 3);

        let chain = build(&[1.0, 2.0, 3.0], &[(0, 1), (1, 2)]);
        let topo = chain.attention(0);
        let expected: [f32; 3] = [1.0, 2.0 / 3.0, 1.0 / 3.0];
        assert_eq!(topo.len(), 3);
        for (got, want) in topo.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "got {}, want {}", got, want);
        }

        // Critical path: normalized, non-decreasing along the chain, sink = 1.
        let crit = chain.attention(1);
        assert_eq!(crit.len(), 3);
        assert!(crit.iter().all(|&s| (0.0..=1.0).contains(&s)));
        assert!(crit[0] <= crit[1] && crit[1] <= crit[2]);
        assert_eq!(crit[2], 1.0);
    }

    #[test]
    fn test_serialization() {
        let dag = build(&[1.5, 2.5], &[(0, 1)]);

        // Binary
        let restored = WasmDag::from_bytes(&dag.to_bytes()).unwrap();
        assert_eq!(restored.node_count(), 2);
        assert_eq!(restored.edge_count(), 1);
        assert_same_contents(&restored, &dag);

        // JSON
        let from_json = WasmDag::from_json(&dag.to_json()).unwrap();
        assert_eq!(from_json.node_count(), 2);
        assert_same_contents(&from_json, &dag);
    }
}