// tenflowers_core/context.rs

1use crate::device::context::{DeviceContext, DEVICE_MANAGER};
2use crate::{Device, Result};
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6/// Execution context for TenfloweRS operations
7pub struct Context {
8    /// Default device for operations
9    default_device: Device,
10    /// Device contexts cache
11    device_contexts: RwLock<HashMap<Device, Arc<dyn DeviceContext>>>,
12    /// Context attributes
13    attributes: RwLock<HashMap<String, String>>,
14    /// Eager execution mode
15    eager_mode: bool,
16    /// Enable profiling
17    profiling_enabled: bool,
18}
19
20impl Context {
21    /// Create a new execution context
22    pub fn new() -> Result<Self> {
23        Ok(Self {
24            default_device: Device::Cpu,
25            device_contexts: RwLock::new(HashMap::new()),
26            attributes: RwLock::new(HashMap::new()),
27            eager_mode: true,
28            profiling_enabled: false,
29        })
30    }
31
32    /// Create a context with specific device
33    pub fn with_device(device: Device) -> Result<Self> {
34        let mut ctx = Self::new()?;
35        ctx.default_device = device;
36        Ok(ctx)
37    }
38
39    /// Get the default device
40    pub fn default_device(&self) -> Device {
41        self.default_device
42    }
43
44    /// Set the default device
45    pub fn set_default_device(&mut self, device: Device) {
46        self.default_device = device;
47    }
48
49    /// Check if eager execution is enabled
50    pub fn is_eager(&self) -> bool {
51        self.eager_mode
52    }
53
54    /// Set eager execution mode
55    pub fn set_eager_mode(&mut self, eager: bool) {
56        self.eager_mode = eager;
57    }
58
59    /// Enable/disable profiling
60    pub fn set_profiling(&mut self, enabled: bool) {
61        self.profiling_enabled = enabled;
62    }
63
64    /// Get device context
65    pub fn get_device_context(&self, device: &Device) -> Result<Arc<dyn DeviceContext>> {
66        // Check cache first
67        {
68            let contexts = self
69                .device_contexts
70                .read()
71                .expect("read lock should not be poisoned");
72            if let Some(ctx) = contexts.get(device) {
73                return Ok(Arc::clone(ctx));
74            }
75        }
76
77        // Get from global manager
78        let ctx = DEVICE_MANAGER.get_context(device)?;
79
80        // Cache it
81        {
82            let mut contexts = self
83                .device_contexts
84                .write()
85                .expect("write lock should not be poisoned");
86            contexts.insert(*device, Arc::clone(&ctx));
87        }
88
89        Ok(ctx)
90    }
91
92    /// Set a context attribute
93    pub fn set_attribute(&self, key: String, value: String) {
94        let mut attrs = self
95            .attributes
96            .write()
97            .expect("write lock should not be poisoned");
98        attrs.insert(key, value);
99    }
100
101    /// Get a context attribute
102    pub fn get_attribute(&self, key: &str) -> Option<String> {
103        let attrs = self
104            .attributes
105            .read()
106            .expect("read lock should not be poisoned");
107        attrs.get(key).cloned()
108    }
109}
110
111// Global context for eager execution
112lazy_static::lazy_static! {
113    static ref GLOBAL_CONTEXT: RwLock<Option<Arc<Context>>> = RwLock::new(None);
114}
115
116/// Get the current global context
117pub fn get_context() -> Result<Arc<Context>> {
118    let ctx_opt = GLOBAL_CONTEXT
119        .read()
120        .expect("read lock should not be poisoned");
121    if let Some(ctx) = ctx_opt.as_ref() {
122        Ok(Arc::clone(ctx))
123    } else {
124        drop(ctx_opt);
125
126        // Create new context
127        let ctx = Arc::new(Context::new()?);
128        let mut ctx_opt = GLOBAL_CONTEXT
129            .write()
130            .expect("write lock should not be poisoned");
131        *ctx_opt = Some(Arc::clone(&ctx));
132        Ok(ctx)
133    }
134}
135
136/// Set the global context
137pub fn set_context(ctx: Arc<Context>) {
138    let mut ctx_opt = GLOBAL_CONTEXT
139        .write()
140        .expect("write lock should not be poisoned");
141    *ctx_opt = Some(ctx);
142}
143
144/// Context scope for temporary device placement
145pub struct DeviceScope {
146    previous_device: Device,
147    context: Arc<Context>,
148}
149
150impl DeviceScope {
151    /// Create a new device scope
152    pub fn new(device: Device) -> Result<Self> {
153        let ctx = get_context()?;
154        let previous = ctx.default_device();
155
156        // Clone context and modify
157        let mut new_ctx = (*ctx).clone();
158        new_ctx.set_default_device(device);
159        set_context(Arc::new(new_ctx));
160
161        Ok(Self {
162            previous_device: previous,
163            context: ctx,
164        })
165    }
166}
167
168impl Drop for DeviceScope {
169    fn drop(&mut self) {
170        // Restore previous context
171        let mut restored_ctx = (*self.context).clone();
172        restored_ctx.set_default_device(self.previous_device);
173        set_context(Arc::new(restored_ctx));
174    }
175}
176
177// Make Context cloneable for DeviceScope
178impl Clone for Context {
179    fn clone(&self) -> Self {
180        Self {
181            default_device: self.default_device,
182            device_contexts: RwLock::new(HashMap::new()), // Don't clone cache
183            attributes: RwLock::new(
184                self.attributes
185                    .read()
186                    .expect("read lock should not be poisoned")
187                    .clone(),
188            ),
189            eager_mode: self.eager_mode,
190            profiling_enabled: self.profiling_enabled,
191        }
192    }
193}
194
/// Evaluates `$body` with `$device` as the default device of the global context.
///
/// Expands to a block that creates a
/// [`DeviceScope`](crate::context::DeviceScope) guard, evaluates `$body`, and
/// yields its value. The guard is dropped when that block ends. The guard is
/// created with `?`, so the macro can only be used inside a function whose
/// error type the scope's error converts into.
#[macro_export]
macro_rules! with_device {
    ($device:expr, $body:block) => {{
        let _device_scope_guard = $crate::context::DeviceScope::new($device)?;
        $body
    }};
}