Skip to main content

tensorlogic_infer/
workspace.rs

//! Workspace management for efficient memory reuse.
//!
//! This module provides workspace allocation and management for reducing
//! memory allocation overhead during inference:
//! - Pre-allocated memory pools
//! - Workspace recycling
//! - Size-based workspace selection
//! - Automatic workspace expansion
//! - Memory defragmentation
//! - Multi-threaded workspace management
11
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, Mutex};
15use thiserror::Error;
16
17/// Workspace management errors.
18#[derive(Error, Debug, Clone, PartialEq)]
19pub enum WorkspaceError {
20    #[error(
21        "Workspace allocation failed: requested {requested} bytes, available {available} bytes"
22    )]
23    AllocationFailed { requested: usize, available: usize },
24
25    #[error("Workspace not found: {0}")]
26    NotFound(String),
27
28    #[error("Invalid workspace size: {0}")]
29    InvalidSize(usize),
30
31    #[error("Workspace limit exceeded: {limit} bytes")]
32    LimitExceeded { limit: usize },
33
34    #[error("Workspace is in use")]
35    InUse,
36}
37
38/// Workspace allocation strategy.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40pub enum AllocationStrategy {
41    /// Best fit - find smallest workspace that fits
42    BestFit,
43    /// First fit - use first workspace that fits
44    FirstFit,
45    /// Exact fit - only use exact size matches
46    ExactFit,
47    /// Power of 2 - round up to power of 2 sizes
48    PowerOfTwo,
49}
50
51/// Workspace configuration.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct WorkspaceConfig {
54    /// Initial workspace size (bytes)
55    pub initial_size: usize,
56    /// Maximum workspace size (bytes)
57    pub max_size: usize,
58    /// Growth factor when expanding
59    pub growth_factor: f64,
60    /// Allocation strategy
61    pub strategy: AllocationStrategy,
62    /// Enable automatic expansion
63    pub auto_expand: bool,
64    /// Enable defragmentation
65    pub enable_defragmentation: bool,
66    /// Defragmentation threshold (fragmentation ratio)
67    pub defrag_threshold: f64,
68    /// Number of size buckets for pooling
69    pub num_buckets: usize,
70}
71
72impl Default for WorkspaceConfig {
73    fn default() -> Self {
74        Self {
75            initial_size: 1024 * 1024,    // 1 MB
76            max_size: 1024 * 1024 * 1024, // 1 GB
77            growth_factor: 2.0,
78            strategy: AllocationStrategy::BestFit,
79            auto_expand: true,
80            enable_defragmentation: false,
81            defrag_threshold: 0.5,
82            num_buckets: 16,
83        }
84    }
85}
86
87impl WorkspaceConfig {
88    /// Create configuration for large models.
89    pub fn large_model() -> Self {
90        Self {
91            initial_size: 64 * 1024 * 1024,   // 64 MB
92            max_size: 8 * 1024 * 1024 * 1024, // 8 GB
93            growth_factor: 1.5,
94            num_buckets: 32,
95            ..Default::default()
96        }
97    }
98
99    /// Create configuration for small models.
100    pub fn small_model() -> Self {
101        Self {
102            initial_size: 256 * 1024,    // 256 KB
103            max_size: 128 * 1024 * 1024, // 128 MB
104            growth_factor: 2.0,
105            num_buckets: 8,
106            ..Default::default()
107        }
108    }
109
110    /// Create configuration optimized for memory.
111    pub fn memory_optimized() -> Self {
112        Self {
113            initial_size: 512 * 1024,    // 512 KB
114            max_size: 256 * 1024 * 1024, // 256 MB
115            growth_factor: 1.2,
116            enable_defragmentation: true,
117            defrag_threshold: 0.3,
118            ..Default::default()
119        }
120    }
121}
122
/// Bookkeeping record for one reusable workspace buffer.
#[derive(Debug, Clone)]
pub struct Workspace {
    /// Identifier of this workspace.
    pub id: String,
    /// Capacity in bytes.
    pub size: usize,
    /// `true` while the workspace is acquired.
    pub in_use: bool,
    /// How many times the workspace has been acquired.
    pub allocation_count: usize,
    /// Accumulated time in use (kept for statistics).
    pub total_use_time: std::time::Duration,
}
137
138impl Workspace {
139    /// Create a new workspace.
140    pub fn new(id: String, size: usize) -> Self {
141        Self {
142            id,
143            size,
144            in_use: false,
145            allocation_count: 0,
146            total_use_time: std::time::Duration::ZERO,
147        }
148    }
149
150    /// Mark as in use.
151    pub fn acquire(&mut self) -> Result<(), WorkspaceError> {
152        if self.in_use {
153            return Err(WorkspaceError::InUse);
154        }
155        self.in_use = true;
156        self.allocation_count += 1;
157        Ok(())
158    }
159
160    /// Mark as available.
161    pub fn release(&mut self) {
162        self.in_use = false;
163    }
164}
165
166/// Workspace pool for managing multiple workspaces.
167pub struct WorkspacePool {
168    config: WorkspaceConfig,
169    workspaces: HashMap<String, Workspace>,
170    free_lists: HashMap<usize, VecDeque<String>>, // Size bucket -> workspace IDs
171    next_id: usize,
172    stats: WorkspaceStats,
173}
174
175impl WorkspacePool {
176    /// Create a new workspace pool.
177    pub fn new(config: WorkspaceConfig) -> Self {
178        let mut pool = Self {
179            config,
180            workspaces: HashMap::new(),
181            free_lists: HashMap::new(),
182            next_id: 0,
183            stats: WorkspaceStats::default(),
184        };
185
186        // Pre-allocate initial workspaces
187        pool.preallocate_workspaces();
188
189        pool
190    }
191
192    /// Pre-allocate workspaces based on configuration.
193    fn preallocate_workspaces(&mut self) {
194        let sizes = self.compute_bucket_sizes();
195        for size in sizes {
196            let _ = self.create_workspace(size);
197        }
198    }
199
200    /// Compute workspace sizes for buckets.
201    fn compute_bucket_sizes(&self) -> Vec<usize> {
202        let mut sizes = Vec::new();
203        let mut size = self.config.initial_size;
204
205        for _ in 0..self.config.num_buckets {
206            sizes.push(size);
207            size = (size as f64 * self.config.growth_factor) as usize;
208            if size > self.config.max_size {
209                break;
210            }
211        }
212
213        sizes
214    }
215
216    /// Create a new workspace.
217    fn create_workspace(&mut self, size: usize) -> String {
218        let id = format!("ws_{}", self.next_id);
219        self.next_id += 1;
220
221        let workspace = Workspace::new(id.clone(), size);
222        self.workspaces.insert(id.clone(), workspace);
223
224        // Add to free list
225        let bucket = self.size_to_bucket(size);
226        self.free_lists
227            .entry(bucket)
228            .or_default()
229            .push_back(id.clone());
230
231        self.stats.total_created += 1;
232        self.stats.current_total_size += size;
233
234        id
235    }
236
237    /// Convert size to bucket size.
238    fn size_to_bucket(&self, size: usize) -> usize {
239        match self.config.strategy {
240            AllocationStrategy::PowerOfTwo => size.next_power_of_two(),
241            _ => {
242                // Find nearest bucket size
243                let sizes = self.compute_bucket_sizes();
244                sizes.iter().find(|&&s| s >= size).copied().unwrap_or(size)
245            }
246        }
247    }
248
249    /// Allocate a workspace of at least the given size.
250    pub fn allocate(&mut self, size: usize) -> Result<String, WorkspaceError> {
251        if size > self.config.max_size {
252            return Err(WorkspaceError::InvalidSize(size));
253        }
254
255        let workspace_id = match self.config.strategy {
256            AllocationStrategy::BestFit => self.find_best_fit(size),
257            AllocationStrategy::FirstFit => self.find_first_fit(size),
258            AllocationStrategy::ExactFit => self.find_exact_fit(size),
259            AllocationStrategy::PowerOfTwo => {
260                let bucket_size = size.next_power_of_two();
261                self.find_first_fit(bucket_size)
262            }
263        };
264
265        match workspace_id {
266            Some(id) => {
267                self.workspaces.get_mut(&id).unwrap().acquire()?;
268                self.stats.total_allocations += 1;
269                Ok(id)
270            }
271            None => {
272                // No suitable workspace found
273                if self.config.auto_expand {
274                    let new_size = self.size_to_bucket(size);
275                    let id = self.create_workspace(new_size);
276                    self.workspaces.get_mut(&id).unwrap().acquire()?;
277                    self.stats.total_allocations += 1;
278                    self.stats.total_expansions += 1;
279                    Ok(id)
280                } else {
281                    Err(WorkspaceError::AllocationFailed {
282                        requested: size,
283                        available: self.max_available_size(),
284                    })
285                }
286            }
287        }
288    }
289
290    /// Release a workspace back to the pool.
291    pub fn release(&mut self, id: &str) -> Result<(), WorkspaceError> {
292        let workspace_size = {
293            let workspace = self
294                .workspaces
295                .get_mut(id)
296                .ok_or_else(|| WorkspaceError::NotFound(id.to_string()))?;
297
298            workspace.release();
299            workspace.size
300        };
301
302        self.stats.total_releases += 1;
303
304        // Add back to free list
305        let bucket = self.size_to_bucket(workspace_size);
306        self.free_lists
307            .entry(bucket)
308            .or_default()
309            .push_back(id.to_string());
310
311        Ok(())
312    }
313
314    /// Find best fit workspace.
315    fn find_best_fit(&mut self, size: usize) -> Option<String> {
316        let mut best_id: Option<String> = None;
317        let mut best_size = usize::MAX;
318
319        for (ws_id, workspace) in &self.workspaces {
320            if !workspace.in_use && workspace.size >= size && workspace.size < best_size {
321                best_id = Some(ws_id.clone());
322                best_size = workspace.size;
323            }
324        }
325
326        if let Some(ref id) = best_id {
327            let bucket = self.size_to_bucket(best_size);
328            if let Some(list) = self.free_lists.get_mut(&bucket) {
329                list.retain(|ws_id| ws_id != id);
330            }
331        }
332
333        best_id
334    }
335
336    /// Find first fit workspace.
337    fn find_first_fit(&mut self, size: usize) -> Option<String> {
338        for (ws_id, workspace) in &self.workspaces {
339            if !workspace.in_use && workspace.size >= size {
340                let id = ws_id.clone();
341                let bucket = self.size_to_bucket(workspace.size);
342                if let Some(list) = self.free_lists.get_mut(&bucket) {
343                    list.retain(|ws_id| ws_id != &id);
344                }
345                return Some(id);
346            }
347        }
348        None
349    }
350
351    /// Find exact fit workspace.
352    fn find_exact_fit(&mut self, size: usize) -> Option<String> {
353        let bucket = self.size_to_bucket(size);
354        if let Some(list) = self.free_lists.get_mut(&bucket) {
355            list.pop_front()
356        } else {
357            None
358        }
359    }
360
361    /// Get maximum available workspace size.
362    fn max_available_size(&self) -> usize {
363        self.workspaces
364            .values()
365            .filter(|ws| !ws.in_use)
366            .map(|ws| ws.size)
367            .max()
368            .unwrap_or(0)
369    }
370
371    /// Get statistics.
372    pub fn stats(&self) -> &WorkspaceStats {
373        &self.stats
374    }
375
376    /// Perform defragmentation if needed.
377    pub fn defragment(&mut self) -> DefragmentationResult {
378        if !self.config.enable_defragmentation {
379            return DefragmentationResult {
380                freed_bytes: 0,
381                merged_workspaces: 0,
382            };
383        }
384
385        let fragmentation_ratio = self.compute_fragmentation_ratio();
386        if fragmentation_ratio < self.config.defrag_threshold {
387            return DefragmentationResult {
388                freed_bytes: 0,
389                merged_workspaces: 0,
390            };
391        }
392
393        // Simple defragmentation: merge adjacent free workspaces
394        // In a real implementation, this would involve memory compaction
395        let freed_bytes = 0;
396        let merged_workspaces = 0;
397
398        // Placeholder for actual defragmentation logic
399        self.stats.total_defragmentations += 1;
400
401        DefragmentationResult {
402            freed_bytes,
403            merged_workspaces,
404        }
405    }
406
407    /// Compute fragmentation ratio.
408    fn compute_fragmentation_ratio(&self) -> f64 {
409        let total_free = self
410            .workspaces
411            .values()
412            .filter(|ws| !ws.in_use)
413            .map(|ws| ws.size)
414            .sum::<usize>();
415
416        let max_free = self.max_available_size();
417
418        if total_free == 0 {
419            0.0
420        } else {
421            1.0 - (max_free as f64 / total_free as f64)
422        }
423    }
424
425    /// Clear all workspaces.
426    pub fn clear(&mut self) {
427        self.workspaces.clear();
428        self.free_lists.clear();
429        self.stats = WorkspaceStats::default();
430        self.preallocate_workspaces();
431    }
432}
433
434/// Thread-safe workspace pool.
435pub struct SharedWorkspacePool {
436    inner: Arc<Mutex<WorkspacePool>>,
437}
438
439impl SharedWorkspacePool {
440    /// Create a new shared workspace pool.
441    pub fn new(config: WorkspaceConfig) -> Self {
442        Self {
443            inner: Arc::new(Mutex::new(WorkspacePool::new(config))),
444        }
445    }
446
447    /// Allocate a workspace.
448    pub fn allocate(&self, size: usize) -> Result<String, WorkspaceError> {
449        self.inner.lock().unwrap().allocate(size)
450    }
451
452    /// Release a workspace.
453    pub fn release(&self, id: &str) -> Result<(), WorkspaceError> {
454        self.inner.lock().unwrap().release(id)
455    }
456
457    /// Get statistics.
458    pub fn stats(&self) -> WorkspaceStats {
459        self.inner.lock().unwrap().stats().clone()
460    }
461
462    /// Perform defragmentation.
463    pub fn defragment(&self) -> DefragmentationResult {
464        self.inner.lock().unwrap().defragment()
465    }
466}
467
468impl Clone for SharedWorkspacePool {
469    fn clone(&self) -> Self {
470        Self {
471            inner: Arc::clone(&self.inner),
472        }
473    }
474}
475
476/// Workspace usage statistics.
477#[derive(Debug, Clone, Default, Serialize, Deserialize)]
478pub struct WorkspaceStats {
479    /// Total workspaces created
480    pub total_created: usize,
481    /// Total allocations
482    pub total_allocations: usize,
483    /// Total releases
484    pub total_releases: usize,
485    /// Total expansions
486    pub total_expansions: usize,
487    /// Total defragmentations
488    pub total_defragmentations: usize,
489    /// Current total size (bytes)
490    pub current_total_size: usize,
491}
492
493impl WorkspaceStats {
494    /// Get hit rate (allocations without expansion).
495    pub fn hit_rate(&self) -> f64 {
496        if self.total_allocations == 0 {
497            0.0
498        } else {
499            1.0 - (self.total_expansions as f64 / self.total_allocations as f64)
500        }
501    }
502
503    /// Get average workspace size.
504    pub fn avg_workspace_size(&self) -> f64 {
505        if self.total_created == 0 {
506            0.0
507        } else {
508            self.current_total_size as f64 / self.total_created as f64
509        }
510    }
511}
512
513/// Defragmentation result.
514#[derive(Debug, Clone, Serialize, Deserialize)]
515pub struct DefragmentationResult {
516    /// Bytes freed
517    pub freed_bytes: usize,
518    /// Number of workspaces merged
519    pub merged_workspaces: usize,
520}
521
#[cfg(test)]
mod tests {
    //! Unit tests for workspace records, presets, the pool and its shared
    //! wrapper.

    use super::*;

    #[test]
    fn test_workspace_creation() {
        let ws = Workspace::new("test".to_string(), 1024);
        assert_eq!(ws.size, 1024);
        assert!(!ws.in_use);
        assert_eq!(ws.allocation_count, 0);
    }

    #[test]
    fn test_workspace_acquire_release() {
        let mut ws = Workspace::new("test".to_string(), 1024);

        assert!(ws.acquire().is_ok());
        assert!(ws.in_use);
        assert_eq!(ws.allocation_count, 1);

        // Cannot acquire twice
        assert_eq!(ws.acquire(), Err(WorkspaceError::InUse));

        ws.release();
        assert!(!ws.in_use);

        // Can acquire again
        assert!(ws.acquire().is_ok());
        assert_eq!(ws.allocation_count, 2);
    }

    #[test]
    fn test_workspace_config() {
        let config = WorkspaceConfig::large_model();
        assert!(config.initial_size > WorkspaceConfig::default().initial_size);

        let config = WorkspaceConfig::small_model();
        assert!(config.max_size < WorkspaceConfig::default().max_size);
    }

    #[test]
    fn test_workspace_pool_creation() {
        let config = WorkspaceConfig::default();
        let pool = WorkspacePool::new(config);

        assert!(pool.stats().total_created > 0);
    }

    #[test]
    fn test_workspace_allocation() {
        let config = WorkspaceConfig::default();
        let mut pool = WorkspacePool::new(config);

        let id = pool.allocate(512).unwrap();
        assert!(!id.is_empty());

        let workspace = pool.workspaces.get(&id).unwrap();
        assert!(workspace.in_use);
        assert!(workspace.size >= 512);
    }

    #[test]
    fn test_workspace_release() {
        let config = WorkspaceConfig::default();
        let mut pool = WorkspacePool::new(config);

        let id = pool.allocate(512).unwrap();
        assert!(pool.release(&id).is_ok());

        let workspace = pool.workspaces.get(&id).unwrap();
        assert!(!workspace.in_use);
    }

    #[test]
    fn test_allocation_strategies() {
        // Every strategy must be able to serve a small request from a
        // default-configured pool.
        for strategy in [
            AllocationStrategy::BestFit,
            AllocationStrategy::FirstFit,
            AllocationStrategy::ExactFit,
            AllocationStrategy::PowerOfTwo,
        ] {
            let config = WorkspaceConfig {
                strategy,
                ..Default::default()
            };
            let mut pool = WorkspacePool::new(config);

            let id = pool.allocate(512);
            assert!(id.is_ok(), "strategy {:?} failed: {:?}", strategy, id);
        }
    }

    #[test]
    fn test_auto_expansion() {
        let config = WorkspaceConfig {
            initial_size: 1024,
            max_size: 1024 * 1024,
            auto_expand: true,
            num_buckets: 2, // Limit pre-allocated buckets
            ..Default::default()
        };
        let mut pool = WorkspacePool::new(config);

        // Record initial expansion count
        let initial_expansions = pool.stats().total_expansions;

        // With 2 buckets and growth factor 2.0 the pool pre-allocates 1024
        // and 2048 bytes, so a 5 KiB request cannot be served without growing.
        let id = pool.allocate(5 * 1024);
        assert!(id.is_ok());

        assert!(pool.stats().total_expansions > initial_expansions);
    }

    #[test]
    fn test_allocation_without_expansion() {
        let config = WorkspaceConfig {
            initial_size: 1024,
            max_size: 2048,
            auto_expand: false,
            ..Default::default()
        };
        let mut pool = WorkspacePool::new(config);

        // Requests above `max_size` are rejected outright.
        assert_eq!(
            pool.allocate(100 * 1024),
            Err(WorkspaceError::InvalidSize(100 * 1024))
        );

        // The pool pre-allocates 1024 and 2048 bytes. Once the 2048-byte
        // workspace is taken, a second 2048-byte request cannot be served
        // and, with expansion disabled, must fail.
        assert!(pool.allocate(2048).is_ok());
        assert_eq!(
            pool.allocate(2048),
            Err(WorkspaceError::AllocationFailed {
                requested: 2048,
                available: 1024,
            })
        );
        assert_eq!(pool.stats().total_expansions, 0);
    }

    #[test]
    fn test_stats_hit_rate() {
        let stats = WorkspaceStats {
            total_allocations: 10,
            total_expansions: 2,
            ..Default::default()
        };

        // Compare with a tolerance rather than relying on exact float equality.
        assert!((stats.hit_rate() - 0.8).abs() < 1e-12);
    }

    #[test]
    fn test_shared_workspace_pool() {
        let config = WorkspaceConfig::default();
        let pool = SharedWorkspacePool::new(config);

        let id = pool.allocate(512).unwrap();
        assert!(pool.release(&id).is_ok());

        let stats = pool.stats();
        assert!(stats.total_allocations > 0);
    }

    #[test]
    fn test_fragmentation_ratio() {
        let config = WorkspaceConfig::default();
        let pool = WorkspacePool::new(config);

        let ratio = pool.compute_fragmentation_ratio();
        assert!((0.0..=1.0).contains(&ratio));
    }

    #[test]
    fn test_defragmentation() {
        let config = WorkspaceConfig {
            enable_defragmentation: true,
            ..Default::default()
        };
        let mut pool = WorkspacePool::new(config);

        let result = pool.defragment();
        // Should not panic
        assert_eq!(result.freed_bytes, 0); // No actual defrag yet
    }
}
699}