Skip to main content

Graph

Struct Graph 

Source
pub struct Graph { /* private fields */ }
Expand description

Eager autograd graph with explicit backward pass.

Implementations§

Source§

impl Graph

Source

pub fn backward(&mut self, target: NodeId) -> Result<(), AutogradError>

Backpropagates gradients from a scalar target.

Source

pub fn backward_with_checkpoints( &mut self, target: NodeId, config: &CheckpointConfig, ) -> Result<(), AutogradError>

Backward pass with activation checkpointing.

Source§

impl Graph

Source

pub fn new() -> Self

Creates an empty graph with an automatic parallel backend.

Uses all available CPU threads for parallel matmul, conv2d, softmax, etc.

Source

pub fn new_single_threaded() -> Self

Creates an empty graph without a parallel backend (single-threaded).

Source

pub fn set_backend(&mut self, backend: Box<dyn BackwardOps>)

Sets a compute backend for GPU-accelerated operations. When a backend is set, supported ops dispatch through it. When no backend is set (the default), ops use direct CPU kernel calls.

Source

pub fn clear_backend(&mut self)

Remove the backend, reverting to CPU kernel calls.

Source

pub fn variable(&mut self, value: Tensor) -> NodeId

Adds a trainable leaf node.

Source

pub fn constant(&mut self, value: Tensor) -> NodeId

Adds a non-trainable leaf node.

Source

pub fn value(&self, node: NodeId) -> Result<&Tensor, AutogradError>

Returns an immutable reference to the node's value.

Source

pub fn value_mut(&mut self, node: NodeId) -> Result<&mut Tensor, AutogradError>

Returns a mutable reference to the node's value.

Source

pub fn requires_grad(&self, node: NodeId) -> Result<bool, AutogradError>

Returns whether the node is trainable.

Source

pub fn grad(&self, node: NodeId) -> Result<Option<&Tensor>, AutogradError>

Returns an immutable reference to the node's gradient, if one has already been computed.

Source

pub fn grad_mut( &mut self, node: NodeId, ) -> Result<Option<&mut Tensor>, AutogradError>

Returns a mutable reference to the node's gradient, if one has already been computed.

Source

pub fn set_grad( &mut self, node: NodeId, grad: Tensor, ) -> Result<(), AutogradError>

Sets the gradient for a node, replacing any existing gradient.

Source

pub fn node_count(&self) -> usize

Returns the current number of nodes in the graph.

Source

pub fn zero_grads(&mut self)

Clears gradients for all nodes.

Source

pub fn truncate(&mut self, keep_nodes: usize) -> Result<(), AutogradError>

Truncates the graph to the given node count (keep_nodes).

Source

pub fn add( &mut self, left: NodeId, right: NodeId, ) -> Result<NodeId, AutogradError>

Adds two nodes with broadcasting support.

Source

pub fn sub( &mut self, left: NodeId, right: NodeId, ) -> Result<NodeId, AutogradError>

Subtracts two nodes with broadcasting support.

Source

pub fn mul( &mut self, left: NodeId, right: NodeId, ) -> Result<NodeId, AutogradError>

Multiplies two nodes elementwise with broadcasting support.

Source

pub fn relu(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies ReLU activation to one node.

Source

pub fn matmul_2d( &mut self, left: NodeId, right: NodeId, ) -> Result<NodeId, AutogradError>

Performs rank-2 matrix multiplication.

Source

pub fn div( &mut self, left: NodeId, right: NodeId, ) -> Result<NodeId, AutogradError>

Divides two nodes elementwise with broadcasting support.

Source

pub fn neg(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise negation.

Source

pub fn exp(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise natural exponential.

Source

pub fn log(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise natural logarithm.

Source

pub fn sqrt(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise square root.

Source

pub fn sigmoid(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise sigmoid activation: 1 / (1 + exp(-x)).

Source

pub fn gelu(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise GELU activation (fast approximation): x * sigmoid(1.702 * x).

Source

pub fn silu(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise SiLU (Swish) activation: x * sigmoid(x).

Source

pub fn mish(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise Mish activation: x * tanh(softplus(x)).

Source

pub fn tanh(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise hyperbolic tangent.

Source

pub fn abs(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies element-wise absolute value.

Source

pub fn pow( &mut self, base: NodeId, exponent: NodeId, ) -> Result<NodeId, AutogradError>

Applies element-wise power: base ^ exponent.

Source

pub fn clamp( &mut self, input: NodeId, min_val: f32, max_val: f32, ) -> Result<NodeId, AutogradError>

Applies element-wise clamping to [min_val, max_val].

Source

pub fn leaky_relu( &mut self, input: NodeId, negative_slope: f32, ) -> Result<NodeId, AutogradError>

Applies element-wise leaky ReLU: max(0, x) + negative_slope * min(0, x).

Source

pub fn softmax(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies softmax along the last dimension.

Source

pub fn log_softmax(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Applies log-softmax along the last dimension.

Source

pub fn transpose_2d(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Transposes a rank-2 node within the graph, so that gradients flow through it during the backward pass.

Source

pub fn reshape( &mut self, input: NodeId, new_shape: Vec<usize>, ) -> Result<NodeId, AutogradError>

Reshapes a node within the graph, preserving the backward (gradient) path.

Source

pub fn unsqueeze( &mut self, input: NodeId, axis: usize, ) -> Result<NodeId, AutogradError>

Inserts a size-1 dimension at axis within the graph, preserving the backward (gradient) path.

Source

pub fn squeeze( &mut self, input: NodeId, axis: usize, ) -> Result<NodeId, AutogradError>

Removes the dimension at axis within the graph, preserving the backward (gradient) path.

Source

pub fn cat( &mut self, inputs: &[NodeId], axis: usize, ) -> Result<NodeId, AutogradError>

Concatenates multiple nodes along axis.

Source

pub fn select( &mut self, input: NodeId, axis: usize, index: usize, ) -> Result<NodeId, AutogradError>

Selects a single index along axis, reducing that dimension.

Source

pub fn narrow( &mut self, input: NodeId, axis: usize, start: usize, length: usize, ) -> Result<NodeId, AutogradError>

Narrows (slices) a node along axis from start for length elements.

Source

pub fn gather( &mut self, input: NodeId, axis: usize, index: NodeId, ) -> Result<NodeId, AutogradError>

Gathers elements along axis using an index tensor (from another node).

For each position in the index tensor, retrieves the value from input at the index along axis.

Source

pub fn scatter_add( &mut self, input: NodeId, index: NodeId, src: NodeId, axis: usize, ) -> Result<NodeId, AutogradError>

Scatter-add operation: scatters src values into input at index positions along axis.

Forward: output = input.scatter_add(axis, index, src)

Source

pub fn pad( &mut self, input: NodeId, padding: &[usize], value: f32, ) -> Result<NodeId, AutogradError>

Pads the tensor with a constant value along each dimension.

padding is a flat array of (before, after) pairs, one pair per dimension: [before_0, after_0, before_1, after_1, ...].

Source

pub fn repeat( &mut self, input: NodeId, repeats: &[usize], ) -> Result<NodeId, AutogradError>

Repeats the tensor along each dimension.

Source

pub fn sum(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Reduces a node to a scalar: the sum of all its elements.

Source

pub fn mean(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Reduces a node to a scalar: the mean of all its elements.

Source

pub fn conv2d_nhwc( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, stride_h: usize, stride_w: usize, ) -> Result<NodeId, AutogradError>

NHWC 2-D convolution forward. input shape [N,H,W,C_in], weight shape [KH,KW,C_in,C_out], optional bias shape [C_out].

Source

pub fn max_pool2d_nhwc( &mut self, input: NodeId, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize, ) -> Result<NodeId, AutogradError>

NHWC max-pooling forward with argmax tracking for backward.

Source

pub fn avg_pool2d_nhwc( &mut self, input: NodeId, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize, ) -> Result<NodeId, AutogradError>

NHWC average-pooling forward.

Source

pub fn batch_norm2d_nhwc( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, running_mean: NodeId, running_var: NodeId, epsilon: f32, ) -> Result<NodeId, AutogradError>

NHWC batch-normalization forward (inference mode: uses running stats). gamma/beta/running_mean/running_var must be rank-1 of size C.

Source

pub fn layer_norm( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, epsilon: f32, ) -> Result<NodeId, AutogradError>

Layer normalization over the last dimension.

Input can be any rank; normalization is applied over the last axis. gamma and beta must have shape [last_dim].

Source

pub fn group_norm( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, num_groups: usize, epsilon: f32, ) -> Result<NodeId, AutogradError>

Group normalization on NHWC input [N, H, W, C].

gamma and beta must have shape [C]. num_groups must divide C.

Source

pub fn flatten(&mut self, input: NodeId) -> Result<NodeId, AutogradError>

Flatten rank-4 NHWC tensor [N,H,W,C] to rank-2 [N, H*W*C].

Source

pub fn sum_axis( &mut self, input: NodeId, axis: usize, ) -> Result<NodeId, AutogradError>

Reduces one node by summing along a single axis (removing that dimension).

Source

pub fn mean_axis( &mut self, input: NodeId, axis: usize, ) -> Result<NodeId, AutogradError>

Reduces one node by averaging along a single axis (removing that dimension).

Source

pub fn depthwise_conv2d_nhwc( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, stride_h: usize, stride_w: usize, ) -> Result<NodeId, AutogradError>

NHWC depthwise 2-D convolution forward. input shape [N,H,W,C], weight shape [KH,KW,C,1], optional bias shape [C].

Source

pub fn scatter( &mut self, input: NodeId, indices: NodeId, src: NodeId, ) -> Result<NodeId, AutogradError>

Scatter: write values from src into input at row positions given by indices. input shape: [N, D], indices shape: [M], src shape: [M, D]. Result: input with rows at indices replaced by src rows.

Source

pub fn embedding_lookup( &mut self, weight: NodeId, indices: NodeId, ) -> Result<NodeId, AutogradError>

Embedding lookup: gather rows from weight matrix at given indices. weight shape: [vocab_size, embed_dim], indices shape: [seq_len]. Result shape: [seq_len, embed_dim].

Source

pub fn conv1d_nlc( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, stride: usize, ) -> Result<NodeId, AutogradError>

NLC 1-D convolution forward. input shape [N,L,C_in], weight shape [K,C_in,C_out], optional bias shape [C_out].

Source

pub fn conv3d_ndhwc( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, stride_d: usize, stride_h: usize, stride_w: usize, ) -> Result<NodeId, AutogradError>

NDHWC 3-D convolution forward (no padding). input shape [N,D,H,W,C_in], weight shape [KD,KH,KW,C_in,C_out], optional bias shape [C_out].

Source

pub fn scaled_dot_product_attention( &mut self, query: NodeId, key: NodeId, value: NodeId, ) -> Result<NodeId, AutogradError>

Scaled dot-product attention forward. query shape [seq_q, d_k], key shape [seq_k, d_k], value shape [seq_k, d_v]. Returns [seq_q, d_v].

Source

pub fn conv_transpose2d_nhwc( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, stride_h: usize, stride_w: usize, ) -> Result<NodeId, AutogradError>

NHWC transposed 2-D convolution forward. input shape [N,H,W,C_in], weight shape [KH,KW,C_out,C_in], optional bias shape [C_out]. Output shape: [N, (H-1)*stride_h + KH, (W-1)*stride_w + KW, C_out].

Source

pub fn adaptive_avg_pool2d_nhwc( &mut self, input: NodeId, out_h: usize, out_w: usize, ) -> Result<NodeId, AutogradError>

NHWC adaptive average pool 2d forward. input shape [N,H,W,C], output shape [N,out_h,out_w,C].

Source

pub fn adaptive_max_pool2d_nhwc( &mut self, input: NodeId, out_h: usize, out_w: usize, ) -> Result<NodeId, AutogradError>

NHWC adaptive max pool 2d forward with argmax tracking for backward.

Source

pub fn instance_norm_nhwc( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, epsilon: f32, ) -> Result<NodeId, AutogradError>

Instance normalization (NHWC) forward. Normalizes per (N,C) pair across H*W spatial dimensions. gamma and beta must have shape [C].

Source

pub fn prelu( &mut self, input: NodeId, alpha: NodeId, ) -> Result<NodeId, AutogradError>

PReLU activation forward. alpha is a parameter node with shape [C] or [1]. For NHWC inputs, channels are the last dimension.

Source

pub fn pixel_shuffle( &mut self, input: NodeId, upscale_factor: usize, ) -> Result<NodeId, AutogradError>

Pixel shuffle forward: rearranges [N, H, W, C*r^2] -> [N, H*r, W*r, C], where r = upscale_factor.

Source

pub fn upsample_nearest( &mut self, input: NodeId, scale_factor: usize, ) -> Result<NodeId, AutogradError>

Nearest-neighbor upsample forward: [N, H, W, C] -> [N, H*r, W*r, C], where r = scale_factor.

Source

pub fn rnn_forward( &mut self, input: NodeId, w_ih: NodeId, w_hh: NodeId, bias: NodeId, ) -> Result<NodeId, AutogradError>

RNN forward pass through all timesteps (for BPTT). input: [seq_len, input_size], w_ih: [input_size, hidden_size], w_hh: [hidden_size, hidden_size], bias: [hidden_size]. Returns output [seq_len, hidden_size].

Source

pub fn lstm_forward( &mut self, input: NodeId, w_ih: NodeId, w_hh: NodeId, bias: NodeId, ) -> Result<NodeId, AutogradError>

LSTM forward pass through all timesteps (for BPTT). input: [seq_len, input_size], w_ih: [input_size, 4*hidden_size], w_hh: [hidden_size, 4*hidden_size], bias: [4*hidden_size]. Returns output [seq_len, hidden_size].

Source

pub fn gru_forward( &mut self, input: NodeId, w_ih: NodeId, w_hh: NodeId, bias_ih: NodeId, bias_hh: NodeId, ) -> Result<NodeId, AutogradError>

GRU forward pass through all timesteps (for BPTT). input: [seq_len, input_size], w_ih: [input_size, 3*hidden_size], w_hh: [hidden_size, 3*hidden_size], bias_ih: [3*hidden_size], bias_hh: [3*hidden_size]. Returns output [seq_len, hidden_size].

Source

pub fn deformable_conv2d_nhwc( &mut self, input: NodeId, weight: NodeId, offsets: NodeId, bias: Option<NodeId>, stride: usize, padding: usize, ) -> Result<NodeId, AutogradError>

Deformable conv2d NHWC forward. input: [N,H,W,C_in], weight: [KH,KW,C_in,C_out], offsets: [N,OH,OW,KH*KW*2].

Source

pub fn clip_grad_norm(&mut self, param_nodes: &[NodeId], max_norm: f32) -> f32

Clips gradients by global L2 norm. Returns the original norm before clipping. If total_norm > max_norm, scales all gradients by max_norm / total_norm.

Source

pub fn clip_grad_value(&mut self, param_nodes: &[NodeId], max_value: f32)

Clips each gradient element to [-max_value, max_value].

Trait Implementations§

Source§

impl Default for Graph

Source§

fn default() -> Self

Returns the “default value” for a type. Read more

Auto Trait Implementations§

§

impl Freeze for Graph

§

impl !RefUnwindSafe for Graph

§

impl !Send for Graph

§

impl !Sync for Graph

§

impl Unpin for Graph

§

impl UnsafeUnpin for Graph

§

impl !UnwindSafe for Graph

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.