1use wasm_bindgen::prelude::*;
11use serde::{Serialize, Deserialize};
12
// Opt-in allocator swap: only when the crate is built with the `wee_alloc`
// Cargo feature does this static replace the default global allocator.
#[cfg(feature = "wee_alloc")]
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc<'static> = wee_alloc::WeeAlloc::INIT;
17
18#[derive(Serialize, Deserialize, Clone, Copy)]
20struct WasmNode {
21 id: u32,
22 op: u8,
23 cost: f32,
24}
25
26#[wasm_bindgen]
29pub struct WasmDag {
30 nodes: Vec<WasmNode>,
31 edges: Vec<(u32, u32)>,
32}
33
34#[wasm_bindgen]
35impl WasmDag {
36 #[wasm_bindgen(constructor)]
38 pub fn new() -> Self {
39 Self {
40 nodes: Vec::new(),
41 edges: Vec::new(),
42 }
43 }
44
45 #[inline]
48 pub fn add_node(&mut self, op: u8, cost: f32) -> u32 {
49 let id = self.nodes.len() as u32;
50 self.nodes.push(WasmNode { id, op, cost });
51 id
52 }
53
54 #[inline]
57 pub fn add_edge(&mut self, from: u32, to: u32) -> bool {
58 if from >= self.nodes.len() as u32 || to >= self.nodes.len() as u32 {
60 return false;
61 }
62
63 if self.has_path(to, from) {
65 return false;
66 }
67
68 self.edges.push((from, to));
69 true
70 }
71
72 #[inline]
74 pub fn node_count(&self) -> u32 {
75 self.nodes.len() as u32
76 }
77
78 #[inline]
80 pub fn edge_count(&self) -> u32 {
81 self.edges.len() as u32
82 }
83
84 pub fn topo_sort(&self) -> Vec<u32> {
87 let n = self.nodes.len();
88 let mut in_degree = vec![0u32; n];
89
90 for &(_, to) in &self.edges {
92 in_degree[to as usize] += 1;
93 }
94
95 let mut queue: Vec<u32> = (0..n as u32)
97 .filter(|&i| in_degree[i as usize] == 0)
98 .collect();
99
100 let mut result = Vec::with_capacity(n);
101
102 while let Some(node) = queue.pop() {
103 result.push(node);
104
105 for &(from, to) in &self.edges {
107 if from == node {
108 in_degree[to as usize] -= 1;
109 if in_degree[to as usize] == 0 {
110 queue.push(to);
111 }
112 }
113 }
114 }
115
116 result
117 }
118
119 pub fn critical_path(&self) -> JsValue {
122 let topo = self.topo_sort();
123 let n = self.nodes.len();
124
125 let mut dist = vec![(0.0f32, u32::MAX); n];
127
128 for &node in &topo {
130 if !self.has_incoming(node) {
131 dist[node as usize] = (self.nodes[node as usize].cost, u32::MAX);
132 }
133 }
134
135 for &from in &topo {
137 let from_cost = dist[from as usize].0;
138
139 for &(f, to) in &self.edges {
140 if f == from {
141 let new_cost = from_cost + self.nodes[to as usize].cost;
142 if new_cost > dist[to as usize].0 {
143 dist[to as usize] = (new_cost, from);
144 }
145 }
146 }
147 }
148
149 let (max_idx, (max_cost, _)) = dist.iter()
151 .enumerate()
152 .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap())
153 .unwrap();
154
155 let mut path = Vec::new();
157 let mut current = max_idx as u32;
158
159 while current != u32::MAX {
160 path.push(current);
161 current = dist[current as usize].1;
162 }
163
164 path.reverse();
165
166 let path_str = path.iter()
168 .map(|id| id.to_string())
169 .collect::<Vec<_>>()
170 .join(",");
171
172 let json = format!("{{\"path\":[{}],\"cost\":{}}}", path_str, max_cost);
173 JsValue::from_str(&json)
174 }
175
176 pub fn attention(&self, mechanism: u8) -> Vec<f32> {
179 compute_attention(self, mechanism)
180 }
181
182 pub fn to_bytes(&self) -> Vec<u8> {
184 #[derive(Serialize)]
185 struct SerDag<'a> {
186 nodes: &'a [WasmNode],
187 edges: &'a [(u32, u32)],
188 }
189
190 let data = SerDag {
191 nodes: &self.nodes,
192 edges: &self.edges,
193 };
194
195 bincode::serialize(&data).unwrap_or_default()
196 }
197
198 pub fn from_bytes(data: &[u8]) -> Result<WasmDag, JsValue> {
200 #[derive(Deserialize)]
201 struct SerDag {
202 nodes: Vec<WasmNode>,
203 edges: Vec<(u32, u32)>,
204 }
205
206 bincode::deserialize::<SerDag>(data)
207 .map(|d| WasmDag {
208 nodes: d.nodes,
209 edges: d.edges,
210 })
211 .map_err(|e| JsValue::from_str(&format!("Deserialize error: {}", e)))
212 }
213
214 pub fn to_json(&self) -> String {
216 #[derive(Serialize)]
217 struct SerDag<'a> {
218 nodes: &'a [WasmNode],
219 edges: &'a [(u32, u32)],
220 }
221
222 let data = SerDag {
223 nodes: &self.nodes,
224 edges: &self.edges,
225 };
226
227 serde_json::to_string(&data).unwrap_or_else(|_| String::from("{}"))
228 }
229
230 pub fn from_json(json: &str) -> Result<WasmDag, JsValue> {
232 #[derive(Deserialize)]
233 struct SerDag {
234 nodes: Vec<WasmNode>,
235 edges: Vec<(u32, u32)>,
236 }
237
238 serde_json::from_str::<SerDag>(json)
239 .map(|d| WasmDag {
240 nodes: d.nodes,
241 edges: d.edges,
242 })
243 .map_err(|e| JsValue::from_str(&format!("JSON error: {}", e)))
244 }
245}
246
247impl WasmDag {
249 #[inline(always)]
251 fn has_path(&self, from: u32, to: u32) -> bool {
252 if from == to {
253 return true;
254 }
255
256 let mut visited = vec![false; self.nodes.len()];
257 let mut stack = Vec::with_capacity(8);
258 stack.push(from);
259
260 while let Some(node) = stack.pop() {
261 if visited[node as usize] {
262 continue;
263 }
264 visited[node as usize] = true;
265
266 for &(f, t) in &self.edges {
267 if f == node {
268 if t == to {
269 return true;
270 }
271 stack.push(t);
272 }
273 }
274 }
275
276 false
277 }
278
279 #[inline(always)]
281 fn has_incoming(&self, node: u32) -> bool {
282 self.edges.iter().any(|&(_, to)| to == node)
283 }
284}
285
286#[inline]
293fn compute_attention(dag: &WasmDag, mechanism: u8) -> Vec<f32> {
294 let n = dag.nodes.len();
295
296 match mechanism {
297 0 => {
298 let topo = dag.topo_sort();
300 let mut scores = vec![0.0f32; n];
301
302 for (i, &node_id) in topo.iter().enumerate() {
303 scores[node_id as usize] = 1.0 - (i as f32 / n as f32);
304 }
305
306 scores
307 }
308
309 1 => {
310 let topo = dag.topo_sort();
312 let mut dist = vec![0.0f32; n];
313
314 for &from in &topo {
316 for &(f, to) in &dag.edges {
317 if f == from {
318 let new_dist = dist[from as usize] + dag.nodes[to as usize].cost;
319 if new_dist > dist[to as usize] {
320 dist[to as usize] = new_dist;
321 }
322 }
323 }
324 }
325
326 let max_dist = dist.iter().fold(0.0f32, |a, &b| a.max(b));
328 if max_dist > 0.0 {
329 dist.iter_mut().for_each(|d| *d /= max_dist);
330 }
331
332 dist
333 }
334
335 _ => {
336 vec![1.0f32 / n as f32; n]
338 }
339 }
340}
341
// Unit tests for the native build. Error paths that construct a `JsValue`
// are deliberately not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    /// Node and edge insertion, including every rejection branch of
    /// `add_edge`.
    #[test]
    fn test_basic_dag() {
        let mut dag = WasmDag::new();

        let n0 = dag.add_node(1, 1.0);
        let n1 = dag.add_node(2, 2.0);
        let n2 = dag.add_node(3, 3.0);

        // Ids are handed out in insertion order.
        assert_eq!((n0, n1, n2), (0, 1, 2));
        assert_eq!(dag.node_count(), 3);

        assert!(dag.add_edge(n0, n1));
        assert!(dag.add_edge(n1, n2));
        assert_eq!(dag.edge_count(), 2);

        // Closing the chain into a cycle is rejected.
        assert!(!dag.add_edge(n2, n0));
        // So are self-loops and edges to or from unknown nodes.
        assert!(!dag.add_edge(n0, n0));
        assert!(!dag.add_edge(n0, 99));
        assert!(!dag.add_edge(99, n0));

        // Rejected edges must not be stored.
        assert_eq!(dag.edge_count(), 2);
    }

    /// A simple chain sorts into insertion order; an empty graph sorts to
    /// an empty order.
    #[test]
    fn test_topo_sort() {
        let mut dag = WasmDag::new();

        let n0 = dag.add_node(0, 1.0);
        let n1 = dag.add_node(1, 1.0);
        let n2 = dag.add_node(2, 1.0);

        assert!(dag.add_edge(n0, n1));
        assert!(dag.add_edge(n1, n2));

        let topo = dag.topo_sort();
        assert_eq!(topo, vec![0, 1, 2]);

        assert!(WasmDag::new().topo_sort().is_empty());
    }

    /// Checks the shape and value ranges of each attention mechanism.
    #[test]
    fn test_attention() {
        let mut dag = WasmDag::new();

        dag.add_node(0, 1.0);
        dag.add_node(1, 2.0);
        dag.add_node(2, 3.0);

        // Unknown mechanism: uniform scores.
        let uniform = dag.attention(2);
        assert_eq!(uniform.len(), 3);
        assert!(uniform.iter().all(|s| (s - 0.333).abs() < 0.01));

        // Topological mechanism: every sorted node gets a score in (0, 1].
        let topo = dag.attention(0);
        assert_eq!(topo.len(), 3);
        assert!(topo.iter().all(|&s| s > 0.0 && s <= 1.0));

        // On a chain, earlier nodes score higher topologically, and the
        // longest-path scores run from 0 at the source to 1 at the sink.
        assert!(dag.add_edge(0, 1));
        assert!(dag.add_edge(1, 2));

        let chain_topo = dag.attention(0);
        assert!(chain_topo[0] > chain_topo[1] && chain_topo[1] > chain_topo[2]);

        let chain_path = dag.attention(1);
        assert_eq!(chain_path.len(), 3);
        assert_eq!(chain_path[0], 0.0);
        assert_eq!(chain_path[2], 1.0);
    }

    /// Round-trips the graph through both the binary and the JSON format.
    #[test]
    fn test_serialization() {
        let mut dag = WasmDag::new();

        dag.add_node(1, 1.5);
        dag.add_node(2, 2.5);
        assert!(dag.add_edge(0, 1));

        let bytes = dag.to_bytes();
        assert!(!bytes.is_empty());
        let restored = WasmDag::from_bytes(&bytes).unwrap();
        assert_eq!(restored.node_count(), 2);
        assert_eq!(restored.edge_count(), 1);
        assert_eq!(restored.topo_sort(), dag.topo_sort());

        let json = dag.to_json();
        let from_json = WasmDag::from_json(&json).unwrap();
        assert_eq!(from_json.node_count(), 2);
        assert_eq!(from_json.edge_count(), 1);
        assert_eq!(from_json.topo_sort(), dag.topo_sort());
    }
}