// trustformers_core/traits.rs
//! Core traits defining the fundamental abstractions of TrustformeRS.
//!
//! This module contains the essential traits that form the foundation of the TrustformeRS
//! transformer library. These traits define the interfaces for models, layers, configuration,
//! tokenization, optimization, and parameter initialization.
//!
//! # Overview
//!
//! The traits in this module establish a consistent API across all transformer implementations:
//!
//! - [`Model`]: The main trait for transformer models with forward pass and loading capabilities
//! - [`Layer`]: Building blocks for neural network architectures
//! - [`Config`]: Configuration management for models and components
//! - [`WeightReader`]: Interface for loading pretrained model weights
//! - [`Tokenizer`]: Text tokenization and encoding/decoding
//! - [`Optimizer`]: Parameter optimization algorithms
//! - [`ParameterInit`]: Weight initialization strategies
//!
//! # Examples
//!
//! ```no_run
//! use trustformers_core::traits::{Model, Config};
//! use trustformers_core::tensor::Tensor;
//! use trustformers_core::errors::Result;
//! use std::io::Read;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Deserialize, Serialize)]
//! struct MyConfig;
//!
//! impl Config for MyConfig {
//!     fn architecture(&self) -> &'static str {
//!         "my-model"
//!     }
//! }
//!
//! // Example model implementation
//! struct MyModel {
//!     config: MyConfig,
//!     // ... model layers
//! }
//!
//! impl Model for MyModel {
//!     type Config = MyConfig;
//!     type Input = Tensor;
//!     type Output = Tensor;
//!
//!     fn forward(&self, input: Self::Input) -> Result<Self::Output> {
//!         // Model forward pass implementation
//!         Ok(input)
//!     }
//!
//!     fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
//!         // Load pretrained weights
//!         Ok(())
//!     }
//!
//!     fn get_config(&self) -> &Self::Config {
//!         &self.config
//!     }
//!
//!     fn num_parameters(&self) -> usize {
//!         0 // A real model sums parameter counts over its layers.
//!     }
//! }
//! ```
use crate::errors::Result;
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::io::Read;
/// The main trait for transformer models.
///
/// This trait defines the interface that all transformer models must implement,
/// providing a consistent API for forward passes, weight loading, and configuration access.
///
/// # Associated Types
///
/// - `Config`: The configuration type for this model, must implement [`Config`]
/// - `Input`: The input type for the model's forward pass
/// - `Output`: The output type produced by the model
///
/// # Thread Safety
///
/// Models must be `Send + Sync` to support multi-threaded inference and training.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::{Model, Config};
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
/// use std::io::Read;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Serialize)]
/// struct BertConfig {
///     hidden_size: usize,
///     num_attention_heads: usize,
///     // ... other config fields
/// }
///
/// impl Config for BertConfig {
///     fn architecture(&self) -> &'static str {
///         "bert"
///     }
/// }
///
/// struct BertModel {
///     config: BertConfig,
///     // ... model layers
/// }
///
/// impl Model for BertModel {
///     type Config = BertConfig;
///     type Input = Tensor;
///     type Output = Tensor;
///
///     fn forward(&self, input: Self::Input) -> Result<Self::Output> {
///         // BERT forward pass implementation
///         Ok(input)
///     }
///
///     fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
///         // Load BERT weights from reader
///         Ok(())
///     }
///
///     fn get_config(&self) -> &Self::Config {
///         &self.config
///     }
///
///     fn num_parameters(&self) -> usize {
///         0 // A real model sums parameter counts over its layers.
///     }
/// }
/// ```
pub trait Model: Send + Sync {
    /// The configuration type describing this model's architecture.
    type Config: Config;
    /// The input accepted by [`Model::forward`].
    type Input;
    /// The output produced by [`Model::forward`].
    type Output;

    /// Performs a forward pass through the model.
    ///
    /// # Arguments
    ///
    /// * `input` - The input data for the model
    ///
    /// # Returns
    ///
    /// Returns `Ok(output)` on success, or an error if the forward pass fails.
    ///
    /// # Errors
    ///
    /// May return errors for:
    /// - Invalid input dimensions
    /// - Numerical computation errors
    /// - Out of memory conditions
    fn forward(&self, input: Self::Input) -> Result<Self::Output>;

    /// Loads pretrained weights into the model.
    ///
    /// This method reads model weights from a reader (typically a file or network stream)
    /// and updates the model's parameters accordingly.
    ///
    /// # Arguments
    ///
    /// * `reader` - A reader providing access to the pretrained weight data
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` on successful loading, or an error if loading fails.
    ///
    /// # Errors
    ///
    /// May return errors for:
    /// - IO errors while reading
    /// - Incompatible weight formats
    /// - Mismatched tensor shapes
    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()>;

    /// Returns a reference to the model's configuration.
    ///
    /// # Returns
    ///
    /// A reference to the model's configuration object.
    fn get_config(&self) -> &Self::Config;

    /// Returns the total number of parameters in the model.
    ///
    /// This method calculates the total number of trainable parameters
    /// across all layers in the model. It's useful for compression metrics,
    /// model size analysis, and memory usage calculations.
    ///
    /// # Returns
    ///
    /// The total number of parameters as a `usize`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use trustformers_core::traits::Model;
    ///
    /// fn analyze_model_size<M: Model>(model: &M) {
    ///     let params = model.num_parameters();
    ///     let memory_mb = params * 4 / (1024 * 1024); // Assuming f32 weights
    ///     println!("Model has {} parameters ({} MB)", params, memory_mb);
    /// }
    /// ```
    fn num_parameters(&self) -> usize;
}
196
/// A building block for neural network architectures.
///
/// The `Layer` trait represents a single computational unit in a neural network,
/// such as a linear transformation, attention mechanism, or normalization layer.
/// Layers can be composed together to build complete models.
///
/// # Associated Types
///
/// - `Input`: The input type accepted by this layer
/// - `Output`: The output type produced by this layer
///
/// # Thread Safety
///
/// Layers must be `Send + Sync` to support parallel computation.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::Layer;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
///
/// struct LinearLayer {
///     weight: Tensor,
///     bias: Option<Tensor>,
/// }
///
/// impl Layer for LinearLayer {
///     type Input = Tensor;
///     type Output = Tensor;
///
///     fn forward(&self, input: Self::Input) -> Result<Self::Output> {
///         // Compute linear transformation: y = xW^T + b
///         let output = input.matmul(&self.weight.transpose()?)?;
///         if let Some(bias) = &self.bias {
///             output.add(bias)
///         } else {
///             Ok(output)
///         }
///     }
/// }
/// ```
pub trait Layer: Send + Sync {
    /// The input accepted by [`Layer::forward`].
    type Input;
    /// The output produced by [`Layer::forward`].
    type Output;

    /// Performs the forward computation of this layer.
    ///
    /// # Arguments
    ///
    /// * `input` - The input data to process
    ///
    /// # Returns
    ///
    /// Returns `Ok(output)` containing the layer's output, or an error if computation fails.
    ///
    /// # Errors
    ///
    /// May return errors for:
    /// - Invalid input dimensions
    /// - Numerical errors during computation
    /// - Resource allocation failures
    fn forward(&self, input: Self::Input) -> Result<Self::Output>;
}
264
/// Configuration trait for models and components.
///
/// This trait provides a standardized interface for configuration objects
/// that can be serialized, deserialized, and validated. All model configurations
/// must implement this trait to ensure compatibility with the TrustformeRS ecosystem.
///
/// # Requirements
///
/// Implementing types must be serializable and deserializable using serde.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::Config;
/// use trustformers_core::errors::Result;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Debug, Clone, Deserialize, Serialize)]
/// struct GPT2Config {
///     vocab_size: usize,
///     hidden_size: usize,
///     num_layers: usize,
///     num_heads: usize,
/// }
///
/// impl Config for GPT2Config {
///     fn validate(&self) -> Result<()> {
///         if self.hidden_size % self.num_heads != 0 {
///             // NOTE: error construction shown with `anyhow!` for illustration;
///             // substitute the constructor of `trustformers_core::errors`'s
///             // error type in real code.
///             return Err(anyhow::anyhow!(
///                 "hidden_size must be divisible by num_heads"
///             ));
///         }
///         Ok(())
///     }
///
///     fn architecture(&self) -> &'static str {
///         "gpt2"
///     }
/// }
/// ```
pub trait Config: for<'de> Deserialize<'de> + Serialize {
    /// Validates the configuration for correctness.
    ///
    /// This method should check that all configuration parameters are valid
    /// and compatible with each other. The default implementation accepts
    /// all configurations as valid.
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` if the configuration is valid, or an error describing
    /// the validation failure.
    ///
    /// # Example
    ///
    /// Common validations include:
    /// - Checking that dimensions are compatible
    /// - Verifying that values are within acceptable ranges
    /// - Ensuring required fields are properly set
    fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the architecture name for this configuration.
    ///
    /// This should return a static string identifying the model architecture,
    /// such as "bert", "gpt2", "t5", etc. This is used for model registration
    /// and automatic model selection.
    ///
    /// # Returns
    ///
    /// A static string slice containing the architecture name.
    fn architecture(&self) -> &'static str;
}
341
/// Interface for reading model weights from various sources.
///
/// `WeightReader` provides an abstraction over different weight storage formats,
/// allowing models to load pretrained parameters from files, network sources,
/// or other storage backends.
///
/// # Supported Formats
///
/// Implementations may support various formats including:
/// - SafeTensors (.safetensors)
/// - PyTorch checkpoints (.pt, .bin)
/// - NumPy arrays (.npz)
/// - Custom formats
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::WeightReader;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
///
/// struct SafeTensorsReader {
///     // ... implementation details
/// }
///
/// impl WeightReader for SafeTensorsReader {
///     fn read_tensor(&mut self, name: &str) -> Result<Tensor> {
///         // Read tensor from SafeTensors file
///         // ...
///         Tensor::zeros(&[768, 768])
///     }
///
///     fn list_tensors(&self) -> Vec<String> {
///         vec![
///             "bert.embeddings.word_embeddings.weight".to_string(),
///             "bert.encoder.layer.0.attention.self.query.weight".to_string(),
///             // ... more tensor names
///         ]
///     }
/// }
/// ```
pub trait WeightReader {
    /// Reads a tensor by name from the weight source.
    ///
    /// # Arguments
    ///
    /// * `name` - The name/key of the tensor to read (e.g., "encoder.layer.0.weight")
    ///
    /// # Returns
    ///
    /// Returns `Ok(tensor)` containing the requested tensor, or an error if the
    /// tensor cannot be found or loaded.
    ///
    /// # Errors
    ///
    /// May return errors for:
    /// - Tensor not found with the given name
    /// - IO errors while reading
    /// - Corrupted or invalid tensor data
    /// - Unsupported tensor format
    fn read_tensor(&mut self, name: &str) -> Result<Tensor>;

    /// Lists all available tensor names in the weight source.
    ///
    /// This method is useful for debugging and for discovering the structure
    /// of saved model weights.
    ///
    /// # Returns
    ///
    /// A vector containing the names of all available tensors.
    fn list_tensors(&self) -> Vec<String>;
}
417
/// Text tokenization interface for transformer models.
///
/// The `Tokenizer` trait provides methods for converting between text and token IDs,
/// which is essential for preparing input data for transformer models. Implementations
/// may use various tokenization algorithms such as WordPiece, BPE, or SentencePiece.
///
/// # Thread Safety
///
/// Tokenizers must be `Send + Sync` to support concurrent tokenization.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::{Tokenizer, TokenizedInput};
/// use trustformers_core::errors::Result;
///
/// struct BertTokenizer {
///     vocab: std::collections::HashMap<String, u32>,
///     // ... other fields
/// }
///
/// impl Tokenizer for BertTokenizer {
///     fn encode(&self, text: &str) -> Result<TokenizedInput> {
///         // Tokenize text into subwords
///         let ids = vec![101, 2023, 2003, 1037, 3231, 102]; // [CLS] this is a test [SEP]
///         let len = ids.len();
///         Ok(TokenizedInput::with_token_type_ids(
///             ids,
///             vec![1; len],
///             Some(vec![0; len]),
///         ))
///     }
///
///     fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
///         // Encode two texts for tasks like question answering
///         let first = self.encode(text)?;
///         let second = self.encode(text2)?;
///         let first_len = first.input_ids.len();
///         let second_len = second.input_ids.len();
///         // Combine both sequences; segment IDs are 0 then 1.
///         let mut input_ids = first.input_ids;
///         input_ids.extend_from_slice(&second.input_ids);
///         let mut token_type_ids = vec![0; first_len];
///         token_type_ids.extend(vec![1; second_len]);
///         Ok(TokenizedInput::with_token_type_ids(
///             input_ids,
///             vec![1; first_len + second_len],
///             Some(token_type_ids),
///         ))
///     }
///
///     fn decode(&self, ids: &[u32]) -> Result<String> {
///         // Convert token IDs back to text
///         Ok("this is a test".to_string())
///     }
///
///     fn vocab_size(&self) -> usize {
///         30522 // BERT base vocabulary size
///     }
///
///     fn get_vocab(&self) -> std::collections::HashMap<String, u32> {
///         self.vocab.clone()
///     }
///
///     fn token_to_id(&self, token: &str) -> Option<u32> {
///         self.vocab.get(token).copied()
///     }
///
///     fn id_to_token(&self, id: u32) -> Option<String> {
///         self.vocab.iter().find(|(_, &v)| v == id).map(|(k, _)| k.clone())
///     }
/// }
/// ```
pub trait Tokenizer: Send + Sync {
    /// Encodes a single text string into tokens.
    ///
    /// # Arguments
    ///
    /// * `text` - The input text to tokenize
    ///
    /// # Returns
    ///
    /// Returns a `TokenizedInput` containing:
    /// - `input_ids`: The token IDs
    /// - `attention_mask`: Binary mask indicating real vs padding tokens
    /// - `token_type_ids`: Optional segment IDs for models like BERT
    ///
    /// # Errors
    ///
    /// May return errors for:
    /// - Invalid UTF-8 sequences
    /// - Text exceeding maximum length
    /// - Unknown tokens that cannot be handled
    fn encode(&self, text: &str) -> Result<TokenizedInput>;

    /// Encodes a pair of texts for sequence-pair tasks.
    ///
    /// This method is used for tasks that require two input sequences,
    /// such as question answering, textual entailment, or sequence classification.
    ///
    /// # Arguments
    ///
    /// * `text` - The first text sequence
    /// * `text2` - The second text sequence
    ///
    /// # Returns
    ///
    /// Returns a `TokenizedInput` with both sequences encoded and separated
    /// by appropriate special tokens (e.g., [SEP] for BERT).
    ///
    /// # Errors
    ///
    /// May return errors for the same reasons as `encode()`.
    fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput>;

    /// Decodes token IDs back into text.
    ///
    /// # Arguments
    ///
    /// * `ids` - The token IDs to decode
    ///
    /// # Returns
    ///
    /// Returns the decoded text string. Special tokens may be included
    /// or excluded depending on the implementation.
    ///
    /// # Errors
    ///
    /// May return errors for:
    /// - Invalid token IDs
    /// - Decoding errors
    fn decode(&self, ids: &[u32]) -> Result<String>;

    /// Returns the size of the tokenizer's vocabulary.
    ///
    /// # Returns
    ///
    /// The total number of tokens in the vocabulary.
    fn vocab_size(&self) -> usize;

    /// Returns a copy of the vocabulary as a mapping from tokens to IDs.
    ///
    /// # Returns
    ///
    /// A HashMap containing the vocabulary mapping.
    fn get_vocab(&self) -> std::collections::HashMap<String, u32>;

    /// Converts a token string to its corresponding ID.
    ///
    /// # Arguments
    ///
    /// * `token` - The token string to convert
    ///
    /// # Returns
    ///
    /// The token ID if the token exists in the vocabulary, None otherwise.
    fn token_to_id(&self, token: &str) -> Option<u32>;

    /// Converts a token ID to its corresponding token string.
    ///
    /// # Arguments
    ///
    /// * `id` - The token ID to convert
    ///
    /// # Returns
    ///
    /// The token string if the ID exists in the vocabulary, None otherwise.
    fn id_to_token(&self, id: u32) -> Option<String>;
}
571
/// Represents tokenized input ready for model consumption.
///
/// `TokenizedInput` bundles everything a transformer model needs after a
/// tokenizer has processed raw text: the token IDs themselves, the attention
/// mask, and a set of optional per-token annotations.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::TokenizedInput;
///
/// let input = TokenizedInput {
///     input_ids: vec![101, 2023, 2003, 1037, 3231, 102], // [CLS] this is a test [SEP]
///     attention_mask: vec![1, 1, 1, 1, 1, 1], // every token is real (no padding)
///     token_type_ids: Some(vec![0, 0, 0, 0, 0, 0]), // all from the first segment
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Default)]
pub struct TokenizedInput {
    /// Token IDs representing the encoded text; each entry indexes the
    /// tokenizer's vocabulary.
    pub input_ids: Vec<u32>,

    /// Binary attention mask: 1 marks a real token, 0 marks padding.
    /// Keeps the model from attending to padding positions.
    pub attention_mask: Vec<u8>,

    /// Optional segment (token type) IDs used by models such as BERT for
    /// multi-sequence tasks; typically 0 for the first sequence, 1 for the second.
    pub token_type_ids: Option<Vec<u32>>,

    /// Optional mask flagging special tokens (1) such as [CLS]/[SEP]/[PAD]
    /// versus regular tokens (0).
    pub special_tokens_mask: Option<Vec<u8>>,

    /// Optional (start, end) character offsets mapping each token back into
    /// the original text.
    pub offset_mapping: Option<Vec<(usize, usize)>>,

    /// Optional tokens truncated away when the input exceeded the maximum length.
    pub overflowing_tokens: Option<Vec<u32>>,
}

impl TokenizedInput {
    /// Builds a `TokenizedInput` from the two required fields, leaving every
    /// optional annotation unset.
    pub fn new(input_ids: Vec<u32>, attention_mask: Vec<u8>) -> Self {
        Self {
            input_ids,
            attention_mask,
            ..Self::default()
        }
    }

    /// Like [`TokenizedInput::new`], but additionally records segment
    /// (token type) IDs.
    pub fn with_token_type_ids(
        input_ids: Vec<u32>,
        attention_mask: Vec<u8>,
        token_type_ids: Option<Vec<u32>>,
    ) -> Self {
        Self {
            token_type_ids,
            ..Self::new(input_ids, attention_mask)
        }
    }
}
651
/// Parameter optimization algorithms for training neural networks.
///
/// The `Optimizer` trait defines the interface for gradient-based optimization
/// algorithms such as SGD, Adam, AdamW, etc. Optimizers update model parameters
/// based on computed gradients to minimize the loss function.
///
/// # Thread Safety
///
/// Optimizers must be `Send + Sync` to support distributed training.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::Optimizer;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
///
/// struct SGD {
///     learning_rate: f32,
///     momentum: f32,
///     velocity: std::collections::HashMap<String, Tensor>,
/// }
///
/// impl Optimizer for SGD {
///     fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
///         // SGD with momentum: v = momentum * v - lr * grad
///         // parameter += v
///         Ok(())
///     }
///
///     fn zero_grad(&mut self) {
///         // Clear accumulated gradients
///     }
///
///     fn step(&mut self) {
///         // Apply updates to all parameters
///     }
///
///     fn get_lr(&self) -> f32 {
///         self.learning_rate
///     }
///
///     fn set_lr(&mut self, lr: f32) {
///         self.learning_rate = lr;
///     }
/// }
/// ```
pub trait Optimizer: Send + Sync {
    /// Updates a parameter based on its gradient.
    ///
    /// # Arguments
    ///
    /// * `parameter` - The parameter tensor to update
    /// * `grad` - The gradient tensor for this parameter
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` on successful update, or an error if the update fails.
    ///
    /// # Errors
    ///
    /// May return errors for:
    /// - Mismatched tensor shapes
    /// - Numerical errors (e.g., NaN or Inf values)
    /// - Memory allocation failures
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()>;

    /// Clears all accumulated gradients.
    ///
    /// This should be called before each backward pass to ensure
    /// gradients don't accumulate across batches (unless gradient
    /// accumulation is intentionally being used).
    fn zero_grad(&mut self);

    /// Performs a single optimization step.
    ///
    /// This method applies all pending parameter updates. It should be
    /// called after gradients have been computed for all parameters.
    fn step(&mut self);

    /// Gets the current learning rate.
    ///
    /// # Returns
    ///
    /// The current learning rate value.
    fn get_lr(&self) -> f32;

    /// Sets a new learning rate.
    ///
    /// # Arguments
    ///
    /// * `lr` - The new learning rate value
    ///
    /// # Note
    ///
    /// This is useful for implementing learning rate schedules.
    fn set_lr(&mut self, lr: f32);

    /// Accumulates gradients for gradient accumulation.
    ///
    /// This method is used when training with gradient accumulation,
    /// where gradients from multiple batches are accumulated before
    /// performing an update step.
    ///
    /// # Arguments
    ///
    /// * `parameter` - The parameter tensor
    /// * `grad` - The gradient to accumulate
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` on success, or an error if accumulation fails.
    ///
    /// # Default Implementation
    ///
    /// The default implementation simply calls `update()`. Override this
    /// method for optimizers that need special gradient accumulation logic.
    fn accumulate_grad(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Default: applies the gradient immediately via `update()`; optimizers
        // that truly buffer gradients must override this method.
        self.update(parameter, grad)
    }

    /// Applies accumulated gradients after gradient accumulation.
    ///
    /// This method should be called after accumulating gradients from
    /// multiple batches to apply the averaged update.
    ///
    /// # Arguments
    ///
    /// * `accumulation_steps` - The number of accumulation steps performed
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` on success, or an error if application fails.
    ///
    /// # Default Implementation
    ///
    /// The default implementation is a no-op. Override this method for
    /// optimizers that implement gradient accumulation.
    fn apply_accumulated_grads(&mut self, accumulation_steps: usize) -> Result<()> {
        // Default implementation: no-op; `accumulation_steps` intentionally unused.
        let _ = accumulation_steps;
        Ok(())
    }
}
797
/// Weight initialization strategies for neural network parameters.
///
/// The `ParameterInit` trait provides various initialization methods that help
/// ensure proper gradient flow and training stability. Different initialization
/// strategies are optimal for different activation functions and architectures.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::ParameterInit;
/// use trustformers_core::tensor::Tensor;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut weight = Tensor::zeros(&[768, 768])?;
///
/// // Initialize with Xavier/Glorot uniform for tanh activations
/// weight.xavier_uniform();
///
/// // Or use Kaiming/He initialization for ReLU activations
/// weight.kaiming_normal("fan_in", "relu");
/// # Ok(())
/// # }
/// ```
pub trait ParameterInit {
    /// Initializes the tensor with values from a normal distribution.
    ///
    /// # Arguments
    ///
    /// * `mean` - The mean of the normal distribution
    /// * `std` - The standard deviation of the normal distribution
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use trustformers_core::tensor::Tensor;
    /// # use trustformers_core::traits::ParameterInit;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut tensor = Tensor::zeros(&[100, 100])?;
    /// tensor.normal(0.0, 0.02); // Common for transformer embeddings
    /// # Ok(())
    /// # }
    /// ```
    fn normal(&mut self, mean: f32, std: f32);

    /// Initializes the tensor with values from a uniform distribution.
    ///
    /// # Arguments
    ///
    /// * `min` - The minimum value (inclusive)
    /// * `max` - The maximum value (exclusive)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use trustformers_core::tensor::Tensor;
    /// # use trustformers_core::traits::ParameterInit;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut tensor = Tensor::zeros(&[100, 100])?;
    /// tensor.uniform(-0.1, 0.1);
    /// # Ok(())
    /// # }
    /// ```
    fn uniform(&mut self, min: f32, max: f32);

    /// Xavier/Glorot uniform initialization.
    ///
    /// Initializes weights to maintain variance across layers, optimal for
    /// tanh and sigmoid activations. The range is [-x, x] where
    /// x = sqrt(6 / (fan_in + fan_out)).
    ///
    /// # References
    ///
    /// Glorot & Bengio (2010): "Understanding the difficulty of training
    /// deep feedforward neural networks"
    fn xavier_uniform(&mut self);

    /// Xavier/Glorot normal initialization.
    ///
    /// Similar to `xavier_uniform` but uses a normal distribution with
    /// std = sqrt(2 / (fan_in + fan_out)).
    fn xavier_normal(&mut self);

    /// Kaiming/He uniform initialization.
    ///
    /// Designed for ReLU and similar activations. Maintains variance when
    /// half of the neurons are zeroed out by ReLU.
    ///
    /// # Arguments
    ///
    /// * `mode` - Either "fan_in" or "fan_out", determines which dimension to use
    /// * `nonlinearity` - The activation function ("relu", "leaky_relu", "linear")
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use trustformers_core::tensor::Tensor;
    /// # use trustformers_core::traits::ParameterInit;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut conv_weight = Tensor::zeros(&[64, 32, 3, 3])?;
    /// conv_weight.kaiming_uniform("fan_in", "relu");
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # References
    ///
    /// He et al. (2015): "Delving Deep into Rectifiers: Surpassing
    /// Human-Level Performance on ImageNet Classification"
    fn kaiming_uniform(&mut self, mode: &str, nonlinearity: &str);

    /// Kaiming/He normal initialization.
    ///
    /// Similar to `kaiming_uniform` but uses a normal distribution.
    /// Generally preferred over uniform for deeper networks.
    ///
    /// # Arguments
    ///
    /// * `mode` - Either "fan_in" or "fan_out"
    /// * `nonlinearity` - The activation function ("relu", "leaky_relu", "linear")
    fn kaiming_normal(&mut self, mode: &str, nonlinearity: &str);
}
918}