1use crate::error::ValidationError;
22use crate::types::{CsrMatrix, SolverResult};
23
/// Largest row count and largest column count accepted by [`validate_csr_matrix`].
pub const MAX_NODES: usize = 10_000_000;

/// Largest number of stored entries (`values.len()`) accepted by [`validate_csr_matrix`].
pub const MAX_EDGES: usize = 100_000_000;

/// Dimension limit (2^16).
///
/// NOTE(review): not referenced anywhere in this file — `validate_csr_matrix`
/// bounds dimensions by [`MAX_NODES`]. Confirm where (or whether) this is enforced.
pub const MAX_DIM: usize = 1 << 16;

/// Inclusive upper bound for the `max_iterations` argument of [`validate_params`].
pub const MAX_ITERATIONS: usize = 1_000_000;

/// Largest accepted body size in bytes (10 MiB), enforced by [`validate_body_size`].
pub const MAX_BODY_SIZE: usize = 10 << 20;
42
43pub fn validate_csr_matrix(matrix: &CsrMatrix<f32>) -> Result<(), ValidationError> {
76 if matrix.rows > MAX_NODES || matrix.cols > MAX_NODES {
78 return Err(ValidationError::MatrixTooLarge {
79 rows: matrix.rows,
80 cols: matrix.cols,
81 max_dim: MAX_NODES,
82 });
83 }
84
85 let nnz = matrix.values.len();
87 if nnz > MAX_EDGES {
88 return Err(ValidationError::DimensionMismatch(format!(
89 "nnz {} exceeds maximum allowed {}",
90 nnz, MAX_EDGES,
91 )));
92 }
93
94 let expected_row_ptr_len = matrix.rows + 1;
96 if matrix.row_ptr.len() != expected_row_ptr_len {
97 return Err(ValidationError::DimensionMismatch(format!(
98 "row_ptr length {} does not equal rows + 1 = {}",
99 matrix.row_ptr.len(),
100 expected_row_ptr_len,
101 )));
102 }
103
104 for i in 1..matrix.row_ptr.len() {
106 if matrix.row_ptr[i] < matrix.row_ptr[i - 1] {
107 return Err(ValidationError::NonMonotonicRowPtrs { position: i });
108 }
109 }
110
111 if matrix.row_ptr[0] != 0 {
113 return Err(ValidationError::DimensionMismatch(format!(
114 "row_ptr[0] = {} (expected 0)",
115 matrix.row_ptr[0],
116 )));
117 }
118 let expected_nnz = matrix.row_ptr[matrix.rows];
119 if expected_nnz != nnz {
120 return Err(ValidationError::DimensionMismatch(format!(
121 "values length {} does not match row_ptr[rows] = {}",
122 nnz, expected_nnz,
123 )));
124 }
125
126 if matrix.col_indices.len() != nnz {
128 return Err(ValidationError::DimensionMismatch(format!(
129 "col_indices length {} does not match values length {}",
130 matrix.col_indices.len(),
131 nnz,
132 )));
133 }
134
135 for row in 0..matrix.rows {
137 let start = matrix.row_ptr[row];
138 let end = matrix.row_ptr[row + 1];
139
140 let mut prev_col: Option<usize> = None;
141 for idx in start..end {
142 let col = matrix.col_indices[idx];
143 if col >= matrix.cols {
144 return Err(ValidationError::IndexOutOfBounds {
145 index: col as u32,
146 row,
147 cols: matrix.cols,
148 });
149 }
150
151 let val = matrix.values[idx];
152 if !val.is_finite() {
153 return Err(ValidationError::NonFiniteValue(format!(
154 "matrix[{}, {}] = {}",
155 row, col, val,
156 )));
157 }
158
159 if let Some(pc) = prev_col {
161 if col < pc {
162 tracing::warn!(
163 row = row,
164 "column indices not sorted within row (col {} follows {}); \
165 performance may be degraded",
166 col,
167 pc,
168 );
169 }
170 }
171 prev_col = Some(col);
172 }
173 }
174
175 Ok(())
176}
177
178pub fn validate_rhs(rhs: &[f32], expected_len: usize) -> Result<(), ValidationError> {
195 if rhs.len() != expected_len {
197 return Err(ValidationError::DimensionMismatch(format!(
198 "rhs length {} does not match expected {}",
199 rhs.len(),
200 expected_len,
201 )));
202 }
203
204 let mut all_zero = true;
206 for (i, &v) in rhs.iter().enumerate() {
207 if !v.is_finite() {
208 return Err(ValidationError::NonFiniteValue(format!(
209 "rhs[{}] = {}",
210 i, v,
211 )));
212 }
213 if v != 0.0 {
214 all_zero = false;
215 }
216 }
217
218 if all_zero && !rhs.is_empty() {
219 tracing::warn!("rhs vector is all zeros; solution will be trivially zero");
220 }
221
222 Ok(())
223}
224
225pub fn validate_rhs_vector(rhs: &[f32], expected_len: usize) -> Result<(), ValidationError> {
230 validate_rhs(rhs, expected_len)
231}
232
233pub fn validate_params(tolerance: f64, max_iterations: usize) -> Result<(), ValidationError> {
249 if !tolerance.is_finite() || tolerance <= 0.0 || tolerance > 1.0 {
250 return Err(ValidationError::ParameterOutOfRange {
251 name: "tolerance".into(),
252 value: format!("{tolerance:.2e}"),
253 expected: "(0.0, 1.0]".into(),
254 });
255 }
256
257 if max_iterations == 0 || max_iterations > MAX_ITERATIONS {
258 return Err(ValidationError::ParameterOutOfRange {
259 name: "max_iterations".into(),
260 value: max_iterations.to_string(),
261 expected: format!("[1, {}]", MAX_ITERATIONS),
262 });
263 }
264
265 Ok(())
266}
267
268pub fn validate_solver_input(
282 matrix: &CsrMatrix<f32>,
283 rhs: &[f32],
284 tolerance: f64,
285) -> Result<(), ValidationError> {
286 validate_csr_matrix(matrix)?;
287 validate_rhs(rhs, matrix.rows)?;
288
289 if matrix.rows != matrix.cols {
291 return Err(ValidationError::DimensionMismatch(format!(
292 "solver requires a square matrix but got {}x{}",
293 matrix.rows, matrix.cols,
294 )));
295 }
296
297 if !tolerance.is_finite() || tolerance <= 0.0 {
299 return Err(ValidationError::ParameterOutOfRange {
300 name: "tolerance".into(),
301 value: tolerance.to_string(),
302 expected: "finite positive value".into(),
303 });
304 }
305
306 Ok(())
307}
308
309pub fn validate_output(result: &SolverResult) -> Result<(), ValidationError> {
326 for (i, &v) in result.solution.iter().enumerate() {
328 if !v.is_finite() {
329 return Err(ValidationError::NonFiniteValue(format!(
330 "solution[{}] = {}",
331 i, v,
332 )));
333 }
334 }
335
336 if !result.residual_norm.is_finite() {
338 return Err(ValidationError::NonFiniteValue(format!(
339 "residual_norm = {}",
340 result.residual_norm,
341 )));
342 }
343
344 if result.iterations == 0 {
346 return Err(ValidationError::ParameterOutOfRange {
347 name: "iterations".into(),
348 value: "0".into(),
349 expected: ">= 1".into(),
350 });
351 }
352
353 Ok(())
354}
355
356pub fn validate_body_size(size: usize) -> Result<(), ValidationError> {
368 if size > MAX_BODY_SIZE {
369 return Err(ValidationError::ParameterOutOfRange {
370 name: "body_size".into(),
371 value: format!("{} bytes", size),
372 expected: format!("<= {} bytes (10 MiB)", MAX_BODY_SIZE),
373 });
374 }
375 Ok(())
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Algorithm, ConvergenceInfo, CsrMatrix, SolverResult};
    use std::time::Duration;

    /// Builds an `n x n` identity matrix in CSR form: row `i` stores a single
    /// `1.0` at column `i`.
    fn make_identity(n: usize) -> CsrMatrix<f32> {
        let mut row_ptr = vec![0usize; n + 1];
        let mut col_indices = Vec::with_capacity(n);
        let mut values = Vec::with_capacity(n);
        for i in 0..n {
            row_ptr[i + 1] = i + 1;
            col_indices.push(i);
            values.push(1.0);
        }
        CsrMatrix {
            values,
            col_indices,
            row_ptr,
            rows: n,
            cols: n,
        }
    }

    // ---------------------------------------------------------------- CSR

    #[test]
    fn valid_identity() {
        let mat = make_identity(4);
        assert!(validate_csr_matrix(&mat).is_ok());
    }

    #[test]
    fn valid_empty_matrix() {
        let m = CsrMatrix {
            row_ptr: vec![0],
            col_indices: vec![],
            values: vec![],
            rows: 0,
            cols: 0,
        };
        assert!(validate_csr_matrix(&m).is_ok());
    }

    #[test]
    fn valid_from_coo() {
        let m = CsrMatrix::<f32>::from_coo(
            3,
            3,
            vec![
                (0, 0, 2.0),
                (0, 1, -0.5),
                (1, 0, -0.5),
                (1, 1, 2.0),
                (1, 2, -0.5),
                (2, 1, -0.5),
                (2, 2, 2.0),
            ],
        );
        assert!(validate_csr_matrix(&m).is_ok());
    }

    #[test]
    fn accepts_unsorted_columns_within_row() {
        // Out-of-order column indices only produce a warning, never an error.
        let m = CsrMatrix {
            row_ptr: vec![0, 2],
            col_indices: vec![1, 0],
            values: vec![1.0, 2.0],
            rows: 1,
            cols: 2,
        };
        assert!(validate_csr_matrix(&m).is_ok());
    }

    #[test]
    fn rejects_too_large_matrix() {
        let m = CsrMatrix {
            row_ptr: vec![0, 0],
            col_indices: vec![],
            values: vec![],
            rows: MAX_NODES + 1,
            cols: 1,
        };
        assert!(matches!(
            validate_csr_matrix(&m),
            Err(ValidationError::MatrixTooLarge { .. })
        ));
    }

    #[test]
    fn rejects_wrong_row_ptr_length() {
        let m = CsrMatrix {
            row_ptr: vec![0, 1],
            col_indices: vec![0],
            values: vec![1.0],
            rows: 3,
            cols: 3,
        };
        assert!(matches!(
            validate_csr_matrix(&m),
            Err(ValidationError::DimensionMismatch(_))
        ));
    }

    #[test]
    fn non_monotonic_row_ptr() {
        let mut mat = make_identity(4);
        // row_ptr becomes [0, 1, 0, 3, 4]; the first decrease is at index 2.
        mat.row_ptr[2] = 0;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(
            err,
            ValidationError::NonMonotonicRowPtrs { position: 2 }
        ));
    }

    #[test]
    fn rejects_row_ptr_not_starting_at_zero() {
        let m = CsrMatrix {
            row_ptr: vec![1, 2],
            col_indices: vec![0],
            values: vec![1.0],
            rows: 1,
            cols: 1,
        };
        match validate_csr_matrix(&m) {
            Err(ValidationError::DimensionMismatch(msg)) => {
                assert!(msg.contains("row_ptr[0]"), "msg: {msg}");
            }
            other => panic!("expected DimensionMismatch for row_ptr[0], got {other:?}"),
        }
    }

    #[test]
    fn rejects_values_length_not_matching_row_ptr_end() {
        // row_ptr claims 2 stored entries but only 1 value is present.
        let m = CsrMatrix {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0],
            values: vec![1.0],
            rows: 2,
            cols: 2,
        };
        match validate_csr_matrix(&m) {
            Err(ValidationError::DimensionMismatch(msg)) => {
                assert!(msg.contains("values length"), "msg: {msg}");
            }
            other => panic!("expected DimensionMismatch for values length, got {other:?}"),
        }
    }

    #[test]
    fn rejects_col_indices_length_mismatch() {
        let m = CsrMatrix {
            row_ptr: vec![0, 1],
            col_indices: vec![0, 0],
            values: vec![1.0],
            rows: 1,
            cols: 1,
        };
        match validate_csr_matrix(&m) {
            Err(ValidationError::DimensionMismatch(msg)) => {
                assert!(msg.contains("col_indices length"), "msg: {msg}");
            }
            other => panic!("expected DimensionMismatch for col_indices, got {other:?}"),
        }
    }

    #[test]
    fn col_index_out_of_bounds() {
        let mut mat = make_identity(4);
        mat.col_indices[1] = 99;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(
            err,
            ValidationError::IndexOutOfBounds {
                index: 99,
                row: 1,
                cols: 4
            }
        ));
    }

    #[test]
    fn nan_value_rejected() {
        let mut mat = make_identity(4);
        mat.values[0] = f32::NAN;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    #[test]
    fn inf_value_rejected() {
        let mut mat = make_identity(4);
        mat.values[0] = f32::INFINITY;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    // ---------------------------------------------------------------- RHS

    #[test]
    fn valid_rhs() {
        assert!(validate_rhs(&[1.0, 2.0, 3.0], 3).is_ok());
    }

    #[test]
    fn accepts_empty_rhs() {
        assert!(validate_rhs(&[], 0).is_ok());
    }

    #[test]
    fn rhs_dimension_mismatch() {
        let err = validate_rhs(&[1.0, 2.0], 3).unwrap_err();
        assert!(matches!(err, ValidationError::DimensionMismatch(_)));
    }

    #[test]
    fn rhs_nan_rejected() {
        let err = validate_rhs(&[1.0, f32::NAN, 3.0], 3).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    #[test]
    fn rhs_inf_rejected() {
        let err = validate_rhs(&[1.0, f32::NEG_INFINITY, 3.0], 3).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    #[test]
    fn warns_on_all_zero_rhs() {
        // The warning goes through `tracing`; no subscriber is installed in this
        // test, so only the Ok result is asserted.
        assert!(validate_rhs(&[0.0, 0.0, 0.0], 3).is_ok());
    }

    #[test]
    fn rhs_vector_alias_works() {
        assert!(validate_rhs_vector(&[1.0, 2.0], 2).is_ok());
        assert!(validate_rhs_vector(&[1.0, 2.0], 3).is_err());
    }

    // ---------------------------------------------------------------- params

    #[test]
    fn valid_params() {
        assert!(validate_params(1e-8, 500).is_ok());
        assert!(validate_params(1.0, 1).is_ok());
    }

    #[test]
    fn accepts_max_iterations_at_limit() {
        assert!(validate_params(1e-6, MAX_ITERATIONS).is_ok());
    }

    #[test]
    fn rejects_zero_tolerance() {
        match validate_params(0.0, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_negative_tolerance() {
        match validate_params(-1e-6, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_tolerance_above_one() {
        match validate_params(1.5, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_nan_tolerance() {
        match validate_params(f64::NAN, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_zero_iterations() {
        match validate_params(1e-6, 0) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "max_iterations");
            }
            other => panic!("expected ParameterOutOfRange for max_iterations, got {other:?}"),
        }
    }

    #[test]
    fn rejects_excessive_iterations() {
        match validate_params(1e-6, MAX_ITERATIONS + 1) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "max_iterations");
            }
            other => panic!("expected ParameterOutOfRange for max_iterations, got {other:?}"),
        }
    }

    // ---------------------------------------------------------------- solver input

    #[test]
    fn full_input_validation() {
        let mat = make_identity(3);
        let rhs = vec![1.0f32, 2.0, 3.0];
        assert!(validate_solver_input(&mat, &rhs, 1e-6).is_ok());
    }

    #[test]
    fn non_square_rejected() {
        let mat = CsrMatrix {
            values: vec![],
            col_indices: vec![],
            row_ptr: vec![0, 0, 0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0f32, 2.0];
        let err = validate_solver_input(&mat, &rhs, 1e-6).unwrap_err();
        assert!(matches!(err, ValidationError::DimensionMismatch(_)));
    }

    #[test]
    fn solver_input_rejects_rhs_length_mismatch() {
        let mat = make_identity(3);
        let rhs = vec![1.0f32, 2.0];
        let err = validate_solver_input(&mat, &rhs, 1e-6).unwrap_err();
        assert!(matches!(err, ValidationError::DimensionMismatch(_)));
    }

    #[test]
    fn invalid_tolerance_rejected() {
        let mat = make_identity(2);
        let rhs = vec![1.0f32, 2.0];
        assert!(validate_solver_input(&mat, &rhs, -1.0).is_err());
        assert!(validate_solver_input(&mat, &rhs, 0.0).is_err());
        assert!(validate_solver_input(&mat, &rhs, f64::NAN).is_err());
    }

    // ---------------------------------------------------------------- output

    #[test]
    fn valid_output() {
        let result = SolverResult {
            solution: vec![1.0, 2.0, 3.0],
            iterations: 10,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(5),
            convergence_history: vec![ConvergenceInfo {
                iteration: 0,
                residual_norm: 1.0,
            }],
            algorithm: Algorithm::Neumann,
        };
        assert!(validate_output(&result).is_ok());
    }

    #[test]
    fn rejects_nan_in_solution() {
        let result = SolverResult {
            solution: vec![1.0, f32::NAN, 3.0],
            iterations: 1,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::NonFiniteValue(ref msg)) => {
                assert!(msg.contains("solution"), "msg: {msg}");
            }
            other => panic!("expected NonFiniteValue for solution, got {other:?}"),
        }
    }

    #[test]
    fn rejects_inf_in_solution() {
        let result = SolverResult {
            solution: vec![f32::INFINITY],
            iterations: 1,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::NonFiniteValue(ref msg)) => {
                assert!(msg.contains("solution"), "msg: {msg}");
            }
            other => panic!("expected NonFiniteValue for solution, got {other:?}"),
        }
    }

    #[test]
    fn rejects_nan_residual() {
        let result = SolverResult {
            solution: vec![1.0],
            iterations: 1,
            residual_norm: f64::NAN,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::NonFiniteValue(ref msg)) => {
                assert!(msg.contains("residual"), "msg: {msg}");
            }
            other => panic!("expected NonFiniteValue for residual, got {other:?}"),
        }
    }

    #[test]
    fn rejects_inf_residual() {
        let result = SolverResult {
            solution: vec![1.0],
            iterations: 1,
            residual_norm: f64::INFINITY,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        assert!(matches!(
            validate_output(&result),
            Err(ValidationError::NonFiniteValue(_))
        ));
    }

    #[test]
    fn rejects_zero_iterations_in_output() {
        let result = SolverResult {
            solution: vec![1.0],
            iterations: 0,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "iterations");
            }
            other => panic!("expected ParameterOutOfRange, got {other:?}"),
        }
    }

    // ---------------------------------------------------------------- body size

    #[test]
    fn valid_body_size() {
        assert!(validate_body_size(0).is_ok());
        assert!(validate_body_size(1024).is_ok());
        assert!(validate_body_size(MAX_BODY_SIZE).is_ok());
    }

    #[test]
    fn rejects_oversized_body() {
        match validate_body_size(MAX_BODY_SIZE + 1) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "body_size");
            }
            other => panic!("expected ParameterOutOfRange, got {other:?}"),
        }
    }
}