//! Deep Distributed QP optimizer: poses each parameter update as a small
//! quadratic program and solves it with an ADMM-style operator-splitting
//! scheme across consensus nodes, with an optional learned step-size policy.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};

use crate::{common::StateMemoryStats, traits::StatefulOptimizer};

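/// Configuration for the [`DeepDistributedQP`] optimizer.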
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DeepDistributedQPConfig {
    /// Base learning rate applied to the QP solution when updating parameters.
    pub learning_rate: f32,

    /// Number of consensus nodes the problem is distributed across.
    pub num_consensus_nodes: usize,

    /// Maximum number of operator-splitting iterations per solve.
    pub max_iterations: usize,

    /// Consensus-error tolerance used as the convergence criterion.
    pub tolerance: f32,

    /// Over-relaxation parameter for the consensus update.
    pub relaxation_parameter: f32,

    /// ADMM-style penalty parameter (also scales the soft-threshold).
    pub penalty_parameter: f32,

    /// Fixed step size used when adaptive stepping is disabled.
    pub step_size: f32,

    /// Whether to let the policy network choose per-node step sizes.
    pub adaptive_step_size: bool,

    /// Hidden-layer widths of the step-size policy network.
    pub network_hidden_dims: Vec<usize>,

    /// Whether to warm-start each solve from the previous solution.
    pub warm_start: bool,

    /// Run a consensus update every this many iterations.
    pub consensus_frequency: usize,

    /// Maximum problem size this configuration targets.
    pub max_problem_size: usize,
}

impl Default for DeepDistributedQPConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            num_consensus_nodes: 4,
            max_iterations: 100,
            tolerance: 1e-6,
            relaxation_parameter: 1.6,
            penalty_parameter: 1.0,
            step_size: 1.0,
            adaptive_step_size: true,
            network_hidden_dims: vec![64, 32],
            warm_start: true,
            consensus_frequency: 10,
            max_problem_size: 10000,
        }
    }
}

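/// State local to one consensus node.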
#[derive(Clone, Debug)]
struct ConsensusNode {
    /// This node's local copy of the optimization variables.
    local_variables: Tensor,

    /// Dual variables (Lagrange multipliers) for the consensus constraints.
    dual_variables: Tensor,

    /// Residuals of the node's local constraints.
    constraint_residuals: Tensor,

    /// Distance from the current consensus average, updated each round.
    consensus_error: f32,

    #[allow(dead_code)]
    node_id: usize,
}

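/// Small feed-forward network that maps solver statistics to a step size.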
#[derive(Clone, Debug)]
struct PolicyNetwork {
    /// Weight matrices for each layer, including the output layer.
    weights: Vec<Tensor>,

    /// Bias vectors matching `weights`.
    biases: Vec<Tensor>,

    /// Input normalization statistics.
    input_mean: Tensor,
    input_std: Tensor,

    /// Scalar applied to the network output.
    output_scale: f32,
}

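/// Per-parameter (or per-problem) solver state.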
#[derive(Clone, Debug)]
pub struct DeepDistributedQPState {
    /// Consensus nodes participating in this parameter's solve.
    consensus_nodes: Vec<ConsensusNode>,

    /// Learned step-size policy, created lazily on first solve.
    policy_network: Option<PolicyNetwork>,

    /// Last solution, used for warm starting.
    previous_solution: Option<Tensor>,

    /// Cached QP data (P, q, A, b); only `q` is currently consumed.
    #[allow(dead_code)]
    problem_matrix_p: Option<Tensor>,
    problem_vector_q: Option<Tensor>,
    #[allow(dead_code)]
    constraint_matrix_a: Option<Tensor>,
    #[allow(dead_code)]
    constraint_vector_b: Option<Tensor>,

    /// Iteration counter for the most recent solve.
    iteration: usize,

    /// Consensus errors recorded at each consensus round.
    convergence_history: Vec<f32>,

    /// Wall-clock solve times in seconds.
    solve_times: Vec<f32>,

    #[allow(dead_code)]
    problem_size: usize,
}

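/// Deep Distributed QP optimizer.
///
/// Each parameter gets its own consensus-node ensemble; `Optimizer::update`
/// poses the gradient step as a small QP, solves it with operator splitting
/// plus periodic consensus averaging, and applies the averaged solution
/// scaled by the learning rate.
///
/// A minimal usage sketch (mirrors the unit tests at the bottom of this file):
///
/// ```ignore
/// let mut opt = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
/// let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
/// let grad = Tensor::from_slice(&[0.1, 0.2, 0.1], &[3])?;
/// opt.update(&mut param, &grad)?;
/// opt.step();
/// ```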
#[derive(Clone, Debug)]
pub struct DeepDistributedQP {
    config: DeepDistributedQPConfig,
    states: HashMap<String, DeepDistributedQPState>,
    step: usize,
    memory_stats: StateMemoryStats,

    /// Optional global consensus variable shared across problems.
    global_consensus: Option<Tensor>,

    /// Number of QP solves completed so far.
    problems_solved: usize,

    /// Running average of the estimated speedup over a sequential baseline.
    cumulative_speedup: f32,
}

impl DeepDistributedQP {
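    /// Creates an optimizer with the given core settings; all other
    /// configuration fields take their defaults.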
    pub fn new(
        learning_rate: f32,
        num_consensus_nodes: usize,
        max_iterations: usize,
        tolerance: f32,
    ) -> Self {
        Self {
            config: DeepDistributedQPConfig {
                learning_rate,
                num_consensus_nodes,
                max_iterations,
                tolerance,
                ..Default::default()
            },
            states: HashMap::new(),
            step: 0,
            memory_stats: StateMemoryStats {
                momentum_elements: 0,
                variance_elements: 0,
                third_moment_elements: 0,
                total_bytes: 0,
                num_parameters: 0,
            },
            global_consensus: None,
            problems_solved: 0,
            cumulative_speedup: 1.0,
        }
    }

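    /// Preset tuned for large problems: more nodes, tighter tolerance,
    /// more frequent consensus, and a deeper policy network.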
    pub fn for_large_scale() -> Self {
        Self {
            config: DeepDistributedQPConfig {
                learning_rate: 5e-4,
                num_consensus_nodes: 8,
                max_iterations: 500,
                tolerance: 1e-8,
                relaxation_parameter: 1.8,
                penalty_parameter: 0.5,
                step_size: 0.8,
                adaptive_step_size: true,
                network_hidden_dims: vec![128, 64, 32],
                warm_start: true,
                consensus_frequency: 5,
                max_problem_size: 50000,
            },
            states: HashMap::new(),
            step: 0,
            memory_stats: StateMemoryStats {
                momentum_elements: 0,
                variance_elements: 0,
                third_moment_elements: 0,
                total_bytes: 0,
                num_parameters: 0,
            },
            global_consensus: None,
            problems_solved: 0,
            cumulative_speedup: 1.0,
        }
    }

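    /// Preset tuned for portfolio-style QPs: a stronger constraint penalty
    /// and a smaller maximum problem size.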
    pub fn for_portfolio_optimization() -> Self {
        Self {
            config: DeepDistributedQPConfig {
                learning_rate: 1e-3,
                num_consensus_nodes: 6,
                max_iterations: 200,
                tolerance: 1e-7,
                relaxation_parameter: 1.5,
                penalty_parameter: 2.0,
                step_size: 1.2,
                adaptive_step_size: true,
                network_hidden_dims: vec![64, 32, 16],
                warm_start: true,
                consensus_frequency: 15,
                max_problem_size: 5000,
            },
            states: HashMap::new(),
            step: 0,
            memory_stats: StateMemoryStats {
                momentum_elements: 0,
                variance_elements: 0,
                third_moment_elements: 0,
                total_bytes: 0,
                num_parameters: 0,
            },
            global_consensus: None,
            problems_solved: 0,
            cumulative_speedup: 1.0,
        }
    }

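    /// Creates an optimizer from a fully specified configuration.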
    pub fn with_config(config: DeepDistributedQPConfig) -> Self {
        Self {
            config,
            states: HashMap::new(),
            step: 0,
            memory_stats: StateMemoryStats {
                momentum_elements: 0,
                variance_elements: 0,
                third_moment_elements: 0,
                total_bytes: 0,
                num_parameters: 0,
            },
            global_consensus: None,
            problems_solved: 0,
            cumulative_speedup: 1.0,
        }
    }

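    /// Allocates zero-initialized consensus nodes for the given problem size.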
    fn initialize_consensus_nodes(&self, problem_size: usize) -> Result<Vec<ConsensusNode>> {
        let mut nodes = Vec::with_capacity(self.config.num_consensus_nodes);

        for node_id in 0..self.config.num_consensus_nodes {
            nodes.push(ConsensusNode {
                local_variables: Tensor::zeros(&[problem_size])?,
                dual_variables: Tensor::zeros(&[problem_size])?,
                constraint_residuals: Tensor::zeros(&[problem_size])?,
                consensus_error: f32::INFINITY,
                node_id,
            });
        }

        Ok(nodes)
    }

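    /// Builds the step-size policy network: Xavier-scaled hidden layers and a
    /// small-initialized scalar output head.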
    fn create_policy_network(&self, input_size: usize) -> Result<PolicyNetwork> {
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        let mut prev_size = input_size;
        for &hidden_size in &self.config.network_hidden_dims {
            // Xavier/Glorot-style scaling keeps activations well conditioned.
            let scale = (2.0 / (prev_size + hidden_size) as f32).sqrt();
            let weight = Tensor::randn(&[prev_size, hidden_size])?.mul_scalar(scale)?;
            let bias = Tensor::zeros(&[hidden_size])?;

            weights.push(weight);
            biases.push(bias);
            prev_size = hidden_size;
        }

        // Scalar output head, initialized small so early step sizes stay conservative.
        let output_weight = Tensor::randn(&[prev_size, 1])?.mul_scalar(0.01)?;
        let output_bias = Tensor::zeros(&[1])?;
        weights.push(output_weight);
        biases.push(output_bias);

        Ok(PolicyNetwork {
            weights,
            biases,
            input_mean: Tensor::zeros(&[input_size])?,
            input_std: Tensor::ones(&[input_size])?,
            output_scale: 1.0,
        })
    }

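    /// Forward pass of the policy network: normalize, flatten to a single-row
    /// batch, apply ReLU hidden layers and a linear output head, then rescale.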
    fn policy_forward(&self, network: &PolicyNetwork, input: &Tensor) -> Result<Tensor> {
        let normalized_input = input.sub(&network.input_mean)?.div(&network.input_std)?;

        let input_shape = normalized_input.shape();
        let batch_size = 1;
        let feature_size = input_shape.iter().product::<usize>();
        let reshaped_input = normalized_input.reshape(&[batch_size, feature_size])?;

        let mut x = reshaped_input;

        // Hidden layers with ReLU activations.
        for i in 0..network.weights.len() - 1 {
            x = x.matmul(&network.weights[i])?.add(&network.biases[i])?;
            x = x.relu()?;
        }

        // Linear output layer.
        let output_idx = network.weights.len() - 1;
        x = x.matmul(&network.weights[output_idx])?.add(&network.biases[output_idx])?;

        let output = x.mul_scalar(network.output_scale)?;

        // Squeeze the singleton batch dimension back out.
        let final_output = if output.shape().len() == 2 && output.shape()[0] == 1 {
            output.reshape(&[output.shape()[1]])?
        } else {
            output
        };

        Ok(final_output)
    }

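    /// One operator-splitting iteration for a single node: gradient step,
    /// proximal (soft-threshold) step, then dual ascent on the residuals.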
    fn operator_splitting_update(
        &self,
        node: &mut ConsensusNode,
        gradient: &Tensor,
        step_size: f32,
    ) -> Result<()> {
        // Gradient step on the local variables.
        let gradient_step = node.local_variables.sub(&gradient.mul_scalar(step_size)?)?;

        // Proximal step: soft-threshold with the scaled penalty parameter.
        let threshold = step_size * self.config.penalty_parameter;
        node.local_variables = self.soft_threshold(&gradient_step, threshold)?;

        // Dual ascent on the constraint residuals.
        let constraint_violation = node.constraint_residuals.clone();
        node.dual_variables = node
            .dual_variables
            .add(&constraint_violation.mul_scalar(self.config.penalty_parameter)?)?;

        Ok(())
    }

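    /// Soft-thresholding operator `S_t(x) = sign(x) * max(|x| - t, 0)`,
    /// the proximal map of the scaled l1 norm.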
    fn soft_threshold(&self, input: &Tensor, threshold: f32) -> Result<Tensor> {
        let positive_part = input.sub_scalar(threshold)?.relu()?;
        let negative_part = input.add_scalar(threshold)?.neg()?.relu()?.neg()?;
        positive_part.add(&negative_part)
    }

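    /// One consensus round: averages node variables, relaxes each node toward
    /// the average, and returns the mean consensus error.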
    fn consensus_update(&self, nodes: &mut [ConsensusNode]) -> Result<f32> {
        let num_nodes = nodes.len();
        if num_nodes < 2 {
            return Ok(0.0);
        }

        // Average the local variables across all nodes.
        let mut consensus_sum = nodes[0].local_variables.clone();
        for node in nodes.iter().skip(1) {
            consensus_sum = consensus_sum.add(&node.local_variables)?;
        }
        let consensus_avg = consensus_sum.div_scalar(num_nodes as f32)?;

        let mut total_consensus_error = 0.0f32;
        for node in nodes.iter_mut() {
            let consensus_diff = consensus_avg.sub(&node.local_variables)?;
            let consensus_error = consensus_diff.norm()?;

            // Damped relaxation step toward the consensus average.
            let update = consensus_diff.mul_scalar(self.config.relaxation_parameter)?;
            node.local_variables = node.local_variables.add(&update.mul_scalar(0.1)?)?;
            node.consensus_error = consensus_error;
            total_consensus_error += consensus_error;
        }

        Ok(total_consensus_error / num_nodes as f32)
    }

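    /// Queries the policy network for a step size from gradient, variable, and
    /// dual norms plus the node's consensus error, clamped to [0.001, 2.0].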
    fn adaptive_step_size(
        &self,
        network: &PolicyNetwork,
        node: &ConsensusNode,
        gradient: &Tensor,
    ) -> Result<f32> {
        // Scalar features summarizing the node's current solver state.
        let grad_norm = gradient.norm()?;
        let var_norm = node.local_variables.norm()?;
        let dual_norm = node.dual_variables.norm()?;
        let consensus_error = node.consensus_error;

        let features =
            Tensor::from_slice(&[grad_norm, var_norm, dual_norm, consensus_error], &[4])?;

        // The policy network emits a single scalar; take its first element.
        let step_size_tensor = self.policy_forward(network, &features)?;
        let step_size = step_size_tensor.data()?[0];

        // Keep the learned step size in a numerically safe range.
        let step_size = step_size.clamp(0.001, 2.0);

        Ok(step_size)
    }

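    /// Core solve loop: lazily initializes per-parameter state, optionally
    /// warm-starts from the last solution, runs operator-splitting iterations
    /// with periodic consensus rounds, and returns the averaged node solution.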
    fn solve_distributed_qp(&mut self, param_id: &str, gradient: &Tensor) -> Result<Tensor> {
        let problem_size = gradient.len();

        let param_key = param_id.to_string();
        let state_exists = self.states.contains_key(&param_key);

        if !state_exists {
            let consensus_nodes =
                self.initialize_consensus_nodes(problem_size).unwrap_or_default();
            let new_state = DeepDistributedQPState {
                consensus_nodes,
                policy_network: None,
                previous_solution: None,
                problem_matrix_p: None,
                problem_vector_q: Some(gradient.clone()),
                constraint_matrix_a: None,
                constraint_vector_b: None,
                iteration: 0,
                convergence_history: Vec::new(),
                solve_times: Vec::new(),
                problem_size,
            };
            self.states.insert(param_key.clone(), new_state);
        }

        let state = self.states.get_mut(&param_key).unwrap();
        let needs_policy_network = state.policy_network.is_none();
        let needs_consensus_nodes = state.consensus_nodes.is_empty();

        if needs_policy_network {
            // Four input features: gradient, variable, and dual norms plus consensus error.
            let policy_network = self.create_policy_network(4)?;
            let state = self.states.get_mut(&param_key).unwrap();
            state.policy_network = Some(policy_network);
        }

        if needs_consensus_nodes {
            let consensus_nodes = self.initialize_consensus_nodes(problem_size)?;
            let state = self.states.get_mut(&param_key).unwrap();
            state.consensus_nodes = consensus_nodes;
        }

        let state = self.states.get_mut(&param_key).unwrap();

        // Warm start: seed every node from the previous solution when available.
        if let (true, Some(prev_solution)) =
            (self.config.warm_start, state.previous_solution.as_ref())
        {
            for node in &mut state.consensus_nodes {
                node.local_variables = prev_solution.clone();
            }
        }

        let start_time = std::time::Instant::now();
        // Tracks whether the consensus loop hit the tolerance before max_iterations.
        let mut _converged = false;
        for iteration in 0..self.config.max_iterations {
            let state = self.states.get_mut(&param_key).unwrap();
            state.iteration = iteration;

            let adaptive_step = self.config.adaptive_step_size;
            let consensus_frequency = self.config.consensus_frequency;
            let tolerance = self.config.tolerance;
            let step_size = self.config.step_size;

            // Clone node state so `&self` methods can be called inside the loop.
            let mut consensus_nodes = state.consensus_nodes.clone();
            let policy_network = state.policy_network.clone();

            for node in &mut consensus_nodes {
                // Pick the step size: learned policy if enabled, fixed otherwise.
                let actual_step_size = if adaptive_step {
                    if let Some(ref network) = policy_network {
                        self.adaptive_step_size(network, node, gradient)?
                    } else {
                        step_size
                    }
                } else {
                    step_size
                };

                self.operator_splitting_update(node, gradient, actual_step_size)?;
            }

            let state = self.states.get_mut(&param_key).unwrap();
            state.consensus_nodes = consensus_nodes;

            // Periodically synchronize the nodes and check convergence.
            if iteration % consensus_frequency == 0 {
                let state = self.states.get_mut(&param_key).unwrap();
                let mut nodes = state.consensus_nodes.clone();

                let consensus_error = self.consensus_update(&mut nodes)?;

                let state = self.states.get_mut(&param_key).unwrap();
                state.consensus_nodes = nodes;
                state.convergence_history.push(consensus_error);

                if consensus_error < tolerance {
                    _converged = true;
                    break;
                }
            }
        }

        let solve_time = start_time.elapsed().as_secs_f32();
        let state = self.states.get_mut(&param_key).unwrap();
        state.solve_times.push(solve_time);

        // Final solution: average of the node-local variables.
        let mut solution = state.consensus_nodes[0].local_variables.clone();
        for node in state.consensus_nodes.iter().skip(1) {
            solution = solution.add(&node.local_variables)?;
        }
        solution = solution.div_scalar(state.consensus_nodes.len() as f32)?;

        state.previous_solution = Some(solution.clone());

        self.problems_solved += 1;

        // Speedup bookkeeping against a crude 2x-sequential baseline estimate.
        let baseline_time = solve_time * 2.0;
        let current_speedup = baseline_time / solve_time.max(1e-6);
        self.cumulative_speedup = (self.cumulative_speedup * (self.problems_solved - 1) as f32
            + current_speedup)
            / self.problems_solved as f32;

        Ok(solution)
    }

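    /// Per-parameter solver statistics keyed by parameter id:
    /// `(last iteration, average solve time in seconds, last consensus error, converged)`.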
    pub fn qp_solver_stats(&self) -> HashMap<String, (usize, f32, f32, bool)> {
        self.states
            .iter()
            .map(|(name, state)| {
                let avg_solve_time = if !state.solve_times.is_empty() {
                    state.solve_times.iter().sum::<f32>() / state.solve_times.len() as f32
                } else {
                    0.0
                };

                let last_consensus_error =
                    state.convergence_history.last().copied().unwrap_or(f32::INFINITY);
                let converged = last_consensus_error < self.config.tolerance;

                (
                    name.clone(),
                    (
                        state.iteration,
                        avg_solve_time,
                        last_consensus_error,
                        converged,
                    ),
                )
            })
            .collect()
    }

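    /// Running average of the estimated speedup over the sequential baseline.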
    pub fn cumulative_speedup(&self) -> f32 {
        self.cumulative_speedup
    }

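    /// Total bytes held by consensus-node tensors and policy networks across
    /// all parameter states.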
    pub fn distributed_memory_usage(&self) -> usize {
        self.states
            .values()
            .map(|state| {
                let nodes_memory = state
                    .consensus_nodes
                    .iter()
                    .map(|node| {
                        node.local_variables.memory_usage()
                            + node.dual_variables.memory_usage()
                            + node.constraint_residuals.memory_usage()
                    })
                    .sum::<usize>();

                let network_memory = if let Some(ref network) = state.policy_network {
                    network.weights.iter().map(|w| w.memory_usage()).sum::<usize>()
                        + network.biases.iter().map(|b| b.memory_usage()).sum::<usize>()
                        + network.input_mean.memory_usage()
                        + network.input_std.memory_usage()
                } else {
                    0
                };

                nodes_memory + network_memory
            })
            .sum()
    }
}

impl Optimizer for DeepDistributedQP {
    fn update(&mut self, parameter: &mut Tensor, gradient: &Tensor) -> Result<()> {
        // Best-effort parameter identity: state count, shape, and a hash of the
        // first few values (the trait provides no stable parameter id).
        let param_id = format!(
            "param_{}_{:?}_{}",
            self.states.len(),
            parameter.shape(),
            parameter
                .data_f32()
                .unwrap_or_default()
                .get(0..5)
                .unwrap_or(&[])
                .iter()
                .fold(0u64, |acc, &x| acc.wrapping_add(x.to_bits() as u64))
        );
        let qp_solution = self.solve_distributed_qp(&param_id, gradient)?;

        // Apply the QP solution as a scaled descent step.
        let update = qp_solution.mul_scalar(self.config.learning_rate)?;
        *parameter = parameter.sub(&update)?;

        Ok(())
    }

    fn zero_grad(&mut self) {
        for state in self.states.values_mut() {
            state.problem_vector_q = None;
        }
    }

    fn step(&mut self) {
        self.step += 1;
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

impl StatefulOptimizer for DeepDistributedQP {
    type Config = DeepDistributedQPConfig;
    type State = StateMemoryStats;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn state(&self) -> &Self::State {
        &self.memory_stats
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.memory_stats
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state_dict = HashMap::new();
        state_dict.insert("step".to_string(), Tensor::scalar(self.step as f32)?);
        state_dict.insert(
            "problems_solved".to_string(),
            Tensor::scalar(self.problems_solved as f32)?,
        );
        Ok(state_dict)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(step_tensor) = state.get("step") {
            self.step = step_tensor.to_scalar()? as usize;
        }
        if let Some(problems_tensor) = state.get("problems_solved") {
            self.problems_solved = problems_tensor.to_scalar()? as usize;
        }
        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        self.memory_stats.clone()
    }

    fn reset_state(&mut self) {
        self.states.clear();
        self.step = 0;
        self.problems_solved = 0;
        self.cumulative_speedup = 1.0;
        self.global_consensus = None;
    }

    fn num_parameters(&self) -> usize {
        self.states.len()
    }
}

impl DeepDistributedQP {
    /// Number of consensus nodes (workers) per problem.
    pub fn num_workers(&self) -> usize {
        self.config.num_consensus_nodes
    }

    /// Current base learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.config.learning_rate
    }

    /// Upper bound on consensus (communication) rounds per solve.
    pub fn communication_rounds(&self) -> usize {
        self.config.max_iterations / self.config.consensus_frequency
    }

    /// Fraction of iterations spent synchronizing nodes.
    pub fn synchronization_overhead(&self) -> f32 {
        1.0 / self.config.consensus_frequency as f32
    }

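    /// Solves a QP posed in the conventional form
    /// `minimize 0.5 * x' P x + q' x  subject to  A x = b, G x <= h`,
    /// where the equality data (`a`, `b`) and inequality data (`g`, `h`)
    /// are optional.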
    pub fn solve_qp(
        &mut self,
        problem_id: &str,
        p: &Tensor,
        q: &Tensor,
        a: Option<&Tensor>,
        b: Option<&Tensor>,
        g: Option<&Tensor>,
        h: Option<&Tensor>,
    ) -> Result<Tensor> {
        let problem_key = problem_id.to_string();
        let state_exists = self.states.contains_key(&problem_key);

        if !state_exists {
            let consensus_nodes = self.initialize_consensus_nodes(q.len()).unwrap_or_default();
            let new_state = DeepDistributedQPState {
                consensus_nodes,
                policy_network: None,
                previous_solution: None,
                problem_matrix_p: Some(p.clone()),
                problem_vector_q: Some(q.clone()),
                constraint_matrix_a: a.cloned(),
                constraint_vector_b: b.cloned(),
                iteration: 0,
                convergence_history: Vec::new(),
                solve_times: Vec::new(),
                problem_size: q.len(),
            };
            self.states.insert(problem_key.clone(), new_state);
        }

        let state = self.states.get_mut(&problem_key).unwrap();

        // Seed inequality-constraint residuals: r = G x - h.
        if let Some(constraint_mat) = g {
            for node in &mut state.consensus_nodes {
                node.constraint_residuals = constraint_mat.matmul(&node.local_variables)?;
                if let Some(h_vec) = h {
                    node.constraint_residuals = node.constraint_residuals.sub(h_vec)?;
                }
            }
        }

        self.solve_distributed_qp(problem_id, q)
    }

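    /// Replaces the policy network's weights and biases for the given
    /// parameter state, if one exists.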
    pub fn set_policy_weights(
        &mut self,
        param_id: &str,
        weights: Vec<Tensor>,
        biases: Vec<Tensor>,
    ) -> Result<()> {
        if let Some(state) = self.states.get_mut(param_id) {
            if let Some(ref mut network) = state.policy_network {
                network.weights = weights;
                network.biases = biases;
            }
        }
        Ok(())
    }

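    /// Placeholder training hook for the step-size policy: experience features
    /// are gathered but unused; currently it only nudges `output_scale`.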
    pub fn train_policy(
        &mut self,
        param_id: &str,
        experience_data: &[(Tensor, f32)],
    ) -> Result<()> {
        if let Some(state) = self.states.get_mut(param_id) {
            if let Some(ref mut network) = state.policy_network {
                if !experience_data.is_empty() {
                    let _features: Vec<_> =
                        experience_data.iter().map(|(f, _)| f.clone()).collect();
                    // Stub update: scale the output slightly instead of real training.
                    network.output_scale *= 1.01;
                }
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deep_distributed_qp_creation() {
        let optimizer = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
        assert_eq!(optimizer.learning_rate(), 1e-3);
        assert_eq!(optimizer.config.num_consensus_nodes, 4);
        assert_eq!(optimizer.config.max_iterations, 100);
    }

    #[test]
    fn test_deep_distributed_qp_presets() {
        let large_scale = DeepDistributedQP::for_large_scale();
        assert_eq!(large_scale.config.num_consensus_nodes, 8);
        assert_eq!(large_scale.config.max_iterations, 500);

        let portfolio = DeepDistributedQP::for_portfolio_optimization();
        assert_eq!(portfolio.config.num_consensus_nodes, 6);
        assert_eq!(portfolio.config.penalty_parameter, 2.0);
    }

    #[test]
    fn test_consensus_nodes_initialization() -> Result<()> {
        let optimizer = DeepDistributedQP::new(1e-3, 3, 50, 1e-6);
        let nodes = optimizer.initialize_consensus_nodes(5)?;

        assert_eq!(nodes.len(), 3);
        for (i, node) in nodes.iter().enumerate() {
            assert_eq!(node.node_id, i);
            assert_eq!(node.local_variables.shape(), &[5]);
        }

        Ok(())
    }

    #[test]
    fn test_policy_network_creation() -> Result<()> {
        let optimizer = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
        let network = optimizer.create_policy_network(4)?;

        // Two hidden layers plus the output layer.
        assert_eq!(network.weights.len(), 3);
        assert_eq!(network.biases.len(), 3);
        assert_eq!(network.input_mean.shape(), &[4]);

        Ok(())
    }

    #[test]
    fn test_soft_threshold() -> Result<()> {
        let optimizer = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
        let input = Tensor::from_slice(&[-2.0, -0.5, 0.0, 0.5, 2.0], &[5])?;
        let threshold = 1.0;

        let result = optimizer.soft_threshold(&input, threshold)?;
        let result_vec = result.data()?;

        // Values inside [-1, 1] are zeroed; the rest shrink toward zero by 1.
        assert!((result_vec[0] - (-1.0)).abs() < 1e-5);
        assert!(result_vec[1].abs() < 1e-5);
        assert!(result_vec[2].abs() < 1e-5);
        assert!(result_vec[3].abs() < 1e-5);
        assert!((result_vec[4] - 1.0).abs() < 1e-5);

        Ok(())
    }

    #[test]
    fn test_simple_qp_solve() -> Result<()> {
        let mut optimizer = DeepDistributedQP::new(0.1, 2, 20, 1e-4);
        let mut parameter = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
        let gradient = Tensor::from_slice(&[0.1, 0.2, 0.1], &[3])?;

        optimizer.update(&mut parameter, &gradient)?;
        optimizer.step();

        // The solve should complete without error and preserve the parameter shape.
        assert_eq!(parameter.shape(), &[3]);

        Ok(())
    }

    #[test]
    fn test_qp_solver_stats() -> Result<()> {
        let mut optimizer = DeepDistributedQP::new(1e-3, 2, 10, 1e-4);
        let mut param = Tensor::from_slice(&[1.0, 2.0], &[2])?;
        let grad = Tensor::from_slice(&[0.1, 0.1], &[2])?;

        optimizer.update(&mut param, &grad)?;

        let stats = optimizer.qp_solver_stats();
        assert_eq!(stats.len(), 1);

        let (iterations, solve_time, _consensus_error, _converged) =
            stats.values().next().unwrap();
        assert!(*iterations <= 10);
        assert!(*solve_time >= 0.0);

        Ok(())
    }

    #[test]
    fn test_memory_usage() -> Result<()> {
        let mut optimizer = DeepDistributedQP::new(1e-3, 3, 10, 1e-4);
        let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4])?;
        let grad = Tensor::from_slice(&[0.1, 0.1, 0.1, 0.1], &[4])?;

        let memory_before = optimizer.distributed_memory_usage();
        optimizer.update(&mut param, &grad)?;
        let memory_after = optimizer.distributed_memory_usage();

        assert!(memory_after >= memory_before);

        Ok(())
    }
}