tensorlogic_scirs_backend/
memory_pool.rs1use crate::Scirs2Tensor;
7use scirs2_core::ndarray::{ArrayD, IxDyn};
8use std::collections::HashMap;
9
10#[derive(Default)]
12pub struct TensorPool {
13 available: HashMap<Vec<usize>, Vec<Scirs2Tensor>>,
15 pub(crate) allocations: usize,
17 pub(crate) reuses: usize,
18}
19
20impl TensorPool {
21 pub fn new() -> Self {
23 TensorPool {
24 available: HashMap::new(),
25 allocations: 0,
26 reuses: 0,
27 }
28 }
29
30 pub fn get(&mut self, shape: &[usize]) -> Scirs2Tensor {
32 let shape_key = shape.to_vec();
33
34 if let Some(tensors) = self.available.get_mut(&shape_key) {
36 if let Some(tensor) = tensors.pop() {
37 self.reuses += 1;
38 let mut tensor = tensor;
40 tensor.fill(0.0);
41 return tensor;
42 }
43 }
44
45 self.allocations += 1;
47 ArrayD::zeros(IxDyn(shape))
48 }
49
50 pub fn get_ones(&mut self, shape: &[usize]) -> Scirs2Tensor {
52 let mut tensor = self.get(shape);
53 tensor.fill(1.0);
54 tensor
55 }
56
57 pub fn return_tensor(&mut self, tensor: Scirs2Tensor) {
59 let shape = tensor.shape().to_vec();
60 self.available.entry(shape).or_default().push(tensor);
61 }
62
63 pub fn clear(&mut self) {
65 self.available.clear();
66 self.allocations = 0;
67 self.reuses = 0;
68 }
69
70 pub fn available_count(&self) -> usize {
72 self.available.values().map(|v| v.len()).sum()
73 }
74
75 pub fn stats(&self) -> PoolStats {
77 PoolStats {
78 allocations: self.allocations,
79 reuses: self.reuses,
80 available: self.available_count(),
81 reuse_rate: if self.allocations + self.reuses > 0 {
82 self.reuses as f64 / (self.allocations + self.reuses) as f64
83 } else {
84 0.0
85 },
86 }
87 }
88}
89
/// Snapshot of a tensor pool's counters, as produced by `TensorPool::stats`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct PoolStats {
    /// Requests that had to allocate a new tensor.
    pub allocations: usize,
    /// Requests served from a pooled tensor.
    pub reuses: usize,
    /// Idle tensors currently held by the pool, across all shapes.
    pub available: usize,
    /// `reuses / (allocations + reuses)`, or `0.0` when nothing was requested.
    pub reuse_rate: f64,
}
102
103impl std::fmt::Display for PoolStats {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(
106 f,
107 "PoolStats {{ allocations: {}, reuses: {}, available: {}, reuse_rate: {:.2}% }}",
108 self.allocations,
109 self.reuses,
110 self.available,
111 self.reuse_rate * 100.0
112 )
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_basic() {
        let mut pool = TensorPool::new();

        let t1 = pool.get(&[2, 3]);
        assert_eq!(t1.shape(), &[2, 3]);
        assert_eq!(pool.allocations, 1);
        assert_eq!(pool.reuses, 0);

        pool.return_tensor(t1);
        assert_eq!(pool.available_count(), 1);

        let t2 = pool.get(&[2, 3]);
        assert_eq!(t2.shape(), &[2, 3]);
        assert_eq!(pool.allocations, 1);
        assert_eq!(pool.reuses, 1);
    }

    #[test]
    fn test_pool_different_shapes() {
        let mut pool = TensorPool::new();

        let t1 = pool.get(&[2, 3]);
        let t2 = pool.get(&[4, 5]);

        pool.return_tensor(t1);
        pool.return_tensor(t2);

        assert_eq!(pool.available_count(), 2);

        let t3 = pool.get(&[2, 3]);
        assert_eq!(t3.shape(), &[2, 3]);
        assert_eq!(pool.reuses, 1);

        let t4 = pool.get(&[4, 5]);
        assert_eq!(t4.shape(), &[4, 5]);
        assert_eq!(pool.reuses, 2);
    }

    #[test]
    fn test_pool_shape_mismatch_allocates() {
        let mut pool = TensorPool::new();

        let t1 = pool.get(&[2, 3]);
        pool.return_tensor(t1);

        // Same element count, different shape: must not be reused.
        let t2 = pool.get(&[3, 2]);
        assert_eq!(t2.shape(), &[3, 2]);
        assert_eq!(pool.allocations, 2);
        assert_eq!(pool.reuses, 0);
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn test_pool_stats() {
        let mut pool = TensorPool::new();

        let t1 = pool.get(&[2, 2]);
        let t2 = pool.get(&[2, 2]);
        let t3 = pool.get(&[2, 2]);

        pool.return_tensor(t1);
        pool.return_tensor(t2);
        pool.return_tensor(t3);

        let _t4 = pool.get(&[2, 2]);
        let _t5 = pool.get(&[2, 2]);

        let stats = pool.stats();
        assert_eq!(stats.allocations, 3);
        assert_eq!(stats.reuses, 2);
        assert_eq!(stats.available, 1);
        assert!((stats.reuse_rate - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_stats_empty_pool() {
        let stats = TensorPool::new().stats();
        assert_eq!(stats.allocations, 0);
        assert_eq!(stats.reuses, 0);
        assert_eq!(stats.available, 0);
        assert_eq!(stats.reuse_rate, 0.0);
    }

    #[test]
    fn test_stats_display() {
        let stats = PoolStats {
            allocations: 3,
            reuses: 1,
            available: 0,
            reuse_rate: 0.25,
        };
        assert_eq!(
            stats.to_string(),
            "PoolStats { allocations: 3, reuses: 1, available: 0, reuse_rate: 25.00% }"
        );
    }

    #[test]
    fn test_get_ones() {
        let mut pool = TensorPool::new();
        let t = pool.get_ones(&[2, 2]);

        assert_eq!(t.shape(), &[2, 2]);
        assert!(t.iter().all(|&x| x == 1.0));
        assert_eq!(pool.allocations, 1);
    }

    #[test]
    fn test_get_ones_reuses_and_overwrites() {
        let mut pool = TensorPool::new();

        let mut t1 = pool.get(&[2, 2]);
        t1.fill(7.0);
        pool.return_tensor(t1);

        let t2 = pool.get_ones(&[2, 2]);
        assert!(t2.iter().all(|&x| x == 1.0));
        assert_eq!(pool.allocations, 1);
        assert_eq!(pool.reuses, 1);
    }

    #[test]
    fn test_pool_clear() {
        let mut pool = TensorPool::new();

        let t1 = pool.get(&[2, 2]);
        pool.return_tensor(t1);

        assert_eq!(pool.available_count(), 1);

        pool.clear();
        assert_eq!(pool.available_count(), 0);
        assert_eq!(pool.allocations, 0);
        assert_eq!(pool.reuses, 0);
    }

    #[test]
    fn test_pool_zeroing() {
        let mut pool = TensorPool::new();

        let mut t1 = pool.get(&[2, 2]);
        t1.fill(5.0);
        pool.return_tensor(t1);

        let t2 = pool.get(&[2, 2]);
        assert!(t2.iter().all(|&x| x == 0.0));
    }
}