1use crate::error::ValidationError;
22use crate::types::{CsrMatrix, SolverResult};
23
/// Largest `rows` or `cols` value accepted by [`validate_csr_matrix`].
pub const MAX_NODES: usize = 10_000_000;

/// Largest number of stored entries (`values.len()`) accepted by
/// [`validate_csr_matrix`].
pub const MAX_EDGES: usize = 100_000_000;

/// Dimension limit of 2^16 (65 536).
///
/// NOTE(review): no validator in this module reads this constant (matrix
/// dimensions are bounded by [`MAX_NODES`]); presumably it is used elsewhere
/// in the crate — confirm before changing it.
pub const MAX_DIM: usize = 1 << 16;

/// Inclusive upper bound on `max_iterations` enforced by [`validate_params`].
pub const MAX_ITERATIONS: usize = 1_000_000;

/// Largest request body, in bytes, accepted by [`validate_body_size`]
/// (10 MiB = 10 × 2^20 bytes).
pub const MAX_BODY_SIZE: usize = 10 << 20;
42
43pub fn validate_csr_matrix(matrix: &CsrMatrix<f32>) -> Result<(), ValidationError> {
76 if matrix.rows > MAX_NODES || matrix.cols > MAX_NODES {
78 return Err(ValidationError::MatrixTooLarge {
79 rows: matrix.rows,
80 cols: matrix.cols,
81 max_dim: MAX_NODES,
82 });
83 }
84
85 let nnz = matrix.values.len();
87 if nnz > MAX_EDGES {
88 return Err(ValidationError::DimensionMismatch(format!(
89 "nnz {} exceeds maximum allowed {}",
90 nnz, MAX_EDGES,
91 )));
92 }
93
94 let expected_row_ptr_len = matrix.rows + 1;
96 if matrix.row_ptr.len() != expected_row_ptr_len {
97 return Err(ValidationError::DimensionMismatch(format!(
98 "row_ptr length {} does not equal rows + 1 = {}",
99 matrix.row_ptr.len(),
100 expected_row_ptr_len,
101 )));
102 }
103
104 for i in 1..matrix.row_ptr.len() {
106 if matrix.row_ptr[i] < matrix.row_ptr[i - 1] {
107 return Err(ValidationError::NonMonotonicRowPtrs { position: i });
108 }
109 }
110
111 if matrix.row_ptr[0] != 0 {
113 return Err(ValidationError::DimensionMismatch(format!(
114 "row_ptr[0] = {} (expected 0)",
115 matrix.row_ptr[0],
116 )));
117 }
118 let expected_nnz = matrix.row_ptr[matrix.rows];
119 if expected_nnz != nnz {
120 return Err(ValidationError::DimensionMismatch(format!(
121 "values length {} does not match row_ptr[rows] = {}",
122 nnz, expected_nnz,
123 )));
124 }
125
126 if matrix.col_indices.len() != nnz {
128 return Err(ValidationError::DimensionMismatch(format!(
129 "col_indices length {} does not match values length {}",
130 matrix.col_indices.len(),
131 nnz,
132 )));
133 }
134
135 for row in 0..matrix.rows {
137 let start = matrix.row_ptr[row];
138 let end = matrix.row_ptr[row + 1];
139
140 let mut prev_col: Option<usize> = None;
141 for idx in start..end {
142 let col = matrix.col_indices[idx];
143 if col >= matrix.cols {
144 return Err(ValidationError::IndexOutOfBounds {
145 index: col as u32,
146 row,
147 cols: matrix.cols,
148 });
149 }
150
151 let val = matrix.values[idx];
152 if !val.is_finite() {
153 return Err(ValidationError::NonFiniteValue(format!(
154 "matrix[{}, {}] = {}",
155 row, col, val,
156 )));
157 }
158
159 if let Some(pc) = prev_col {
161 if col < pc {
162 tracing::warn!(
163 row = row,
164 "column indices not sorted within row (col {} follows {}); \
165 performance may be degraded",
166 col,
167 pc,
168 );
169 }
170 }
171 prev_col = Some(col);
172 }
173 }
174
175 Ok(())
176}
177
178pub fn validate_rhs(rhs: &[f32], expected_len: usize) -> Result<(), ValidationError> {
195 if rhs.len() != expected_len {
197 return Err(ValidationError::DimensionMismatch(format!(
198 "rhs length {} does not match expected {}",
199 rhs.len(),
200 expected_len,
201 )));
202 }
203
204 let mut all_zero = true;
206 for (i, &v) in rhs.iter().enumerate() {
207 if !v.is_finite() {
208 return Err(ValidationError::NonFiniteValue(format!(
209 "rhs[{}] = {}",
210 i, v,
211 )));
212 }
213 if v != 0.0 {
214 all_zero = false;
215 }
216 }
217
218 if all_zero && !rhs.is_empty() {
219 tracing::warn!("rhs vector is all zeros; solution will be trivially zero");
220 }
221
222 Ok(())
223}
224
225pub fn validate_rhs_vector(rhs: &[f32], expected_len: usize) -> Result<(), ValidationError> {
230 validate_rhs(rhs, expected_len)
231}
232
233pub fn validate_params(tolerance: f64, max_iterations: usize) -> Result<(), ValidationError> {
249 if !tolerance.is_finite() || tolerance <= 0.0 || tolerance > 1.0 {
250 return Err(ValidationError::ParameterOutOfRange {
251 name: "tolerance".into(),
252 value: format!("{tolerance:.2e}"),
253 expected: "(0.0, 1.0]".into(),
254 });
255 }
256
257 if max_iterations == 0 || max_iterations > MAX_ITERATIONS {
258 return Err(ValidationError::ParameterOutOfRange {
259 name: "max_iterations".into(),
260 value: max_iterations.to_string(),
261 expected: format!("[1, {}]", MAX_ITERATIONS),
262 });
263 }
264
265 Ok(())
266}
267
268pub fn validate_solver_input(
282 matrix: &CsrMatrix<f32>,
283 rhs: &[f32],
284 tolerance: f64,
285) -> Result<(), ValidationError> {
286 validate_csr_matrix(matrix)?;
287 validate_rhs(rhs, matrix.rows)?;
288
289 if matrix.rows != matrix.cols {
291 return Err(ValidationError::DimensionMismatch(format!(
292 "solver requires a square matrix but got {}x{}",
293 matrix.rows, matrix.cols,
294 )));
295 }
296
297 if !tolerance.is_finite() || tolerance <= 0.0 {
299 return Err(ValidationError::ParameterOutOfRange {
300 name: "tolerance".into(),
301 value: tolerance.to_string(),
302 expected: "finite positive value".into(),
303 });
304 }
305
306 Ok(())
307}
308
309pub fn validate_output(result: &SolverResult) -> Result<(), ValidationError> {
326 for (i, &v) in result.solution.iter().enumerate() {
328 if !v.is_finite() {
329 return Err(ValidationError::NonFiniteValue(format!(
330 "solution[{}] = {}",
331 i, v,
332 )));
333 }
334 }
335
336 if !result.residual_norm.is_finite() {
338 return Err(ValidationError::NonFiniteValue(format!(
339 "residual_norm = {}",
340 result.residual_norm,
341 )));
342 }
343
344 if result.iterations == 0 {
346 return Err(ValidationError::ParameterOutOfRange {
347 name: "iterations".into(),
348 value: "0".into(),
349 expected: ">= 1".into(),
350 });
351 }
352
353 Ok(())
354}
355
356pub fn validate_body_size(size: usize) -> Result<(), ValidationError> {
368 if size > MAX_BODY_SIZE {
369 return Err(ValidationError::ParameterOutOfRange {
370 name: "body_size".into(),
371 value: format!("{} bytes", size),
372 expected: format!("<= {} bytes (10 MiB)", MAX_BODY_SIZE),
373 });
374 }
375 Ok(())
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Algorithm, ConvergenceInfo, CsrMatrix, SolverResult};
    use std::time::Duration;

    /// Builds the `n x n` identity matrix in CSR form (one `1.0` per row on
    /// the diagonal).
    fn make_identity(n: usize) -> CsrMatrix<f32> {
        let mut row_ptr = vec![0usize; n + 1];
        let mut col_indices = Vec::with_capacity(n);
        let mut values = Vec::with_capacity(n);
        for i in 0..n {
            row_ptr[i + 1] = i + 1;
            col_indices.push(i);
            values.push(1.0);
        }
        CsrMatrix {
            values,
            col_indices,
            row_ptr,
            rows: n,
            cols: n,
        }
    }

    // ---- validate_csr_matrix ------------------------------------------------

    #[test]
    fn valid_identity() {
        let mat = make_identity(4);
        assert!(validate_csr_matrix(&mat).is_ok());
    }

    #[test]
    fn valid_empty_matrix() {
        let m = CsrMatrix {
            row_ptr: vec![0],
            col_indices: vec![],
            values: vec![],
            rows: 0,
            cols: 0,
        };
        assert!(validate_csr_matrix(&m).is_ok());
    }

    #[test]
    fn valid_from_coo() {
        let m = CsrMatrix::<f32>::from_coo(
            3,
            3,
            vec![
                (0, 0, 2.0),
                (0, 1, -0.5),
                (1, 0, -0.5),
                (1, 1, 2.0),
                (1, 2, -0.5),
                (2, 1, -0.5),
                (2, 2, 2.0),
            ],
        );
        assert!(validate_csr_matrix(&m).is_ok());
    }

    #[test]
    fn rejects_too_large_matrix() {
        let m = CsrMatrix {
            row_ptr: vec![0, 0],
            col_indices: vec![],
            values: vec![],
            rows: MAX_NODES + 1,
            cols: 1,
        };
        assert!(matches!(
            validate_csr_matrix(&m),
            Err(ValidationError::MatrixTooLarge { .. })
        ));
    }

    #[test]
    fn rejects_too_many_columns() {
        let m = CsrMatrix {
            row_ptr: vec![0, 0],
            col_indices: vec![],
            values: vec![],
            rows: 1,
            cols: MAX_NODES + 1,
        };
        assert!(matches!(
            validate_csr_matrix(&m),
            Err(ValidationError::MatrixTooLarge { .. })
        ));
    }

    #[test]
    fn rejects_wrong_row_ptr_length() {
        let m = CsrMatrix {
            row_ptr: vec![0, 1],
            col_indices: vec![0],
            values: vec![1.0],
            rows: 3,
            cols: 3,
        };
        assert!(matches!(
            validate_csr_matrix(&m),
            Err(ValidationError::DimensionMismatch(_))
        ));
    }

    #[test]
    fn non_monotonic_row_ptr() {
        let mut mat = make_identity(4);
        mat.row_ptr[2] = 0;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(err, ValidationError::NonMonotonicRowPtrs { .. }));
    }

    #[test]
    fn non_monotonic_row_ptr_reports_later_position() {
        // row_ptr = [0, 1, 0, 3, 4]: first decrease is at index 2.
        let mut mat = make_identity(4);
        mat.row_ptr[2] = 0;
        assert!(matches!(
            validate_csr_matrix(&mat),
            Err(ValidationError::NonMonotonicRowPtrs { position: 2 })
        ));
    }

    #[test]
    fn rejects_row_ptr_not_starting_at_zero() {
        let m = CsrMatrix {
            row_ptr: vec![1, 2],
            col_indices: vec![0],
            values: vec![1.0],
            rows: 1,
            cols: 1,
        };
        match validate_csr_matrix(&m) {
            Err(ValidationError::DimensionMismatch(msg)) => {
                assert!(msg.contains("row_ptr[0]"), "msg: {msg}");
            }
            other => panic!("expected DimensionMismatch for row_ptr[0], got {other:?}"),
        }
    }

    #[test]
    fn rejects_values_len_not_matching_last_row_ptr() {
        let m = CsrMatrix {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0],
            values: vec![1.0],
            rows: 2,
            cols: 2,
        };
        match validate_csr_matrix(&m) {
            Err(ValidationError::DimensionMismatch(msg)) => {
                assert!(msg.contains("row_ptr[rows]"), "msg: {msg}");
            }
            other => panic!("expected DimensionMismatch for row_ptr[rows], got {other:?}"),
        }
    }

    #[test]
    fn rejects_col_indices_len_mismatch() {
        let m = CsrMatrix {
            row_ptr: vec![0, 1],
            col_indices: vec![0, 0],
            values: vec![1.0],
            rows: 1,
            cols: 1,
        };
        match validate_csr_matrix(&m) {
            Err(ValidationError::DimensionMismatch(msg)) => {
                assert!(msg.contains("col_indices"), "msg: {msg}");
            }
            other => panic!("expected DimensionMismatch for col_indices, got {other:?}"),
        }
    }

    #[test]
    fn col_index_out_of_bounds() {
        let mut mat = make_identity(4);
        mat.col_indices[1] = 99;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(err, ValidationError::IndexOutOfBounds { .. }));
    }

    #[test]
    fn accepts_unsorted_columns_within_row() {
        // Row 0 stores columns [1, 0]: out of order, but still valid CSR.
        let m = CsrMatrix {
            row_ptr: vec![0, 2, 3],
            col_indices: vec![1, 0, 1],
            values: vec![1.0, 2.0, 3.0],
            rows: 2,
            cols: 2,
        };
        assert!(validate_csr_matrix(&m).is_ok());
    }

    #[test]
    fn nan_value_rejected() {
        let mut mat = make_identity(4);
        mat.values[0] = f32::NAN;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    #[test]
    fn inf_value_rejected() {
        let mut mat = make_identity(4);
        mat.values[0] = f32::INFINITY;
        let err = validate_csr_matrix(&mat).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    // ---- validate_rhs / validate_rhs_vector ---------------------------------

    #[test]
    fn valid_rhs() {
        assert!(validate_rhs(&[1.0, 2.0, 3.0], 3).is_ok());
    }

    #[test]
    fn empty_rhs_is_valid_for_empty_system() {
        assert!(validate_rhs(&[], 0).is_ok());
    }

    #[test]
    fn rhs_dimension_mismatch() {
        let err = validate_rhs(&[1.0, 2.0], 3).unwrap_err();
        assert!(matches!(err, ValidationError::DimensionMismatch(_)));
    }

    #[test]
    fn rhs_nan_rejected() {
        let err = validate_rhs(&[1.0, f32::NAN, 3.0], 3).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    #[test]
    fn rhs_inf_rejected() {
        let err = validate_rhs(&[1.0, f32::NEG_INFINITY, 3.0], 3).unwrap_err();
        assert!(matches!(err, ValidationError::NonFiniteValue(_)));
    }

    #[test]
    fn warns_on_all_zero_rhs() {
        // An all-zero rhs must still be accepted. The warning itself goes to
        // `tracing`; no subscriber is installed here, so it is not asserted.
        assert!(validate_rhs(&[0.0, 0.0, 0.0], 3).is_ok());
    }

    #[test]
    fn rhs_vector_alias_works() {
        assert!(validate_rhs_vector(&[1.0, 2.0], 2).is_ok());
        assert!(validate_rhs_vector(&[1.0, 2.0], 3).is_err());
    }

    // ---- validate_params ----------------------------------------------------

    #[test]
    fn valid_params() {
        assert!(validate_params(1e-8, 500).is_ok());
        assert!(validate_params(1.0, 1).is_ok());
    }

    #[test]
    fn rejects_zero_tolerance() {
        match validate_params(0.0, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_negative_tolerance() {
        match validate_params(-1e-6, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_tolerance_above_one() {
        match validate_params(1.5, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_nan_tolerance() {
        match validate_params(f64::NAN, 100) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "tolerance");
            }
            other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
        }
    }

    #[test]
    fn rejects_zero_iterations() {
        match validate_params(1e-6, 0) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "max_iterations");
            }
            other => panic!("expected ParameterOutOfRange for max_iterations, got {other:?}"),
        }
    }

    #[test]
    fn rejects_excessive_iterations() {
        match validate_params(1e-6, MAX_ITERATIONS + 1) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "max_iterations");
            }
            other => panic!("expected ParameterOutOfRange for max_iterations, got {other:?}"),
        }
    }

    // ---- validate_solver_input ----------------------------------------------

    #[test]
    fn full_input_validation() {
        let mat = make_identity(3);
        let rhs = vec![1.0f32, 2.0, 3.0];
        assert!(validate_solver_input(&mat, &rhs, 1e-6).is_ok());
    }

    #[test]
    fn non_square_rejected() {
        let mat = CsrMatrix {
            values: vec![],
            col_indices: vec![],
            row_ptr: vec![0, 0, 0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0f32, 2.0];
        let err = validate_solver_input(&mat, &rhs, 1e-6).unwrap_err();
        assert!(matches!(err, ValidationError::DimensionMismatch(_)));
    }

    #[test]
    fn solver_input_rejects_rhs_length_mismatch() {
        let mat = make_identity(3);
        let rhs = vec![1.0f32, 2.0];
        let err = validate_solver_input(&mat, &rhs, 1e-6).unwrap_err();
        assert!(matches!(err, ValidationError::DimensionMismatch(_)));
    }

    #[test]
    fn invalid_tolerance_rejected() {
        let mat = make_identity(2);
        let rhs = vec![1.0f32, 2.0];
        assert!(validate_solver_input(&mat, &rhs, -1.0).is_err());
        assert!(validate_solver_input(&mat, &rhs, 0.0).is_err());
        assert!(validate_solver_input(&mat, &rhs, f64::NAN).is_err());
    }

    // ---- validate_output ----------------------------------------------------

    #[test]
    fn valid_output() {
        let result = SolverResult {
            solution: vec![1.0, 2.0, 3.0],
            iterations: 10,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(5),
            convergence_history: vec![ConvergenceInfo {
                iteration: 0,
                residual_norm: 1.0,
            }],
            algorithm: Algorithm::Neumann,
        };
        assert!(validate_output(&result).is_ok());
    }

    #[test]
    fn rejects_nan_in_solution() {
        let result = SolverResult {
            solution: vec![1.0, f32::NAN, 3.0],
            iterations: 1,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::NonFiniteValue(ref msg)) => {
                assert!(msg.contains("solution"), "msg: {msg}");
            }
            other => panic!("expected NonFiniteValue for solution, got {other:?}"),
        }
    }

    #[test]
    fn rejects_inf_in_solution() {
        let result = SolverResult {
            solution: vec![f32::INFINITY],
            iterations: 1,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::NonFiniteValue(ref msg)) => {
                assert!(msg.contains("solution"), "msg: {msg}");
            }
            other => panic!("expected NonFiniteValue for solution, got {other:?}"),
        }
    }

    #[test]
    fn rejects_nan_residual() {
        let result = SolverResult {
            solution: vec![1.0],
            iterations: 1,
            residual_norm: f64::NAN,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::NonFiniteValue(ref msg)) => {
                assert!(msg.contains("residual"), "msg: {msg}");
            }
            other => panic!("expected NonFiniteValue for residual, got {other:?}"),
        }
    }

    #[test]
    fn rejects_inf_residual() {
        let result = SolverResult {
            solution: vec![1.0],
            iterations: 1,
            residual_norm: f64::INFINITY,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        assert!(matches!(
            validate_output(&result),
            Err(ValidationError::NonFiniteValue(_))
        ));
    }

    #[test]
    fn rejects_zero_iterations_in_output() {
        let result = SolverResult {
            solution: vec![1.0],
            iterations: 0,
            residual_norm: 1e-8,
            wall_time: Duration::from_millis(1),
            convergence_history: vec![],
            algorithm: Algorithm::Neumann,
        };
        match validate_output(&result) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "iterations");
            }
            other => panic!("expected ParameterOutOfRange, got {other:?}"),
        }
    }

    // ---- validate_body_size -------------------------------------------------

    #[test]
    fn valid_body_size() {
        assert!(validate_body_size(0).is_ok());
        assert!(validate_body_size(1024).is_ok());
        assert!(validate_body_size(MAX_BODY_SIZE).is_ok());
    }

    #[test]
    fn rejects_oversized_body() {
        match validate_body_size(MAX_BODY_SIZE + 1) {
            Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
                assert_eq!(name, "body_size");
            }
            other => panic!("expected ParameterOutOfRange, got {other:?}"),
        }
    }
}