Skip to main content

teaql_runtime/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::sync::Mutex;
4
5use teaql_core::{EntityDescriptor, Record, Value};
6use teaql_sql::{CompiledQuery, DatabaseKind};
7
8use crate::{
9    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, InternalIdGenerator,
10    Language, MetadataStore, ObjectLocation, RepositoryBehavior, RepositoryBehaviorRegistry,
11    RepositoryRegistry, RuntimeError, local_id_generator, translate_check_result,
12};
13
/// Category of a SQL statement handed to the SQL log; [`SqlLogOptions`]
/// decides per category whether the statement is captured.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every variant other than [`SqlLogOperation::Select`];
    /// it is always the negation of [`SqlLogOperation::is_select`].
    pub fn is_mutation(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
34pub struct SqlLogOptions {
35    pub select: bool,
36    pub mutation: bool,
37}
38
39impl SqlLogOptions {
40    pub fn select_only() -> Self {
41        Self {
42            select: true,
43            mutation: false,
44        }
45    }
46
47    pub fn mutation_only() -> Self {
48        Self {
49            select: false,
50            mutation: true,
51        }
52    }
53
54    pub fn all() -> Self {
55        Self {
56            select: true,
57            mutation: true,
58        }
59    }
60
61    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
62        if operation.is_select() {
63            self.select
64        } else {
65            self.mutation
66        }
67    }
68}
69
70#[derive(Debug, Clone, PartialEq)]
71pub struct SqlLogEntry {
72    pub operation: SqlLogOperation,
73    pub sql: String,
74    pub params: Vec<Value>,
75    pub debug_sql: String,
76}
77
78#[derive(Default)]
79pub struct UserContext {
80    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
81    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
82    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
83    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
84    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
85    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
86    language: Language,
87    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
88    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
89    locals: BTreeMap<String, Value>,
90    sql_log_options: SqlLogOptions,
91    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
92}
93
94impl UserContext {
95    pub fn new() -> Self {
96        Self::default()
97    }
98
99    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
100        module.apply_to(&mut self);
101        self
102    }
103
104    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
105        self.metadata = Some(Box::new(metadata));
106        self
107    }
108
109    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
110        self.metadata = Some(Box::new(metadata));
111    }
112
113    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
114        self.repository_registry = Some(Box::new(registry));
115        self
116    }
117
118    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
119        self.repository_registry = Some(Box::new(registry));
120    }
121
122    pub fn with_repository_behavior_registry(
123        mut self,
124        registry: impl RepositoryBehaviorRegistry + 'static,
125    ) -> Self {
126        self.repository_behavior_registry = Some(Box::new(registry));
127        self
128    }
129
130    pub fn set_repository_behavior_registry(
131        &mut self,
132        registry: impl RepositoryBehaviorRegistry + 'static,
133    ) {
134        self.repository_behavior_registry = Some(Box::new(registry));
135    }
136
137    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
138        self.checker_registry = Some(Box::new(registry));
139        self
140    }
141
142    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
143        self.checker_registry = Some(Box::new(registry));
144    }
145
146    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
147        self.event_sink = Some(Box::new(sink));
148        self
149    }
150
151    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
152        self.event_sink = Some(Box::new(sink));
153    }
154
155    pub fn with_internal_id_generator(
156        mut self,
157        generator: impl InternalIdGenerator + 'static,
158    ) -> Self {
159        self.internal_id_generator = Some(Box::new(generator));
160        self
161    }
162
163    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
164        self.internal_id_generator = Some(Box::new(generator));
165    }
166
167    pub fn with_language(mut self, language: Language) -> Self {
168        self.language = language;
169        self
170    }
171
172    pub fn set_language(&mut self, language: Language) {
173        self.language = language;
174    }
175
176    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
177        self.sql_log_options = options;
178        self
179    }
180
181    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
182        self.sql_log_options = options;
183    }
184
185    pub fn enable_select_sql_log(&mut self) {
186        self.sql_log_options.select = true;
187    }
188
189    pub fn enable_mutation_sql_log(&mut self) {
190        self.sql_log_options.mutation = true;
191    }
192
193    pub fn enable_all_sql_log(&mut self) {
194        self.sql_log_options = SqlLogOptions::all();
195    }
196
197    pub fn disable_sql_log(&mut self) {
198        self.sql_log_options = SqlLogOptions::default();
199        self.clear_sql_logs();
200    }
201
202    pub fn sql_log_options(&self) -> SqlLogOptions {
203        self.sql_log_options
204    }
205
206    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
207        self.sql_log_entries
208            .lock()
209            .map(|entries| entries.clone())
210            .unwrap_or_default()
211    }
212
213    pub fn clear_sql_logs(&self) {
214        if let Ok(mut entries) = self.sql_log_entries.lock() {
215            entries.clear();
216        }
217    }
218
219    pub(crate) fn record_sql_log(
220        &self,
221        operation: SqlLogOperation,
222        query: &CompiledQuery,
223        database_kind: DatabaseKind,
224    ) {
225        if !self.sql_log_options.enabled_for(operation) {
226            return;
227        }
228        if let Ok(mut entries) = self.sql_log_entries.lock() {
229            entries.push(SqlLogEntry {
230                operation,
231                sql: query.sql.clone(),
232                params: query.params.clone(),
233                debug_sql: query.debug_sql(database_kind),
234            });
235        }
236    }
237
238    pub fn language(&self) -> Language {
239        self.language
240    }
241
242    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
243        let Some(language) = Language::from_code(code) else {
244            return Err(RuntimeError::Language(format!(
245                "unsupported language code: {code}"
246            )));
247        };
248        self.language = language;
249        Ok(())
250    }
251
252    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
253        self.internal_id_generator
254            .as_ref()
255            .map(|generator| generator.generate_id(entity))
256            .transpose()
257    }
258
259    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
260        match self.generate_id(entity)? {
261            Some(id) => Ok(id),
262            None => local_id_generator().generate_id(entity),
263        }
264    }
265
266    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
267        self.metadata
268            .as_ref()
269            .and_then(|metadata| metadata.entity(name))
270    }
271
272    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
273        self.entity(name)
274            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
275    }
276
277    pub fn insert_resource<T>(&mut self, resource: T)
278    where
279        T: Send + Sync + 'static,
280    {
281        self.typed_resources
282            .insert(TypeId::of::<T>(), Box::new(resource));
283    }
284
285    pub fn get_resource<T>(&self) -> Option<&T>
286    where
287        T: Send + Sync + 'static,
288    {
289        self.typed_resources
290            .get(&TypeId::of::<T>())
291            .and_then(|value| value.downcast_ref::<T>())
292    }
293
294    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
295    where
296        T: Send + Sync + 'static,
297    {
298        self.get_resource::<T>()
299            .ok_or(ContextError::MissingTypedResource(
300                std::any::type_name::<T>(),
301            ))
302    }
303
304    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
305    where
306        T: Send + Sync + 'static,
307    {
308        self.named_resources.insert(name.into(), Box::new(resource));
309    }
310
311    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
312    where
313        T: Send + Sync + 'static,
314    {
315        self.named_resources
316            .get(name)
317            .and_then(|value| value.downcast_ref::<T>())
318    }
319
320    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
321    where
322        T: Send + Sync + 'static,
323    {
324        self.get_named_resource::<T>(name)
325            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
326    }
327
328    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
329        self.locals.insert(key.into(), value.into());
330    }
331
332    pub fn local(&self, key: &str) -> Option<&Value> {
333        self.locals.get(key)
334    }
335
336    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
337        self.locals.remove(key)
338    }
339
340    pub fn has_repository(&self, entity: &str) -> bool {
341        let in_registry = self
342            .repository_registry
343            .as_ref()
344            .map(|registry| registry.contains(entity))
345            .unwrap_or(false);
346        in_registry || self.entity(entity).is_some()
347    }
348
349    pub fn repository_behavior(
350        &self,
351        entity: &str,
352    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
353        self.repository_behavior_registry
354            .as_ref()
355            .and_then(|registry| registry.behavior(entity))
356    }
357
358    pub fn has_checker(&self, entity: &str) -> bool {
359        self.checker_registry
360            .as_ref()
361            .and_then(|registry| registry.checker(entity))
362            .is_some()
363    }
364
365    pub fn check_and_fix_record(
366        &self,
367        entity: &str,
368        record: &mut Record,
369    ) -> Result<(), RuntimeError> {
370        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
371    }
372
373    pub fn check_and_fix_record_at(
374        &self,
375        entity: &str,
376        record: &mut Record,
377        location: &ObjectLocation,
378    ) -> Result<(), RuntimeError> {
379        let Some(checker) = self
380            .checker_registry
381            .as_ref()
382            .and_then(|registry| registry.checker(entity))
383        else {
384            return Ok(());
385        };
386        let mut results = CheckResults::new();
387        checker.check_and_fix(self, record, location, &mut results);
388        if results.is_empty() {
389            Ok(())
390        } else {
391            self.translate_check_results(&mut results);
392            Err(RuntimeError::Check(results))
393        }
394    }
395
396    pub fn translate_check_results(&self, results: &mut CheckResults) {
397        for result in results {
398            result.message = Some(translate_check_result(self.language, result));
399        }
400    }
401
402    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
403        let Some(sink) = self.event_sink.as_ref() else {
404            return Ok(());
405        };
406        sink.on_event(self, &event)
407    }
408}