// rustgym/leetcode/_307_range_sum_query_mutable.rs

/// Fenwick (binary indexed) tree over `n` zero-indexed `i32` slots.
///
/// Besides the Fenwick array, a plain copy of the current values is kept in
/// `data` so a single element can be read back in O(1) via [`BitTree::get`].
struct BitTree {
    /// Fenwick array: `tree[i]` holds the sum of the values at indices
    /// `(i & (i + 1))..=i`.
    tree: Vec<i32>,
    /// Current value at each index.
    data: Vec<i32>,
    /// Number of slots (the length of both `tree` and `data`).
    n: usize,
}

impl BitTree {
    /// Creates a tree with `n` slots, all initialised to zero.
    fn new(n: usize) -> Self {
        BitTree {
            tree: vec![0; n],
            data: vec![0; n],
            n,
        }
    }

    /// Returns the current value stored at index `i`.
    ///
    /// # Panics
    /// Panics if `i >= n`.
    fn get(&self, i: usize) -> i32 {
        self.data[i]
    }

    /// Returns the prefix sum of the values at indices `0..=i`.
    ///
    /// # Panics
    /// Panics if `i >= n`.
    fn sum(&self, i: usize) -> i32 {
        // Walk downwards: from node `j`, jump to the index just before the
        // start of the range `tree[j]` covers; stop once that start is 0.
        std::iter::successors(Some(i), |&j| (j & (j + 1)).checked_sub(1))
            .map(|j| self.tree[j])
            .sum()
    }

    /// Adds `v` to the value at index `i`.
    ///
    /// # Panics
    /// Panics if `i >= n`.
    fn add(&mut self, i: usize, v: i32) {
        self.data[i] += v;
        let n = self.n;
        // Walk upwards through every Fenwick node whose range contains `i`.
        let covering = std::iter::successors(Some(i), |&j| Some(j | (j + 1)).filter(|&k| k < n));
        for j in covering {
            self.tree[j] += v;
        }
    }
}

52struct NumArray {
53    bit_tree: BitTree,
54}
55
56impl NumArray {
57    fn new(nums: Vec<i32>) -> Self {
58        let n = nums.len();
59        let mut bit_tree = BitTree::new(n);
60        for i in 0..n {
61            bit_tree.add(i, nums[i]);
62        }
63        NumArray { bit_tree }
64    }
65
66    fn update(&mut self, i: i32, val: i32) {
67        let i = i as usize;
68        self.bit_tree.add(i as usize, val - self.bit_tree.get(i))
69    }
70
71    fn sum_range(&self, i: i32, j: i32) -> i32 {
72        let i = i as usize;
73        let j = j as usize;
74        if i > 0 {
75            self.bit_tree.sum(j) - self.bit_tree.sum(i - 1)
76        } else {
77            self.bit_tree.sum(j)
78        }
79    }
80}
81
/// LeetCode 307 example: query the whole range, update one element, and
/// query again.
#[test]
fn test() {
    let mut num_array = NumArray::new(vec![1, 3, 5]);
    assert_eq!(num_array.sum_range(0, 2), 9);

    num_array.update(1, 2);
    assert_eq!(num_array.sum_range(0, 2), 8);
}