Skip to main content

tensorlogic_scirs_backend/
memory_pool.rs

1//! Memory pooling for efficient tensor allocation.
2//!
3//! This module provides a memory pool that reuses tensor allocations
4//! to reduce the overhead of creating and destroying arrays during execution.
5
6use crate::Scirs2Tensor;
7use scirs2_core::ndarray::{ArrayD, IxDyn};
8use std::collections::HashMap;
9
10/// Memory pool for reusing tensor allocations
11#[derive(Default)]
12pub struct TensorPool {
13    /// Available tensors grouped by shape
14    available: HashMap<Vec<usize>, Vec<Scirs2Tensor>>,
15    /// Statistics
16    pub(crate) allocations: usize,
17    pub(crate) reuses: usize,
18}
19
20impl TensorPool {
21    /// Create a new empty tensor pool
22    pub fn new() -> Self {
23        TensorPool {
24            available: HashMap::new(),
25            allocations: 0,
26            reuses: 0,
27        }
28    }
29
30    /// Get a tensor with the specified shape, either from the pool or newly allocated
31    pub fn get(&mut self, shape: &[usize]) -> Scirs2Tensor {
32        let shape_key = shape.to_vec();
33
34        // Try to reuse from pool
35        if let Some(tensors) = self.available.get_mut(&shape_key) {
36            if let Some(tensor) = tensors.pop() {
37                self.reuses += 1;
38                // Zero out the tensor before reuse
39                let mut tensor = tensor;
40                tensor.fill(0.0);
41                return tensor;
42            }
43        }
44
45        // Allocate new tensor
46        self.allocations += 1;
47        ArrayD::zeros(IxDyn(shape))
48    }
49
50    /// Get a tensor filled with ones
51    pub fn get_ones(&mut self, shape: &[usize]) -> Scirs2Tensor {
52        let mut tensor = self.get(shape);
53        tensor.fill(1.0);
54        tensor
55    }
56
57    /// Return a tensor to the pool for reuse
58    pub fn return_tensor(&mut self, tensor: Scirs2Tensor) {
59        let shape = tensor.shape().to_vec();
60        self.available.entry(shape).or_default().push(tensor);
61    }
62
63    /// Clear all pooled tensors
64    pub fn clear(&mut self) {
65        self.available.clear();
66        self.allocations = 0;
67        self.reuses = 0;
68    }
69
70    /// Get the number of available tensors in the pool
71    pub fn available_count(&self) -> usize {
72        self.available.values().map(|v| v.len()).sum()
73    }
74
75    /// Get pool statistics
76    pub fn stats(&self) -> PoolStats {
77        PoolStats {
78            allocations: self.allocations,
79            reuses: self.reuses,
80            available: self.available_count(),
81            reuse_rate: if self.allocations + self.reuses > 0 {
82                self.reuses as f64 / (self.allocations + self.reuses) as f64
83            } else {
84                0.0
85            },
86        }
87    }
88}
89
/// Snapshot of a [`TensorPool`]'s counters, as returned by [`TensorPool::stats`].
///
/// The all-zero [`Default`] value matches the statistics of a freshly created pool.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct PoolStats {
    /// Total number of new allocations
    pub allocations: usize,
    /// Total number of reuses from pool
    pub reuses: usize,
    /// Number of tensors currently available in pool
    pub available: usize,
    /// Fraction of requests served from the pool, `reuses / (allocations + reuses)`,
    /// in the range `0.0..=1.0`; `0.0` when no requests have been made.
    pub reuse_rate: f64,
}
102
103impl std::fmt::Display for PoolStats {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        write!(
106            f,
107            "PoolStats {{ allocations: {}, reuses: {}, available: {}, reuse_rate: {:.2}% }}",
108            self.allocations,
109            self.reuses,
110            self.available,
111            self.reuse_rate * 100.0
112        )
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_basic() {
        let mut pool = TensorPool::new();

        // First allocation
        let t1 = pool.get(&[2, 3]);
        assert_eq!(t1.shape(), &[2, 3]);
        assert_eq!(pool.allocations, 1);
        assert_eq!(pool.reuses, 0);

        // Return to pool
        pool.return_tensor(t1);
        assert_eq!(pool.available_count(), 1);

        // Reuse from pool
        let t2 = pool.get(&[2, 3]);
        assert_eq!(t2.shape(), &[2, 3]);
        assert_eq!(pool.allocations, 1);
        assert_eq!(pool.reuses, 1);
        assert_eq!(pool.available_count(), 0);
    }

    #[test]
    fn test_pool_different_shapes() {
        let mut pool = TensorPool::new();

        let t1 = pool.get(&[2, 3]);
        let t2 = pool.get(&[4, 5]);

        pool.return_tensor(t1);
        pool.return_tensor(t2);

        assert_eq!(pool.available_count(), 2);

        // Get tensor with shape [2, 3] - should reuse
        let t3 = pool.get(&[2, 3]);
        assert_eq!(t3.shape(), &[2, 3]);
        assert_eq!(pool.reuses, 1);

        // Get tensor with shape [4, 5] - should reuse
        let t4 = pool.get(&[4, 5]);
        assert_eq!(t4.shape(), &[4, 5]);
        assert_eq!(pool.reuses, 2);
    }

    #[test]
    fn test_pool_shape_mismatch_allocates() {
        let mut pool = TensorPool::new();

        // Same element count, different shape: must not be reused.
        let t1 = pool.get(&[2, 3]);
        pool.return_tensor(t1);

        let t2 = pool.get(&[3, 2]);
        assert_eq!(t2.shape(), &[3, 2]);
        assert_eq!(pool.allocations, 2);
        assert_eq!(pool.reuses, 0);
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn test_pool_stats() {
        let mut pool = TensorPool::new();

        // Allocate 3 tensors
        let t1 = pool.get(&[2, 2]);
        let t2 = pool.get(&[2, 2]);
        let t3 = pool.get(&[2, 2]);

        pool.return_tensor(t1);
        pool.return_tensor(t2);
        pool.return_tensor(t3);

        // Reuse 2 tensors
        let _t4 = pool.get(&[2, 2]);
        let _t5 = pool.get(&[2, 2]);

        let stats = pool.stats();
        assert_eq!(stats.allocations, 3);
        assert_eq!(stats.reuses, 2);
        assert_eq!(stats.available, 1);
        assert!((stats.reuse_rate - 0.4).abs() < 1e-6); // 2/(3+2) = 0.4
    }

    #[test]
    fn test_stats_empty_pool() {
        // No requests yet: the rate must not divide by zero.
        let stats = TensorPool::new().stats();
        assert_eq!(stats.allocations, 0);
        assert_eq!(stats.reuses, 0);
        assert_eq!(stats.available, 0);
        assert_eq!(stats.reuse_rate, 0.0);
    }

    #[test]
    fn test_stats_display() {
        let stats = PoolStats {
            allocations: 3,
            reuses: 1,
            available: 2,
            reuse_rate: 0.25,
        };
        assert_eq!(
            stats.to_string(),
            "PoolStats { allocations: 3, reuses: 1, available: 2, reuse_rate: 25.00% }"
        );
    }

    #[test]
    fn test_get_ones() {
        let mut pool = TensorPool::new();
        let t = pool.get_ones(&[2, 2]);

        assert_eq!(t.shape(), &[2, 2]);
        assert_eq!(t[[0, 0]], 1.0);
        assert_eq!(t[[1, 1]], 1.0);
        assert_eq!(pool.allocations, 1);
    }

    #[test]
    fn test_get_ones_reuses_dirty_tensor() {
        let mut pool = TensorPool::new();

        let mut t1 = pool.get(&[2, 2]);
        t1.fill(5.0);
        pool.return_tensor(t1);

        // A reused tensor must be fully overwritten with ones.
        let t2 = pool.get_ones(&[2, 2]);
        assert!(t2.iter().all(|&x| x == 1.0));
        assert_eq!(pool.allocations, 1);
        assert_eq!(pool.reuses, 1);
    }

    #[test]
    fn test_pool_clear() {
        let mut pool = TensorPool::new();

        let t1 = pool.get(&[2, 2]);
        pool.return_tensor(t1);

        assert_eq!(pool.available_count(), 1);

        pool.clear();
        assert_eq!(pool.available_count(), 0);
        assert_eq!(pool.allocations, 0);
        assert_eq!(pool.reuses, 0);
    }

    #[test]
    fn test_pool_zeroing() {
        let mut pool = TensorPool::new();

        // Create tensor and fill with non-zero values
        let mut t1 = pool.get(&[2, 2]);
        t1.fill(5.0);
        pool.return_tensor(t1);

        // Get tensor from pool - every element should be zeroed
        let t2 = pool.get(&[2, 2]);
        assert!(t2.iter().all(|&x| x == 0.0));
        assert_eq!(pool.reuses, 1);
    }
}