//! Layer sub-modules, public re-exports, and the [`ModelLayer`] dispatch enum.
mod activation;
mod attention;
mod conv;
mod linear;
mod misc;
mod norm;
mod pool;
mod recurrent;

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::lora::LoraLinear;
use crate::ModelError;

// Re-export all public types from sub-modules.
pub use activation::{
    GELULayer, LeakyReLULayer, MishLayer, PReLULayer, ReLULayer, SiLULayer, SigmoidLayer, TanhLayer,
};
pub use attention::{
    EmbeddingLayer, FeedForwardLayer, MultiHeadAttentionLayer, TransformerEncoderLayer,
};
pub use conv::{
    Conv1dLayer, Conv2dLayer, Conv3dLayer, ConvTranspose2dLayer, DeformableConv2dLayer,
    DepthwiseConv2dLayer, SeparableConv2dLayer,
};
pub use linear::LinearLayer;
pub use misc::{
    DropoutLayer, FlattenLayer, MaskHead, PixelShuffleLayer, ResidualBlock, SoftmaxLayer,
    UpsampleLayer,
};
pub use norm::{BatchNorm2dLayer, GroupNormLayer, InstanceNormLayer, LayerNormLayer};
pub use pool::{
    AdaptiveAvgPool2dLayer, AdaptiveMaxPool2dLayer, AvgPool2dLayer, GlobalAvgPool2dLayer,
    MaxPool2dLayer,
};
pub use recurrent::{GruLayer, LstmLayer, RnnLayer};

39#[derive(Debug, Clone)]
40pub enum ModelLayer {
41    Linear(LinearLayer),
42    ReLU(ReLULayer),
43    LeakyReLU(LeakyReLULayer),
44    Sigmoid(SigmoidLayer),
45    Tanh(TanhLayer),
46    Dropout(DropoutLayer),
47    Conv2d(Conv2dLayer),
48    BatchNorm2d(BatchNorm2dLayer),
49    MaxPool2d(MaxPool2dLayer),
50    AvgPool2d(AvgPool2dLayer),
51    GlobalAvgPool2d(GlobalAvgPool2dLayer),
52    Flatten(FlattenLayer),
53    Softmax(SoftmaxLayer),
54    Embedding(EmbeddingLayer),
55    LayerNorm(LayerNormLayer),
56    GroupNorm(GroupNormLayer),
57    DepthwiseConv2d(DepthwiseConv2dLayer),
58    SeparableConv2d(SeparableConv2dLayer),
59    LoraLinear(LoraLinear),
60    Conv1d(Conv1dLayer),
61    Conv3d(Conv3dLayer),
62    ConvTranspose2d(ConvTranspose2dLayer),
63    AdaptiveAvgPool2d(AdaptiveAvgPool2dLayer),
64    AdaptiveMaxPool2d(AdaptiveMaxPool2dLayer),
65    InstanceNorm(InstanceNormLayer),
66    PixelShuffle(PixelShuffleLayer),
67    Upsample(UpsampleLayer),
68    GELU(GELULayer),
69    SiLU(SiLULayer),
70    Mish(MishLayer),
71    PReLU(PReLULayer),
72    ResidualBlock(ResidualBlock),
73    Rnn(RnnLayer),
74    Lstm(LstmLayer),
75    Gru(GruLayer),
76    MultiHeadAttention(MultiHeadAttentionLayer),
77    TransformerEncoder(TransformerEncoderLayer),
78    FeedForward(FeedForwardLayer),
79    DeformableConv2d(DeformableConv2dLayer),
80}
81
82impl ModelLayer {
83    pub(crate) fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
84        match self {
85            Self::Linear(layer) => layer.forward(graph, input),
86            Self::ReLU(layer) => layer.forward(graph, input),
87            Self::LeakyReLU(layer) => layer.forward(graph, input),
88            Self::Sigmoid(layer) => layer.forward(graph, input),
89            Self::Tanh(layer) => layer.forward(graph, input),
90            Self::Dropout(layer) => layer.forward(graph, input),
91            Self::Flatten(layer) => layer.forward(graph, input),
92            Self::Conv2d(layer) => layer.forward(graph, input),
93            Self::BatchNorm2d(layer) => layer.forward(graph, input),
94            Self::MaxPool2d(layer) => layer.forward(graph, input),
95            Self::AvgPool2d(layer) => layer.forward(graph, input),
96            Self::GlobalAvgPool2d(layer) => layer.forward(graph, input),
97            Self::Embedding(layer) => layer.forward(graph, input),
98            Self::DepthwiseConv2d(layer) => layer.forward(graph, input),
99            Self::SeparableConv2d(layer) => layer.forward(graph, input),
100            Self::LoraLinear(layer) => layer.forward(graph, input),
101            Self::GELU(layer) => layer.forward(graph, input),
102            Self::SiLU(layer) => layer.forward(graph, input),
103            Self::Mish(layer) => layer.forward(graph, input),
104            Self::LayerNorm(layer) => layer.forward(graph, input),
105            Self::GroupNorm(layer) => layer.forward(graph, input),
106            Self::Conv1d(layer) => layer.forward(graph, input),
107            Self::Conv3d(layer) => layer.forward(graph, input),
108            Self::MultiHeadAttention(layer) => layer.forward(graph, input),
109            Self::ConvTranspose2d(layer) => layer.forward(graph, input),
110            Self::AdaptiveAvgPool2d(layer) => layer.forward(graph, input),
111            Self::AdaptiveMaxPool2d(layer) => layer.forward(graph, input),
112            Self::InstanceNorm(layer) => layer.forward(graph, input),
113            Self::PReLU(layer) => layer.forward(graph, input),
114            Self::Softmax(layer) => layer.forward(graph, input),
115            Self::PixelShuffle(layer) => layer.forward(graph, input),
116            Self::Upsample(layer) => layer.forward(graph, input),
117            Self::ResidualBlock(layer) => layer.forward(graph, input),
118            Self::Rnn(layer) => layer.forward(graph, input),
119            Self::Lstm(layer) => layer.forward(graph, input),
120            Self::Gru(layer) => layer.forward(graph, input),
121            Self::TransformerEncoder(layer) => layer.forward(graph, input),
122            Self::FeedForward(layer) => layer.forward(graph, input),
123            Self::DeformableConv2d(layer) => layer.forward(graph, input),
124        }
125    }
126
127    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
128        match self {
129            Self::Conv2d(layer) => layer.forward_inference(input),
130            Self::BatchNorm2d(layer) => layer.forward_inference(input),
131            Self::MaxPool2d(layer) => layer.forward_inference(input),
132            Self::AvgPool2d(layer) => layer.forward_inference(input),
133            Self::GlobalAvgPool2d(layer) => layer.forward_inference(input),
134            Self::Flatten(layer) => layer.forward_inference(input),
135            Self::Softmax(layer) => layer.forward_inference(input),
136            Self::DepthwiseConv2d(layer) => layer.forward_inference(input),
137            Self::SeparableConv2d(layer) => layer.forward_inference(input),
138            Self::Conv1d(layer) => layer.forward_inference(input),
139            Self::Conv3d(layer) => layer.forward_inference(input),
140            Self::ConvTranspose2d(layer) => layer.forward_inference(input),
141            Self::AdaptiveAvgPool2d(layer) => layer.forward_inference(input),
142            Self::AdaptiveMaxPool2d(layer) => layer.forward_inference(input),
143            Self::InstanceNorm(layer) => layer.forward_inference(input),
144            Self::PixelShuffle(layer) => layer.forward_inference(input),
145            Self::Upsample(layer) => layer.forward_inference(input),
146            Self::GELU(layer) => layer.forward_inference(input),
147            Self::SiLU(layer) => layer.forward_inference(input),
148            Self::Mish(layer) => layer.forward_inference(input),
149            Self::PReLU(layer) => layer.forward_inference(input),
150            Self::ResidualBlock(layer) => layer.forward_inference(input),
151            Self::Rnn(layer) => layer.forward_inference(input),
152            Self::Lstm(layer) => layer.forward_inference(input),
153            Self::Gru(layer) => layer.forward_inference(input),
154            Self::MultiHeadAttention(layer) => layer.forward_inference(input),
155            Self::TransformerEncoder(layer) => layer.forward_inference(input),
156            Self::FeedForward(layer) => layer.forward_inference(input),
157            Self::DeformableConv2d(layer) => layer.forward_inference(input),
158            Self::Linear(layer) => layer.forward_inference(input),
159            Self::ReLU(_)
160            | Self::LeakyReLU(_)
161            | Self::Sigmoid(_)
162            | Self::Tanh(_)
163            | Self::Dropout(_)
164            | Self::Embedding(_)
165            | Self::LayerNorm(_)
166            | Self::GroupNorm(_)
167            | Self::LoraLinear(_) => Err(ModelError::GraphOnlyLayer),
168        }
169    }
170
171    pub fn supports_graph_forward(&self) -> bool {
172        // Bilinear upsample still requires inference-only mode.
173        if let Self::Upsample(u) = self {
174            !u.is_bilinear()
175        } else {
176            true
177        }
178    }
179
180    pub fn supports_inference_forward(&self) -> bool {
181        matches!(
182            self,
183            Self::Conv2d(_)
184                | Self::BatchNorm2d(_)
185                | Self::MaxPool2d(_)
186                | Self::AvgPool2d(_)
187                | Self::GlobalAvgPool2d(_)
188                | Self::Flatten(_)
189                | Self::Softmax(_)
190                | Self::DepthwiseConv2d(_)
191                | Self::SeparableConv2d(_)
192                | Self::Conv1d(_)
193                | Self::Conv3d(_)
194                | Self::ConvTranspose2d(_)
195                | Self::AdaptiveAvgPool2d(_)
196                | Self::AdaptiveMaxPool2d(_)
197                | Self::InstanceNorm(_)
198                | Self::PixelShuffle(_)
199                | Self::Upsample(_)
200                | Self::GELU(_)
201                | Self::SiLU(_)
202                | Self::Mish(_)
203                | Self::PReLU(_)
204                | Self::ResidualBlock(_)
205                | Self::Rnn(_)
206                | Self::Lstm(_)
207                | Self::Gru(_)
208                | Self::MultiHeadAttention(_)
209                | Self::TransformerEncoder(_)
210                | Self::FeedForward(_)
211                | Self::DeformableConv2d(_)
212        )
213    }
214}