tenflowers_core/
context.rs1use crate::device::context::{DeviceContext, DEVICE_MANAGER};
2use crate::{Device, Result};
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6pub struct Context {
8 default_device: Device,
10 device_contexts: RwLock<HashMap<Device, Arc<dyn DeviceContext>>>,
12 attributes: RwLock<HashMap<String, String>>,
14 eager_mode: bool,
16 profiling_enabled: bool,
18}
19
20impl Context {
21 pub fn new() -> Result<Self> {
23 Ok(Self {
24 default_device: Device::Cpu,
25 device_contexts: RwLock::new(HashMap::new()),
26 attributes: RwLock::new(HashMap::new()),
27 eager_mode: true,
28 profiling_enabled: false,
29 })
30 }
31
32 pub fn with_device(device: Device) -> Result<Self> {
34 let mut ctx = Self::new()?;
35 ctx.default_device = device;
36 Ok(ctx)
37 }
38
39 pub fn default_device(&self) -> Device {
41 self.default_device
42 }
43
44 pub fn set_default_device(&mut self, device: Device) {
46 self.default_device = device;
47 }
48
49 pub fn is_eager(&self) -> bool {
51 self.eager_mode
52 }
53
54 pub fn set_eager_mode(&mut self, eager: bool) {
56 self.eager_mode = eager;
57 }
58
59 pub fn set_profiling(&mut self, enabled: bool) {
61 self.profiling_enabled = enabled;
62 }
63
64 pub fn get_device_context(&self, device: &Device) -> Result<Arc<dyn DeviceContext>> {
66 {
68 let contexts = self
69 .device_contexts
70 .read()
71 .expect("read lock should not be poisoned");
72 if let Some(ctx) = contexts.get(device) {
73 return Ok(Arc::clone(ctx));
74 }
75 }
76
77 let ctx = DEVICE_MANAGER.get_context(device)?;
79
80 {
82 let mut contexts = self
83 .device_contexts
84 .write()
85 .expect("write lock should not be poisoned");
86 contexts.insert(*device, Arc::clone(&ctx));
87 }
88
89 Ok(ctx)
90 }
91
92 pub fn set_attribute(&self, key: String, value: String) {
94 let mut attrs = self
95 .attributes
96 .write()
97 .expect("write lock should not be poisoned");
98 attrs.insert(key, value);
99 }
100
101 pub fn get_attribute(&self, key: &str) -> Option<String> {
103 let attrs = self
104 .attributes
105 .read()
106 .expect("read lock should not be poisoned");
107 attrs.get(key).cloned()
108 }
109}
110
111lazy_static::lazy_static! {
113 static ref GLOBAL_CONTEXT: RwLock<Option<Arc<Context>>> = RwLock::new(None);
114}
115
116pub fn get_context() -> Result<Arc<Context>> {
118 let ctx_opt = GLOBAL_CONTEXT
119 .read()
120 .expect("read lock should not be poisoned");
121 if let Some(ctx) = ctx_opt.as_ref() {
122 Ok(Arc::clone(ctx))
123 } else {
124 drop(ctx_opt);
125
126 let ctx = Arc::new(Context::new()?);
128 let mut ctx_opt = GLOBAL_CONTEXT
129 .write()
130 .expect("write lock should not be poisoned");
131 *ctx_opt = Some(Arc::clone(&ctx));
132 Ok(ctx)
133 }
134}
135
136pub fn set_context(ctx: Arc<Context>) {
138 let mut ctx_opt = GLOBAL_CONTEXT
139 .write()
140 .expect("write lock should not be poisoned");
141 *ctx_opt = Some(ctx);
142}
143
144pub struct DeviceScope {
146 previous_device: Device,
147 context: Arc<Context>,
148}
149
150impl DeviceScope {
151 pub fn new(device: Device) -> Result<Self> {
153 let ctx = get_context()?;
154 let previous = ctx.default_device();
155
156 let mut new_ctx = (*ctx).clone();
158 new_ctx.set_default_device(device);
159 set_context(Arc::new(new_ctx));
160
161 Ok(Self {
162 previous_device: previous,
163 context: ctx,
164 })
165 }
166}
167
168impl Drop for DeviceScope {
169 fn drop(&mut self) {
170 let mut restored_ctx = (*self.context).clone();
172 restored_ctx.set_default_device(self.previous_device);
173 set_context(Arc::new(restored_ctx));
174 }
175}
176
177impl Clone for Context {
179 fn clone(&self) -> Self {
180 Self {
181 default_device: self.default_device,
182 device_contexts: RwLock::new(HashMap::new()), attributes: RwLock::new(
184 self.attributes
185 .read()
186 .expect("read lock should not be poisoned")
187 .clone(),
188 ),
189 eager_mode: self.eager_mode,
190 profiling_enabled: self.profiling_enabled,
191 }
192 }
193}
194
/// Evaluates `$body` with `$device` as the default device of the global context.
///
/// Expands to a block that creates a `DeviceScope` guard, evaluates `$body`, and
/// yields the body's value. The guard is dropped when the block ends, restoring
/// the previous context.
///
/// Entering the scope uses `?`, so the macro can only be used inside a function
/// whose return type accepts the crate's error via `?`.
#[macro_export]
macro_rules! with_device {
    ($device:expr, $body:block) => {{
        // Must be a named binding (not `_`) so the guard lives until the block ends.
        let _device_scope_guard = $crate::context::DeviceScope::new($device)?;
        $body
    }};
}