1use crate::error::{NeuralError, Result};
13use oxicode::{config as oxicode_config, serde as oxicode_serde};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct WasmTensor {
32 data: Vec<f32>,
33 shape: Vec<usize>,
34}
35
36impl WasmTensor {
37 pub fn from_vec(data: Vec<f32>, shape: Vec<usize>) -> Self {
39 Self { data, shape }
40 }
41
42 pub fn zeros(shape: Vec<usize>) -> Self {
44 let n: usize = shape.iter().product();
45 Self {
46 data: vec![0.0_f32; n],
47 shape,
48 }
49 }
50
51 pub fn shape(&self) -> &[usize] {
53 &self.shape
54 }
55
56 pub fn numel(&self) -> usize {
58 self.data.len()
59 }
60
61 pub fn data(&self) -> &[f32] {
63 &self.data
64 }
65
66 pub fn data_mut(&mut self) -> &mut Vec<f32> {
68 &mut self.data
69 }
70
71 pub fn into_data(self) -> Vec<f32> {
73 self.data
74 }
75
76 pub fn batch_size(&self) -> usize {
78 self.shape.first().copied().unwrap_or(1)
79 }
80
81 pub fn reshape(mut self, new_shape: Vec<usize>) -> Result<Self> {
83 let n: usize = new_shape.iter().product();
84 if n != self.data.len() {
85 return Err(NeuralError::ShapeMismatch(format!(
86 "WasmTensor::reshape: old numel={} new numel={n}",
87 self.data.len()
88 )));
89 }
90 self.shape = new_shape;
91 Ok(self)
92 }
93
94 pub fn relu_inplace(&mut self) {
96 for v in self.data.iter_mut() {
97 if *v < 0.0 {
98 *v = 0.0;
99 }
100 }
101 }
102
103 pub fn sigmoid_inplace(&mut self) {
105 for v in self.data.iter_mut() {
106 *v = 1.0 / (1.0 + (-*v).exp());
107 }
108 }
109
110 pub fn tanh_inplace(&mut self) {
112 for v in self.data.iter_mut() {
113 *v = v.tanh();
114 }
115 }
116
117 pub fn softmax_inplace(&mut self) {
119 if self.shape.is_empty() || self.data.is_empty() {
120 return;
121 }
122 let last_dim = *self.shape.last().unwrap_or(&1);
123 if last_dim == 0 {
124 return;
125 }
126 let batch = self.data.len() / last_dim;
127 for b in 0..batch {
128 let slice = &mut self.data[b * last_dim..(b + 1) * last_dim];
129 let max = slice.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
130 let mut sum = 0.0_f32;
131 for v in slice.iter_mut() {
132 *v = (*v - max).exp();
133 sum += *v;
134 }
135 if sum > 0.0 {
136 for v in slice.iter_mut() {
137 *v /= sum;
138 }
139 }
140 }
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
150pub enum WasmLayer {
151 Dense {
153 in_features: usize,
154 out_features: usize,
155 weights: Vec<f32>,
157 bias: Vec<f32>,
158 },
159 ReLU,
161 Sigmoid,
163 Tanh,
165 Softmax,
167 Dropout { rate: f32 },
169 LayerNorm {
171 normalized_shape: usize,
172 weight: Vec<f32>,
173 bias: Vec<f32>,
174 eps: f32,
175 },
176 Flatten,
178}
179
180impl WasmLayer {
181 pub fn type_name(&self) -> &str {
183 match self {
184 WasmLayer::Dense { .. } => "Dense",
185 WasmLayer::ReLU => "ReLU",
186 WasmLayer::Sigmoid => "Sigmoid",
187 WasmLayer::Tanh => "Tanh",
188 WasmLayer::Softmax => "Softmax",
189 WasmLayer::Dropout { .. } => "Dropout",
190 WasmLayer::LayerNorm { .. } => "LayerNorm",
191 WasmLayer::Flatten => "Flatten",
192 }
193 }
194
195 pub fn parameter_count(&self) -> usize {
197 match self {
198 WasmLayer::Dense { weights, bias, .. } => weights.len() + bias.len(),
199 WasmLayer::LayerNorm { weight, bias, .. } => weight.len() + bias.len(),
200 _ => 0,
201 }
202 }
203
204 pub fn forward(&self, input: WasmTensor) -> Result<WasmTensor> {
206 match self {
207 WasmLayer::Dense {
208 in_features,
209 out_features,
210 weights,
211 bias,
212 } => dense_forward(input, *in_features, *out_features, weights, bias),
213 WasmLayer::ReLU => {
214 let mut t = input;
215 t.relu_inplace();
216 Ok(t)
217 }
218 WasmLayer::Sigmoid => {
219 let mut t = input;
220 t.sigmoid_inplace();
221 Ok(t)
222 }
223 WasmLayer::Tanh => {
224 let mut t = input;
225 t.tanh_inplace();
226 Ok(t)
227 }
228 WasmLayer::Softmax => {
229 let mut t = input;
230 t.softmax_inplace();
231 Ok(t)
232 }
233 WasmLayer::Dropout { .. } => Ok(input),
234 WasmLayer::LayerNorm {
235 normalized_shape,
236 weight,
237 bias,
238 eps,
239 } => layer_norm_forward(input, *normalized_shape, weight, bias, *eps),
240 WasmLayer::Flatten => {
241 let batch = input.batch_size();
242 let rest = input.numel() / batch.max(1);
243 input.reshape(vec![batch, rest])
244 }
245 }
246 }
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct WasmNeuralNet {
276 name: String,
277 layers: Vec<WasmLayer>,
278 input_shape: Vec<usize>,
279 metadata: HashMap<String, String>,
280}
281
282impl WasmNeuralNet {
283 pub fn new(name: impl Into<String>) -> Self {
285 Self {
286 name: name.into(),
287 layers: Vec::new(),
288 input_shape: Vec::new(),
289 metadata: HashMap::new(),
290 }
291 }
292
293 pub fn name(&self) -> &str {
295 &self.name
296 }
297
298 pub fn num_layers(&self) -> usize {
300 self.layers.len()
301 }
302
303 pub fn layers(&self) -> &[WasmLayer] {
305 &self.layers
306 }
307
308 pub fn input_shape(&self) -> &[usize] {
310 &self.input_shape
311 }
312
313 pub fn set_input_shape(&mut self, shape: Vec<usize>) {
315 self.input_shape = shape;
316 }
317
318 pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
320 self.metadata.insert(key.into(), value.into());
321 }
322
323 pub fn get_metadata(&self, key: &str) -> Option<&str> {
325 self.metadata.get(key).map(|s| s.as_str())
326 }
327
328 pub fn add_layer(&mut self, layer: WasmLayer) {
330 self.layers.push(layer);
331 }
332
333 pub fn total_parameters(&self) -> usize {
335 self.layers.iter().map(|l| l.parameter_count()).sum()
336 }
337
338 pub fn forward(&self, input: WasmTensor) -> Result<WasmTensor> {
340 let mut x = input;
341 for layer in &self.layers {
342 x = layer.forward(x)?;
343 }
344 Ok(x)
345 }
346
347 pub fn to_bytes(&self) -> Result<Vec<u8>> {
349 let cfg = oxicode_config::standard();
350 oxicode_serde::encode_to_vec(self, cfg)
351 .map_err(|e| NeuralError::SerializationError(format!("oxicode encode: {e}")))
352 }
353
354 pub fn from_bytes(data: &[u8]) -> Result<Self> {
356 let cfg = oxicode_config::standard();
357 let (net, _) = oxicode_serde::decode_from_slice::<Self, _>(data, cfg)
358 .map_err(|e| NeuralError::DeserializationError(format!("oxicode decode: {e}")))?;
359 Ok(net)
360 }
361
362 pub fn to_json(&self) -> Result<String> {
364 serde_json::to_string(self)
365 .map_err(|e| NeuralError::SerializationError(format!("json encode: {e}")))
366 }
367
368 pub fn from_json(json: &str) -> Result<Self> {
370 serde_json::from_str(json)
371 .map_err(|e| NeuralError::DeserializationError(format!("json decode: {e}")))
372 }
373
374 pub fn summary(&self) -> String {
376 let mut s = format!("WasmNeuralNet '{}'\n", self.name);
377 for (i, layer) in self.layers.iter().enumerate() {
378 s.push_str(&format!(" [{i}] {}\n", layer.type_name()));
379 }
380 s.push_str(&format!("Total parameters: {}\n", self.total_parameters()));
381 s
382 }
383}
384
385fn dense_forward(
390 input: WasmTensor,
391 in_features: usize,
392 out_features: usize,
393 weights: &[f32],
394 bias: &[f32],
395) -> Result<WasmTensor> {
396 let shape = input.shape().to_vec();
397 if shape.len() < 2 {
398 return Err(NeuralError::ShapeMismatch(
399 "Dense: input must be at least 2-D [batch, features]".to_string(),
400 ));
401 }
402 let feat_dim = *shape.last().unwrap_or(&0);
403 if feat_dim != in_features {
404 return Err(NeuralError::ShapeMismatch(format!(
405 "Dense: expected in_features={in_features}, got {feat_dim}"
406 )));
407 }
408 if weights.len() != out_features * in_features {
409 return Err(NeuralError::ShapeMismatch(format!(
410 "Dense: weights len {} != {out_features}×{in_features}",
411 weights.len()
412 )));
413 }
414 if bias.len() != out_features {
415 return Err(NeuralError::ShapeMismatch(format!(
416 "Dense: bias len {} != {out_features}",
417 bias.len()
418 )));
419 }
420 let batch: usize = shape[..shape.len() - 1].iter().product::<usize>().max(1);
421 let input_data = input.data();
422 let mut output = vec![0.0_f32; batch * out_features];
423 for b in 0..batch {
424 for o in 0..out_features {
425 let mut acc = bias[o];
426 for i in 0..in_features {
427 acc += input_data[b * in_features + i] * weights[o * in_features + i];
428 }
429 output[b * out_features + o] = acc;
430 }
431 }
432 let mut out_shape = shape[..shape.len() - 1].to_vec();
433 out_shape.push(out_features);
434 Ok(WasmTensor::from_vec(output, out_shape))
435}
436
437fn layer_norm_forward(
438 input: WasmTensor,
439 normalized_shape: usize,
440 weight: &[f32],
441 bias: &[f32],
442 eps: f32,
443) -> Result<WasmTensor> {
444 let shape = input.shape().to_vec();
445 let feat_dim = *shape.last().unwrap_or(&0);
446 if feat_dim != normalized_shape {
447 return Err(NeuralError::ShapeMismatch(format!(
448 "LayerNorm: expected {normalized_shape}, got {feat_dim}"
449 )));
450 }
451 let batch: usize = (input.numel() / feat_dim.max(1)).max(1);
452 let data = input.data().to_vec();
453 let mut out_data = vec![0.0_f32; data.len()];
454 for b in 0..batch {
455 let slice = &data[b * feat_dim..(b + 1) * feat_dim];
456 let mean: f32 = slice.iter().sum::<f32>() / feat_dim as f32;
457 let var: f32 = slice.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / feat_dim as f32;
458 let std_inv = 1.0 / (var + eps).sqrt();
459 for (j, &v) in slice.iter().enumerate() {
460 out_data[b * feat_dim + j] = (v - mean) * std_inv * weight[j] + bias[j];
461 }
462 }
463 Ok(WasmTensor::from_vec(out_data, shape))
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `Dense(identity) → ReLU → Dense(all 0.5)` on two features with
    /// zero biases, so outputs are easy to compute by hand.
    fn make_tiny_net() -> WasmNeuralNet {
        let mut net = WasmNeuralNet::new("tiny");
        net.add_layer(WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![1.0_f32, 0.0, 0.0, 1.0],
            bias: vec![0.0, 0.0],
        });
        net.add_layer(WasmLayer::ReLU);
        net.add_layer(WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![0.5_f32, 0.5, 0.5, 0.5],
            bias: vec![0.0, 0.0],
        });
        net
    }

    /// A 2→2 identity Dense layer with zero bias.
    fn identity_dense() -> WasmLayer {
        WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![1.0_f32, 0.0, 0.0, 1.0],
            bias: vec![0.0, 0.0],
        }
    }

    #[test]
    fn test_wasm_tensor_creation() {
        let t = WasmTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        assert_eq!(t.numel(), 4);
        assert_eq!(t.batch_size(), 2);
    }

    #[test]
    fn test_wasm_tensor_zeros() {
        let t = WasmTensor::zeros(vec![2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        assert!(t.data().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_batch_size_defaults_to_one_for_scalar() {
        let t = WasmTensor::from_vec(vec![7.0_f32], vec![]);
        assert_eq!(t.batch_size(), 1);
    }

    #[test]
    fn test_wasm_tensor_reshape_ok() {
        let t = WasmTensor::from_vec(vec![1.0_f32; 6], vec![2, 3]);
        let t2 = t.reshape(vec![3, 2]).expect("ok");
        assert_eq!(t2.shape(), &[3, 2]);
        assert_eq!(t2.numel(), 6);
    }

    #[test]
    fn test_wasm_tensor_reshape_err() {
        let t = WasmTensor::from_vec(vec![1.0_f32; 6], vec![2, 3]);
        assert!(t.reshape(vec![4, 2]).is_err());
    }

    #[test]
    fn test_relu_inplace() {
        let mut t = WasmTensor::from_vec(vec![-1.0_f32, 2.0, -3.0, 4.0], vec![1, 4]);
        t.relu_inplace();
        assert_eq!(t.data(), &[0.0, 2.0, 0.0, 4.0]);
    }

    #[test]
    fn test_sigmoid_range() {
        let mut t = WasmTensor::from_vec(vec![-100.0_f32, 0.0, 100.0], vec![1, 3]);
        t.sigmoid_inplace();
        let d = t.data();
        assert!(d[0] >= 0.0 && d[0] < 0.01);
        assert!((d[1] - 0.5).abs() < 1e-4);
        assert!(d[2] > 0.99 && d[2] <= 1.0);
    }

    #[test]
    fn test_tanh_inplace() {
        let mut t = WasmTensor::from_vec(vec![-1.0_f32, 0.0, 1.0], vec![1, 3]);
        t.tanh_inplace();
        let d = t.data();
        assert_eq!(d[1], 0.0);
        assert!((d[2] - 1.0_f32.tanh()).abs() < 1e-6);
        // tanh is odd.
        assert!((d[0] + d[2]).abs() < 1e-6);
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let mut t = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0], vec![1, 3]);
        t.softmax_inplace();
        let sum: f32 = t.data().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum={sum}");
        let d = t.data();
        assert!(d[0] < d[1] && d[1] < d[2], "softmax must preserve ordering");
    }

    #[test]
    fn test_softmax_is_row_wise() {
        let mut t = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0, 3.0, 2.0, 1.0], vec![2, 3]);
        t.softmax_inplace();
        let d = t.data();
        for row in d.chunks(3) {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "row sum={sum}");
        }
        // The two rows are mirror images, so their probabilities are too.
        assert!((d[0] - d[5]).abs() < 1e-6);
        assert!((d[2] - d[3]).abs() < 1e-6);
    }

    #[test]
    fn test_dense_identity() {
        let layer = identity_dense();
        let input = WasmTensor::from_vec(vec![3.0_f32, 4.0], vec![1, 2]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.shape(), &[1, 2]);
        assert!((out.data()[0] - 3.0).abs() < 1e-5);
        assert!((out.data()[1] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn test_dense_bias_is_added() {
        let layer = WasmLayer::Dense {
            in_features: 2,
            out_features: 1,
            weights: vec![2.0_f32, 3.0],
            bias: vec![1.0],
        };
        let input = WasmTensor::from_vec(vec![1.0_f32, 1.0], vec![1, 2]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.shape(), &[1, 1]);
        // 1 (bias) + 1·2 + 1·3
        assert!((out.data()[0] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_dense_keeps_leading_axes() {
        let layer = identity_dense();
        let input = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], vec![2, 1, 2]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.shape(), &[2, 1, 2]);
        assert_eq!(out.data(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_dense_rejects_1d_input() {
        let layer = identity_dense();
        let input = WasmTensor::from_vec(vec![1.0_f32, 2.0], vec![2]);
        assert!(layer.forward(input).is_err());
    }

    #[test]
    fn test_dense_shape_mismatch_err() {
        let layer = WasmLayer::Dense {
            in_features: 3,
            out_features: 2,
            weights: vec![1.0_f32; 6],
            bias: vec![0.0; 2],
        };
        let input = WasmTensor::from_vec(vec![1.0_f32; 4], vec![1, 4]);
        assert!(layer.forward(input).is_err());
    }

    #[test]
    fn test_dense_bad_weights_or_bias_err() {
        let bad_weights = WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![1.0_f32; 3],
            bias: vec![0.0; 2],
        };
        let input = WasmTensor::from_vec(vec![1.0_f32; 2], vec![1, 2]);
        assert!(bad_weights.forward(input).is_err());

        let bad_bias = WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![1.0_f32; 4],
            bias: vec![0.0; 1],
        };
        let input = WasmTensor::from_vec(vec![1.0_f32; 2], vec![1, 2]);
        assert!(bad_bias.forward(input).is_err());
    }

    #[test]
    fn test_layer_norm_zero_mean() {
        let feat = 4;
        let layer = WasmLayer::LayerNorm {
            normalized_shape: feat,
            weight: vec![1.0_f32; feat],
            bias: vec![0.0_f32; feat],
            eps: 1e-5,
        };
        let input = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], vec![1, feat]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.shape(), &[1, feat]);
        let mean: f32 = out.data().iter().sum::<f32>() / feat as f32;
        assert!(mean.abs() < 1e-4, "mean={mean}");
        // With unit weight and zero bias, the (biased) variance is ~1.
        let var: f32 =
            out.data().iter().map(|v| (v - mean).powi(2)).sum::<f32>() / feat as f32;
        assert!((var - 1.0).abs() < 1e-3, "var={var}");
    }

    #[test]
    fn test_layer_norm_shape_mismatch_err() {
        let layer = WasmLayer::LayerNorm {
            normalized_shape: 4,
            weight: vec![1.0_f32; 4],
            bias: vec![0.0_f32; 4],
            eps: 1e-5,
        };
        let input = WasmTensor::from_vec(vec![1.0_f32; 3], vec![1, 3]);
        assert!(layer.forward(input).is_err());
    }

    #[test]
    fn test_dropout_is_identity() {
        let layer = WasmLayer::Dropout { rate: 0.5 };
        let data = vec![1.0_f32, 2.0, 3.0];
        let input = WasmTensor::from_vec(data.clone(), vec![1, 3]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.data(), data.as_slice());
    }

    #[test]
    fn test_flatten_layer() {
        let layer = WasmLayer::Flatten;
        let input = WasmTensor::from_vec(vec![1.0_f32; 24], vec![2, 3, 4]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.shape(), &[2, 12]);
        assert_eq!(out.numel(), 24);
    }

    #[test]
    fn test_layer_parameter_counts() {
        let dense = WasmLayer::Dense {
            in_features: 3,
            out_features: 2,
            weights: vec![0.0; 6],
            bias: vec![0.0; 2],
        };
        assert_eq!(dense.parameter_count(), 8);
        let norm = WasmLayer::LayerNorm {
            normalized_shape: 4,
            weight: vec![1.0; 4],
            bias: vec![0.0; 4],
            eps: 1e-5,
        };
        assert_eq!(norm.parameter_count(), 8);
        assert_eq!(WasmLayer::ReLU.parameter_count(), 0);
        assert_eq!(WasmLayer::Dropout { rate: 0.1 }.parameter_count(), 0);
        assert_eq!(WasmLayer::Flatten.parameter_count(), 0);
    }

    #[test]
    fn test_net_forward() {
        let net = make_tiny_net();
        let input = WasmTensor::from_vec(vec![1.0_f32, -1.0], vec![1, 2]);
        let out = net.forward(input).expect("ok");
        assert_eq!(out.shape(), &[1, 2]);
        // identity → [1, -1]; ReLU → [1, 0]; all-0.5 mixing → [0.5, 0.5]
        assert!((out.data()[0] - 0.5).abs() < 1e-6);
        assert!((out.data()[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_net_total_params() {
        let net = make_tiny_net();
        // Two Dense layers, each with 4 weights + 2 biases.
        assert_eq!(net.total_parameters(), 12);
    }

    #[test]
    fn test_net_binary_roundtrip() {
        let net = make_tiny_net();
        let bytes = net.to_bytes().expect("serialize ok");
        let net2 = WasmNeuralNet::from_bytes(&bytes).expect("deserialize ok");
        assert_eq!(net2.name(), "tiny");
        assert_eq!(net2.num_layers(), 3);
        assert_eq!(net2.total_parameters(), net.total_parameters());

        let input = WasmTensor::from_vec(vec![2.0_f32, -3.0], vec![1, 2]);
        let before = net.forward(input.clone()).expect("ok");
        let after = net2.forward(input).expect("ok");
        assert_eq!(before.data(), after.data());
    }

    #[test]
    fn test_net_json_roundtrip() {
        let mut net = make_tiny_net();
        net.set_input_shape(vec![1, 2]);
        net.add_metadata("author", "test");
        let json = net.to_json().expect("json ok");
        let net2 = WasmNeuralNet::from_json(&json).expect("from json ok");
        assert_eq!(net2.name(), "tiny");
        assert_eq!(net2.num_layers(), 3);
        assert_eq!(net2.input_shape(), &[1, 2]);
        assert_eq!(net2.get_metadata("author"), Some("test"));
    }

    #[test]
    fn test_net_summary() {
        let net = make_tiny_net();
        let s = net.summary();
        assert!(s.contains("tiny"));
        assert!(s.contains("Dense"));
        assert!(s.contains("ReLU"));
        assert!(s.contains("Total parameters: 12"));
    }

    #[test]
    fn test_net_metadata() {
        let mut net = WasmNeuralNet::new("m");
        net.add_metadata("version", "1.0");
        assert_eq!(net.get_metadata("version"), Some("1.0"));
        assert_eq!(net.get_metadata("missing"), None);
        net.add_metadata("version", "2.0");
        assert_eq!(net.get_metadata("version"), Some("2.0"));
    }

    #[test]
    fn test_from_bytes_invalid_err() {
        assert!(WasmNeuralNet::from_bytes(b"not valid data").is_err());
    }

    #[test]
    fn test_net_deterministic() {
        let net = make_tiny_net();
        let input = WasmTensor::from_vec(vec![2.0_f32, 3.0], vec![1, 2]);
        let out1 = net.forward(input.clone()).expect("ok");
        let out2 = net.forward(input).expect("ok");
        // Same inputs and weights must give bit-identical outputs.
        assert_eq!(out1.data(), out2.data());
    }

    #[test]
    fn test_wasm_layer_type_names() {
        assert_eq!(WasmLayer::ReLU.type_name(), "ReLU");
        assert_eq!(WasmLayer::Sigmoid.type_name(), "Sigmoid");
        assert_eq!(WasmLayer::Flatten.type_name(), "Flatten");
        assert_eq!(WasmLayer::Softmax.type_name(), "Softmax");
        assert_eq!(WasmLayer::Tanh.type_name(), "Tanh");
        assert_eq!(WasmLayer::Dropout { rate: 0.1 }.type_name(), "Dropout");
        assert_eq!(identity_dense().type_name(), "Dense");
    }
}