§zyx-nn
Neural network modules for the zyx machine learning library.
This crate provides common neural network building blocks as reusable types that implement the Module trait. They are designed to work with zyx’s kernel fusion and autograd system.
§Features
§Linear & Normalization
- Linear — Dense fully-connected layer
- LayerNorm — Layer normalization
- BatchNorm — Batch normalization
- GroupNorm — Group normalization
- RMSNorm — Root mean square normalization
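To illustrate how two of the normalization layers listed above differ, here is a minimal plain-Rust sketch that normalizes a single feature vector. It does not use zyx: the function names `layer_norm` and `rms_norm` are hypothetical, and the learnable scale and bias parameters that real layers usually carry are omitted.

```rust
/// Layer normalization of one feature vector: subtract the mean, then
/// divide by the standard deviation (eps keeps the division stable).
/// Hypothetical helper, not part of zyx-nn.
fn layer_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + eps).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// RMS normalization: divide by the root mean square, without centering.
/// Hypothetical helper, not part of zyx-nn.
fn rms_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n;
    let denom = (mean_sq + eps).sqrt();
    x.iter().map(|v| v / denom).collect()
}
```

The layer-normalized output has approximately zero mean, while the RMS-normalized output keeps the sign and relative offset of the inputs and only rescales them.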
§Recurrent Layers
- RNNCell — Simple recurrent cell
- GRUCell — Gated recurrent unit
- LSTMCell — Long short-term memory
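As a rough picture of what the simplest of these cells computes, the sketch below performs one Elman-style step, h' = tanh(W_ih x + W_hh h + b), on plain vectors. It is an illustration only: `rnn_cell_step` and its argument layout are assumptions, and RNNCell's actual parameters and nonlinearity options may differ.

```rust
/// One step of an Elman RNN cell: h' = tanh(W_ih * x + W_hh * h + b).
/// Weights are row-major: w_ih[i] and w_hh[i] produce hidden unit i.
/// Hypothetical helper, not part of zyx-nn.
fn rnn_cell_step(
    x: &[f32],
    h: &[f32],
    w_ih: &[Vec<f32>],
    w_hh: &[Vec<f32>],
    b: &[f32],
) -> Vec<f32> {
    (0..h.len())
        .map(|i| {
            // Contribution of the current input to hidden unit i.
            let from_input: f32 = w_ih[i].iter().zip(x).map(|(w, v)| w * v).sum();
            // Contribution of the previous hidden state to hidden unit i.
            let from_hidden: f32 = w_hh[i].iter().zip(h).map(|(w, v)| w * v).sum();
            (from_input + from_hidden + b[i]).tanh()
        })
        .collect()
}
```

Calling the function repeatedly, feeding each returned vector back in as `h`, unrolls the cell over a sequence.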
§Attention Mechanisms
- CausalSelfAttention — Causal self-attention for transformers
- MultiheadAttention — Multi-head attention with configurable heads
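The "causal" part of causal self-attention means that a position may only attend to itself and to earlier positions. A minimal sketch of that masking step on one row of attention scores follows; `causal_softmax_row` is a hypothetical helper written in plain Rust, not zyx's implementation.

```rust
/// Softmax over scores[0..=query_pos]; every later position gets weight 0.
/// Panics if query_pos is out of range. Hypothetical helper, not part of zyx-nn.
fn causal_softmax_row(scores: &[f32], query_pos: usize) -> Vec<f32> {
    let visible = &scores[..=query_pos];
    // Subtract the maximum before exponentiating for numerical stability.
    let max = visible.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = visible.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut weights: Vec<f32> = exps.iter().map(|e| e / total).collect();
    // Future positions are masked out entirely.
    weights.resize(scores.len(), 0.0);
    weights
}
```

For the first query position the result is always a one-hot row, since only a single position is visible.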
§Embeddings & Convolution
- Embedding — Learnable embedding lookup
- Conv2d — 2D convolution
§Transformers
- TransformerEncoderLayer — Single transformer encoder block
- TransformerDecoderLayer — Single transformer decoder block
- PositionalEncoding — Sinusoidal positional embeddings
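For readers unfamiliar with sinusoidal positional embeddings, the sketch below builds the table using the formulation from "Attention Is All You Need": even dimensions use a sine, odd dimensions a cosine, and each pair shares one frequency. It is plain Rust with a hypothetical function name; whether PositionalEncoding uses exactly this layout is an assumption.

```rust
/// Returns a max_len x d_model table where
///   table[pos][2i]     = sin(pos / 10000^(2i / d_model))
///   table[pos][2i + 1] = cos(pos / 10000^(2i / d_model))
/// Hypothetical helper, not part of zyx-nn.
fn sinusoidal_encoding(max_len: usize, d_model: usize) -> Vec<Vec<f32>> {
    (0..max_len)
        .map(|pos| {
            (0..d_model)
                .map(|j| {
                    // Dimensions 2i and 2i+1 share the same frequency.
                    let exponent = (2 * (j / 2)) as f32 / d_model as f32;
                    let angle = pos as f32 / 10000f32.powf(exponent);
                    if j % 2 == 0 { angle.sin() } else { angle.cos() }
                })
                .collect()
        })
        .collect()
}
```

The table depends only on the sequence length and model width, so it can be computed once and added to the token embeddings.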
§Python Bindings
- py feature enables Python interoperability via pyo3
§Usage
use zyx::{Tensor, DType};
use zyx_nn::{Linear, LayerNorm};
let linear = Linear::new(128, 64, true, DType::F32).unwrap();
let x = Tensor::randn([32, 128], DType::F32).unwrap();
let y = linear.forward(&x).unwrap();
let norm = LayerNorm::new([64], 1e-5, true, true, DType::F32).unwrap();
let z = norm.forward(&y).unwrap();
§API
All modules implement the Module trait from zyx-derive, which provides:
forward(x: impl Into<Tensor>) -> Result<Tensor, ZyxError>— Forward pass
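A shared forward method is what lets layers be chained uniformly. The sketch below shows that pattern with a deliberately simplified stand-in: `ToyModule`, `Scale`, `Offset`, and `run_all` are all hypothetical, and plain slices with a `String` error replace zyx's Tensor and ZyxError.

```rust
/// Hypothetical, simplified stand-in for a module trait. The real trait
/// operates on zyx tensors and returns Result<Tensor, ZyxError>.
trait ToyModule {
    fn forward(&self, x: &[f32]) -> Result<Vec<f32>, String>;
}

/// Multiplies every element by a constant factor.
struct Scale {
    factor: f32,
}

/// Adds a bias vector element-wise; fails on a length mismatch.
struct Offset {
    bias: Vec<f32>,
}

impl ToyModule for Scale {
    fn forward(&self, x: &[f32]) -> Result<Vec<f32>, String> {
        Ok(x.iter().map(|v| v * self.factor).collect())
    }
}

impl ToyModule for Offset {
    fn forward(&self, x: &[f32]) -> Result<Vec<f32>, String> {
        if x.len() != self.bias.len() {
            return Err(format!("expected {} values, got {}", self.bias.len(), x.len()));
        }
        Ok(x.iter().zip(&self.bias).map(|(v, b)| v + b).collect())
    }
}

/// Runs the layers in order, stopping at the first error.
fn run_all(layers: &[Box<dyn ToyModule>], x: &[f32]) -> Result<Vec<f32>, String> {
    let mut out = x.to_vec();
    for layer in layers {
        out = layer.forward(&out)?;
    }
    Ok(out)
}
```

Returning a Result from every forward call means shape errors surface at the layer that detects them instead of panicking deeper in the stack.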
§Autograd
zyx uses GradientTape for automatic differentiation. The tape must be created before the forward pass to capture the computation graph.
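Why the ordering matters can be seen in a toy model of tape-based differentiation: operations executed before recording starts leave no trace, so no gradient can flow through them. The sketch below is a conceptual scalar example in plain Rust. `ToyGraph` and its methods are hypothetical and say nothing about how GradientTape is implemented internally.

```rust
/// Toy scalar autograd graph; a conceptual sketch, not zyx's GradientTape.
struct ToyGraph {
    values: Vec<f64>,
    /// For each node: (parent index, local derivative) pairs.
    edges: Vec<Vec<(usize, f64)>>,
    recording: bool,
}

impl ToyGraph {
    fn new() -> Self {
        ToyGraph { values: Vec::new(), edges: Vec::new(), recording: false }
    }

    /// Analogous to creating the tape: operations are tracked from here on.
    fn start_tape(&mut self) {
        self.recording = true;
    }

    /// Registers an input value and returns its node index.
    fn leaf(&mut self, value: f64) -> usize {
        self.values.push(value);
        self.edges.push(Vec::new());
        self.values.len() - 1
    }

    /// Multiplies two nodes and returns the index of the result.
    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        // d(a*b)/da = b and d(a*b)/db = a, stored only while recording.
        let local = if self.recording { vec![(a, vb), (b, va)] } else { Vec::new() };
        self.values.push(va * vb);
        self.edges.push(local);
        self.values.len() - 1
    }

    /// Reverse-mode sweep from `out`; parents always have smaller indices.
    fn gradient(&self, out: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[out] = 1.0;
        for node in (0..=out).rev() {
            let g = grads[node];
            for &(parent, d) in &self.edges[node] {
                grads[parent] += g * d;
            }
        }
        grads
    }
}
```

In this toy model, a product computed before `start_tape` yields a zero gradient for both inputs, while the same product computed afterwards yields the expected partial derivatives.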
use zyx::{Tensor, DType, GradientTape};
use zyx_nn::Linear;
let mut linear = Linear::new(128, 64, true, DType::F32).unwrap();
let x = Tensor::randn([32, 128], DType::F32).unwrap();
let target = Tensor::randn([32, 64], DType::F32).unwrap();
// Create gradient tape BEFORE forward pass
let tape = GradientTape::new();
let y = linear.forward(&x).unwrap();
let loss = y.mse_loss(&target).unwrap();
// Compute gradients w.r.t. model parameters
let grads = tape.gradient(&loss, &[linear.weight, linear.bias.unwrap()]);
§Features
- py — Enable Python bindings via pyo3
§License
LGPL-3.0-only
Structs§
- BatchNorm: Batch norm
- CausalSelfAttention: Causal self attention
- Conv2d: Applies a 2D convolution over an input signal composed of several input planes.
- Embedding: Embedding layer
- GRUCell: GRU cell (PyTorch-style)
- GroupNorm: Group normalization
- LSTMCell: A single LSTM (Long Short-Term Memory) cell.
- LayerNorm: A Layer Normalization layer.
- Linear: Linear layer
- MultiheadAttention: Implements multi-head attention as described in “Attention Is All You Need”.
- PositionalEncoding: Sinusoidal positional encoding module for transformers.
- RMSNorm: RMS norm layer
- RNNCell: An Elman RNN cell with optional nonlinearity.
- TransformerDecoderLayer: A single layer of a Transformer decoder.
- TransformerEncoderLayer: A single Transformer Encoder layer, analogous to torch.nn.TransformerEncoderLayer.
Derive Macros§
- Module: Procedural macro Module