Skip to main content

teaql_runtime/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::sync::Mutex;
4
5use teaql_core::{EntityDescriptor, Record, Value};
6use teaql_sql::{CompiledQuery, DatabaseKind};
7
8use crate::{
9    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
10    InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
11    RepositoryBehaviorRegistry, RepositoryRegistry, RuntimeError, local_id_generator,
12    translate_check_result,
13};
14
/// Kind of SQL statement being executed, used to decide whether it should be logged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` for read-only `Select` statements.
    pub const fn is_select(self) -> bool {
        matches!(self, Self::Select)
    }

    /// Returns `true` for every statement that is not a `Select`.
    ///
    /// Spelled out exhaustively (rather than `!self.is_select()`) so that adding a new
    /// variant is a compile error here and forces an explicit read/write classification.
    pub const fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}

/// Switches controlling which SQL statements are captured in the SQL log.
///
/// The `Default` value has both switches off, i.e. nothing is logged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SqlLogOptions {
    /// Capture `Select` statements.
    pub select: bool,
    /// Capture every non-`Select` statement (see [`SqlLogOperation::is_mutation`]).
    pub mutation: bool,
}

impl SqlLogOptions {
    /// Logs only `Select` statements.
    pub const fn select_only() -> Self {
        Self {
            select: true,
            mutation: false,
        }
    }

    /// Logs only mutating (non-`Select`) statements.
    pub const fn mutation_only() -> Self {
        Self {
            select: false,
            mutation: true,
        }
    }

    /// Logs every statement.
    pub const fn all() -> Self {
        Self {
            select: true,
            mutation: true,
        }
    }

    /// Returns whether a statement of kind `operation` should be logged under these options.
    pub const fn enabled_for(self, operation: SqlLogOperation) -> bool {
        if operation.is_select() {
            self.select
        } else {
            self.mutation
        }
    }
}
70
71#[derive(Debug, Clone, PartialEq)]
72pub struct SqlLogEntry {
73    pub operation: SqlLogOperation,
74    pub sql: String,
75    pub params: Vec<Value>,
76    pub debug_sql: String,
77}
78
79#[derive(Default)]
80pub struct UserContext {
81    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
82    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
83    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
84    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
85    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
86    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
87    language: Language,
88    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
89    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
90    locals: BTreeMap<String, Value>,
91    pub(crate) initial_graphs: Vec<GraphNode>,
92    sql_log_options: SqlLogOptions,
93    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
94}
95
96impl UserContext {
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
102        module.apply_to(&mut self);
103        self
104    }
105
106    pub fn initial_graphs(&self) -> &[GraphNode] {
107        &self.initial_graphs
108    }
109
110    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
111        self.initial_graphs = graphs;
112    }
113
114    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
115        self.metadata = Some(Box::new(metadata));
116        self
117    }
118
119    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
120        self.metadata = Some(Box::new(metadata));
121    }
122
123    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
124        self.repository_registry = Some(Box::new(registry));
125        self
126    }
127
128    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
129        self.repository_registry = Some(Box::new(registry));
130    }
131
132    pub fn with_repository_behavior_registry(
133        mut self,
134        registry: impl RepositoryBehaviorRegistry + 'static,
135    ) -> Self {
136        self.repository_behavior_registry = Some(Box::new(registry));
137        self
138    }
139
140    pub fn set_repository_behavior_registry(
141        &mut self,
142        registry: impl RepositoryBehaviorRegistry + 'static,
143    ) {
144        self.repository_behavior_registry = Some(Box::new(registry));
145    }
146
147    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
148        self.checker_registry = Some(Box::new(registry));
149        self
150    }
151
152    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
153        self.checker_registry = Some(Box::new(registry));
154    }
155
156    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
157        self.event_sink = Some(Box::new(sink));
158        self
159    }
160
161    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
162        self.event_sink = Some(Box::new(sink));
163    }
164
165    pub fn with_internal_id_generator(
166        mut self,
167        generator: impl InternalIdGenerator + 'static,
168    ) -> Self {
169        self.internal_id_generator = Some(Box::new(generator));
170        self
171    }
172
173    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
174        self.internal_id_generator = Some(Box::new(generator));
175    }
176
177    pub fn with_language(mut self, language: Language) -> Self {
178        self.language = language;
179        self
180    }
181
182    pub fn set_language(&mut self, language: Language) {
183        self.language = language;
184    }
185
186    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
187        self.sql_log_options = options;
188        self
189    }
190
191    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
192        self.sql_log_options = options;
193    }
194
195    pub fn enable_select_sql_log(&mut self) {
196        self.sql_log_options.select = true;
197    }
198
199    pub fn enable_mutation_sql_log(&mut self) {
200        self.sql_log_options.mutation = true;
201    }
202
203    pub fn enable_all_sql_log(&mut self) {
204        self.sql_log_options = SqlLogOptions::all();
205    }
206
207    pub fn disable_sql_log(&mut self) {
208        self.sql_log_options = SqlLogOptions::default();
209        self.clear_sql_logs();
210    }
211
212    pub fn sql_log_options(&self) -> SqlLogOptions {
213        self.sql_log_options
214    }
215
216    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
217        self.sql_log_entries
218            .lock()
219            .map(|entries| entries.clone())
220            .unwrap_or_default()
221    }
222
223    pub fn clear_sql_logs(&self) {
224        if let Ok(mut entries) = self.sql_log_entries.lock() {
225            entries.clear();
226        }
227    }
228
229    pub(crate) fn record_sql_log(
230        &self,
231        operation: SqlLogOperation,
232        query: &CompiledQuery,
233        database_kind: DatabaseKind,
234    ) {
235        if !self.sql_log_options.enabled_for(operation) {
236            return;
237        }
238        if let Ok(mut entries) = self.sql_log_entries.lock() {
239            entries.push(SqlLogEntry {
240                operation,
241                sql: query.sql.clone(),
242                params: query.params.clone(),
243                debug_sql: query.debug_sql(database_kind),
244            });
245        }
246    }
247
248    pub fn language(&self) -> Language {
249        self.language
250    }
251
252    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
253        let Some(language) = Language::from_code(code) else {
254            return Err(RuntimeError::Language(format!(
255                "unsupported language code: {code}"
256            )));
257        };
258        self.language = language;
259        Ok(())
260    }
261
262    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
263        self.internal_id_generator
264            .as_ref()
265            .map(|generator| generator.generate_id(entity))
266            .transpose()
267    }
268
269    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
270        match self.generate_id(entity)? {
271            Some(id) => Ok(id),
272            None => local_id_generator().generate_id(entity),
273        }
274    }
275
276    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
277        self.metadata
278            .as_ref()
279            .and_then(|metadata| metadata.entity(name))
280    }
281
282    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
283        self.entity(name)
284            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
285    }
286
287    pub fn insert_resource<T>(&mut self, resource: T)
288    where
289        T: Send + Sync + 'static,
290    {
291        self.typed_resources
292            .insert(TypeId::of::<T>(), Box::new(resource));
293    }
294
295    pub fn get_resource<T>(&self) -> Option<&T>
296    where
297        T: Send + Sync + 'static,
298    {
299        self.typed_resources
300            .get(&TypeId::of::<T>())
301            .and_then(|value| value.downcast_ref::<T>())
302    }
303
304    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
305    where
306        T: Send + Sync + 'static,
307    {
308        self.get_resource::<T>()
309            .ok_or(ContextError::MissingTypedResource(
310                std::any::type_name::<T>(),
311            ))
312    }
313
314    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
315    where
316        T: Send + Sync + 'static,
317    {
318        self.named_resources.insert(name.into(), Box::new(resource));
319    }
320
321    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
322    where
323        T: Send + Sync + 'static,
324    {
325        self.named_resources
326            .get(name)
327            .and_then(|value| value.downcast_ref::<T>())
328    }
329
330    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
331    where
332        T: Send + Sync + 'static,
333    {
334        self.get_named_resource::<T>(name)
335            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
336    }
337
338    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
339        self.locals.insert(key.into(), value.into());
340    }
341
342    pub fn local(&self, key: &str) -> Option<&Value> {
343        self.locals.get(key)
344    }
345
346    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
347        self.locals.remove(key)
348    }
349
350    pub fn has_repository(&self, entity: &str) -> bool {
351        let in_registry = self
352            .repository_registry
353            .as_ref()
354            .map(|registry| registry.contains(entity))
355            .unwrap_or(false);
356        in_registry || self.entity(entity).is_some()
357    }
358
359    pub fn repository_behavior(
360        &self,
361        entity: &str,
362    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
363        self.repository_behavior_registry
364            .as_ref()
365            .and_then(|registry| registry.behavior(entity))
366    }
367
368    pub fn has_checker(&self, entity: &str) -> bool {
369        self.checker_registry
370            .as_ref()
371            .and_then(|registry| registry.checker(entity))
372            .is_some()
373    }
374
375    pub fn check_and_fix_record(
376        &self,
377        entity: &str,
378        record: &mut Record,
379    ) -> Result<(), RuntimeError> {
380        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
381    }
382
383    pub fn check_and_fix_record_at(
384        &self,
385        entity: &str,
386        record: &mut Record,
387        location: &ObjectLocation,
388    ) -> Result<(), RuntimeError> {
389        let Some(checker) = self
390            .checker_registry
391            .as_ref()
392            .and_then(|registry| registry.checker(entity))
393        else {
394            return Ok(());
395        };
396        let mut results = CheckResults::new();
397        checker.check_and_fix(self, record, location, &mut results);
398        if results.is_empty() {
399            Ok(())
400        } else {
401            self.translate_check_results(&mut results);
402            Err(RuntimeError::Check(results))
403        }
404    }
405
406    pub fn translate_check_results(&self, results: &mut CheckResults) {
407        for result in results {
408            result.message = Some(translate_check_result(self.language, result));
409        }
410    }
411
412    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
413        let Some(sink) = self.event_sink.as_ref() else {
414            return Ok(());
415        };
416        sink.on_event(self, &event)
417    }
418}