Skip to main content

tensorlogic_infer/
workspace.rs

1//! Workspace management for efficient memory reuse.
2//!
3//! This module provides workspace allocation and management for reducing
4//! memory allocation overhead during inference:
5//! - Pre-allocated memory pools
6//! - Workspace recycling
7//! - Size-based workspace selection
8//! - Automatic workspace expansion
9//! - Memory defragmentation
10//! - Multi-threaded workspace management
11
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, Mutex};
15use thiserror::Error;
16
17/// Workspace management errors.
18#[derive(Error, Debug, Clone, PartialEq)]
19pub enum WorkspaceError {
20    #[error(
21        "Workspace allocation failed: requested {requested} bytes, available {available} bytes"
22    )]
23    AllocationFailed { requested: usize, available: usize },
24
25    #[error("Workspace not found: {0}")]
26    NotFound(String),
27
28    #[error("Invalid workspace size: {0}")]
29    InvalidSize(usize),
30
31    #[error("Workspace limit exceeded: {limit} bytes")]
32    LimitExceeded { limit: usize },
33
34    #[error("Workspace is in use")]
35    InUse,
36}
37
38/// Workspace allocation strategy.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40pub enum AllocationStrategy {
41    /// Best fit - find smallest workspace that fits
42    BestFit,
43    /// First fit - use first workspace that fits
44    FirstFit,
45    /// Exact fit - only use exact size matches
46    ExactFit,
47    /// Power of 2 - round up to power of 2 sizes
48    PowerOfTwo,
49}
50
51/// Workspace configuration.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct WorkspaceConfig {
54    /// Initial workspace size (bytes)
55    pub initial_size: usize,
56    /// Maximum workspace size (bytes)
57    pub max_size: usize,
58    /// Growth factor when expanding
59    pub growth_factor: f64,
60    /// Allocation strategy
61    pub strategy: AllocationStrategy,
62    /// Enable automatic expansion
63    pub auto_expand: bool,
64    /// Enable defragmentation
65    pub enable_defragmentation: bool,
66    /// Defragmentation threshold (fragmentation ratio)
67    pub defrag_threshold: f64,
68    /// Number of size buckets for pooling
69    pub num_buckets: usize,
70}
71
72impl Default for WorkspaceConfig {
73    fn default() -> Self {
74        Self {
75            initial_size: 1024 * 1024,    // 1 MB
76            max_size: 1024 * 1024 * 1024, // 1 GB
77            growth_factor: 2.0,
78            strategy: AllocationStrategy::BestFit,
79            auto_expand: true,
80            enable_defragmentation: false,
81            defrag_threshold: 0.5,
82            num_buckets: 16,
83        }
84    }
85}
86
87impl WorkspaceConfig {
88    /// Create configuration for large models.
89    pub fn large_model() -> Self {
90        Self {
91            initial_size: 64 * 1024 * 1024,   // 64 MB
92            max_size: 8 * 1024 * 1024 * 1024, // 8 GB
93            growth_factor: 1.5,
94            num_buckets: 32,
95            ..Default::default()
96        }
97    }
98
99    /// Create configuration for small models.
100    pub fn small_model() -> Self {
101        Self {
102            initial_size: 256 * 1024,    // 256 KB
103            max_size: 128 * 1024 * 1024, // 128 MB
104            growth_factor: 2.0,
105            num_buckets: 8,
106            ..Default::default()
107        }
108    }
109
110    /// Create configuration optimized for memory.
111    pub fn memory_optimized() -> Self {
112        Self {
113            initial_size: 512 * 1024,    // 512 KB
114            max_size: 256 * 1024 * 1024, // 256 MB
115            growth_factor: 1.2,
116            enable_defragmentation: true,
117            defrag_threshold: 0.3,
118            ..Default::default()
119        }
120    }
121}
122
/// Bookkeeping record for one reusable workspace buffer.
///
/// Only size and usage metadata are stored here; the struct holds no backing memory.
#[derive(Debug, Clone)]
pub struct Workspace {
    /// Identifier assigned by the owning pool.
    pub id: String,
    /// Capacity in bytes.
    pub size: usize,
    /// `true` while the workspace is acquired.
    pub in_use: bool,
    /// How many times the workspace has been successfully acquired.
    pub allocation_count: usize,
    /// Accumulated time in use, for statistics. Note: `Workspace::acquire` and
    /// `Workspace::release` do not update this value.
    pub total_use_time: std::time::Duration,
}
137
138impl Workspace {
139    /// Create a new workspace.
140    pub fn new(id: String, size: usize) -> Self {
141        Self {
142            id,
143            size,
144            in_use: false,
145            allocation_count: 0,
146            total_use_time: std::time::Duration::ZERO,
147        }
148    }
149
150    /// Mark as in use.
151    pub fn acquire(&mut self) -> Result<(), WorkspaceError> {
152        if self.in_use {
153            return Err(WorkspaceError::InUse);
154        }
155        self.in_use = true;
156        self.allocation_count += 1;
157        Ok(())
158    }
159
160    /// Mark as available.
161    pub fn release(&mut self) {
162        self.in_use = false;
163    }
164}
165
166/// Workspace pool for managing multiple workspaces.
167pub struct WorkspacePool {
168    config: WorkspaceConfig,
169    workspaces: HashMap<String, Workspace>,
170    free_lists: HashMap<usize, VecDeque<String>>, // Size bucket -> workspace IDs
171    next_id: usize,
172    stats: WorkspaceStats,
173}
174
175impl WorkspacePool {
176    /// Create a new workspace pool.
177    pub fn new(config: WorkspaceConfig) -> Self {
178        let mut pool = Self {
179            config,
180            workspaces: HashMap::new(),
181            free_lists: HashMap::new(),
182            next_id: 0,
183            stats: WorkspaceStats::default(),
184        };
185
186        // Pre-allocate initial workspaces
187        pool.preallocate_workspaces();
188
189        pool
190    }
191
192    /// Pre-allocate workspaces based on configuration.
193    fn preallocate_workspaces(&mut self) {
194        let sizes = self.compute_bucket_sizes();
195        for size in sizes {
196            let _ = self.create_workspace(size);
197        }
198    }
199
200    /// Compute workspace sizes for buckets.
201    fn compute_bucket_sizes(&self) -> Vec<usize> {
202        let mut sizes = Vec::new();
203        let mut size = self.config.initial_size;
204
205        for _ in 0..self.config.num_buckets {
206            sizes.push(size);
207            size = (size as f64 * self.config.growth_factor) as usize;
208            if size > self.config.max_size {
209                break;
210            }
211        }
212
213        sizes
214    }
215
216    /// Create a new workspace.
217    fn create_workspace(&mut self, size: usize) -> String {
218        let id = format!("ws_{}", self.next_id);
219        self.next_id += 1;
220
221        let workspace = Workspace::new(id.clone(), size);
222        self.workspaces.insert(id.clone(), workspace);
223
224        // Add to free list
225        let bucket = self.size_to_bucket(size);
226        self.free_lists
227            .entry(bucket)
228            .or_default()
229            .push_back(id.clone());
230
231        self.stats.total_created += 1;
232        self.stats.current_total_size += size;
233
234        id
235    }
236
237    /// Convert size to bucket size.
238    fn size_to_bucket(&self, size: usize) -> usize {
239        match self.config.strategy {
240            AllocationStrategy::PowerOfTwo => size.next_power_of_two(),
241            _ => {
242                // Find nearest bucket size
243                let sizes = self.compute_bucket_sizes();
244                sizes.iter().find(|&&s| s >= size).copied().unwrap_or(size)
245            }
246        }
247    }
248
249    /// Allocate a workspace of at least the given size.
250    pub fn allocate(&mut self, size: usize) -> Result<String, WorkspaceError> {
251        if size > self.config.max_size {
252            return Err(WorkspaceError::InvalidSize(size));
253        }
254
255        let workspace_id = match self.config.strategy {
256            AllocationStrategy::BestFit => self.find_best_fit(size),
257            AllocationStrategy::FirstFit => self.find_first_fit(size),
258            AllocationStrategy::ExactFit => self.find_exact_fit(size),
259            AllocationStrategy::PowerOfTwo => {
260                let bucket_size = size.next_power_of_two();
261                self.find_first_fit(bucket_size)
262            }
263        };
264
265        match workspace_id {
266            Some(id) => {
267                self.workspaces
268                    .get_mut(&id)
269                    .expect("workspace id from find_first_fit or create_workspace is valid")
270                    .acquire()?;
271                self.stats.total_allocations += 1;
272                Ok(id)
273            }
274            None => {
275                // No suitable workspace found
276                if self.config.auto_expand {
277                    let new_size = self.size_to_bucket(size);
278                    let id = self.create_workspace(new_size);
279                    self.workspaces
280                        .get_mut(&id)
281                        .expect("workspace id from create_workspace is valid")
282                        .acquire()?;
283                    self.stats.total_allocations += 1;
284                    self.stats.total_expansions += 1;
285                    Ok(id)
286                } else {
287                    Err(WorkspaceError::AllocationFailed {
288                        requested: size,
289                        available: self.max_available_size(),
290                    })
291                }
292            }
293        }
294    }
295
296    /// Release a workspace back to the pool.
297    pub fn release(&mut self, id: &str) -> Result<(), WorkspaceError> {
298        let workspace_size = {
299            let workspace = self
300                .workspaces
301                .get_mut(id)
302                .ok_or_else(|| WorkspaceError::NotFound(id.to_string()))?;
303
304            workspace.release();
305            workspace.size
306        };
307
308        self.stats.total_releases += 1;
309
310        // Add back to free list
311        let bucket = self.size_to_bucket(workspace_size);
312        self.free_lists
313            .entry(bucket)
314            .or_default()
315            .push_back(id.to_string());
316
317        Ok(())
318    }
319
320    /// Find best fit workspace.
321    fn find_best_fit(&mut self, size: usize) -> Option<String> {
322        let mut best_id: Option<String> = None;
323        let mut best_size = usize::MAX;
324
325        for (ws_id, workspace) in &self.workspaces {
326            if !workspace.in_use && workspace.size >= size && workspace.size < best_size {
327                best_id = Some(ws_id.clone());
328                best_size = workspace.size;
329            }
330        }
331
332        if let Some(ref id) = best_id {
333            let bucket = self.size_to_bucket(best_size);
334            if let Some(list) = self.free_lists.get_mut(&bucket) {
335                list.retain(|ws_id| ws_id != id);
336            }
337        }
338
339        best_id
340    }
341
342    /// Find first fit workspace.
343    fn find_first_fit(&mut self, size: usize) -> Option<String> {
344        for (ws_id, workspace) in &self.workspaces {
345            if !workspace.in_use && workspace.size >= size {
346                let id = ws_id.clone();
347                let bucket = self.size_to_bucket(workspace.size);
348                if let Some(list) = self.free_lists.get_mut(&bucket) {
349                    list.retain(|ws_id| ws_id != &id);
350                }
351                return Some(id);
352            }
353        }
354        None
355    }
356
357    /// Find exact fit workspace.
358    fn find_exact_fit(&mut self, size: usize) -> Option<String> {
359        let bucket = self.size_to_bucket(size);
360        if let Some(list) = self.free_lists.get_mut(&bucket) {
361            list.pop_front()
362        } else {
363            None
364        }
365    }
366
367    /// Get maximum available workspace size.
368    fn max_available_size(&self) -> usize {
369        self.workspaces
370            .values()
371            .filter(|ws| !ws.in_use)
372            .map(|ws| ws.size)
373            .max()
374            .unwrap_or(0)
375    }
376
377    /// Get statistics.
378    pub fn stats(&self) -> &WorkspaceStats {
379        &self.stats
380    }
381
382    /// Perform defragmentation if needed.
383    pub fn defragment(&mut self) -> DefragmentationResult {
384        if !self.config.enable_defragmentation {
385            return DefragmentationResult {
386                freed_bytes: 0,
387                merged_workspaces: 0,
388            };
389        }
390
391        let fragmentation_ratio = self.compute_fragmentation_ratio();
392        if fragmentation_ratio < self.config.defrag_threshold {
393            return DefragmentationResult {
394                freed_bytes: 0,
395                merged_workspaces: 0,
396            };
397        }
398
399        // Defragmentation strategy: pairwise-merge free workspaces from
400        // smallest to largest. This increases the max contiguous free block
401        // without changing total free bytes.
402        let mut free_blocks: Vec<(String, usize)> = self
403            .workspaces
404            .iter()
405            .filter_map(|(id, ws)| {
406                if ws.in_use {
407                    None
408                } else {
409                    Some((id.clone(), ws.size))
410                }
411            })
412            .collect();
413
414        if free_blocks.len() < 2 {
415            self.stats.total_defragmentations += 1;
416            return DefragmentationResult {
417                freed_bytes: 0,
418                merged_workspaces: 0,
419            };
420        }
421
422        free_blocks.sort_by_key(|(_, size)| *size);
423
424        let freed_bytes = 0;
425        let mut merged_workspaces = 0;
426
427        let mut pair_index = 0usize;
428        while pair_index + 1 < free_blocks.len() {
429            let (id_a, size_a) = &free_blocks[pair_index];
430            let (id_b, size_b) = &free_blocks[pair_index + 1];
431            let merged_size = size_a.saturating_add(*size_b);
432
433            // If the merged block would violate workspace limits, skip this pair.
434            if merged_size > self.config.max_size {
435                pair_index += 2;
436                continue;
437            }
438
439            self.remove_from_free_list(id_a, *size_a);
440            self.remove_from_free_list(id_b, *size_b);
441            self.workspaces.remove(id_a);
442            self.workspaces.remove(id_b);
443
444            let merged_id = format!("ws_{}", self.next_id);
445            self.next_id += 1;
446            self.workspaces.insert(
447                merged_id.clone(),
448                Workspace::new(merged_id.clone(), merged_size),
449            );
450
451            let bucket = self.size_to_bucket(merged_size);
452            self.free_lists
453                .entry(bucket)
454                .or_default()
455                .push_back(merged_id);
456
457            merged_workspaces += 1;
458            pair_index += 2;
459        }
460
461        self.stats.total_defragmentations += 1;
462
463        DefragmentationResult {
464            freed_bytes,
465            merged_workspaces,
466        }
467    }
468
469    fn remove_from_free_list(&mut self, id: &str, size: usize) {
470        let bucket = self.size_to_bucket(size);
471        if let Some(list) = self.free_lists.get_mut(&bucket) {
472            list.retain(|ws_id| ws_id != id);
473        }
474    }
475
476    /// Compute fragmentation ratio.
477    fn compute_fragmentation_ratio(&self) -> f64 {
478        let total_free = self
479            .workspaces
480            .values()
481            .filter(|ws| !ws.in_use)
482            .map(|ws| ws.size)
483            .sum::<usize>();
484
485        let max_free = self.max_available_size();
486
487        if total_free == 0 {
488            0.0
489        } else {
490            1.0 - (max_free as f64 / total_free as f64)
491        }
492    }
493
494    /// Clear all workspaces.
495    pub fn clear(&mut self) {
496        self.workspaces.clear();
497        self.free_lists.clear();
498        self.stats = WorkspaceStats::default();
499        self.preallocate_workspaces();
500    }
501}
502
503/// Thread-safe workspace pool.
504pub struct SharedWorkspacePool {
505    inner: Arc<Mutex<WorkspacePool>>,
506}
507
508impl SharedWorkspacePool {
509    /// Create a new shared workspace pool.
510    pub fn new(config: WorkspaceConfig) -> Self {
511        Self {
512            inner: Arc::new(Mutex::new(WorkspacePool::new(config))),
513        }
514    }
515
516    /// Allocate a workspace.
517    pub fn allocate(&self, size: usize) -> Result<String, WorkspaceError> {
518        self.inner
519            .lock()
520            .expect("lock should not be poisoned")
521            .allocate(size)
522    }
523
524    /// Release a workspace.
525    pub fn release(&self, id: &str) -> Result<(), WorkspaceError> {
526        self.inner
527            .lock()
528            .expect("lock should not be poisoned")
529            .release(id)
530    }
531
532    /// Get statistics.
533    pub fn stats(&self) -> WorkspaceStats {
534        self.inner
535            .lock()
536            .expect("lock should not be poisoned")
537            .stats()
538            .clone()
539    }
540
541    /// Perform defragmentation.
542    pub fn defragment(&self) -> DefragmentationResult {
543        self.inner
544            .lock()
545            .expect("lock should not be poisoned")
546            .defragment()
547    }
548}
549
550impl Clone for SharedWorkspacePool {
551    fn clone(&self) -> Self {
552        Self {
553            inner: Arc::clone(&self.inner),
554        }
555    }
556}
557
558/// Workspace usage statistics.
559#[derive(Debug, Clone, Default, Serialize, Deserialize)]
560pub struct WorkspaceStats {
561    /// Total workspaces created
562    pub total_created: usize,
563    /// Total allocations
564    pub total_allocations: usize,
565    /// Total releases
566    pub total_releases: usize,
567    /// Total expansions
568    pub total_expansions: usize,
569    /// Total defragmentations
570    pub total_defragmentations: usize,
571    /// Current total size (bytes)
572    pub current_total_size: usize,
573}
574
575impl WorkspaceStats {
576    /// Get hit rate (allocations without expansion).
577    pub fn hit_rate(&self) -> f64 {
578        if self.total_allocations == 0 {
579            0.0
580        } else {
581            1.0 - (self.total_expansions as f64 / self.total_allocations as f64)
582        }
583    }
584
585    /// Get average workspace size.
586    pub fn avg_workspace_size(&self) -> f64 {
587        if self.total_created == 0 {
588            0.0
589        } else {
590            self.current_total_size as f64 / self.total_created as f64
591        }
592    }
593}
594
595/// Defragmentation result.
596#[derive(Debug, Clone, Serialize, Deserialize)]
597pub struct DefragmentationResult {
598    /// Bytes freed
599    pub freed_bytes: usize,
600    /// Number of workspaces merged
601    pub merged_workspaces: usize,
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a pool from `WorkspaceConfig::default()` after applying `tweak`.
    fn pool_with(tweak: impl FnOnce(&mut WorkspaceConfig)) -> WorkspacePool {
        let mut config = WorkspaceConfig::default();
        tweak(&mut config);
        WorkspacePool::new(config)
    }

    #[test]
    fn test_workspace_creation() {
        let fresh = Workspace::new(String::from("test"), 1024);
        assert_eq!(fresh.size, 1024);
        assert_eq!(fresh.allocation_count, 0);
        assert!(!fresh.in_use);
    }

    #[test]
    fn test_workspace_acquire_release() {
        let mut ws = Workspace::new(String::from("test"), 1024);

        ws.acquire().expect("an idle workspace can be acquired");
        assert!(ws.in_use);
        assert_eq!(ws.allocation_count, 1);

        // A second acquire while still held must be rejected.
        assert!(matches!(ws.acquire(), Err(WorkspaceError::InUse)));

        ws.release();
        assert!(!ws.in_use);

        // After release the workspace is reusable and the counter keeps climbing.
        ws.acquire().expect("a released workspace can be re-acquired");
        assert_eq!(ws.allocation_count, 2);
    }

    #[test]
    fn test_workspace_config() {
        let defaults = WorkspaceConfig::default();
        assert!(WorkspaceConfig::large_model().initial_size > defaults.initial_size);
        assert!(WorkspaceConfig::small_model().max_size < defaults.max_size);
    }

    #[test]
    fn test_workspace_pool_creation() {
        let pool = pool_with(|_| {});
        assert!(pool.stats().total_created > 0);
    }

    #[test]
    fn test_workspace_allocation() {
        let mut pool = pool_with(|_| {});
        let id = pool.allocate(512).expect("512-byte allocation succeeds");
        assert!(!id.is_empty());

        let ws = pool
            .workspaces
            .get(&id)
            .expect("allocated id is tracked by the pool");
        assert!(ws.in_use);
        assert!(ws.size >= 512);
    }

    #[test]
    fn test_workspace_release() {
        let mut pool = pool_with(|_| {});
        let id = pool.allocate(512).expect("512-byte allocation succeeds");
        pool.release(&id).expect("releasing an allocated id succeeds");

        let ws = pool
            .workspaces
            .get(&id)
            .expect("released id is still tracked by the pool");
        assert!(!ws.in_use);
    }

    #[test]
    fn test_allocation_strategies() {
        let strategies = [
            AllocationStrategy::BestFit,
            AllocationStrategy::FirstFit,
            AllocationStrategy::ExactFit,
            AllocationStrategy::PowerOfTwo,
        ];
        for strategy in strategies {
            let mut pool = pool_with(|config| config.strategy = strategy);
            assert!(
                pool.allocate(512).is_ok(),
                "strategy {:?} failed to allocate 512 bytes",
                strategy
            );
        }
    }

    #[test]
    fn test_auto_expansion() {
        // Two buckets with the default 2.0 growth factor: 1024 and 2048 bytes.
        let mut pool = pool_with(|config| {
            config.initial_size = 1024;
            config.max_size = 1024 * 1024;
            config.auto_expand = true;
            config.num_buckets = 2;
        });
        let expansions_before = pool.stats().total_expansions;

        // 5 KiB is larger than every pre-allocated bucket, forcing an expansion.
        assert!(pool.allocate(5 * 1024).is_ok());
        assert!(pool.stats().total_expansions > expansions_before);
    }

    #[test]
    fn test_allocation_without_expansion() {
        let mut pool = pool_with(|config| {
            config.initial_size = 1024;
            config.max_size = 2048;
            config.auto_expand = false;
        });
        // The request exceeds `max_size`; only check that the call returns, not panics.
        let _ = pool.allocate(100 * 1024);
    }

    #[test]
    fn test_stats_hit_rate() {
        let stats = WorkspaceStats {
            total_allocations: 10,
            total_expansions: 2,
            ..WorkspaceStats::default()
        };
        assert_eq!(stats.hit_rate(), 0.8);
    }

    #[test]
    fn test_shared_workspace_pool() {
        let pool = SharedWorkspacePool::new(WorkspaceConfig::default());
        let id = pool.allocate(512).expect("512-byte allocation succeeds");
        pool.release(&id).expect("releasing an allocated id succeeds");
        assert!(pool.stats().total_allocations > 0);
    }

    #[test]
    fn test_fragmentation_ratio() {
        let pool = pool_with(|_| {});
        let ratio = pool.compute_fragmentation_ratio();
        assert!((0.0..=1.0).contains(&ratio), "ratio out of range: {}", ratio);
    }

    #[test]
    fn test_defragmentation() {
        let mut pool = pool_with(|config| config.enable_defragmentation = true);
        // Merging never changes the number of free bytes.
        assert_eq!(pool.defragment().freed_bytes, 0);
    }

    #[test]
    fn test_defragmentation_merges_when_threshold_met() {
        let mut pool = pool_with(|config| {
            config.enable_defragmentation = true;
            config.defrag_threshold = 0.0;
            config.num_buckets = 4;
        });

        let count_before = pool.workspaces.len();
        let result = pool.defragment();
        let count_after = pool.workspaces.len();

        assert_eq!(result.freed_bytes, 0);
        assert!(result.merged_workspaces > 0);
        assert!(count_after < count_before);
    }
}