tenflowers_neural/model/
mod.rs1pub mod examples;
2pub mod functional;
3pub mod sequential;
4pub mod subclass;
5pub mod subclass_examples;
6
7pub use functional::{
8 FunctionalModel, FunctionalModelBuilder, Input, LayerOp, Node, NodeId, SharedLayer,
9};
10pub use sequential::Sequential;
11pub use subclass::{helpers, CustomModel, LayerContainer, ModelBase, ModelExt};
12
13use tenflowers_core::{Result, Tensor};
14
15#[cfg(feature = "serialize")]
16use std::path::Path;
17#[cfg(feature = "serialize")]
18use tenflowers_core::TensorError;
19
20#[cfg(feature = "serialize")]
21use serde::{Deserialize, Serialize};
22
/// Serializable snapshot of a model, available with the `serialize` feature.
///
/// `Debug` and `Clone` are derived so that a snapshot can be logged and
/// duplicated without going through a serialization round-trip.
#[cfg(feature = "serialize")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelState {
    /// Parameter values stored as `f32` buffers.
    // NOTE(review): presumably one flattened buffer per parameter tensor — confirm against the writers.
    pub parameters: Vec<Vec<f32>>,
    /// Dimension lists.
    // NOTE(review): presumably `shapes[i]` describes `parameters[i]` — confirm against the writers.
    pub shapes: Vec<Vec<usize>>,
    /// Free-form string key/value pairs stored alongside the parameters.
    pub metadata: std::collections::HashMap<String, String>,
}
34
35pub trait Model<T> {
37 fn forward(&self, input: &Tensor<T>) -> Result<Tensor<T>>;
38 fn parameters(&self) -> Vec<&Tensor<T>>;
39 fn parameters_mut(&mut self) -> Vec<&mut Tensor<T>>;
40 fn set_training(&mut self, training: bool);
41 fn zero_grad(&mut self);
42
43 fn extract_features(&self, input: &Tensor<T>) -> Result<Option<Vec<Tensor<T>>>> {
46 let _ = input; Ok(None)
49 }
50
51 fn as_any(&self) -> &dyn std::any::Any;
53}
54
/// Saving and loading hooks for models, available with the `serialize` feature.
///
/// Both methods have default bodies that report "not implemented", so a model
/// type only overrides the direction(s) it actually supports.
#[cfg(feature = "serialize")]
pub trait ModelSerialization<T> {
    /// Writes the model to `_path`.
    ///
    /// # Errors
    ///
    /// The default implementation always returns a serialization error.
    fn save<P: AsRef<Path>>(&self, _path: P) -> Result<()> {
        let message = String::from("Serialization not implemented for this model type");
        Err(TensorError::serialization_error_simple(message))
    }

    /// Reads the model from `_path` into `self`.
    ///
    /// # Errors
    ///
    /// The default implementation always returns a serialization error.
    fn load<P: AsRef<Path>>(&mut self, _path: P) -> Result<()> {
        let message = String::from("Deserialization not implemented for this model type");
        Err(TensorError::serialization_error_simple(message))
    }
}
72
73pub(crate) fn zero_tensor_grad<T>(param: &mut Tensor<T>)
75where
76 T: scirs2_core::num_traits::Zero
77 + Clone
78 + Default
79 + Send
80 + Sync
81 + 'static
82 + bytemuck::Pod
83 + bytemuck::Zeroable,
84{
85 if param.requires_grad() {
86 let zero_grad = create_zero_grad_for_device(param);
87 param.set_grad(zero_grad);
88 }
89}
90
91fn create_zero_grad_for_device<T>(param: &Tensor<T>) -> Option<Tensor<T>>
92where
93 T: scirs2_core::num_traits::Zero
94 + Clone
95 + Default
96 + Send
97 + Sync
98 + 'static
99 + bytemuck::Pod
100 + bytemuck::Zeroable,
101{
102 match param.device() {
103 tenflowers_core::Device::Cpu => Some(Tensor::zeros(param.shape().dims())),
104 #[cfg(feature = "gpu")]
105 tenflowers_core::Device::Gpu(_) => {
106 let cpu_zeros = Tensor::zeros(param.shape().dims());
107 cpu_zeros.to(param.device().clone()).ok() }
109 #[allow(unreachable_patterns)]
110 _ => {
111 let cpu_zeros = Tensor::zeros(param.shape().dims());
113 cpu_zeros.to(param.device().clone()).ok() }
115 }
116}