//! Model abstractions for `tenflowers_neural` (`model/mod.rs`).
1pub mod examples;
2pub mod functional;
3pub mod sequential;
4pub mod subclass;
5pub mod subclass_examples;
6
7pub use functional::{
8    FunctionalModel, FunctionalModelBuilder, Input, LayerOp, Node, NodeId, SharedLayer,
9};
10pub use sequential::Sequential;
11pub use subclass::{helpers, CustomModel, LayerContainer, ModelBase, ModelExt};
12
13use tenflowers_core::{Result, Tensor};
14
15#[cfg(feature = "serialize")]
16use std::path::Path;
17#[cfg(feature = "serialize")]
18use tenflowers_core::TensorError;
19
20#[cfg(feature = "serialize")]
21use serde::{Deserialize, Serialize};
22
23/// Serializable representation of model parameters
24#[cfg(feature = "serialize")]
25#[derive(Serialize, Deserialize)]
26pub struct ModelState {
27    /// Parameter data as flattened vectors
28    pub parameters: Vec<Vec<f32>>,
29    /// Shape information for each parameter
30    pub shapes: Vec<Vec<usize>>,
31    /// Model metadata
32    pub metadata: std::collections::HashMap<String, String>,
33}
34
35/// Core trait for all models
36pub trait Model<T> {
37    fn forward(&self, input: &Tensor<T>) -> Result<Tensor<T>>;
38    fn parameters(&self) -> Vec<&Tensor<T>>;
39    fn parameters_mut(&mut self) -> Vec<&mut Tensor<T>>;
40    fn set_training(&mut self, training: bool);
41    fn zero_grad(&mut self);
42
43    /// Extract intermediate features for knowledge distillation
44    /// Returns None if the model doesn't support feature extraction
45    fn extract_features(&self, input: &Tensor<T>) -> Result<Option<Vec<Tensor<T>>>> {
46        // Default implementation returns None - models can override to provide features
47        let _ = input; // Suppress unused parameter warning
48        Ok(None)
49    }
50
51    /// Provide access to Any for downcasting
52    fn as_any(&self) -> &dyn std::any::Any;
53}
54
55/// Trait for model serialization - separate from Model to maintain dyn compatibility
/// Trait for model serialization.
///
/// Kept separate from [`Model`] to maintain dyn compatibility: these
/// methods are generic over `P`, which would make `Model` non-object-safe.
#[cfg(feature = "serialize")]
pub trait ModelSerialization<T> {
    /// Save model parameters to a file.
    ///
    /// The default implementation always fails; types that support
    /// persistence override this.
    fn save<P: AsRef<Path>>(&self, _path: P) -> Result<()> {
        let msg = String::from("Serialization not implemented for this model type");
        Err(TensorError::serialization_error_simple(msg))
    }

    /// Load model parameters from a file.
    ///
    /// The default implementation always fails; types that support
    /// persistence override this.
    fn load<P: AsRef<Path>>(&mut self, _path: P) -> Result<()> {
        let msg = String::from("Deserialization not implemented for this model type");
        Err(TensorError::serialization_error_simple(msg))
    }
}
72
73/// Zero the gradient of a tensor parameter
74pub(crate) fn zero_tensor_grad<T>(param: &mut Tensor<T>)
75where
76    T: scirs2_core::num_traits::Zero
77        + Clone
78        + Default
79        + Send
80        + Sync
81        + 'static
82        + bytemuck::Pod
83        + bytemuck::Zeroable,
84{
85    if param.requires_grad() {
86        let zero_grad = create_zero_grad_for_device(param);
87        param.set_grad(zero_grad);
88    }
89}
90
91fn create_zero_grad_for_device<T>(param: &Tensor<T>) -> Option<Tensor<T>>
92where
93    T: scirs2_core::num_traits::Zero
94        + Clone
95        + Default
96        + Send
97        + Sync
98        + 'static
99        + bytemuck::Pod
100        + bytemuck::Zeroable,
101{
102    match param.device() {
103        tenflowers_core::Device::Cpu => Some(Tensor::zeros(param.shape().dims())),
104        #[cfg(feature = "gpu")]
105        tenflowers_core::Device::Gpu(_) => {
106            let cpu_zeros = Tensor::zeros(param.shape().dims());
107            cpu_zeros.to(param.device().clone()).ok() // Fall back to no gradient if transfer fails
108        }
109        #[allow(unreachable_patterns)]
110        _ => {
111            // Handle any other device variants from core (e.g. Rocm)
112            let cpu_zeros = Tensor::zeros(param.shape().dims());
113            cpu_zeros.to(param.device().clone()).ok() // Fall back to no gradient if transfer fails
114        }
115    }
116}