
Function linear 

pub fn linear(input: &TritMatrix, weights: &TritMatrix) -> (TritMatrix, usize)

BitNet-style ternary linear layer: output = sparse_matmul(input, W)

input:   [batch × in_features]
W:       [in_features × out_features] (pre-quantized ternary weights)
returns: ([batch × out_features], skipped_ops)
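The shape contract above can be illustrated with a minimal sketch. It is not the crate's implementation, and several parts are assumptions. `TritMatrix` here is a hypothetical row-major stand-in with entries in {-1, 0, +1}. `skipped_ops` is assumed to count the multiply-accumulate steps skipped because either operand is zero. The description does not say how integer sums are mapped back into a `TritMatrix`, so the sketch assumes sign requantization (`signum`). Because both operands are ternary, each nonzero product is just an add or a subtract, so no multiplications are needed.

```rust
/// Hypothetical stand-in for the crate's `TritMatrix` (assumption):
/// row-major storage, every entry in {-1, 0, +1}.
#[derive(Debug, Clone, PartialEq)]
pub struct TritMatrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<i8>,
}

impl TritMatrix {
    pub fn new(rows: usize, cols: usize, data: Vec<i8>) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        assert!(data.iter().all(|&t| (-1..=1).contains(&t)), "entries must be ternary");
        TritMatrix { rows, cols, data }
    }

    /// Entry at row `r`, column `c`.
    pub fn get(&self, r: usize, c: usize) -> i8 {
        self.data[r * self.cols + c]
    }
}

/// Sketch of a ternary linear layer: [batch × in] · [in × out] -> [batch × out].
/// Returns the output and the number of skipped (zero-product) steps.
/// ASSUMPTION: sums are requantized to trits with `signum`.
pub fn linear_sketch(input: &TritMatrix, weights: &TritMatrix) -> (TritMatrix, usize) {
    assert_eq!(input.cols, weights.rows, "in_features mismatch");
    let (batch, in_f, out_f) = (input.rows, input.cols, weights.cols);
    let mut out = Vec::with_capacity(batch * out_f);
    let mut skipped = 0usize;

    for b in 0..batch {
        for o in 0..out_f {
            let mut acc: i32 = 0;
            for k in 0..in_f {
                let x = input.get(b, k);
                let w = weights.get(k, o);
                if x == 0 || w == 0 {
                    // Zero product: nothing to accumulate, count it as skipped.
                    skipped += 1;
                    continue;
                }
                // Both nonzero: product is +1 if the signs match, -1 otherwise.
                if x == w { acc += 1 } else { acc -= 1 }
            }
            // ASSUMPTION: map the integer sum back to {-1, 0, +1}.
            out.push(acc.signum() as i8);
        }
    }
    (TritMatrix::new(batch, out_f, out), skipped)
}
```

For example, a 1 × 3 input `[1, 0, -1]` against the 3 × 2 weights `[[1, -1], [1, 1], [0, -1]]` yields the output `[1, 0]` with 3 skipped steps. Under the counting assumption above, `skipped_ops` grows with the sparsity of both the activations and the weights.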