1use crate::error::{NeuralError, Result};
4use crate::nas::search_space::{Architecture, LayerType};
5use std::collections::HashMap;
6use std::fmt;
7
8pub trait ArchitectureEncoding: Send + Sync + fmt::Display {
10 fn to_vector(&self) -> Vec<f64>;
12 fn from_vector(vec: &[f64]) -> Result<Self>
14 where
15 Self: Sized;
16 fn dimension(&self) -> usize;
18 fn mutate(&self, mutation_rate: f32) -> Result<Box<dyn ArchitectureEncoding>>;
20 fn crossover(&self, other: &dyn ArchitectureEncoding) -> Result<Box<dyn ArchitectureEncoding>>;
22 fn to_architecture(&self) -> Result<Architecture>;
24 fn clone_box(&self) -> Box<dyn ArchitectureEncoding>;
26}
27
28#[derive(Debug, Clone)]
30pub struct NodeType {
31 pub layer_type: LayerType,
32 pub is_input: bool,
33 pub is_output: bool,
34}
35
/// Free-form metadata attached to a graph node.
///
/// Within this module every node is given the name `node_<index>`, the
/// operation type `"default"` and an empty parameter map.
#[derive(Debug, Clone)]
pub struct NodeAttributes {
    /// Human-readable node name.
    pub name: String,
    /// Operation label; only `"default"` is assigned in this module.
    pub operation_type: String,
    /// Named numeric parameters; never populated in this module.
    pub parameters: HashMap<String, f64>,
}
43
44#[derive(Debug, Clone)]
46pub struct GraphEncoding {
47 pub nodes: Vec<NodeType>,
49 pub edges: Vec<Vec<bool>>,
51 pub node_attrs: Vec<NodeAttributes>,
53}
54
55impl GraphEncoding {
56 pub fn new(nodes: Vec<NodeType>, edges: Vec<Vec<bool>>) -> Self {
58 let node_attrs = nodes
59 .iter()
60 .enumerate()
61 .map(|(i, _)| NodeAttributes {
62 name: format!("node_{}", i),
63 operation_type: "default".to_string(),
64 parameters: HashMap::new(),
65 })
66 .collect();
67 Self {
68 nodes,
69 edges,
70 node_attrs,
71 }
72 }
73
74 pub fn random<R: scirs2_core::random::Rng>(
76 rng: &mut scirs2_core::random::prelude::Random<R>,
77 ) -> Result<Self> {
78 let num_nodes = rng.random_range(3..=8);
79 let mut nodes = Vec::with_capacity(num_nodes);
80 nodes.push(NodeType {
81 layer_type: LayerType::Dense(rng.random_range(64..=256)),
82 is_input: true,
83 is_output: false,
84 });
85 for _ in 1..num_nodes - 1 {
86 let layer_type = match rng.random_range(0..5) {
87 0 => LayerType::Dense(rng.random_range(32..=512)),
88 1 => LayerType::Conv2D {
89 filters: rng.random_range(16..=256),
90 kernel_size: (3, 3),
91 stride: (1, 1),
92 },
93 2 => LayerType::Dropout(rng.random_range(10..50) as f32 / 100.0),
94 3 => LayerType::BatchNorm,
95 _ => LayerType::Activation("relu".to_string()),
96 };
97 nodes.push(NodeType {
98 layer_type,
99 is_input: false,
100 is_output: false,
101 });
102 }
103 nodes.push(NodeType {
104 layer_type: LayerType::Dense(rng.random_range(1..=10)),
105 is_input: false,
106 is_output: true,
107 });
108 let mut edges = vec![vec![false; num_nodes]; num_nodes];
109 for i in 0..num_nodes - 1 {
110 edges[i][i + 1] = true;
111 }
112 for (i, row) in edges.iter_mut().enumerate().take(num_nodes) {
113 for cell in row.iter_mut().take(num_nodes).skip(i + 2) {
114 if rng.random_bool(0.2) {
115 *cell = true;
116 }
117 }
118 }
119 Ok(Self::new(nodes, edges))
120 }
121
122 fn compute_complexity_factor(&self) -> f32 {
123 let mut layer_types = std::collections::HashSet::new();
124 for node in &self.nodes {
125 layer_types.insert(std::mem::discriminant(&node.layer_type));
126 }
127 let mut complexity = layer_types.len() as f32 / self.nodes.len().max(1) as f32;
128 let mut connections = 0;
129 for row in &self.edges {
130 connections += row.iter().filter(|&&x| x).count();
131 }
132 let n = self.nodes.len();
133 complexity += connections as f32 / (n * n).max(1) as f32;
134 complexity.min(1.0)
135 }
136
137 fn choose_kernel_size<R: scirs2_core::random::Rng>(
138 &self,
139 rng: &mut scirs2_core::random::prelude::Random<R>,
140 ) -> (usize, usize) {
141 let sizes = [(1usize, 1usize), (3, 3), (5, 5), (7, 7)];
142 let idx = rng.random_range(0..sizes.len());
143 sizes[idx]
144 }
145
146 fn choose_stride<R: scirs2_core::random::Rng>(
147 &self,
148 rng: &mut scirs2_core::random::prelude::Random<R>,
149 ) -> (usize, usize) {
150 let strides = [(1usize, 1usize), (2, 2)];
151 let idx = rng.random_range(0..strides.len());
152 strides[idx]
153 }
154
155 fn choose_random_layer_type<R: scirs2_core::random::Rng>(
156 &self,
157 rng: &mut scirs2_core::random::prelude::Random<R>,
158 ) -> LayerType {
159 let k = self.choose_kernel_size(rng);
160 let s = self.choose_stride(rng);
161 match rng.random_range(0..5) {
162 0 => LayerType::Dense(rng.random_range(32..=512)),
163 1 => LayerType::Conv2D {
164 filters: rng.random_range(16..=256),
165 kernel_size: k,
166 stride: s,
167 },
168 2 => LayerType::Dropout(rng.random_range(10..50) as f32 / 100.0),
169 3 => LayerType::BatchNorm,
170 _ => LayerType::Activation("relu".to_string()),
171 }
172 }
173
174 fn would_disconnect_graph(
175 &self,
176 edges: &[Vec<bool>],
177 from: usize,
178 to: usize,
179 num_nodes: usize,
180 ) -> bool {
181 let mut test_edges = edges.to_vec();
182 test_edges[from][to] = !test_edges[from][to];
183 let mut reachable = vec![false; num_nodes];
184 for (i, node) in self.nodes.iter().enumerate() {
185 if node.is_input {
186 reachable[i] = true;
187 }
188 }
189 let mut changed = true;
190 while changed {
191 changed = false;
192 for i in 0..num_nodes {
193 if reachable[i] {
194 for j in 0..num_nodes {
195 if test_edges[i][j] && !reachable[j] {
196 reachable[j] = true;
197 changed = true;
198 }
199 }
200 }
201 }
202 }
203 for (i, node) in self.nodes.iter().enumerate() {
204 if node.is_output && !reachable[i] {
205 return true;
206 }
207 }
208 false
209 }
210
211 fn add_node<R: scirs2_core::random::Rng>(
212 &self,
213 mutated: &mut GraphEncoding,
214 rng: &mut scirs2_core::random::prelude::Random<R>,
215 ) -> Result<()> {
216 let new_layer_type = self.choose_random_layer_type(rng);
217 let new_node = NodeType {
218 layer_type: new_layer_type,
219 is_input: false,
220 is_output: false,
221 };
222 mutated.nodes.push(new_node);
223 let new_size = mutated.nodes.len();
224 for row in &mut mutated.edges {
225 row.push(false);
226 }
227 mutated.edges.push(vec![false; new_size]);
228 mutated.node_attrs.push(NodeAttributes {
229 name: format!("node_{}", new_size - 1),
230 operation_type: "default".to_string(),
231 parameters: HashMap::new(),
232 });
233 if new_size >= 2 {
234 let from_idx = rng.random_range(0..new_size - 1);
235 mutated.edges[from_idx][new_size - 1] = true;
236 if new_size >= 2 {
237 let to_idx = rng.random_range(0..new_size - 1);
238 if to_idx != new_size - 1 {
239 mutated.edges[new_size - 1][to_idx] = true;
240 }
241 }
242 }
243 Ok(())
244 }
245}
246
247impl ArchitectureEncoding for GraphEncoding {
248 fn to_vector(&self) -> Vec<f64> {
249 let mut vec = Vec::new();
250 vec.push(self.nodes.len() as f64);
251 for node in &self.nodes {
252 vec.push(if node.is_input { 1.0 } else { 0.0 });
253 vec.push(if node.is_output { 1.0 } else { 0.0 });
254 match &node.layer_type {
255 LayerType::Dense(units) => {
256 vec.push(1.0);
257 vec.push(*units as f64);
258 }
259 LayerType::Conv2D { filters, .. } => {
260 vec.push(2.0);
261 vec.push(*filters as f64);
262 }
263 LayerType::Dropout(rate) => {
264 vec.push(3.0);
265 vec.push(*rate as f64);
266 }
267 LayerType::BatchNorm => {
268 vec.push(4.0);
269 vec.push(0.0);
270 }
271 LayerType::Activation(_) => {
272 vec.push(5.0);
273 vec.push(0.0);
274 }
275 _ => {
276 vec.push(0.0);
277 vec.push(0.0);
278 }
279 }
280 }
281 for row in &self.edges {
282 for &edge in row {
283 vec.push(if edge { 1.0 } else { 0.0 });
284 }
285 }
286 vec
287 }
288
289 fn from_vector(vec: &[f64]) -> Result<Self> {
290 if vec.is_empty() {
291 return Err(NeuralError::ConfigError(
292 "Empty vector for GraphEncoding".to_string(),
293 ));
294 }
295 let num_nodes = vec[0] as usize;
296 if num_nodes == 0 {
297 return Err(NeuralError::ConfigError(
298 "GraphEncoding must have at least one node".to_string(),
299 ));
300 }
301 let expected_size = 1 + num_nodes * 4 + num_nodes * num_nodes;
302 if vec.len() < expected_size {
303 return Err(NeuralError::ConfigError(format!(
304 "Vector too short: expected at least {}, got {}",
305 expected_size,
306 vec.len()
307 )));
308 }
309 let mut nodes = Vec::with_capacity(num_nodes);
310 let mut node_attrs = Vec::with_capacity(num_nodes);
311 let mut idx = 1;
312 for i in 0..num_nodes {
313 let is_input = vec[idx] > 0.5;
314 let is_output = vec[idx + 1] > 0.5;
315 let layer_type_code = vec[idx + 2] as i32;
316 let layer_param = vec[idx + 3];
317 let layer_type = match layer_type_code {
318 1 => LayerType::Dense(layer_param as usize),
319 2 => LayerType::Conv2D {
320 filters: layer_param as usize,
321 kernel_size: (3, 3),
322 stride: (1, 1),
323 },
324 3 => LayerType::Dropout(layer_param as f32),
325 4 => LayerType::BatchNorm,
326 5 => LayerType::Activation("relu".to_string()),
327 _ => LayerType::Dense(64),
328 };
329 nodes.push(NodeType {
330 layer_type,
331 is_input,
332 is_output,
333 });
334 node_attrs.push(NodeAttributes {
335 name: format!("node_{}", i),
336 operation_type: "default".to_string(),
337 parameters: HashMap::new(),
338 });
339 idx += 4;
340 }
341 let mut edges = vec![vec![false; num_nodes]; num_nodes];
342 for row in edges.iter_mut().take(num_nodes) {
343 for cell in row.iter_mut().take(num_nodes) {
344 if idx < vec.len() {
345 *cell = vec[idx] > 0.5;
346 idx += 1;
347 }
348 }
349 }
350 Ok(GraphEncoding {
351 nodes,
352 edges,
353 node_attrs,
354 })
355 }
356
357 fn dimension(&self) -> usize {
358 1 + self.nodes.len() * 4 + self.edges.len() * self.edges.len()
359 }
360
361 fn mutate(&self, mutation_rate: f32) -> Result<Box<dyn ArchitectureEncoding>> {
362 use scirs2_core::random::prelude::*;
363 let mut rng = thread_rng();
364 let mut mutated = self.clone();
365 let complexity_factor = self.compute_complexity_factor();
366 let adaptive_rate = mutation_rate * (1.0 + complexity_factor * 0.5);
367 let mutation_type = rng.random_range(0..5);
368 match mutation_type {
369 0 => {
370 for node in &mut mutated.nodes {
372 if !node.is_input && !node.is_output && rng.random_bool(adaptive_rate as f64) {
373 node.layer_type = self.choose_random_layer_type(&mut rng);
374 }
375 }
376 }
377 1 => {
378 for node in &mut mutated.nodes {
380 if !node.is_input && !node.is_output && rng.random_bool(adaptive_rate as f64) {
381 match &mut node.layer_type {
382 LayerType::Dense(ref mut units) => {
383 *units = rng.random_range(32..=512);
384 }
385 LayerType::Conv2D {
386 ref mut filters, ..
387 } => {
388 *filters = rng.random_range(16..=256);
389 }
390 LayerType::Dropout(ref mut rate) => {
391 *rate = rng.random_range(10..50) as f32 / 100.0;
392 }
393 _ => {}
394 }
395 }
396 }
397 }
398 2 => {
399 let num_nodes = mutated.nodes.len();
401 for i in 0..num_nodes {
402 for j in 0..num_nodes {
403 if i != j && rng.random_bool(adaptive_rate as f64) {
404 let would_disconnect =
405 self.would_disconnect_graph(&mutated.edges, i, j, num_nodes);
406 if !would_disconnect {
407 mutated.edges[i][j] = !mutated.edges[i][j];
408 }
409 }
410 }
411 }
412 }
413 3 => {
414 if rng.random_bool(adaptive_rate as f64) && mutated.nodes.len() < 20 {
416 self.add_node(&mut mutated, &mut rng)?;
417 }
418 }
419 _ => {
420 for node in &mut mutated.nodes {
422 if !node.is_input
423 && !node.is_output
424 && rng.random_bool(adaptive_rate as f64 * 0.3)
425 {
426 node.layer_type = self.choose_random_layer_type(&mut rng);
427 }
428 }
429 let num_nodes = mutated.nodes.len();
430 for i in 0..num_nodes {
431 for j in 0..num_nodes {
432 if i != j && rng.random_bool(adaptive_rate as f64 * 0.2) {
433 let would_disconnect =
434 self.would_disconnect_graph(&mutated.edges, i, j, num_nodes);
435 if !would_disconnect {
436 mutated.edges[i][j] = !mutated.edges[i][j];
437 }
438 }
439 }
440 }
441 }
442 }
443 Ok(Box::new(mutated))
444 }
445
446 fn crossover(&self, other: &dyn ArchitectureEncoding) -> Result<Box<dyn ArchitectureEncoding>> {
447 use scirs2_core::random::prelude::*;
448 let mut rng = thread_rng();
449 let self_vec = self.to_vector();
450 let other_vec = other.to_vector();
451 let min_len = self_vec.len().min(other_vec.len());
452 let mut mixed_vec = Vec::with_capacity(self_vec.len().max(other_vec.len()));
453 for i in 0..min_len {
454 if rng.random_bool(0.5) {
455 mixed_vec.push(self_vec[i]);
456 } else {
457 mixed_vec.push(other_vec[i]);
458 }
459 }
460 if self_vec.len() > min_len {
461 mixed_vec.extend_from_slice(&self_vec[min_len..]);
462 } else if other_vec.len() > min_len {
463 mixed_vec.extend_from_slice(&other_vec[min_len..]);
464 }
465 let result = GraphEncoding::from_vector(&mixed_vec)?;
466 Ok(Box::new(result))
467 }
468
469 fn to_architecture(&self) -> Result<Architecture> {
470 let mut layers = Vec::new();
471 let mut connections = Vec::new();
472 for node in &self.nodes {
473 layers.push(node.layer_type.clone());
474 }
475 for (i, row) in self.edges.iter().enumerate() {
476 for (j, &connected) in row.iter().enumerate() {
477 if connected {
478 connections.push((i, j));
479 }
480 }
481 }
482 Architecture::new(layers, connections)
483 }
484
485 fn clone_box(&self) -> Box<dyn ArchitectureEncoding> {
486 Box::new(self.clone())
487 }
488}
489
490impl fmt::Display for GraphEncoding {
491 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
492 writeln!(f, "GraphEncoding:")?;
493 writeln!(f, " Nodes: {}", self.nodes.len())?;
494 for (i, node) in self.nodes.iter().enumerate() {
495 write!(f, " {}: {:?}", i, node.layer_type)?;
496 if node.is_input {
497 write!(f, " [INPUT]")?;
498 }
499 if node.is_output {
500 write!(f, " [OUTPUT]")?;
501 }
502 writeln!(f)?;
503 }
504 writeln!(f, " Edges:")?;
505 for (i, row) in self.edges.iter().enumerate() {
506 write!(f, " {}: ", i)?;
507 for (j, &connected) in row.iter().enumerate() {
508 if connected {
509 write!(f, "{} ", j)?;
510 }
511 }
512 writeln!(f)?;
513 }
514 Ok(())
515 }
516}
517
518#[derive(Debug, Clone)]
520pub struct SequentialEncoding {
521 pub layers: Vec<LayerType>,
522}
523
524impl SequentialEncoding {
525 pub fn new(layers: Vec<LayerType>) -> Self {
526 Self { layers }
527 }
528
529 pub fn random<R: scirs2_core::random::Rng>(
530 rng: &mut scirs2_core::random::prelude::Random<R>,
531 ) -> Result<Self> {
532 let num_layers = rng.random_range(3..=10);
533 let mut layers = Vec::with_capacity(num_layers);
534 layers.push(LayerType::Dense(rng.random_range(64..=512)));
535 for _ in 1..num_layers - 1 {
536 let layer_type = match rng.random_range(0..4) {
537 0 => LayerType::Dense(rng.random_range(32..=512)),
538 1 => LayerType::Dropout(rng.random_range(10..50) as f32 / 100.0),
539 2 => LayerType::BatchNorm,
540 _ => LayerType::Activation("relu".to_string()),
541 };
542 layers.push(layer_type);
543 }
544 layers.push(LayerType::Dense(rng.random_range(1..=10)));
545 Ok(Self { layers })
546 }
547}
548
549impl ArchitectureEncoding for SequentialEncoding {
550 fn to_vector(&self) -> Vec<f64> {
551 let mut vec = Vec::new();
552 vec.push(self.layers.len() as f64);
553 for layer in &self.layers {
554 match layer {
555 LayerType::Dense(units) => {
556 vec.push(1.0);
557 vec.push(*units as f64);
558 vec.push(0.0);
559 }
560 LayerType::Conv2D { filters, .. } => {
561 vec.push(2.0);
562 vec.push(*filters as f64);
563 vec.push(0.0);
564 }
565 LayerType::Dropout(rate) => {
566 vec.push(3.0);
567 vec.push(*rate as f64);
568 vec.push(0.0);
569 }
570 LayerType::BatchNorm => {
571 vec.push(4.0);
572 vec.push(0.0);
573 vec.push(0.0);
574 }
575 LayerType::Activation(_) => {
576 vec.push(5.0);
577 vec.push(0.0);
578 vec.push(0.0);
579 }
580 _ => {
581 vec.push(0.0);
582 vec.push(0.0);
583 vec.push(0.0);
584 }
585 }
586 }
587 vec
588 }
589
590 fn from_vector(vec: &[f64]) -> Result<Self> {
591 if vec.is_empty() {
592 return Err(NeuralError::ConfigError(
593 "Empty vector for SequentialEncoding".to_string(),
594 ));
595 }
596 let num_layers = vec[0] as usize;
597 if num_layers == 0 {
598 return Err(NeuralError::ConfigError(
599 "SequentialEncoding must have at least one layer".to_string(),
600 ));
601 }
602 let expected_size = 1 + num_layers * 3;
603 if vec.len() < expected_size {
604 return Err(NeuralError::ConfigError(format!(
605 "Vector too short: expected {}, got {}",
606 expected_size,
607 vec.len()
608 )));
609 }
610 let mut layers = Vec::with_capacity(num_layers);
611 let mut idx = 1;
612 for _ in 0..num_layers {
613 let layer_type_code = vec[idx] as i32;
614 let param1 = vec[idx + 1];
615 let layer_type = match layer_type_code {
616 1 => LayerType::Dense(param1 as usize),
617 2 => LayerType::Conv2D {
618 filters: param1 as usize,
619 kernel_size: (3, 3),
620 stride: (1, 1),
621 },
622 3 => LayerType::Dropout(param1 as f32),
623 4 => LayerType::BatchNorm,
624 5 => LayerType::Activation("relu".to_string()),
625 _ => LayerType::Dense(64),
626 };
627 layers.push(layer_type);
628 idx += 3;
629 }
630 Ok(Self { layers })
631 }
632
633 fn dimension(&self) -> usize {
634 1 + self.layers.len() * 3
635 }
636
637 fn mutate(&self, mutation_rate: f32) -> Result<Box<dyn ArchitectureEncoding>> {
638 use scirs2_core::random::prelude::*;
639 let mut rng = thread_rng();
640 let mut mutated = self.clone();
641 for layer in &mut mutated.layers {
642 if rng.random_bool(mutation_rate as f64) {
643 match layer {
644 LayerType::Dense(ref mut units) => {
645 *units = rng.random_range(32..=512);
646 }
647 LayerType::Conv2D {
648 ref mut filters, ..
649 } => {
650 *filters = rng.random_range(16..=256);
651 }
652 LayerType::Dropout(ref mut rate) => {
653 *rate = rng.random_range(10..50) as f32 / 100.0;
654 }
655 _ => {}
656 }
657 }
658 }
659 if rng.random_bool(mutation_rate as f64 * 0.1) {
660 if mutated.layers.len() < 15 && rng.random_bool(0.7) {
661 let pos = if mutated.layers.len() > 1 {
662 rng.random_range(1..mutated.layers.len())
663 } else {
664 1
665 };
666 let new_layer = match rng.random_range(0..4) {
667 0 => LayerType::Dense(rng.random_range(32..=512)),
668 1 => LayerType::Dropout(rng.random_range(10..50) as f32 / 100.0),
669 2 => LayerType::BatchNorm,
670 _ => LayerType::Activation("relu".to_string()),
671 };
672 mutated.layers.insert(pos, new_layer);
673 } else if mutated.layers.len() > 3 {
674 let pos = rng.random_range(1..mutated.layers.len() - 1);
675 mutated.layers.remove(pos);
676 }
677 }
678 Ok(Box::new(mutated))
679 }
680
681 fn crossover(&self, other: &dyn ArchitectureEncoding) -> Result<Box<dyn ArchitectureEncoding>> {
682 use scirs2_core::random::prelude::*;
683 let mut rng = thread_rng();
684 let self_vec = self.to_vector();
685 let other_vec = other.to_vector();
686 if self_vec.len() >= 4 && other_vec.len() >= 4 {
687 let min_len = self_vec.len().min(other_vec.len());
688 let crossover_point = rng.random_range(1..min_len);
689 let mut child_vec = Vec::new();
690 child_vec.extend_from_slice(&self_vec[..crossover_point]);
691 child_vec.extend_from_slice(&other_vec[crossover_point..]);
692 if let Ok(result) = SequentialEncoding::from_vector(&child_vec) {
693 return Ok(Box::new(result));
694 }
695 }
696 self.mutate(0.1)
697 }
698
699 fn to_architecture(&self) -> Result<Architecture> {
700 let mut connections = Vec::new();
701 for i in 0..self.layers.len().saturating_sub(1) {
702 connections.push((i, i + 1));
703 }
704 Architecture::new(self.layers.clone(), connections)
705 }
706
707 fn clone_box(&self) -> Box<dyn ArchitectureEncoding> {
708 Box::new(self.clone())
709 }
710}
711
712impl fmt::Display for SequentialEncoding {
713 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
714 writeln!(f, "SequentialEncoding:")?;
715 for (i, layer) in self.layers.iter().enumerate() {
716 writeln!(f, " {}: {:?}", i, layer)?;
717 }
718 Ok(())
719 }
720}
721
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::prelude::*;

    /// Shorthand for building a graph node with the given layer and roles.
    fn node(layer_type: LayerType, is_input: bool, is_output: bool) -> NodeType {
        NodeType {
            layer_type,
            is_input,
            is_output,
        }
    }

    #[test]
    fn test_graph_encoding() {
        // Three-node chain: input -> hidden -> output.
        let nodes = vec![
            node(LayerType::Dense(64), true, false),
            node(LayerType::Dense(32), false, false),
            node(LayerType::Dense(10), false, true),
        ];
        let mut edges = vec![vec![false; 3]; 3];
        edges[0][1] = true;
        edges[1][2] = true;

        let vector = GraphEncoding::new(nodes, edges).to_vector();
        assert_eq!(vector[0], 3.0);

        let decoded = GraphEncoding::from_vector(&vector).expect("decode failed");
        assert_eq!(decoded.nodes.len(), 3);
    }

    #[test]
    fn test_sequential_encoding() {
        let vector = SequentialEncoding::new(vec![
            LayerType::Dense(128),
            LayerType::BatchNorm,
            LayerType::Activation("relu".to_string()),
            LayerType::Dropout(0.2),
            LayerType::Dense(10),
        ])
        .to_vector();
        assert_eq!(vector[0], 5.0);

        let decoded = SequentialEncoding::from_vector(&vector).expect("decode failed");
        assert_eq!(decoded.layers.len(), 5);
    }

    #[test]
    fn test_random_generation() {
        let mut rng_inst = thread_rng();
        let layer_count = SequentialEncoding::random(&mut rng_inst)
            .expect("random generation failed")
            .layers
            .len();
        assert!((3..=10).contains(&layer_count));
    }
}