scirs2_optimize/combinatorial/
knapsack.rs1use crate::error::OptimizeError;
8
9pub type KnapsackResult<T> = Result<T, OptimizeError>;
11
/// A single candidate item for the single-constraint knapsack solvers.
///
/// Both fields are unsigned integers, so items are cheap to copy and can be
/// compared for equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct KnapsackItem {
    /// Amount of knapsack capacity the item consumes when packed.
    pub weight: u64,
    /// Value gained by packing the item.
    pub value: u64,
}
22
23pub fn knapsack_dp(items: &[KnapsackItem], capacity: u64) -> KnapsackResult<(u64, Vec<usize>)> {
29 let n = items.len();
30 let w = capacity as usize;
31
32 let table_size = (n + 1).saturating_mul(w + 1);
37 if table_size > 500_000_000 {
38 return Err(OptimizeError::InvalidInput(format!(
39 "DP table size {table_size} exceeds 500M; use branch-and-bound for large capacities"
40 )));
41 }
42
43 let mut dp = vec![0u64; (n + 1) * (w + 1)];
45
46 for i in 1..=n {
47 let iw = items[i - 1].weight as usize;
48 let iv = items[i - 1].value;
49 for c in 0..=w {
50 let without = dp[(i - 1) * (w + 1) + c];
51 let with_item = if iw <= c {
52 dp[(i - 1) * (w + 1) + c - iw].saturating_add(iv)
53 } else {
54 0
55 };
56 dp[i * (w + 1) + c] = without.max(with_item);
57 }
58 }
59
60 let mut selected = Vec::new();
62 let mut remaining = w;
63 for i in (1..=n).rev() {
64 if dp[i * (w + 1) + remaining] != dp[(i - 1) * (w + 1) + remaining] {
65 selected.push(i - 1);
66 let iw = items[i - 1].weight as usize;
67 remaining = remaining.saturating_sub(iw);
68 }
69 }
70 selected.reverse();
71
72 let total = dp[n * (w + 1) + w];
73 Ok((total, selected))
74}
75
76pub fn fractional_knapsack(items: &[KnapsackItem], capacity: u64) -> f64 {
82 if capacity == 0 || items.is_empty() {
83 return 0.0;
84 }
85
86 let mut indexed: Vec<(usize, f64)> = items
88 .iter()
89 .enumerate()
90 .map(|(i, it)| {
91 let ratio = if it.weight == 0 {
92 f64::INFINITY
93 } else {
94 it.value as f64 / it.weight as f64
95 };
96 (i, ratio)
97 })
98 .collect();
99 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
100
101 let mut remaining = capacity as f64;
102 let mut total_value = 0.0;
103
104 for (idx, _ratio) in &indexed {
105 let item = &items[*idx];
106 if item.weight == 0 {
107 total_value += item.value as f64;
108 continue;
109 }
110 let take = (remaining / item.weight as f64).min(1.0);
111 total_value += take * item.value as f64;
112 remaining -= take * item.weight as f64;
113 if remaining <= 0.0 {
114 break;
115 }
116 }
117
118 total_value
119}
120
121pub fn knapsack_greedy(items: &[KnapsackItem], capacity: u64) -> (u64, Vec<usize>) {
128 if capacity == 0 || items.is_empty() {
129 return (0, vec![]);
130 }
131
132 let mut indexed: Vec<(usize, f64)> = items
133 .iter()
134 .enumerate()
135 .map(|(i, it)| {
136 let ratio = if it.weight == 0 {
137 f64::INFINITY
138 } else {
139 it.value as f64 / it.weight as f64
140 };
141 (i, ratio)
142 })
143 .collect();
144 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
145
146 let mut remaining = capacity;
147 let mut total = 0u64;
148 let mut selected = Vec::new();
149
150 for (idx, _) in &indexed {
151 let item = &items[*idx];
152 if item.weight <= remaining {
153 selected.push(*idx);
154 remaining -= item.weight;
155 total += item.value;
156 }
157 }
158
159 selected.sort_unstable();
160 (total, selected)
161}
162
/// One partial solution explored by the branch-and-bound search.
#[derive(Clone, Debug)]
struct BbNode {
    /// Number of positions of the density-sorted order already decided.
    level: usize,
    /// Total value of the items packed so far.
    value: u64,
    /// Total weight of the items packed so far.
    weight: u64,
    /// Upper bound on the value reachable from this node.
    bound: f64,
    /// Per-item flag (indexed like the input slice): `true` if packed.
    taken: Vec<bool>,
}
174
175fn lp_bound(
177 items: &[KnapsackItem],
178 sorted_indices: &[usize],
179 level: usize,
180 value: u64,
181 weight: u64,
182 capacity: u64,
183) -> f64 {
184 if weight > capacity {
185 return 0.0;
186 }
187 let mut remaining = (capacity - weight) as f64;
188 let mut bound = value as f64;
189
190 for &idx in sorted_indices.iter().skip(level) {
191 let item = &items[idx];
192 if item.weight as f64 <= remaining {
193 bound += item.value as f64;
194 remaining -= item.weight as f64;
195 } else {
196 if item.weight > 0 {
198 bound += remaining * (item.value as f64 / item.weight as f64);
199 }
200 break;
201 }
202 }
203 bound
204}
205
206pub fn knapsack_branch_bound(
210 items: &[KnapsackItem],
211 capacity: u64,
212) -> KnapsackResult<(u64, Vec<usize>)> {
213 let n = items.len();
214 if n == 0 || capacity == 0 {
215 return Ok((0, vec![]));
216 }
217
218 let mut sorted_indices: Vec<usize> = (0..n).collect();
220 sorted_indices.sort_by(|&a, &b| {
221 let ra = if items[a].weight == 0 {
222 f64::INFINITY
223 } else {
224 items[a].value as f64 / items[a].weight as f64
225 };
226 let rb = if items[b].weight == 0 {
227 f64::INFINITY
228 } else {
229 items[b].value as f64 / items[b].weight as f64
230 };
231 rb.partial_cmp(&ra).unwrap_or(std::cmp::Ordering::Equal)
232 });
233
234 let mut best_value = 0u64;
235 let mut best_taken = vec![false; n];
236
237 {
239 let (gv, gi) = knapsack_greedy(items, capacity);
240 best_value = gv;
241 for idx in gi {
242 best_taken[idx] = true;
243 }
244 }
245
246 let root = BbNode {
248 level: 0,
249 value: 0,
250 weight: 0,
251 bound: lp_bound(items, &sorted_indices, 0, 0, 0, capacity),
252 taken: vec![false; n],
253 };
254
255 let mut stack: Vec<BbNode> = vec![root];
256
257 while let Some(node) = stack.pop() {
258 if node.level == n {
259 if node.value > best_value {
260 best_value = node.value;
261 best_taken = node.taken.clone();
262 }
263 continue;
264 }
265
266 if node.bound <= best_value as f64 {
267 continue;
268 }
269
270 let item_idx = sorted_indices[node.level];
271 let item = &items[item_idx];
272
273 if node.weight + item.weight <= capacity {
275 let mut taken_with = node.taken.clone();
276 taken_with[item_idx] = true;
277 let new_value = node.value + item.value;
278 let new_weight = node.weight + item.weight;
279 let new_bound = lp_bound(
280 items,
281 &sorted_indices,
282 node.level + 1,
283 new_value,
284 new_weight,
285 capacity,
286 );
287 if new_bound > best_value as f64 {
288 stack.push(BbNode {
289 level: node.level + 1,
290 value: new_value,
291 weight: new_weight,
292 bound: new_bound,
293 taken: taken_with,
294 });
295 }
296 }
297
298 let excl_bound = lp_bound(
300 items,
301 &sorted_indices,
302 node.level + 1,
303 node.value,
304 node.weight,
305 capacity,
306 );
307 if excl_bound > best_value as f64 {
308 stack.push(BbNode {
309 level: node.level + 1,
310 value: node.value,
311 weight: node.weight,
312 bound: excl_bound,
313 taken: node.taken.clone(),
314 });
315 }
316 }
317
318 let selected: Vec<usize> = (0..n).filter(|&i| best_taken[i]).collect();
319 Ok((best_value, selected))
320}
321
/// A candidate item for the multi-dimensional knapsack heuristic.
#[derive(Clone, Debug)]
pub struct MultiKnapsackItem {
    /// Resource consumption per dimension; `multi_knapsack_greedy` requires
    /// one entry per capacity dimension.
    pub weights: Vec<u64>,
    /// Value gained by packing the item.
    pub value: u64,
}
332
333pub fn multi_knapsack_greedy(
340 items: &[MultiKnapsackItem],
341 capacities: &[u64],
342) -> KnapsackResult<(u64, Vec<usize>)> {
343 let n = items.len();
344 let d = capacities.len();
345 if n == 0 || d == 0 {
346 return Ok((0, vec![]));
347 }
348
349 for (i, item) in items.iter().enumerate() {
351 if item.weights.len() != d {
352 return Err(OptimizeError::InvalidInput(format!(
353 "Item {i} has {} weight dimensions but capacities has {d}",
354 item.weights.len()
355 )));
356 }
357 }
358
359 let mut indexed: Vec<(usize, f64)> = items
361 .iter()
362 .enumerate()
363 .map(|(i, it)| {
364 let norm_sq: f64 = it.weights.iter().map(|&w| (w as f64).powi(2)).sum();
365 let norm = norm_sq.sqrt().max(1e-12);
366 (i, it.value as f64 / norm)
367 })
368 .collect();
369 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
370
371 let mut remaining = capacities.to_vec();
372 let mut selected = vec![false; n];
373 let mut total = 0u64;
374
375 for (idx, _) in &indexed {
376 let item = &items[*idx];
377 if item
378 .weights
379 .iter()
380 .enumerate()
381 .all(|(dim, &w)| w <= remaining[dim])
382 {
383 selected[*idx] = true;
384 total += item.value;
385 for (dim, &w) in item.weights.iter().enumerate() {
386 remaining[dim] -= w;
387 }
388 }
389 }
390
391 let mut improved = true;
393 while improved {
394 improved = false;
395 for out_idx in 0..n {
396 if !selected[out_idx] {
397 continue;
398 }
399 for in_idx in 0..n {
400 if selected[in_idx] {
401 continue;
402 }
403 let delta_v = items[in_idx].value as i64 - items[out_idx].value as i64;
405 if delta_v <= 0 {
406 continue;
407 }
408 let feasible = items[in_idx].weights.iter().enumerate().all(|(dim, &w)| {
409 let freed = items[out_idx].weights[dim];
410 freed + remaining[dim] >= w
411 });
412 if feasible {
413 for dim in 0..d {
414 remaining[dim] += items[out_idx].weights[dim];
415 remaining[dim] -= items[in_idx].weights[dim];
416 }
417 total = total
418 .saturating_sub(items[out_idx].value)
419 .saturating_add(items[in_idx].value);
420 selected[out_idx] = false;
421 selected[in_idx] = true;
422 improved = true;
423 break;
424 }
425 }
426 if improved {
427 break;
428 }
429 }
430 }
431
432 let result: Vec<usize> = (0..n).filter(|&i| selected[i]).collect();
433 Ok((total, result))
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for a single-constraint item.
    fn item(weight: u64, value: u64) -> KnapsackItem {
        KnapsackItem { weight, value }
    }

    /// Four-item instance whose optimum at capacity 5 is 11.
    fn classic_items() -> Vec<KnapsackItem> {
        vec![item(2, 3), item(3, 4), item(2, 5), item(3, 6)]
    }

    /// Total weight of the items referenced by `picks`.
    fn packed_weight(items: &[KnapsackItem], picks: &[usize]) -> u64 {
        picks.iter().map(|&i| items[i].weight).sum()
    }

    /// Total value of the items referenced by `picks`.
    fn packed_value(items: &[KnapsackItem], picks: &[usize]) -> u64 {
        picks.iter().map(|&i| items[i].value).sum()
    }

    #[test]
    fn test_dp_classic() {
        let items = classic_items();
        let (val, sel) = knapsack_dp(&items, 5).expect("unexpected None or Err");
        assert_eq!(val, 11, "expected value 11, got {val}");
        // The reported selection must be feasible and add up to the reported value.
        assert!(packed_weight(&items, &sel) <= 5);
        assert_eq!(packed_value(&items, &sel), val);
    }

    #[test]
    fn test_dp_empty() {
        let (val, sel) = knapsack_dp(&[], 10).expect("unexpected None or Err");
        assert_eq!(val, 0);
        assert!(sel.is_empty());
    }

    #[test]
    fn test_dp_zero_capacity() {
        let items = classic_items();
        let (val, sel) = knapsack_dp(&items, 0).expect("unexpected None or Err");
        assert_eq!(val, 0);
        assert!(sel.is_empty());
    }

    #[test]
    fn test_fractional_knapsack() {
        let items = classic_items();
        let val = fractional_knapsack(&items, 5);
        assert!(val >= 9.0 - 1e-9);
    }

    #[test]
    fn test_greedy_knapsack() {
        let items = classic_items();
        let (val, sel) = knapsack_greedy(&items, 5);
        assert!(val > 0);
        assert!(packed_weight(&items, &sel) <= 5);
    }

    #[test]
    fn test_branch_bound_classic() {
        let items = classic_items();
        let (val, sel) = knapsack_branch_bound(&items, 5).expect("unexpected None or Err");
        assert_eq!(val, 11);
        assert!(packed_weight(&items, &sel) <= 5);
    }

    #[test]
    fn test_bb_equals_dp() {
        let items = vec![item(1, 6), item(2, 10), item(3, 12)];
        let cap = 5;
        let (dp_val, _) = knapsack_dp(&items, cap).expect("unexpected None or Err");
        let (bb_val, _) = knapsack_branch_bound(&items, cap).expect("unexpected None or Err");
        assert_eq!(dp_val, bb_val, "DP and B&B should agree");
    }

    #[test]
    fn test_multi_knapsack() {
        let items: Vec<MultiKnapsackItem> = [(vec![2, 1], 5), (vec![1, 2], 5), (vec![3, 3], 8)]
            .into_iter()
            .map(|(weights, value)| MultiKnapsackItem { weights, value })
            .collect();
        let caps = vec![4, 4];
        let (val, sel) = multi_knapsack_greedy(&items, &caps).expect("unexpected None or Err");
        assert!(val > 0);
        // Every capacity dimension must be respected.
        for (dim, &cap) in caps.iter().enumerate() {
            let used: u64 = sel.iter().map(|&i| items[i].weights[dim]).sum();
            assert!(used <= cap);
        }
    }

    #[test]
    fn test_fractional_zero_capacity() {
        let items = classic_items();
        assert_eq!(fractional_knapsack(&items, 0), 0.0);
    }

    #[test]
    fn test_all_items_fit() {
        let items = classic_items();
        let every: Vec<usize> = (0..items.len()).collect();
        let total_weight = packed_weight(&items, &every);
        let total_value = packed_value(&items, &every);
        let (val, sel) = knapsack_dp(&items, total_weight + 100).expect("unexpected None or Err");
        assert_eq!(val, total_value);
        assert_eq!(sel.len(), items.len());
    }
}