1use crate::error::OptimizeResult;
8use crate::result::OptimizeResults;
9use scirs2_core::error::CoreResult as Result;
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11use scirs2_core::random::Rng;
12use scirs2_core::simd_ops::SimdUnifiedOps;
13
/// Physical model used by [`Memristor::update`] to evolve a device's internal state.
///
/// Equality and hashing are derived so models can be compared or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemristorModel {
    /// Linear ionic drift with no window function.
    LinearIonicDrift,
    /// Ionic drift shaped by a state-dependent window with exponent `p_coeff`.
    NonlinearIonicDrift,
    /// Threshold model with an exponential switching rate above a fixed threshold.
    SimmonsTunnelBarrier,
    /// TEAM-style threshold model with separate positive and negative thresholds.
    TeamModel,
    /// Threshold model with a polynomial window `1 - (2x - 1)^(2p)`, using
    /// `p_coeff` for positive and `q_coeff` for negative voltages.
    BiolekModel,
}
28
/// Physical parameters shared by every device built from them.
#[derive(Debug, Clone, PartialEq)]
pub struct MemristorParameters {
    /// Device thickness `D` that normalises the drift rate (the default `10e-9`
    /// suggests metres — confirm).
    pub length: f64,
    /// Ion/dopant mobility; scaled linearly with temperature before use.
    pub mobility: f64,
    /// Resistance in the fully ON state (`state == 1`).
    pub r_on: f64,
    /// Resistance in the fully OFF state (`state == 0`).
    pub r_off: f64,
    /// Initial normalised state, expected in `[0, 1]`.
    pub initial_x: f64,
    /// Linear temperature coefficient of mobility relative to the 300 reference point.
    pub temp_coeff: f64,
    /// Half-width of the relative device-to-device resistance spread (0 disables it).
    pub variability: f64,
    /// Window exponent for the nonlinear-drift model and positive-voltage Biolek model.
    pub p_coeff: f64,
    /// Window exponent for the negative-voltage Biolek model.
    pub q_coeff: f64,
}

impl Default for MemristorParameters {
    /// Default parameters: a 10e-9-thick device switching between 100 and
    /// 16000, starting half-way (`initial_x = 0.5`) with 5% variability.
    fn default() -> Self {
        Self {
            length: 10e-9,
            mobility: 1e-10,
            r_on: 100.0,
            r_off: 16000.0,
            initial_x: 0.5,
            temp_coeff: 0.001,
            variability: 0.05,
            p_coeff: 10.0,
            q_coeff: 10.0,
        }
    }
}
66
67#[derive(Debug, Clone)]
69pub struct Memristor {
70 pub resistance: f64,
72 pub state: f64,
74 pub params: MemristorParameters,
76 pub model: MemristorModel,
78 pub temperature: f64,
80 pub variability_factor: f64,
82 pub voltage_history: Vec<f64>,
84 pub max_history: usize,
86}
87
88impl Memristor {
89 pub fn new(params: MemristorParameters, model: MemristorModel) -> Self {
91 let initial_resistance =
92 params.r_on + (params.r_off - params.r_on) * (1.0 - params.initial_x);
93 let variability_factor = if params.variability > 0.0 {
94 1.0 + (scirs2_core::random::rng().random::<f64>() - 0.5) * 2.0 * params.variability
95 } else {
96 1.0
97 };
98
99 Self {
100 resistance: initial_resistance * variability_factor,
101 state: params.initial_x,
102 params,
103 model,
104 temperature: 300.0, variability_factor,
106 voltage_history: Vec::new(),
107 max_history: 10,
108 }
109 }
110
111 pub fn update(&mut self, voltage: f64, dt: f64) {
113 self.voltage_history.push(voltage);
115 if self.voltage_history.len() > self.max_history {
116 self.voltage_history.remove(0);
117 }
118
119 let temp_factor = 1.0 + self.params.temp_coeff * (self.temperature - 300.0);
121 let effective_mobility = self.params.mobility * temp_factor;
122
123 match self.model {
124 MemristorModel::LinearIonicDrift => {
125 self.update_linear_drift(voltage, dt, effective_mobility);
126 }
127 MemristorModel::NonlinearIonicDrift => {
128 self.update_nonlinear_drift(voltage, dt, effective_mobility);
129 }
130 MemristorModel::SimmonsTunnelBarrier => {
131 self.update_simmons_model(voltage, dt);
132 }
133 MemristorModel::TeamModel => {
134 self.update_team_model(voltage, dt);
135 }
136 MemristorModel::BiolekModel => {
137 self.update_biolek_model(voltage, dt);
138 }
139 }
140
141 self.update_resistance();
143 }
144
145 fn update_linear_drift(&mut self, voltage: f64, dt: f64, mobility: f64) {
147 let dx_dt = (mobility * self.params.r_on) / self.params.length.powi(2) * voltage;
148 self.state += dx_dt * dt;
149 self.state = self.state.max(0.0).min(1.0);
150 }
151
152 fn update_nonlinear_drift(&mut self, voltage: f64, dt: f64, mobility: f64) {
154 let window = if voltage > 0.0 {
156 self.state * (1.0 - self.state).powf(self.params.p_coeff)
157 } else {
158 self.state.powf(self.params.p_coeff) * (1.0 - self.state)
159 };
160
161 let dx_dt = (mobility * self.params.r_on) / self.params.length.powi(2) * voltage * window;
162 self.state += dx_dt * dt;
163 self.state = self.state.max(0.0).min(1.0);
164 }
165
166 fn update_simmons_model(&mut self, voltage: f64, dt: f64) {
168 let _beta = 0.8; let v_th = 0.16; if voltage.abs() > v_th {
172 let sign = voltage.signum();
173 let exp_term = (-(voltage.abs() - v_th) / 0.3).exp();
174 let dx_dt = sign * 10e-15 * (1.0 - exp_term);
175
176 self.state += dx_dt * dt / self.params.length;
177 self.state = self.state.max(0.0).min(1.0);
178 }
179 }
180
181 fn update_team_model(&mut self, voltage: f64, dt: f64) {
183 let v_on = 0.3; let v_off = -0.5; let k_on = 8e-13; let k_off = 8e-13; if voltage > v_on {
189 let dx_dt = k_on * ((voltage / v_on) - 1.0).exp();
190 self.state += dx_dt * dt;
191 } else if voltage < v_off {
192 let dx_dt = -k_off * ((voltage.abs() / v_off.abs()) - 1.0).exp();
193 self.state += dx_dt * dt;
194 }
195
196 self.state = self.state.max(0.0).min(1.0);
197 }
198
199 fn update_biolek_model(&mut self, voltage: f64, dt: f64) {
201 let v_th = 1.0; if voltage.abs() > v_th {
204 let window = if voltage > 0.0 {
205 1.0 - (2.0 * self.state - 1.0).powi(2 * self.params.p_coeff as i32)
206 } else {
207 1.0 - (2.0 * self.state - 1.0).powi(2 * self.params.q_coeff as i32)
208 };
209
210 let dx_dt = voltage * window * 1e-12;
211 self.state += dx_dt * dt;
212 self.state = self.state.max(0.0).min(1.0);
213 }
214 }
215
216 fn update_resistance(&mut self) {
218 let base_resistance =
220 self.params.r_on * self.state + self.params.r_off * (1.0 - self.state);
221 self.resistance = base_resistance * self.variability_factor;
222 }
223
224 pub fn conductance(&self) -> f64 {
226 1.0 / self.resistance
227 }
228
229 pub fn set_temperature(&mut self, temperature: f64) {
231 self.temperature = temperature;
232 }
233
234 pub fn power_dissipation(&self, voltage: f64) -> f64 {
236 voltage.powi(2) / self.resistance
237 }
238
239 pub fn reset(&mut self) {
241 self.state = self.params.initial_x;
242 self.voltage_history.clear();
243 self.update_resistance();
244 }
245}
246
247#[derive(Debug, Clone)]
249pub struct MemristiveCrossbar {
250 pub memristors: Vec<Vec<Memristor>>,
252 pub rows: usize,
254 pub cols: usize,
255 pub row_resistance: Array1<f64>,
257 pub col_resistance: Array1<f64>,
258 pub v_max: f64,
260 pub v_min: f64,
261 pub fault_map: Array2<bool>,
263 pub use_sneak_compensation: bool,
265 pub stats: CrossbarStats,
267}
268
/// Running statistics collected by a [`MemristiveCrossbar`].
#[derive(Debug, Clone, PartialEq)]
pub struct CrossbarStats {
    /// Number of reads (`multiply` calls, including those issued by `update`).
    pub operations: usize,
    /// Accumulated read-power estimate.
    pub power_consumption: f64,
    /// Running mean wall-clock duration of reads, in nanoseconds.
    pub avg_read_time_ns: f64,
    /// Running mean wall-clock duration of writes (`update`), in nanoseconds.
    pub avg_write_time_ns: f64,
    /// Number of devices marked faulty.
    pub faulty_devices: usize,
}

impl Default for CrossbarStats {
    /// Zeroed counters, with placeholder timing estimates of 1 ns per read and
    /// 10 ns per write.
    fn default() -> Self {
        Self {
            operations: 0,
            power_consumption: 0.0,
            avg_read_time_ns: 1.0,
            avg_write_time_ns: 10.0,
            faulty_devices: 0,
        }
    }
}
295
296impl MemristiveCrossbar {
297 pub fn new(
299 rows: usize,
300 cols: usize,
301 params: MemristorParameters,
302 model: MemristorModel,
303 ) -> Self {
304 let mut memristors = Vec::with_capacity(rows);
305 let mut fault_map = Array2::from_elem((rows, cols), false);
306 let mut faulty_count = 0;
307
308 for i in 0..rows {
309 let mut row = Vec::with_capacity(cols);
310 for j in 0..cols {
311 let mut memristor = Memristor::new(params.clone(), model);
312
313 if scirs2_core::random::rng().random::<f64>() < 0.01 {
315 fault_map[[i, j]] = true;
316 faulty_count += 1;
317 if scirs2_core::random::rng().random::<bool>() {
319 memristor.resistance = params.r_off * 10.0; } else {
321 memristor.resistance = params.r_on * 0.1; }
323 }
324
325 row.push(memristor);
326 }
327 memristors.push(row);
328 }
329
330 let wire_r_per_cell = 1.0; let row_resistance = Array1::from_shape_fn(rows, |i| wire_r_per_cell * (i + 1) as f64);
333 let col_resistance = Array1::from_shape_fn(cols, |j| wire_r_per_cell * (j + 1) as f64);
334
335 let mut stats = CrossbarStats::default();
336 stats.faulty_devices = faulty_count;
337
338 Self {
339 memristors,
340 rows,
341 cols,
342 row_resistance,
343 col_resistance,
344 v_max: 1.5, v_min: -1.5, fault_map,
347 use_sneak_compensation: true,
348 stats,
349 }
350 }
351
352 pub fn multiply(&mut self, input: &ArrayView1<f64>) -> Array1<f64> {
354 let start_time = std::time::Instant::now();
355 let mut output = Array1::zeros(self.rows);
356
357 if input.len() >= 4 && self.rows >= 4 {
359 self.multiply_simd(input, &mut output);
360 } else {
361 self.multiply_scalar(input, &mut output);
362 }
363
364 if self.use_sneak_compensation {
366 self.compensate_sneak_paths(&mut output, input);
367 }
368
369 self.stats.operations += 1;
371 self.stats.power_consumption += self.calculate_read_power(input);
372 let elapsed = start_time.elapsed().as_nanos() as f64;
373 self.stats.avg_read_time_ns =
374 (self.stats.avg_read_time_ns * (self.stats.operations - 1) as f64 + elapsed)
375 / self.stats.operations as f64;
376
377 output
378 }
379
380 fn multiply_simd(&self, input: &ArrayView1<f64>, output: &mut Array1<f64>) {
382 for i in 0..self.rows {
383 let mut sum = 0.0;
384 let conductances: Vec<f64> = (0..self.cols)
385 .map(|j| {
386 if self.fault_map[[i, j]] {
387 0.0
388 } else {
389 self.memristors[i][j].conductance()
390 }
391 })
392 .collect();
393
394 if conductances.len() >= input.len() {
396 let g_slice = &conductances[..input.len()];
397 let g_array = Array1::from(g_slice.to_vec());
398 sum = SimdUnifiedOps::simd_dot(&g_array.view(), input);
399 }
400
401 output[i] = sum;
402 }
403 }
404
405 fn multiply_scalar(&self, input: &ArrayView1<f64>, output: &mut Array1<f64>) {
407 for i in 0..self.rows {
408 for j in 0..self.cols.min(input.len()) {
409 if !self.fault_map[[i, j]] {
410 let conductance = self.memristors[i][j].conductance();
411 output[i] += input[j] * conductance;
412 }
413 }
414 }
415 }
416
417 fn compensate_sneak_paths(&self, output: &mut Array1<f64>, input: &ArrayView1<f64>) {
419 let _avg_conductance = self.calculate_average_conductance();
422 let sneak_compensation_factor = 0.95; for i in 0..output.len() {
425 output[i] *= sneak_compensation_factor;
426 }
427 }
428
429 fn calculate_average_conductance(&self) -> f64 {
431 let mut sum = 0.0;
432 let mut count = 0;
433
434 for i in 0..self.rows {
435 for j in 0..self.cols {
436 if !self.fault_map[[i, j]] {
437 sum += self.memristors[i][j].conductance();
438 count += 1;
439 }
440 }
441 }
442
443 if count > 0 {
444 sum / count as f64
445 } else {
446 0.0
447 }
448 }
449
450 fn calculate_read_power(&self, input: &ArrayView1<f64>) -> f64 {
452 let mut power = 0.0;
453 let read_voltage = 0.1; for i in 0..self.rows {
456 for j in 0..self.cols.min(input.len()) {
457 if !self.fault_map[[i, j]] && input[j].abs() > 1e-10 {
458 power += self.memristors[i][j].power_dissipation(read_voltage * input[j]);
459 }
460 }
461 }
462
463 power
464 }
465
466 pub fn update(
468 &mut self,
469 input: &ArrayView1<f64>,
470 target: &ArrayView1<f64>,
471 learning_rate: f64,
472 ) -> Result<()> {
473 let start_time = std::time::Instant::now();
474
475 let current_output = self.multiply(input);
477 let error = target - ¤t_output;
478
479 for i in 0..self.rows.min(error.len()) {
481 for j in 0..self.cols.min(input.len()) {
482 if !self.fault_map[[i, j]] {
483 let desired_delta_g = learning_rate * error[i] * input[j];
485
486 let programming_voltage =
488 self.conductance_change_to_voltage(desired_delta_g, i, j);
489
490 let limited_voltage = programming_voltage.max(self.v_min).min(self.v_max);
492
493 let dt = 1e-6; self.memristors[i][j].update(limited_voltage, dt);
496 }
497 }
498 }
499
500 let elapsed = start_time.elapsed().as_nanos() as f64;
502 self.stats.avg_write_time_ns =
503 (self.stats.avg_write_time_ns * self.stats.operations as f64 + elapsed)
504 / (self.stats.operations + 1) as f64;
505
506 Ok(())
507 }
508
509 fn conductance_change_to_voltage(&self, delta_g: f64, row: usize, col: usize) -> f64 {
511 let current_g = self.memristors[row][col].conductance();
513 let relative_change = delta_g / (current_g + 1e-12);
514
515 if relative_change > 0.0 {
517 0.5 * relative_change.ln().max(-3.0) } else {
519 -0.5 * (-relative_change).ln().max(-3.0) }
521 }
522
523 pub fn refresh(&mut self) -> Result<()> {
525 for i in 0..self.rows {
526 for j in 0..self.cols {
527 if !self.fault_map[[i, j]] {
528 let _target_conductance = self.memristors[i][j].conductance();
530
531 let refresh_voltage = 0.1; self.memristors[i][j].update(refresh_voltage, 1e-7);
534 }
535 }
536 }
537 Ok(())
538 }
539
540 pub fn get_stats(&self) -> &CrossbarStats {
542 &self.stats
543 }
544
545 pub fn reset(&mut self) {
547 for i in 0..self.rows {
548 for j in 0..self.cols {
549 if !self.fault_map[[i, j]] {
550 self.memristors[i][j].reset();
551 }
552 }
553 }
554 self.stats = CrossbarStats::default();
555 self.stats.faulty_devices = self.fault_map.iter().filter(|&&x| x).count();
556 }
557
558 pub fn get_conductance_matrix(&self) -> Array2<f64> {
560 let mut conductances = Array2::zeros((self.rows, self.cols));
561 for i in 0..self.rows {
562 for j in 0..self.cols {
563 conductances[[i, j]] = if self.fault_map[[i, j]] {
564 0.0
565 } else {
566 self.memristors[i][j].conductance()
567 };
568 }
569 }
570 conductances
571 }
572}
573
574#[derive(Debug, Clone)]
577pub struct MemristiveOptimizer {
578 pub crossbar: MemristiveCrossbar,
580 pub parameters: Array1<f64>,
582 pub best_parameters: Array1<f64>,
584 pub best_objective: f64,
586 pub learning_rate: f64,
588 pub momentum: f64,
590 pub momentum_buffer: Array1<f64>,
592 pub nit: usize,
594}
595
596impl MemristiveOptimizer {
597 pub fn new(
599 initial_params: Array1<f64>,
600 learning_rate: f64,
601 momentum: f64,
602 memristor_params: MemristorParameters,
603 model: MemristorModel,
604 ) -> Self {
605 let n = initial_params.len();
606 let crossbar_size = (n as f64).sqrt().ceil() as usize;
607 let crossbar =
608 MemristiveCrossbar::new(crossbar_size, crossbar_size, memristor_params, model);
609
610 Self {
611 crossbar,
612 parameters: initial_params.clone(),
613 best_parameters: initial_params.clone(),
614 best_objective: f64::INFINITY,
615 learning_rate,
616 momentum,
617 momentum_buffer: Array1::zeros(n),
618 nit: 0,
619 }
620 }
621
622 pub fn optimize<F>(
624 &mut self,
625 objective: F,
626 max_nit: usize,
627 ) -> OptimizeResult<OptimizeResults<f64>>
628 where
629 F: Fn(&ArrayView1<f64>) -> f64,
630 {
631 let mut convergence_history = Vec::new();
632
633 for iter in 0..max_nit {
634 let current_obj = objective(&self.parameters.view());
636 convergence_history.push(current_obj);
637
638 if current_obj < self.best_objective {
640 self.best_objective = current_obj;
641 self.best_parameters = self.parameters.clone();
642 }
643
644 let gradient = self.compute_finite_diff_gradient(&objective)?;
646
647 let crossbar_input = self.encode_gradient(&gradient);
649
650 let crossbar_output = self.crossbar.multiply(&crossbar_input.view());
652
653 let decoded_update = self.decode_update(&crossbar_output);
655
656 self.apply_momentum_update(&decoded_update)?;
658
659 self.update_crossbar_weights(&gradient, current_obj)?;
661
662 if self.check_convergence(&convergence_history) {
664 break;
665 }
666
667 self.nit += 1;
668
669 if iter % 100 == 0 {
671 self.crossbar.refresh()?;
672 }
673 }
674
675 Ok(OptimizeResults::<f64> {
676 x: self.best_parameters.clone(),
677 fun: self.best_objective,
678 success: self.best_objective < 1e-6,
679 nit: self.nit,
680 message: "Memristive optimization completed".to_string(),
681 jac: None,
682 hess: None,
683 constr: None,
684 nfev: self.nit,
685 njev: 0,
686 nhev: 0,
687 maxcv: 0,
688 status: 0,
689 })
690 }
691
692 fn compute_finite_diff_gradient<F>(&self, objective: &F) -> Result<Array1<f64>>
694 where
695 F: Fn(&ArrayView1<f64>) -> f64,
696 {
697 let n = self.parameters.len();
698 let mut gradient = Array1::zeros(n);
699 let h = 1e-6;
700 let f0 = objective(&self.parameters.view());
701
702 for i in 0..n {
703 let mut params_plus = self.parameters.clone();
704 params_plus[i] += h;
705 let f_plus = objective(¶ms_plus.view());
706 gradient[i] = (f_plus - f0) / h;
707 }
708
709 Ok(gradient)
710 }
711
712 fn encode_gradient(&self, gradient: &Array1<f64>) -> Array1<f64> {
714 let crossbar_size = self.crossbar.cols;
715 let mut encoded = Array1::zeros(crossbar_size);
716
717 let max_grad = gradient.mapv(|x| x.abs()).fold(0.0, |a, &b| f64::max(a, b));
719 if max_grad > 0.0 {
720 for i in 0..crossbar_size.min(gradient.len()) {
721 encoded[i] = gradient[i] / max_grad;
722 }
723 }
724
725 encoded
726 }
727
728 fn decode_update(&self, crossbar_output: &Array1<f64>) -> Array1<f64> {
730 let n = self.parameters.len();
731 let mut update = Array1::zeros(n);
732
733 for i in 0..n.min(crossbar_output.len()) {
735 update[i] = crossbar_output[i] * self.learning_rate;
736 }
737
738 update
739 }
740
741 fn apply_momentum_update(&mut self, update: &Array1<f64>) -> Result<()> {
743 self.momentum_buffer =
745 &(self.momentum * &self.momentum_buffer) + &((1.0 - self.momentum) * update);
746
747 self.parameters = &self.parameters - &self.momentum_buffer;
749
750 Ok(())
751 }
752
753 fn update_crossbar_weights(
755 &mut self,
756 gradient: &Array1<f64>,
757 objective_value: f64,
758 ) -> Result<()> {
759 let performance_factor = (-objective_value / 10.0).exp(); let encoded_gradient = self.encode_gradient(gradient);
763 let target_output = &encoded_gradient * performance_factor;
764
765 self.crossbar
766 .update(&encoded_gradient.view(), &target_output.view(), 0.01)?;
767
768 Ok(())
769 }
770
771 fn check_convergence(&self, history: &[f64]) -> bool {
773 if history.len() < 10 {
774 return false;
775 }
776
777 let recent = &history[history.len() - 5..];
778 let variance = recent
779 .iter()
780 .fold(0.0, |acc, &x| acc + (x - recent[0]).powi(2))
781 / recent.len() as f64;
782
783 variance < 1e-12
784 }
785}
786
787#[allow(dead_code)]
789pub fn memristive_gradient_descent<F>(
790 objective: F,
791 initial_params: &ArrayView1<f64>,
792 learning_rate: f64,
793 max_nit: usize,
794) -> Result<Array1<f64>>
795where
796 F: Fn(&ArrayView1<f64>) -> f64,
797{
798 let params = MemristorParameters::default();
799 let model = MemristorModel::NonlinearIonicDrift;
800
801 let mut optimizer = MemristiveOptimizer::new(
802 initial_params.to_owned(),
803 learning_rate,
804 0.9, params,
806 model,
807 );
808
809 let result = optimizer.optimize(objective, max_nit)?;
810 Ok(result.x)
811}
812
813#[allow(dead_code)]
815pub fn advanced_memristive_optimization<F>(
816 objective: F,
817 initial_params: &ArrayView1<f64>,
818 learning_rate: f64,
819 max_nit: usize,
820 memristor_params: MemristorParameters,
821 model: MemristorModel,
822) -> OptimizeResult<OptimizeResults<f64>>
823where
824 F: Fn(&ArrayView1<f64>) -> f64,
825{
826 let mut optimizer = MemristiveOptimizer::new(
827 initial_params.to_owned(),
828 learning_rate,
829 0.9,
830 memristor_params,
831 model,
832 );
833
834 optimizer.optimize(objective, max_nit)
835}
836
837#[allow(dead_code)]
839pub fn memristive_neural_optimizer<F>(
840 objective: F,
841 initial_weights: &ArrayView2<f64>,
842 learning_rate: f64,
843 max_nit: usize,
844) -> Result<Array2<f64>>
845where
846 F: Fn(&ArrayView2<f64>) -> f64,
847{
848 let (rows, cols) = initial_weights.dim();
849 let params = MemristorParameters::default();
850 let mut crossbar = MemristiveCrossbar::new(rows, cols, params, MemristorModel::TeamModel);
851
852 for i in 0..rows {
854 for j in 0..cols {
855 let _target_conductance = initial_weights[[i, j]].abs() * 1e-3; let voltage = if initial_weights[[i, j]] > 0.0 {
857 1.0
858 } else {
859 -1.0
860 };
861 crossbar.memristors[i][j].update(voltage, 1e-3);
862 }
863 }
864
865 for _iter in 0..max_nit {
866 let current_weights = crossbar.get_conductance_matrix();
868 let objective_value = objective(¤t_weights.view());
869
870 let mut weight_gradients = Array2::zeros((rows, cols));
872 let h = 1e-6;
873
874 for i in 0..rows {
875 for j in 0..cols {
876 let mut perturbed_weights = current_weights.clone();
877 perturbed_weights[[i, j]] += h;
878 let f_plus = objective(&perturbed_weights.view());
879 weight_gradients[[i, j]] = (f_plus - objective_value) / h;
880 }
881 }
882
883 for i in 0..rows {
885 let row_input = weight_gradients.row(i).to_owned();
886 let target = Array1::zeros(cols); crossbar
888 .update(&row_input.view(), &target.view(), learning_rate)
889 .ok();
890 }
891 }
892
893 Ok(crossbar.get_conductance_matrix())
894}
895
#[cfg(test)]
mod tests {
    use super::*;

    /// Default parameters with device-to-device variability switched off, so
    /// resistances are reproducible.
    fn deterministic_params() -> MemristorParameters {
        MemristorParameters {
            variability: 0.0,
            ..MemristorParameters::default()
        }
    }

    #[test]
    fn test_memristor_models() {
        let params = MemristorParameters::default();
        let mut memristor = Memristor::new(params, MemristorModel::NonlinearIonicDrift);

        let initial_resistance = memristor.resistance;
        memristor.update(1.0, 1e-3);

        assert!(memristor.resistance != initial_resistance);
        // A positive pulse moves the state toward the ON (low-resistance) side.
        assert!(memristor.state > 0.5);
        assert!(memristor.resistance < initial_resistance);
    }

    #[test]
    fn test_voltage_history_is_bounded() {
        let mut memristor =
            Memristor::new(deterministic_params(), MemristorModel::LinearIonicDrift);
        for k in 0..25 {
            memristor.update(k as f64 * 1e-3, 1e-12);
        }
        assert_eq!(memristor.voltage_history.len(), memristor.max_history);
        assert_eq!(*memristor.voltage_history.last().unwrap(), 24.0 * 1e-3);
    }

    #[test]
    fn test_reset_restores_initial_state() {
        let params = deterministic_params();
        let mut memristor = Memristor::new(params.clone(), MemristorModel::LinearIonicDrift);
        let initial_resistance = memristor.resistance;

        memristor.update(0.1, 1e-9);
        memristor.reset();

        assert_eq!(memristor.state, params.initial_x);
        assert!(memristor.voltage_history.is_empty());
        assert!((memristor.resistance - initial_resistance).abs() < 1e-9);
    }

    #[test]
    fn test_crossbar_operations() {
        let params = MemristorParameters::default();
        let mut crossbar = MemristiveCrossbar::new(3, 3, params, MemristorModel::LinearIonicDrift);

        let input = Array1::from(vec![1.0, 0.5, 0.0]);
        let output = crossbar.multiply(&input.view());

        assert_eq!(output.len(), 3);
        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_memristive_optimization() {
        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let initial = Array1::from(vec![1.0, 1.0]);

        let result = memristive_gradient_descent(objective, &initial.view(), 0.1, 100);
        assert!(result.is_ok());

        let final_params = result.unwrap();
        let final_obj = objective(&final_params.view());
        let initial_obj = objective(&initial.view());

        assert!(final_obj < initial_obj);
    }

    #[test]
    fn test_crossbar_with_faults() {
        let params = MemristorParameters::default();
        let mut crossbar = MemristiveCrossbar::new(3, 3, params, MemristorModel::TeamModel);

        let faulty_count = crossbar.fault_map.iter().filter(|&&x| x).count();
        assert_eq!(faulty_count, crossbar.get_stats().faulty_devices);

        // Faulty devices are masked out of the effective conductance matrix.
        let conductances = crossbar.get_conductance_matrix();
        for ((i, j), &faulty) in crossbar.fault_map.indexed_iter() {
            if faulty {
                assert_eq!(conductances[[i, j]], 0.0);
            } else {
                assert!(conductances[[i, j]] > 0.0);
            }
        }

        let input = Array1::ones(3);
        let output = crossbar.multiply(&input.view());
        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_temperature_effects() {
        // Identical devices (no variability) driven by the same small pulse; only
        // the temperature differs. dt is small enough that neither state saturates.
        let mut cold = Memristor::new(deterministic_params(), MemristorModel::LinearIonicDrift);
        let mut hot = Memristor::new(deterministic_params(), MemristorModel::LinearIonicDrift);
        hot.set_temperature(350.0);

        let initial_resistance = cold.resistance;
        cold.update(0.1, 1e-9);
        hot.update(0.1, 1e-9);

        assert!(cold.state < 1.0 && hot.state < 1.0);
        assert!(cold.resistance < initial_resistance);
        // Higher temperature -> higher mobility -> larger state change -> lower R.
        assert!(hot.resistance < cold.resistance);
    }
}