1mod activation;
2mod attention;
3mod conv;
4mod linear;
5mod misc;
6mod norm;
7mod pool;
8mod recurrent;
9
10use yscv_autograd::{Graph, NodeId};
11use yscv_tensor::Tensor;
12
13use super::lora::LoraLinear;
14use crate::ModelError;
15
16pub use activation::{
18 GELULayer, LeakyReLULayer, MishLayer, PReLULayer, ReLULayer, SiLULayer, SigmoidLayer, TanhLayer,
19};
20pub use attention::{
21 EmbeddingLayer, FeedForwardLayer, MultiHeadAttentionLayer, TransformerEncoderLayer,
22};
23pub use conv::{
24 Conv1dLayer, Conv2dLayer, Conv3dLayer, ConvTranspose2dLayer, DeformableConv2dLayer,
25 DepthwiseConv2dLayer, SeparableConv2dLayer,
26};
27pub use linear::LinearLayer;
28pub use misc::{
29 DropoutLayer, FlattenLayer, MaskHead, PixelShuffleLayer, ResidualBlock, SoftmaxLayer,
30 UpsampleLayer,
31};
32pub use norm::{BatchNorm2dLayer, GroupNormLayer, InstanceNormLayer, LayerNormLayer};
33pub use pool::{
34 AdaptiveAvgPool2dLayer, AdaptiveMaxPool2dLayer, AvgPool2dLayer, GlobalAvgPool2dLayer,
35 MaxPool2dLayer,
36};
37pub use recurrent::{GruLayer, LstmLayer, RnnLayer};
38
/// Tagged union over every layer type the model format supports.
///
/// Each variant wraps the concrete layer struct from one of the submodules
/// (`activation`, `attention`, `conv`, `linear`, `misc`, `norm`, `pool`,
/// `recurrent`) or the crate's LoRA adapter. Dispatch to the wrapped layer
/// happens in the `impl ModelLayer` methods below.
///
/// NOTE(review): variant order is preserved as-is — if any serialization in
/// the crate relies on discriminant order, reordering would break it; confirm
/// before rearranging.
#[derive(Debug, Clone)]
pub enum ModelLayer {
    // linear / activation / regularization
    Linear(LinearLayer),
    ReLU(ReLULayer),
    LeakyReLU(LeakyReLULayer),
    Sigmoid(SigmoidLayer),
    Tanh(TanhLayer),
    Dropout(DropoutLayer),
    // convolution / normalization / pooling (2d core set)
    Conv2d(Conv2dLayer),
    BatchNorm2d(BatchNorm2dLayer),
    MaxPool2d(MaxPool2dLayer),
    AvgPool2d(AvgPool2dLayer),
    GlobalAvgPool2d(GlobalAvgPool2dLayer),
    Flatten(FlattenLayer),
    Softmax(SoftmaxLayer),
    Embedding(EmbeddingLayer),
    LayerNorm(LayerNormLayer),
    GroupNorm(GroupNormLayer),
    DepthwiseConv2d(DepthwiseConv2dLayer),
    SeparableConv2d(SeparableConv2dLayer),
    // LoRA adapter (graph-only; see `forward_inference`)
    LoraLinear(LoraLinear),
    Conv1d(Conv1dLayer),
    Conv3d(Conv3dLayer),
    ConvTranspose2d(ConvTranspose2dLayer),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dLayer),
    AdaptiveMaxPool2d(AdaptiveMaxPool2dLayer),
    InstanceNorm(InstanceNormLayer),
    PixelShuffle(PixelShuffleLayer),
    Upsample(UpsampleLayer),
    GELU(GELULayer),
    SiLU(SiLULayer),
    Mish(MishLayer),
    PReLU(PReLULayer),
    ResidualBlock(ResidualBlock),
    // recurrent
    Rnn(RnnLayer),
    Lstm(LstmLayer),
    Gru(GruLayer),
    // attention / transformer
    MultiHeadAttention(MultiHeadAttentionLayer),
    TransformerEncoder(TransformerEncoderLayer),
    FeedForward(FeedForwardLayer),
    DeformableConv2d(DeformableConv2dLayer),
}
81
82impl ModelLayer {
83 pub(crate) fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
84 match self {
85 Self::Linear(layer) => layer.forward(graph, input),
86 Self::ReLU(layer) => layer.forward(graph, input),
87 Self::LeakyReLU(layer) => layer.forward(graph, input),
88 Self::Sigmoid(layer) => layer.forward(graph, input),
89 Self::Tanh(layer) => layer.forward(graph, input),
90 Self::Dropout(layer) => layer.forward(graph, input),
91 Self::Flatten(layer) => layer.forward(graph, input),
92 Self::Conv2d(layer) => layer.forward(graph, input),
93 Self::BatchNorm2d(layer) => layer.forward(graph, input),
94 Self::MaxPool2d(layer) => layer.forward(graph, input),
95 Self::AvgPool2d(layer) => layer.forward(graph, input),
96 Self::GlobalAvgPool2d(layer) => layer.forward(graph, input),
97 Self::Embedding(layer) => layer.forward(graph, input),
98 Self::DepthwiseConv2d(layer) => layer.forward(graph, input),
99 Self::SeparableConv2d(layer) => layer.forward(graph, input),
100 Self::LoraLinear(layer) => layer.forward(graph, input),
101 Self::GELU(layer) => layer.forward(graph, input),
102 Self::SiLU(layer) => layer.forward(graph, input),
103 Self::Mish(layer) => layer.forward(graph, input),
104 Self::LayerNorm(layer) => layer.forward(graph, input),
105 Self::GroupNorm(layer) => layer.forward(graph, input),
106 Self::Conv1d(layer) => layer.forward(graph, input),
107 Self::Conv3d(layer) => layer.forward(graph, input),
108 Self::MultiHeadAttention(layer) => layer.forward(graph, input),
109 Self::ConvTranspose2d(layer) => layer.forward(graph, input),
110 Self::AdaptiveAvgPool2d(layer) => layer.forward(graph, input),
111 Self::AdaptiveMaxPool2d(layer) => layer.forward(graph, input),
112 Self::InstanceNorm(layer) => layer.forward(graph, input),
113 Self::PReLU(layer) => layer.forward(graph, input),
114 Self::Softmax(layer) => layer.forward(graph, input),
115 Self::PixelShuffle(layer) => layer.forward(graph, input),
116 Self::Upsample(layer) => layer.forward(graph, input),
117 Self::ResidualBlock(layer) => layer.forward(graph, input),
118 Self::Rnn(layer) => layer.forward(graph, input),
119 Self::Lstm(layer) => layer.forward(graph, input),
120 Self::Gru(layer) => layer.forward(graph, input),
121 Self::TransformerEncoder(layer) => layer.forward(graph, input),
122 Self::FeedForward(layer) => layer.forward(graph, input),
123 Self::DeformableConv2d(layer) => layer.forward(graph, input),
124 }
125 }
126
127 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
128 match self {
129 Self::Conv2d(layer) => layer.forward_inference(input),
130 Self::BatchNorm2d(layer) => layer.forward_inference(input),
131 Self::MaxPool2d(layer) => layer.forward_inference(input),
132 Self::AvgPool2d(layer) => layer.forward_inference(input),
133 Self::GlobalAvgPool2d(layer) => layer.forward_inference(input),
134 Self::Flatten(layer) => layer.forward_inference(input),
135 Self::Softmax(layer) => layer.forward_inference(input),
136 Self::DepthwiseConv2d(layer) => layer.forward_inference(input),
137 Self::SeparableConv2d(layer) => layer.forward_inference(input),
138 Self::Conv1d(layer) => layer.forward_inference(input),
139 Self::Conv3d(layer) => layer.forward_inference(input),
140 Self::ConvTranspose2d(layer) => layer.forward_inference(input),
141 Self::AdaptiveAvgPool2d(layer) => layer.forward_inference(input),
142 Self::AdaptiveMaxPool2d(layer) => layer.forward_inference(input),
143 Self::InstanceNorm(layer) => layer.forward_inference(input),
144 Self::PixelShuffle(layer) => layer.forward_inference(input),
145 Self::Upsample(layer) => layer.forward_inference(input),
146 Self::GELU(layer) => layer.forward_inference(input),
147 Self::SiLU(layer) => layer.forward_inference(input),
148 Self::Mish(layer) => layer.forward_inference(input),
149 Self::PReLU(layer) => layer.forward_inference(input),
150 Self::ResidualBlock(layer) => layer.forward_inference(input),
151 Self::Rnn(layer) => layer.forward_inference(input),
152 Self::Lstm(layer) => layer.forward_inference(input),
153 Self::Gru(layer) => layer.forward_inference(input),
154 Self::MultiHeadAttention(layer) => layer.forward_inference(input),
155 Self::TransformerEncoder(layer) => layer.forward_inference(input),
156 Self::FeedForward(layer) => layer.forward_inference(input),
157 Self::DeformableConv2d(layer) => layer.forward_inference(input),
158 Self::Linear(layer) => layer.forward_inference(input),
159 Self::ReLU(_)
160 | Self::LeakyReLU(_)
161 | Self::Sigmoid(_)
162 | Self::Tanh(_)
163 | Self::Dropout(_)
164 | Self::Embedding(_)
165 | Self::LayerNorm(_)
166 | Self::GroupNorm(_)
167 | Self::LoraLinear(_) => Err(ModelError::GraphOnlyLayer),
168 }
169 }
170
171 pub fn supports_graph_forward(&self) -> bool {
172 if let Self::Upsample(u) = self {
174 !u.is_bilinear()
175 } else {
176 true
177 }
178 }
179
180 pub fn supports_inference_forward(&self) -> bool {
181 matches!(
182 self,
183 Self::Conv2d(_)
184 | Self::BatchNorm2d(_)
185 | Self::MaxPool2d(_)
186 | Self::AvgPool2d(_)
187 | Self::GlobalAvgPool2d(_)
188 | Self::Flatten(_)
189 | Self::Softmax(_)
190 | Self::DepthwiseConv2d(_)
191 | Self::SeparableConv2d(_)
192 | Self::Conv1d(_)
193 | Self::Conv3d(_)
194 | Self::ConvTranspose2d(_)
195 | Self::AdaptiveAvgPool2d(_)
196 | Self::AdaptiveMaxPool2d(_)
197 | Self::InstanceNorm(_)
198 | Self::PixelShuffle(_)
199 | Self::Upsample(_)
200 | Self::GELU(_)
201 | Self::SiLU(_)
202 | Self::Mish(_)
203 | Self::PReLU(_)
204 | Self::ResidualBlock(_)
205 | Self::Rnn(_)
206 | Self::Lstm(_)
207 | Self::Gru(_)
208 | Self::MultiHeadAttention(_)
209 | Self::TransformerEncoder(_)
210 | Self::FeedForward(_)
211 | Self::DeformableConv2d(_)
212 )
213 }
214}