1use crate::error::{OptimizeError, OptimizeResult};
26
27pub fn knapsack_dp(
62 values: &[f64],
63 weights: &[f64],
64 capacity: f64,
65) -> OptimizeResult<(f64, Vec<bool>)> {
66 let n = values.len();
67 if weights.len() != n {
68 return Err(OptimizeError::InvalidInput(
69 "values and weights must have the same length".to_string(),
70 ));
71 }
72 for &v in values {
73 if v < 0.0 {
74 return Err(OptimizeError::InvalidInput(
75 "all values must be non-negative".to_string(),
76 ));
77 }
78 }
79 for &w in weights {
80 if w < 0.0 {
81 return Err(OptimizeError::InvalidInput(
82 "all weights must be non-negative".to_string(),
83 ));
84 }
85 }
86 if capacity < 0.0 {
87 return Err(OptimizeError::InvalidInput(
88 "capacity must be non-negative".to_string(),
89 ));
90 }
91 if n == 0 {
92 return Ok((0.0, Vec::new()));
93 }
94
95 let scale = {
103 let mut best_scale = 1.0_f64;
105 for &w in weights.iter().chain(std::iter::once(&capacity)) {
106 if w <= 0.0 {
107 continue;
108 }
109 let mut s = 1.0_f64;
111 for _ in 0..6 {
112 if ((w * s).round() - w * s).abs() < 1e-9 {
113 break;
114 }
115 s *= 10.0;
116 }
117 if s > best_scale {
118 best_scale = s;
119 }
120 }
121 let max_scale = if capacity > 0.0 {
123 (20_000_000.0 / capacity).max(1.0)
124 } else {
125 1e6
126 };
127 best_scale.min(max_scale)
128 };
129
130 let scaled_cap = (capacity * scale).round() as usize;
131 let scaled_w: Vec<usize> = weights
132 .iter()
133 .map(|&w| (w * scale).round() as usize)
134 .collect();
135
136 let cols = scaled_cap + 1;
139 let mut prev = vec![0.0_f64; cols];
140 let mut curr = vec![0.0_f64; cols];
141
142 let use_backtrack = (n as u64) * (cols as u64) <= 40_000_000;
145
146 let mut keep: Vec<Vec<bool>> = if use_backtrack {
147 vec![vec![false; cols]; n]
148 } else {
149 Vec::new()
150 };
151
152 for i in 0..n {
153 for w in 0..cols {
154 let sw = scaled_w[i];
155 if sw > w {
156 curr[w] = prev[w];
157 if use_backtrack {
158 keep[i][w] = false;
159 }
160 } else {
161 let take = prev[w - sw] + values[i];
162 if take > prev[w] {
163 curr[w] = take;
164 if use_backtrack {
165 keep[i][w] = true;
166 }
167 } else {
168 curr[w] = prev[w];
169 if use_backtrack {
170 keep[i][w] = false;
171 }
172 }
173 }
174 }
175 std::mem::swap(&mut prev, &mut curr);
176 }
177
178 let opt_val = prev[scaled_cap];
179
180 let selection = if use_backtrack {
182 let mut sel = vec![false; n];
183 let mut w = scaled_cap;
184 for i in (0..n).rev() {
185 if keep[i][w] {
186 sel[i] = true;
187 w -= scaled_w[i];
188 }
189 }
190 sel
191 } else {
192 greedy_selection_fallback(values, weights, capacity)
194 };
195
196 Ok((opt_val, selection))
197}
198
/// Greedy 0-1 selection by value density.
///
/// Items are ranked by `value / weight` in descending order (items whose weight
/// is not positive rank as infinitely dense; ties keep their input order) and
/// each item is taken whenever its weight still fits in the remaining capacity.
/// Returns one flag per item; the result is a heuristic, not an optimum.
fn greedy_selection_fallback(values: &[f64], weights: &[f64], capacity: f64) -> Vec<bool> {
    let item_count = values.len();
    let density = |idx: usize| -> f64 {
        if weights[idx] > 0.0 {
            values[idx] / weights[idx]
        } else {
            f64::INFINITY
        }
    };

    let mut ranked: Vec<usize> = (0..item_count).collect();
    // Stable sort, densest first; incomparable pairs are treated as equal.
    ranked.sort_by(|&lhs, &rhs| {
        let dl = density(lhs);
        let dr = density(rhs);
        dr.partial_cmp(&dl).unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut chosen = vec![false; item_count];
    let mut budget = capacity;
    for idx in ranked {
        let w = weights[idx];
        if w <= budget {
            chosen[idx] = true;
            budget -= w;
        }
    }
    chosen
}
226
227pub fn fractional_knapsack(values: &[f64], weights: &[f64], capacity: f64) -> OptimizeResult<f64> {
256 let n = values.len();
257 if weights.len() != n {
258 return Err(OptimizeError::InvalidInput(
259 "values and weights must have the same length".to_string(),
260 ));
261 }
262 for &v in values {
263 if v < 0.0 {
264 return Err(OptimizeError::InvalidInput(
265 "all values must be non-negative".to_string(),
266 ));
267 }
268 }
269 for &w in weights {
270 if w < 0.0 {
271 return Err(OptimizeError::InvalidInput(
272 "all weights must be non-negative".to_string(),
273 ));
274 }
275 }
276 if capacity < 0.0 {
277 return Err(OptimizeError::InvalidInput(
278 "capacity must be non-negative".to_string(),
279 ));
280 }
281 if n == 0 {
282 return Ok(0.0);
283 }
284
285 let mut order: Vec<usize> = (0..n).collect();
287 order.sort_by(|&a, &b| {
288 let ra = if weights[a] > 0.0 {
289 values[a] / weights[a]
290 } else {
291 f64::INFINITY
292 };
293 let rb = if weights[b] > 0.0 {
294 values[b] / weights[b]
295 } else {
296 f64::INFINITY
297 };
298 rb.partial_cmp(&ra).unwrap_or(std::cmp::Ordering::Equal)
299 });
300
301 let mut total_value = 0.0;
302 let mut remaining = capacity;
303
304 for i in order {
305 if remaining <= 0.0 {
306 break;
307 }
308 let w = weights[i];
309 if w <= 0.0 {
310 total_value += values[i];
312 } else if w <= remaining {
313 total_value += values[i];
314 remaining -= w;
315 } else {
316 total_value += values[i] * (remaining / w);
318 remaining = 0.0;
319 }
320 }
321
322 Ok(total_value)
323}
324
325pub fn bounded_knapsack(
366 values: &[f64],
367 weights: &[f64],
368 counts: &[usize],
369 capacity: f64,
370) -> OptimizeResult<(f64, Vec<usize>)> {
371 let n = values.len();
372 if weights.len() != n || counts.len() != n {
373 return Err(OptimizeError::InvalidInput(
374 "values, weights, and counts must have the same length".to_string(),
375 ));
376 }
377 for &v in values {
378 if v < 0.0 {
379 return Err(OptimizeError::InvalidInput(
380 "all values must be non-negative".to_string(),
381 ));
382 }
383 }
384 for &w in weights {
385 if w < 0.0 {
386 return Err(OptimizeError::InvalidInput(
387 "all weights must be non-negative".to_string(),
388 ));
389 }
390 }
391 if capacity < 0.0 {
392 return Err(OptimizeError::InvalidInput(
393 "capacity must be non-negative".to_string(),
394 ));
395 }
396 if n == 0 {
397 return Ok((0.0, Vec::new()));
398 }
399
400 let mut virtual_values: Vec<f64> = Vec::new();
403 let mut virtual_weights: Vec<f64> = Vec::new();
404 let mut virtual_orig: Vec<usize> = Vec::new();
405 let mut virtual_mult: Vec<usize> = Vec::new();
406
407 for i in 0..n {
408 let mut remaining = counts[i];
409 let mut k = 1usize;
410 while remaining > 0 {
411 let take = k.min(remaining);
412 virtual_values.push(values[i] * take as f64);
413 virtual_weights.push(weights[i] * take as f64);
414 virtual_orig.push(i);
415 virtual_mult.push(take);
416 remaining -= take;
417 k *= 2;
418 }
419 }
420
421 let (opt_val, sel01) = knapsack_dp(&virtual_values, &virtual_weights, capacity)?;
423
424 let mut selection = vec![0usize; n];
426 for (k, selected) in sel01.iter().enumerate() {
427 if *selected {
428 selection[virtual_orig[k]] += virtual_mult[k];
429 }
430 }
431
432 for i in 0..n {
434 if selection[i] > counts[i] {
435 selection[i] = counts[i];
436 }
437 }
438
439 Ok((opt_val, selection))
440}
441
442pub fn multi_dimensional_knapsack(
481 values: &[f64],
482 weights: &[Vec<f64>],
483 capacities: &[f64],
484) -> OptimizeResult<f64> {
485 let n = values.len();
486 let d = weights.len();
487
488 if d == 0 {
489 return Ok(values.iter().sum());
491 }
492 for dim in 0..d {
493 if weights[dim].len() != n {
494 return Err(OptimizeError::InvalidInput(format!(
495 "weights[{}] has length {}, expected {}",
496 dim,
497 weights[dim].len(),
498 n
499 )));
500 }
501 }
502 if capacities.len() != d {
503 return Err(OptimizeError::InvalidInput(format!(
504 "capacities length {} != number of weight dimensions {}",
505 capacities.len(),
506 d
507 )));
508 }
509 for &v in values {
510 if v < 0.0 {
511 return Err(OptimizeError::InvalidInput(
512 "all values must be non-negative".to_string(),
513 ));
514 }
515 }
516 for (dim, cap) in capacities.iter().enumerate() {
517 if *cap < 0.0 {
518 return Err(OptimizeError::InvalidInput(format!(
519 "capacity[{}] must be non-negative",
520 dim
521 )));
522 }
523 }
524 if n == 0 {
525 return Ok(0.0);
526 }
527
528 let best = Mkp::new(n, d, values, weights, capacities).branch_and_bound();
530 Ok(best)
531}
532
/// Depth-first branch-and-bound solver for the multi-dimensional 0-1 knapsack.
///
/// `weights[k][i]` is the consumption of resource `k` by item `i`. Values and
/// weights are expected to be non-negative; the solver itself does not check.
struct Mkp<'a> {
    /// Number of resource dimensions.
    d: usize,
    values: &'a [f64],
    weights: &'a [Vec<f64>],
    capacities: &'a [f64],
    /// Branching order: items by `value / (sum of weights over all dimensions)`,
    /// descending. Items are decided strictly in this order.
    order: Vec<usize>,
    /// `rank[item]` is the position of `item` in `order`.
    rank: Vec<usize>,
    /// For each dimension, the items sorted by `value / weight` in that
    /// dimension, descending (zero-weight items first).
    dim_order: Vec<Vec<usize>>,
}

impl<'a> Mkp<'a> {
    /// Slack allowed when testing whether an item fits.
    const FEASIBILITY_TOL: f64 = 1e-10;
    /// A node is pruned unless its bound beats the incumbent by more than this.
    const PRUNE_TOL: f64 = 1e-9;
    /// Search budget; once exhausted the best value found so far is used.
    const MAX_NODES: usize = 200_000;

    /// Value density, with non-positive weights treated as infinitely dense.
    fn density(value: f64, weight: f64) -> f64 {
        if weight > 0.0 {
            value / weight
        } else {
            f64::INFINITY
        }
    }

    /// Item indices `0..n` stably sorted by descending `ratio`.
    fn ratio_desc_order(n: usize, ratio: impl Fn(usize) -> f64) -> Vec<usize> {
        let mut idx: Vec<usize> = (0..n).collect();
        idx.sort_by(|&a, &b| {
            ratio(b)
                .partial_cmp(&ratio(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        idx
    }

    /// Builds the solver and precomputes the branching and bounding orders.
    fn new(
        n: usize,
        d: usize,
        values: &'a [f64],
        weights: &'a [Vec<f64>],
        capacities: &'a [f64],
    ) -> Self {
        let order = Self::ratio_desc_order(n, |i| {
            let total: f64 = (0..d).map(|k| weights[k][i]).sum::<f64>();
            Self::density(values[i], total)
        });

        let mut rank = vec![0usize; n];
        for (pos, &item) in order.iter().enumerate() {
            rank[item] = pos;
        }

        let dim_order: Vec<Vec<usize>> = (0..d)
            .map(|k| Self::ratio_desc_order(n, |i| Self::density(values[i], weights[k][i])))
            .collect();

        Mkp {
            d,
            values,
            weights,
            capacities,
            order,
            rank,
            dim_order,
        }
    }

    /// Upper bound on the value reachable from a node where the first
    /// `next_idx` items of `order` are decided, `rem_cap` is left and
    /// `curr_val` has been collected.
    ///
    /// Dropping all constraints but one leaves a one-dimensional knapsack whose
    /// fractional (Dantzig) optimum bounds the node from above. This holds for
    /// every dimension, so the smallest of those bounds is used. A single greedy
    /// pass in aggregate-ratio order is not a valid bound when `d >= 2`: it can
    /// spend a scarce resource on the wrong item and underestimate.
    fn upper_bound(&self, next_idx: usize, rem_cap: &[f64], curr_val: f64) -> f64 {
        if self.d == 0 {
            // No constraints: every undecided item can be taken.
            let free: f64 = self.order[next_idx..]
                .iter()
                .map(|&i| self.values[i])
                .sum();
            return curr_val + free;
        }

        let mut tightest = f64::INFINITY;
        for dim in 0..self.d {
            // Same slack as `is_feasible_add`, so the bound never undercuts a
            // selection that the feasibility test would accept.
            let mut cap = rem_cap[dim] + Self::FEASIBILITY_TOL;
            let mut gain = 0.0_f64;
            for &item in &self.dim_order[dim] {
                if self.rank[item] < next_idx {
                    continue; // already decided
                }
                let w = self.weights[dim][item];
                let v = self.values[item];
                if w <= 0.0 {
                    gain += v;
                } else if w <= cap {
                    gain += v;
                    cap -= w;
                } else {
                    // Critical item: take the fraction that fits, then stop.
                    if cap > 0.0 {
                        gain += v * (cap / w);
                    }
                    break;
                }
            }
            if gain < tightest {
                tightest = gain;
            }
        }
        curr_val + tightest
    }

    /// Whether `item` fits into `rem_cap` in every dimension (with a small slack).
    fn is_feasible_add(&self, item: usize, rem_cap: &[f64]) -> bool {
        !(0..self.d)
            .any(|dim| self.weights[dim][item] > rem_cap[dim] + Self::FEASIBILITY_TOL)
    }

    /// Runs the search and returns the best objective value found.
    ///
    /// The search stops after `MAX_NODES` nodes; the result is then the better
    /// of the incumbent and the greedy solution, and may be sub-optimal.
    fn branch_and_bound(&self) -> f64 {
        struct Node {
            rem_cap: Vec<f64>,
            curr_val: f64,
            next_idx: usize,
        }

        let mut best = 0.0_f64;
        let mut stack: Vec<Node> = vec![Node {
            rem_cap: self.capacities.to_vec(),
            curr_val: 0.0,
            next_idx: 0,
        }];
        let mut nodes = 0usize;

        while let Some(node) = stack.pop() {
            nodes += 1;
            if nodes > Self::MAX_NODES {
                break;
            }

            if node.curr_val > best {
                best = node.curr_val;
            }
            if node.next_idx >= self.order.len() {
                continue; // leaf: every item decided
            }

            let ub = self.upper_bound(node.next_idx, &node.rem_cap, node.curr_val);
            if ub <= best + Self::PRUNE_TOL {
                continue;
            }

            let item = self.order[node.next_idx];
            let child_idx = node.next_idx + 1;

            let include = if self.is_feasible_add(item, &node.rem_cap) {
                let mut rem = node.rem_cap.clone();
                for (dim, slot) in rem.iter_mut().enumerate() {
                    *slot -= self.weights[dim][item];
                }
                let val = node.curr_val + self.values[item];
                // Record the incumbent right away so sibling subtrees prune earlier.
                if val > best {
                    best = val;
                }
                Some(Node {
                    rem_cap: rem,
                    curr_val: val,
                    next_idx: child_idx,
                })
            } else {
                None
            };

            // Exclude branch goes on first so the include branch is explored
            // first (the stack is LIFO).
            stack.push(Node {
                rem_cap: node.rem_cap,
                curr_val: node.curr_val,
                next_idx: child_idx,
            });
            stack.extend(include);
        }

        best.max(self.greedy_solution())
    }

    /// Value of the greedy solution that scans `order` and takes every item
    /// that still fits.
    fn greedy_solution(&self) -> f64 {
        let mut rem: Vec<f64> = self.capacities.to_vec();
        let mut val = 0.0_f64;
        for &item in &self.order {
            if self.is_feasible_add(item, &rem) {
                val += self.values[item];
                for (dim, slot) in rem.iter_mut().enumerate() {
                    *slot -= self.weights[dim][item];
                }
            }
        }
        val
    }
}
733
// Unit tests for the knapsack solvers in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // ---- knapsack_dp ----

    // Five items, capacity 8: the optimum is 12 (items 0, 3 and 4, total weight 8).
    // The returned selection must be feasible and must add up to the reported value.
    #[test]
    fn test_knapsack_dp_basic() {
        let values = vec![4.0, 3.0, 5.0, 2.0, 6.0];
        let weights = vec![2.0, 3.0, 4.0, 1.0, 5.0];
        let (val, sel) = knapsack_dp(&values, &weights, 8.0).expect("unexpected None or Err");
        assert_abs_diff_eq!(val, 12.0, epsilon = 1e-6);
        // Total weight of the chosen items.
        let total_w: f64 = weights
            .iter()
            .zip(sel.iter())
            .map(|(&w, &s)| if s { w } else { 0.0 })
            .sum();
        // Total value of the chosen items.
        let total_v: f64 = values
            .iter()
            .zip(sel.iter())
            .map(|(&v, &s)| if s { v } else { 0.0 })
            .sum();
        assert!(total_w <= 8.0 + 1e-9, "weight {} > capacity 8", total_w);
        assert_abs_diff_eq!(total_v, val, epsilon = 1e-6);
    }

    // No items: value 0 and an empty selection.
    #[test]
    fn test_knapsack_dp_empty() {
        let (val, sel) = knapsack_dp(&[], &[], 10.0).expect("unexpected None or Err");
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-12);
        assert!(sel.is_empty());
    }

    // Every item is heavier than the capacity, so nothing can be taken.
    #[test]
    fn test_knapsack_dp_none_fit() {
        let values = vec![10.0, 20.0];
        let weights = vec![5.0, 8.0];
        let (val, _sel) = knapsack_dp(&values, &weights, 3.0).expect("unexpected None or Err");
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-6);
    }

    // Total weight 3 is below the capacity 10, so every item is selected.
    #[test]
    fn test_knapsack_dp_all_fit() {
        let values = vec![1.0, 2.0, 3.0];
        let weights = vec![1.0, 1.0, 1.0];
        let (val, sel) = knapsack_dp(&values, &weights, 10.0).expect("unexpected None or Err");
        assert_abs_diff_eq!(val, 6.0, epsilon = 1e-6);
        assert_eq!(sel, vec![true, true, true]);
    }

    // A negative value is rejected.
    #[test]
    fn test_knapsack_dp_error_negative_value() {
        let result = knapsack_dp(&[-1.0, 2.0], &[1.0, 1.0], 5.0);
        assert!(result.is_err());
    }

    // `values` and `weights` of different lengths are rejected.
    #[test]
    fn test_knapsack_dp_error_length_mismatch() {
        let result = knapsack_dp(&[1.0, 2.0, 3.0], &[1.0, 2.0], 5.0);
        assert!(result.is_err());
    }

    // Capacity 5: no two items fit together, so the best single item
    // (value 50, weight 3) is the optimum.
    #[test]
    fn test_knapsack_dp_integer_weights() {
        let values = vec![10.0, 40.0, 30.0, 50.0];
        let weights = vec![5.0, 4.0, 6.0, 3.0];
        let (val, sel) = knapsack_dp(&values, &weights, 5.0).expect("unexpected None or Err");
        assert_abs_diff_eq!(val, 50.0, epsilon = 1e-6);
        // The selection must respect the capacity.
        let total_w: f64 = weights
            .iter()
            .zip(sel.iter())
            .map(|(&w, &s)| if s { w } else { 0.0 })
            .sum();
        assert!(total_w <= 5.0 + 1e-9);
    }

    // ---- fractional_knapsack ----

    // Ratios are 6, 5 and 4: the first two items are taken whole (weight 30),
    // then 20/30 of the third: 60 + 100 + 80 = 240.
    #[test]
    fn test_fractional_knapsack_basic() {
        let values = vec![60.0, 100.0, 120.0];
        let weights = vec![10.0, 20.0, 30.0];
        let val = fractional_knapsack(&values, &weights, 50.0).expect("failed to create val");
        assert_abs_diff_eq!(val, 240.0, epsilon = 1e-6);
    }

    // No items: value 0.
    #[test]
    fn test_fractional_knapsack_empty() {
        let val = fractional_knapsack(&[], &[], 100.0).expect("failed to create val");
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-12);
    }

    // Both items fit exactly (5 + 10 = 15), so no fraction is needed.
    #[test]
    fn test_fractional_knapsack_exact_fit() {
        let values = vec![10.0, 20.0];
        let weights = vec![5.0, 10.0];
        let val = fractional_knapsack(&values, &weights, 15.0).expect("failed to create val");
        assert_abs_diff_eq!(val, 30.0, epsilon = 1e-9);
    }

    // A zero-weight item is taken for free, on top of the item that fills the capacity.
    #[test]
    fn test_fractional_knapsack_zero_weight_item() {
        let values = vec![100.0, 5.0];
        let weights = vec![0.0, 1.0];
        let val = fractional_knapsack(&values, &weights, 1.0).expect("failed to create val");
        assert_abs_diff_eq!(val, 105.0, epsilon = 1e-9);
    }

    // A negative capacity is rejected.
    #[test]
    fn test_fractional_knapsack_error_negative_capacity() {
        let result = fractional_knapsack(&[1.0], &[1.0], -1.0);
        assert!(result.is_err());
    }

    // ---- bounded_knapsack ----

    // Capacity 7 with limited copies per item; e.g. four copies of item 0 plus
    // one copy of item 2 weigh 7 and are worth 17.
    #[test]
    fn test_bounded_knapsack_basic() {
        let values = vec![3.0, 4.0, 5.0];
        let weights = vec![1.0, 2.0, 3.0];
        let counts = vec![4usize, 3, 2];
        let (val, sel) =
            bounded_knapsack(&values, &weights, &counts, 7.0).expect("unexpected None or Err");
        // Total weight of the chosen copies.
        let total_w: f64 = weights
            .iter()
            .zip(sel.iter())
            .map(|(&w, &c)| w * c as f64)
            .sum();
        assert!(
            total_w <= 7.0 + 1e-9,
            "weight {} exceeds capacity 7",
            total_w
        );
        // No item may be used more often than allowed.
        for i in 0..3 {
            assert!(
                sel[i] <= counts[i],
                "sel[{}]={} > counts[{}]={}",
                i,
                sel[i],
                i,
                counts[i]
            );
        }
        assert!(val >= 17.0 - 1e-6, "val={} should be >= 17", val);
    }

    // With every count equal to 1 the bounded problem is the 0-1 problem.
    #[test]
    fn test_bounded_knapsack_unit_counts() {
        let values = vec![4.0, 3.0, 5.0, 2.0];
        let weights = vec![2.0, 3.0, 4.0, 1.0];
        let counts = vec![1usize; 4];
        let (val_b, _) =
            bounded_knapsack(&values, &weights, &counts, 6.0).expect("unexpected None or Err");
        let (val_dp, _) = knapsack_dp(&values, &weights, 6.0).expect("unexpected None or Err");
        assert_abs_diff_eq!(val_b, val_dp, epsilon = 1e-6);
    }

    // Slices of different lengths are rejected.
    #[test]
    fn test_bounded_knapsack_error_mismatch() {
        let result = bounded_knapsack(&[1.0, 2.0], &[1.0], &[1, 1], 5.0);
        assert!(result.is_err());
    }

    // ---- multi_dimensional_knapsack ----

    // With a single dimension the result must match knapsack_dp.
    #[test]
    fn test_multi_dimensional_knapsack_1d() {
        let values = vec![4.0, 3.0, 5.0, 2.0, 6.0];
        let weights = vec![vec![2.0, 3.0, 4.0, 1.0, 5.0]];
        let caps = vec![8.0];
        let val_md =
            multi_dimensional_knapsack(&values, &weights, &caps).expect("failed to create val_md");
        let (val_dp, _) = knapsack_dp(&values, &weights[0], 8.0).expect("unexpected None or Err");
        assert_abs_diff_eq!(val_md, val_dp, epsilon = 1e-6);
    }

    // Two dimensions: items 0 and 2 together use (3, 6) of (5, 6) and are worth 15,
    // so the result must be at least that.
    #[test]
    fn test_multi_dimensional_knapsack_2d() {
        let values = vec![10.0, 6.0, 5.0];
        let weights = vec![vec![2.0, 3.0, 1.0], vec![4.0, 1.0, 2.0]];
        let caps = vec![5.0, 6.0];
        let val =
            multi_dimensional_knapsack(&values, &weights, &caps).expect("failed to create val");
        assert!(val >= 15.0 - 1e-6, "val={} should be >= 15", val);
    }

    // No dimensions means no constraints: every item is taken.
    #[test]
    fn test_multi_dimensional_knapsack_no_dims() {
        let values = vec![1.0, 2.0, 3.0];
        let weights: Vec<Vec<f64>> = Vec::new();
        let caps: Vec<f64> = Vec::new();
        let val =
            multi_dimensional_knapsack(&values, &weights, &caps).expect("failed to create val");
        assert_abs_diff_eq!(val, 6.0, epsilon = 1e-9);
    }

    // One dimension but no items: value 0.
    #[test]
    fn test_multi_dimensional_knapsack_empty_items() {
        let weights: Vec<Vec<f64>> = vec![vec![]];
        let caps = vec![5.0];
        let val = multi_dimensional_knapsack(&[], &weights, &caps).expect("failed to create val");
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-12);
    }

    // A weight row whose length differs from the number of items is rejected.
    #[test]
    fn test_multi_dimensional_knapsack_error_dim_mismatch() {
        let values = vec![1.0, 2.0];
        // Only one weight for two items.
        let weights = vec![vec![1.0]];
        let caps = vec![5.0];
        let result = multi_dimensional_knapsack(&values, &weights, &caps);
        assert!(result.is_err());
    }
}