Skip to main content

scirs2_neural/wasm/
mod.rs

1//! WebAssembly inference wrapper for neural networks
2//!
3//! This module provides a pure-Rust inference engine for WASM deployment.
4//! No wasm-bindgen bindings are required at the struct level.
5//!
6//! # Key types
7//!
8//! - [`WasmTensor`] – heap-allocated f32 tensor
9//! - [`WasmLayer`] – inference-only layer enum
10//! - [`WasmNeuralNet`] – sequential stack with oxicode serialization
11
12use crate::error::{NeuralError, Result};
13use oxicode::{config as oxicode_config, serde as oxicode_serde};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17// ─────────────────────────────────────────────────────────────────────────────
18// WasmTensor
19// ─────────────────────────────────────────────────────────────────────────────
20
/// A heap-allocated f32 tensor for WebAssembly inference.
///
/// Elements are stored in a single flat buffer. The operations in this module
/// treat the last dimension as contiguous (row-major layout). The constructor
/// does not check the buffer length against `shape`, so callers must keep
/// `data.len()` equal to the product of `shape`.
///
/// # Examples
/// ```
/// use scirs2_neural::wasm::WasmTensor;
/// let t = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], vec![2, 2]);
/// assert_eq!(t.shape(), &[2, 2]);
/// assert_eq!(t.numel(), 4);
/// ```
// NOTE(review): field order is part of the serde-derived encoding; in
// positional binary formats, reordering these fields changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmTensor {
    /// Flat element buffer; `numel()` reports its length.
    data: Vec<f32>,
    /// Dimension sizes (the first entry is used as the batch size).
    shape: Vec<usize>,
}
35
36impl WasmTensor {
37    /// Create a tensor from a data vector and a shape.
38    pub fn from_vec(data: Vec<f32>, shape: Vec<usize>) -> Self {
39        Self { data, shape }
40    }
41
42    /// Create an all-zeros tensor.
43    pub fn zeros(shape: Vec<usize>) -> Self {
44        let n: usize = shape.iter().product();
45        Self {
46            data: vec![0.0_f32; n],
47            shape,
48        }
49    }
50
51    /// Returns a reference to the shape.
52    pub fn shape(&self) -> &[usize] {
53        &self.shape
54    }
55
56    /// Total number of elements.
57    pub fn numel(&self) -> usize {
58        self.data.len()
59    }
60
61    /// Raw data slice.
62    pub fn data(&self) -> &[f32] {
63        &self.data
64    }
65
66    /// Mutable raw data.
67    pub fn data_mut(&mut self) -> &mut Vec<f32> {
68        &mut self.data
69    }
70
71    /// Consume `self` and return the raw data vector.
72    pub fn into_data(self) -> Vec<f32> {
73        self.data
74    }
75
76    /// Batch size (first dimension).
77    pub fn batch_size(&self) -> usize {
78        self.shape.first().copied().unwrap_or(1)
79    }
80
81    /// Reshape without copying. Returns an error if element counts differ.
82    pub fn reshape(mut self, new_shape: Vec<usize>) -> Result<Self> {
83        let n: usize = new_shape.iter().product();
84        if n != self.data.len() {
85            return Err(NeuralError::ShapeMismatch(format!(
86                "WasmTensor::reshape: old numel={} new numel={n}",
87                self.data.len()
88            )));
89        }
90        self.shape = new_shape;
91        Ok(self)
92    }
93
94    /// Apply ReLU element-wise in-place.
95    pub fn relu_inplace(&mut self) {
96        for v in self.data.iter_mut() {
97            if *v < 0.0 {
98                *v = 0.0;
99            }
100        }
101    }
102
103    /// Apply sigmoid element-wise in-place.
104    pub fn sigmoid_inplace(&mut self) {
105        for v in self.data.iter_mut() {
106            *v = 1.0 / (1.0 + (-*v).exp());
107        }
108    }
109
110    /// Apply tanh element-wise in-place.
111    pub fn tanh_inplace(&mut self) {
112        for v in self.data.iter_mut() {
113            *v = v.tanh();
114        }
115    }
116
117    /// Apply row-wise softmax (last dimension) in-place.
118    pub fn softmax_inplace(&mut self) {
119        if self.shape.is_empty() || self.data.is_empty() {
120            return;
121        }
122        let last_dim = *self.shape.last().unwrap_or(&1);
123        if last_dim == 0 {
124            return;
125        }
126        let batch = self.data.len() / last_dim;
127        for b in 0..batch {
128            let slice = &mut self.data[b * last_dim..(b + 1) * last_dim];
129            let max = slice.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
130            let mut sum = 0.0_f32;
131            for v in slice.iter_mut() {
132                *v = (*v - max).exp();
133                sum += *v;
134            }
135            if sum > 0.0 {
136                for v in slice.iter_mut() {
137                    *v /= sum;
138                }
139            }
140        }
141    }
142}
143
144// ─────────────────────────────────────────────────────────────────────────────
145// WasmLayer
146// ─────────────────────────────────────────────────────────────────────────────
147
/// Inference-only layer variants for WASM deployment.
///
/// Each variant carries everything its forward pass needs; parameterised
/// variants (`Dense`, `LayerNorm`) store their weights inline as flat `f32`
/// buffers.
///
/// NOTE(review): with serde's derive, compact binary formats such as the
/// oxicode encoding used by `WasmNeuralNet::to_bytes` likely identify variants
/// by declaration index, so reordering or inserting variants here would break
/// previously saved models. Append new variants at the end (confirm against
/// the oxicode format docs).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WasmLayer {
    /// Fully-connected layer: `y = xW^T + b`
    Dense {
        /// Required size of the input's last dimension.
        in_features: usize,
        /// Size of the output's last dimension.
        out_features: usize,
        /// Row-major weights `[out_features × in_features]`
        weights: Vec<f32>,
        /// One bias value per output feature (`out_features` long).
        bias: Vec<f32>,
    },
    /// ReLU activation
    ReLU,
    /// Sigmoid activation
    Sigmoid,
    /// Tanh activation
    Tanh,
    /// Softmax activation (last dimension)
    Softmax,
    /// Dropout (identity at inference)
    Dropout {
        /// Presumably the training-time drop probability; ignored by `forward`.
        rate: f32,
    },
    /// Layer normalisation over the last dimension.
    LayerNorm {
        /// Required size of the input's last dimension.
        normalized_shape: usize,
        /// Per-feature scale, indexed by position in the last dimension.
        weight: Vec<f32>,
        /// Per-feature offset, indexed by position in the last dimension.
        bias: Vec<f32>,
        /// Added to the variance before the square root to avoid dividing by zero.
        eps: f32,
    },
    /// Flatten: `[batch, ...rest]` → `[batch, rest.product()]`
    Flatten,
}
179
180impl WasmLayer {
181    /// Human-readable layer type name.
182    pub fn type_name(&self) -> &str {
183        match self {
184            WasmLayer::Dense { .. } => "Dense",
185            WasmLayer::ReLU => "ReLU",
186            WasmLayer::Sigmoid => "Sigmoid",
187            WasmLayer::Tanh => "Tanh",
188            WasmLayer::Softmax => "Softmax",
189            WasmLayer::Dropout { .. } => "Dropout",
190            WasmLayer::LayerNorm { .. } => "LayerNorm",
191            WasmLayer::Flatten => "Flatten",
192        }
193    }
194
195    /// Number of trainable parameters.
196    pub fn parameter_count(&self) -> usize {
197        match self {
198            WasmLayer::Dense { weights, bias, .. } => weights.len() + bias.len(),
199            WasmLayer::LayerNorm { weight, bias, .. } => weight.len() + bias.len(),
200            _ => 0,
201        }
202    }
203
204    /// Forward pass.
205    pub fn forward(&self, input: WasmTensor) -> Result<WasmTensor> {
206        match self {
207            WasmLayer::Dense {
208                in_features,
209                out_features,
210                weights,
211                bias,
212            } => dense_forward(input, *in_features, *out_features, weights, bias),
213            WasmLayer::ReLU => {
214                let mut t = input;
215                t.relu_inplace();
216                Ok(t)
217            }
218            WasmLayer::Sigmoid => {
219                let mut t = input;
220                t.sigmoid_inplace();
221                Ok(t)
222            }
223            WasmLayer::Tanh => {
224                let mut t = input;
225                t.tanh_inplace();
226                Ok(t)
227            }
228            WasmLayer::Softmax => {
229                let mut t = input;
230                t.softmax_inplace();
231                Ok(t)
232            }
233            WasmLayer::Dropout { .. } => Ok(input),
234            WasmLayer::LayerNorm {
235                normalized_shape,
236                weight,
237                bias,
238                eps,
239            } => layer_norm_forward(input, *normalized_shape, weight, bias, *eps),
240            WasmLayer::Flatten => {
241                let batch = input.batch_size();
242                let rest = input.numel() / batch.max(1);
243                input.reshape(vec![batch, rest])
244            }
245        }
246    }
247}
248
249// ─────────────────────────────────────────────────────────────────────────────
250// WasmNeuralNet
251// ─────────────────────────────────────────────────────────────────────────────
252
/// Serializable, inference-only sequential neural network for WASM.
///
/// Layers are applied in insertion order by `forward`. The network can be
/// round-tripped through a compact binary form (`to_bytes`/`from_bytes`) or
/// JSON (`to_json`/`from_json`).
///
/// # Examples
/// ```
/// use scirs2_neural::wasm::{WasmNeuralNet, WasmLayer, WasmTensor};
///
/// let mut net = WasmNeuralNet::new("my_model");
/// net.add_layer(WasmLayer::Dense {
///     in_features: 4, out_features: 2,
///     weights: vec![0.1; 4 * 2], bias: vec![0.0; 2],
/// });
/// net.add_layer(WasmLayer::ReLU);
///
/// let input = WasmTensor::from_vec(vec![1.0, 0.0, -1.0, 0.5], vec![1, 4]);
/// let output = net.forward(input).expect("ok");
/// assert_eq!(output.shape(), &[1, 2]);
///
/// let bytes = net.to_bytes().expect("serialize ok");
/// let net2 = WasmNeuralNet::from_bytes(&bytes).expect("deserialize ok");
/// assert_eq!(net2.name(), "my_model");
/// ```
// NOTE(review): field order is part of the serde-derived encoding used by
// `to_bytes`; reordering fields likely breaks previously saved models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmNeuralNet {
    /// Model name, returned by `name()`.
    name: String,
    /// Layers, applied in order by `forward`.
    layers: Vec<WasmLayer>,
    /// Expected per-sample input shape (excluding batch). It is stored and
    /// serialized but `forward` does not check inputs against it.
    input_shape: Vec<usize>,
    /// Free-form key/value annotations.
    // NOTE(review): `HashMap` iteration order is arbitrary (randomly seeded per
    // process), so with more than one entry the serialized order of this map,
    // and therefore the `to_bytes` output, may differ between runs. Consider
    // `BTreeMap` if byte-stable output matters.
    metadata: HashMap<String, String>,
}
281
282impl WasmNeuralNet {
283    /// Create an empty network.
284    pub fn new(name: impl Into<String>) -> Self {
285        Self {
286            name: name.into(),
287            layers: Vec::new(),
288            input_shape: Vec::new(),
289            metadata: HashMap::new(),
290        }
291    }
292
293    /// Network name.
294    pub fn name(&self) -> &str {
295        &self.name
296    }
297
298    /// Number of layers.
299    pub fn num_layers(&self) -> usize {
300        self.layers.len()
301    }
302
303    /// All layers.
304    pub fn layers(&self) -> &[WasmLayer] {
305        &self.layers
306    }
307
308    /// Input shape (may be empty if not set).
309    pub fn input_shape(&self) -> &[usize] {
310        &self.input_shape
311    }
312
313    /// Set expected input shape (excluding batch).
314    pub fn set_input_shape(&mut self, shape: Vec<usize>) {
315        self.input_shape = shape;
316    }
317
318    /// Add a metadata entry.
319    pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
320        self.metadata.insert(key.into(), value.into());
321    }
322
323    /// Get a metadata value.
324    pub fn get_metadata(&self, key: &str) -> Option<&str> {
325        self.metadata.get(key).map(|s| s.as_str())
326    }
327
328    /// Append a layer.
329    pub fn add_layer(&mut self, layer: WasmLayer) {
330        self.layers.push(layer);
331    }
332
333    /// Total parameter count.
334    pub fn total_parameters(&self) -> usize {
335        self.layers.iter().map(|l| l.parameter_count()).sum()
336    }
337
338    /// Run the full forward pass.
339    pub fn forward(&self, input: WasmTensor) -> Result<WasmTensor> {
340        let mut x = input;
341        for layer in &self.layers {
342            x = layer.forward(x)?;
343        }
344        Ok(x)
345    }
346
347    /// Serialise to compact binary (oxicode).
348    pub fn to_bytes(&self) -> Result<Vec<u8>> {
349        let cfg = oxicode_config::standard();
350        oxicode_serde::encode_to_vec(self, cfg)
351            .map_err(|e| NeuralError::SerializationError(format!("oxicode encode: {e}")))
352    }
353
354    /// Deserialise from bytes produced by [`to_bytes`].
355    pub fn from_bytes(data: &[u8]) -> Result<Self> {
356        let cfg = oxicode_config::standard();
357        let (net, _) = oxicode_serde::decode_from_slice::<Self, _>(data, cfg)
358            .map_err(|e| NeuralError::DeserializationError(format!("oxicode decode: {e}")))?;
359        Ok(net)
360    }
361
362    /// Serialise to JSON.
363    pub fn to_json(&self) -> Result<String> {
364        serde_json::to_string(self)
365            .map_err(|e| NeuralError::SerializationError(format!("json encode: {e}")))
366    }
367
368    /// Deserialise from JSON.
369    pub fn from_json(json: &str) -> Result<Self> {
370        serde_json::from_str(json)
371            .map_err(|e| NeuralError::DeserializationError(format!("json decode: {e}")))
372    }
373
374    /// Print a brief summary of the network.
375    pub fn summary(&self) -> String {
376        let mut s = format!("WasmNeuralNet '{}'\n", self.name);
377        for (i, layer) in self.layers.iter().enumerate() {
378            s.push_str(&format!("  [{i}] {}\n", layer.type_name()));
379        }
380        s.push_str(&format!("Total parameters: {}\n", self.total_parameters()));
381        s
382    }
383}
384
385// ─────────────────────────────────────────────────────────────────────────────
386// Layer implementations
387// ─────────────────────────────────────────────────────────────────────────────
388
389fn dense_forward(
390    input: WasmTensor,
391    in_features: usize,
392    out_features: usize,
393    weights: &[f32],
394    bias: &[f32],
395) -> Result<WasmTensor> {
396    let shape = input.shape().to_vec();
397    if shape.len() < 2 {
398        return Err(NeuralError::ShapeMismatch(
399            "Dense: input must be at least 2-D [batch, features]".to_string(),
400        ));
401    }
402    let feat_dim = *shape.last().unwrap_or(&0);
403    if feat_dim != in_features {
404        return Err(NeuralError::ShapeMismatch(format!(
405            "Dense: expected in_features={in_features}, got {feat_dim}"
406        )));
407    }
408    if weights.len() != out_features * in_features {
409        return Err(NeuralError::ShapeMismatch(format!(
410            "Dense: weights len {} != {out_features}×{in_features}",
411            weights.len()
412        )));
413    }
414    if bias.len() != out_features {
415        return Err(NeuralError::ShapeMismatch(format!(
416            "Dense: bias len {} != {out_features}",
417            bias.len()
418        )));
419    }
420    let batch: usize = shape[..shape.len() - 1].iter().product::<usize>().max(1);
421    let input_data = input.data();
422    let mut output = vec![0.0_f32; batch * out_features];
423    for b in 0..batch {
424        for o in 0..out_features {
425            let mut acc = bias[o];
426            for i in 0..in_features {
427                acc += input_data[b * in_features + i] * weights[o * in_features + i];
428            }
429            output[b * out_features + o] = acc;
430        }
431    }
432    let mut out_shape = shape[..shape.len() - 1].to_vec();
433    out_shape.push(out_features);
434    Ok(WasmTensor::from_vec(output, out_shape))
435}
436
437fn layer_norm_forward(
438    input: WasmTensor,
439    normalized_shape: usize,
440    weight: &[f32],
441    bias: &[f32],
442    eps: f32,
443) -> Result<WasmTensor> {
444    let shape = input.shape().to_vec();
445    let feat_dim = *shape.last().unwrap_or(&0);
446    if feat_dim != normalized_shape {
447        return Err(NeuralError::ShapeMismatch(format!(
448            "LayerNorm: expected {normalized_shape}, got {feat_dim}"
449        )));
450    }
451    let batch: usize = (input.numel() / feat_dim.max(1)).max(1);
452    let data = input.data().to_vec();
453    let mut out_data = vec![0.0_f32; data.len()];
454    for b in 0..batch {
455        let slice = &data[b * feat_dim..(b + 1) * feat_dim];
456        let mean: f32 = slice.iter().sum::<f32>() / feat_dim as f32;
457        let var: f32 = slice.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / feat_dim as f32;
458        let std_inv = 1.0 / (var + eps).sqrt();
459        for (j, &v) in slice.iter().enumerate() {
460            out_data[b * feat_dim + j] = (v - mean) * std_inv * weight[j] + bias[j];
461        }
462    }
463    Ok(WasmTensor::from_vec(out_data, shape))
464}
465
466// ─────────────────────────────────────────────────────────────────────────────
467// Tests
468// ─────────────────────────────────────────────────────────────────────────────
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense(identity) → ReLU → Dense(all 0.5), 12 parameters in total.
    fn make_tiny_net() -> WasmNeuralNet {
        let mut net = WasmNeuralNet::new("tiny");
        net.add_layer(WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![1.0_f32, 0.0, 0.0, 1.0], // identity
            bias: vec![0.0, 0.0],
        });
        net.add_layer(WasmLayer::ReLU);
        net.add_layer(WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![0.5_f32, 0.5, 0.5, 0.5],
            bias: vec![0.0, 0.0],
        });
        net
    }

    #[test]
    fn test_wasm_tensor_creation() {
        let t = WasmTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        assert_eq!(t.numel(), 4);
    }

    #[test]
    fn test_wasm_tensor_zeros() {
        let t = WasmTensor::zeros(vec![2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        assert!(t.data().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_wasm_tensor_reshape_ok() {
        let t = WasmTensor::from_vec(vec![1.0_f32; 6], vec![2, 3]);
        let t2 = t.reshape(vec![3, 2]).expect("ok");
        assert_eq!(t2.shape(), &[3, 2]);
    }

    #[test]
    fn test_wasm_tensor_reshape_err() {
        let t = WasmTensor::from_vec(vec![1.0_f32; 6], vec![2, 3]);
        assert!(t.reshape(vec![4, 2]).is_err());
    }

    #[test]
    fn test_relu_inplace() {
        let mut t = WasmTensor::from_vec(vec![-1.0_f32, 2.0, -3.0, 4.0], vec![1, 4]);
        t.relu_inplace();
        assert_eq!(t.data(), &[0.0, 2.0, 0.0, 4.0]);
    }

    #[test]
    fn test_sigmoid_range() {
        let mut t = WasmTensor::from_vec(vec![-100.0_f32, 0.0, 100.0], vec![1, 3]);
        t.sigmoid_inplace();
        let d = t.data();
        assert!(d[0] >= 0.0 && d[0] < 0.01);
        assert!((d[1] - 0.5).abs() < 1e-4);
        assert!(d[2] > 0.99 && d[2] <= 1.0);
    }

    #[test]
    fn test_tanh_inplace() {
        let mut t = WasmTensor::from_vec(vec![0.0_f32, 1.0], vec![1, 2]);
        t.tanh_inplace();
        assert_eq!(t.data()[0], 0.0);
        assert!((t.data()[1] - 1.0_f32.tanh()).abs() < 1e-7);
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let mut t = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0], vec![1, 3]);
        t.softmax_inplace();
        let sum: f32 = t.data().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum={sum}");
    }

    #[test]
    fn test_softmax_rows_independent() {
        let mut t = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0, 0.0, 0.0, 0.0], vec![2, 3]);
        t.softmax_inplace();
        for row in t.data().chunks(3) {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "sum={sum}");
        }
        // A constant row becomes uniform.
        for &v in &t.data()[3..] {
            assert!((v - 1.0 / 3.0).abs() < 1e-6, "v={v}");
        }
    }

    #[test]
    fn test_dense_identity() {
        let layer = WasmLayer::Dense {
            in_features: 2,
            out_features: 2,
            weights: vec![1.0_f32, 0.0, 0.0, 1.0],
            bias: vec![0.0, 0.0],
        };
        let input = WasmTensor::from_vec(vec![3.0_f32, 4.0], vec![1, 2]);
        let out = layer.forward(input).expect("ok");
        assert!((out.data()[0] - 3.0).abs() < 1e-5);
        assert!((out.data()[1] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn test_dense_shape_mismatch_err() {
        let layer = WasmLayer::Dense {
            in_features: 3,
            out_features: 2,
            weights: vec![1.0_f32; 6],
            bias: vec![0.0; 2],
        };
        let input = WasmTensor::from_vec(vec![1.0_f32; 4], vec![1, 4]);
        assert!(layer.forward(input).is_err());
    }

    #[test]
    fn test_layer_norm_zero_mean() {
        let feat = 4;
        let layer = WasmLayer::LayerNorm {
            normalized_shape: feat,
            weight: vec![1.0_f32; feat],
            bias: vec![0.0_f32; feat],
            eps: 1e-5,
        };
        let input = WasmTensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], vec![1, feat]);
        let out = layer.forward(input).expect("ok");
        let mean: f32 = out.data().iter().sum::<f32>() / feat as f32;
        assert!(mean.abs() < 1e-4, "mean={mean}");
    }

    #[test]
    fn test_dropout_is_identity() {
        let layer = WasmLayer::Dropout { rate: 0.5 };
        let data = vec![1.0_f32, 2.0, 3.0];
        let input = WasmTensor::from_vec(data.clone(), vec![1, 3]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.data(), data.as_slice());
    }

    #[test]
    fn test_flatten_layer() {
        let layer = WasmLayer::Flatten;
        let input = WasmTensor::from_vec(vec![1.0_f32; 24], vec![2, 3, 4]);
        let out = layer.forward(input).expect("ok");
        assert_eq!(out.shape(), &[2, 12]);
    }

    #[test]
    fn test_net_forward() {
        let net = make_tiny_net();
        let input = WasmTensor::from_vec(vec![1.0_f32, -1.0], vec![1, 2]);
        let out = net.forward(input).expect("ok");
        assert_eq!(out.shape(), &[1, 2]);
        // identity: [1, -1] → ReLU: [1, 0] → averaging: [0.5, 0.5]
        assert_eq!(out.data(), &[0.5, 0.5]);
    }

    #[test]
    fn test_net_total_params() {
        let net = make_tiny_net();
        assert_eq!(net.total_parameters(), 12); // (4+2) + 0 + (4+2)
    }

    #[test]
    fn test_net_binary_roundtrip() {
        let net = make_tiny_net();
        let bytes = net.to_bytes().expect("serialize ok");
        let net2 = WasmNeuralNet::from_bytes(&bytes).expect("deserialize ok");
        assert_eq!(net2.name(), "tiny");
        assert_eq!(net2.num_layers(), 3);
        assert_eq!(net2.total_parameters(), net.total_parameters());
    }

    #[test]
    fn test_net_json_roundtrip() {
        let net = make_tiny_net();
        let json = net.to_json().expect("json ok");
        let net2 = WasmNeuralNet::from_json(&json).expect("from json ok");
        assert_eq!(net2.name(), "tiny");
        assert_eq!(net2.num_layers(), 3);
    }

    #[test]
    fn test_net_summary() {
        let net = make_tiny_net();
        let s = net.summary();
        assert!(s.contains("tiny"));
        assert!(s.contains("Dense"));
        assert!(s.contains("ReLU"));
        assert!(s.contains("Total parameters: 12"));
    }

    #[test]
    fn test_net_metadata() {
        let mut net = WasmNeuralNet::new("m");
        net.add_metadata("version", "1.0");
        assert_eq!(net.get_metadata("version"), Some("1.0"));
        assert_eq!(net.get_metadata("missing"), None);
    }

    #[test]
    fn test_from_bytes_invalid_err() {
        assert!(WasmNeuralNet::from_bytes(b"not valid data").is_err());
    }

    #[test]
    fn test_net_deterministic() {
        let net = make_tiny_net();
        let input = WasmTensor::from_vec(vec![2.0_f32, 3.0], vec![1, 2]);
        let out1 = net.forward(input.clone()).expect("ok");
        let out2 = net.forward(input).expect("ok");
        for (a, b) in out1.data().iter().zip(out2.data().iter()) {
            assert!((a - b).abs() < 1e-7);
        }
    }

    #[test]
    fn test_wasm_layer_type_names() {
        assert_eq!(WasmLayer::ReLU.type_name(), "ReLU");
        assert_eq!(WasmLayer::Sigmoid.type_name(), "Sigmoid");
        assert_eq!(WasmLayer::Flatten.type_name(), "Flatten");
        assert_eq!(WasmLayer::Softmax.type_name(), "Softmax");
        assert_eq!(WasmLayer::Tanh.type_name(), "Tanh");
        assert_eq!(WasmLayer::Dropout { rate: 0.1 }.type_name(), "Dropout");
        let dense = WasmLayer::Dense {
            in_features: 1,
            out_features: 1,
            weights: vec![1.0],
            bias: vec![0.0],
        };
        assert_eq!(dense.type_name(), "Dense");
        let norm = WasmLayer::LayerNorm {
            normalized_shape: 1,
            weight: vec![1.0],
            bias: vec![0.0],
            eps: 1e-5,
        };
        assert_eq!(norm.type_name(), "LayerNorm");
    }
}