Skip to main content

teaql_runtime/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Mutex;
6
7use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
8use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
9
10use crate::{
11    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
12    InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
13    RepositoryBehaviorRegistry, RepositoryRegistry, RuntimeError, local_id_generator,
14    translate_check_result,
15};
16use crate::{EntityRoot, QueryExecutor, RepositoryError};
17
/// Category of SQL statement that a log entry was recorded for.
///
/// For logging purposes the variants fall into two groups: `Select` on one
/// side, and every other variant (treated as a mutation) on the other.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only when the operation is [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation other than `Select`.
    pub fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
38pub struct SqlLogOptions {
39    pub select: bool,
40    pub mutation: bool,
41}
42
43impl SqlLogOptions {
44    pub fn select_only() -> Self {
45        Self {
46            select: true,
47            mutation: false,
48        }
49    }
50
51    pub fn mutation_only() -> Self {
52        Self {
53            select: false,
54            mutation: true,
55        }
56    }
57
58    pub fn all() -> Self {
59        Self {
60            select: true,
61            mutation: true,
62        }
63    }
64
65    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
66        if operation.is_select() {
67            self.select
68        } else {
69            self.mutation
70        }
71    }
72}
73
74#[derive(Debug, Clone, PartialEq)]
75pub struct SqlLogEntry {
76    pub operation: SqlLogOperation,
77    pub sql: String,
78    pub params: Vec<Value>,
79    pub debug_sql: String,
80}
81
82pub trait SchemaProvider: Send + Sync {
83    fn ensure_schema<'a>(
84        &'a self,
85        ctx: &'a UserContext,
86    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
87}
88
89#[derive(Default)]
90pub struct UserContext {
91    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
92    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
93    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
94    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
95    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
96    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
97    schema_provider: Option<Box<dyn SchemaProvider>>,
98    language: Language,
99    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
100    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
101    locals: BTreeMap<String, Value>,
102    pub(crate) initial_graphs: Vec<GraphNode>,
103    entity_root: EntityRoot,
104    sql_log_options: SqlLogOptions,
105    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
106}
107
108impl UserContext {
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
114        module.apply_to(&mut self);
115        self
116    }
117
118    pub fn entity_root(&self) -> EntityRoot {
119        self.entity_root.clone()
120    }
121
122    pub fn initial_graphs(&self) -> &[GraphNode] {
123        &self.initial_graphs
124    }
125
126    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
127        self.initial_graphs = graphs;
128    }
129
130    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
131        self.metadata = Some(Box::new(metadata));
132        self
133    }
134
135    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
136        self.metadata = Some(Box::new(metadata));
137    }
138
139    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
140        self.repository_registry = Some(Box::new(registry));
141        self
142    }
143
144    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
145        self.repository_registry = Some(Box::new(registry));
146    }
147
148    pub fn with_repository_behavior_registry(
149        mut self,
150        registry: impl RepositoryBehaviorRegistry + 'static,
151    ) -> Self {
152        self.repository_behavior_registry = Some(Box::new(registry));
153        self
154    }
155
156    pub fn set_repository_behavior_registry(
157        &mut self,
158        registry: impl RepositoryBehaviorRegistry + 'static,
159    ) {
160        self.repository_behavior_registry = Some(Box::new(registry));
161    }
162
163    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
164        self.checker_registry = Some(Box::new(registry));
165        self
166    }
167
168    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
169        self.checker_registry = Some(Box::new(registry));
170    }
171
172    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
173        self.event_sink = Some(Box::new(sink));
174        self
175    }
176
177    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
178        self.event_sink = Some(Box::new(sink));
179    }
180
181    pub fn with_internal_id_generator(
182        mut self,
183        generator: impl InternalIdGenerator + 'static,
184    ) -> Self {
185        self.internal_id_generator = Some(Box::new(generator));
186        self
187    }
188
189    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
190        self.internal_id_generator = Some(Box::new(generator));
191    }
192
193    pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
194        self.schema_provider = Some(Box::new(provider));
195        self
196    }
197
198    pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
199        self.schema_provider = Some(Box::new(provider));
200    }
201
202    pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
203        let provider = self
204            .schema_provider
205            .as_ref()
206            .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
207        provider.ensure_schema(self).await
208    }
209
210    pub fn with_language(mut self, language: Language) -> Self {
211        self.language = language;
212        self
213    }
214
215    pub fn set_language(&mut self, language: Language) {
216        self.language = language;
217    }
218
219    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
220        self.sql_log_options = options;
221        self
222    }
223
224    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
225        self.sql_log_options = options;
226    }
227
228    pub fn enable_select_sql_log(&mut self) {
229        self.sql_log_options.select = true;
230    }
231
232    pub fn enable_mutation_sql_log(&mut self) {
233        self.sql_log_options.mutation = true;
234    }
235
236    pub fn enable_all_sql_log(&mut self) {
237        self.sql_log_options = SqlLogOptions::all();
238    }
239
240    pub fn disable_sql_log(&mut self) {
241        self.sql_log_options = SqlLogOptions::default();
242        self.clear_sql_logs();
243    }
244
245    pub fn sql_log_options(&self) -> SqlLogOptions {
246        self.sql_log_options
247    }
248
249    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
250        self.sql_log_entries
251            .lock()
252            .map(|entries| entries.clone())
253            .unwrap_or_default()
254    }
255
256    pub fn clear_sql_logs(&self) {
257        if let Ok(mut entries) = self.sql_log_entries.lock() {
258            entries.clear();
259        }
260    }
261
262    pub(crate) fn record_sql_log(
263        &self,
264        operation: SqlLogOperation,
265        query: &CompiledQuery,
266        database_kind: DatabaseKind,
267    ) {
268        if !self.sql_log_options.enabled_for(operation) {
269            return;
270        }
271        if let Ok(mut entries) = self.sql_log_entries.lock() {
272            entries.push(SqlLogEntry {
273                operation,
274                sql: query.sql.clone(),
275                params: query.params.clone(),
276                debug_sql: query.debug_sql(database_kind),
277            });
278        }
279    }
280
281    pub fn language(&self) -> Language {
282        self.language
283    }
284
285    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
286        let Some(language) = Language::from_code(code) else {
287            return Err(RuntimeError::Language(format!(
288                "unsupported language code: {code}"
289            )));
290        };
291        self.language = language;
292        Ok(())
293    }
294
295    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
296        self.internal_id_generator
297            .as_ref()
298            .map(|generator| generator.generate_id(entity))
299            .transpose()
300    }
301
302    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
303        match self.generate_id(entity)? {
304            Some(id) => Ok(id),
305            None => local_id_generator().generate_id(entity),
306        }
307    }
308
309    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
310        self.metadata
311            .as_ref()
312            .and_then(|metadata| metadata.entity(name))
313    }
314
315    pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
316        self.metadata
317            .as_ref()
318            .map(|metadata| metadata.all_entities())
319            .unwrap_or_default()
320    }
321
322    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
323        self.entity(name)
324            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
325    }
326
327    pub fn insert_resource<T>(&mut self, resource: T)
328    where
329        T: Send + Sync + 'static,
330    {
331        self.typed_resources
332            .insert(TypeId::of::<T>(), Box::new(resource));
333    }
334
335    pub fn get_resource<T>(&self) -> Option<&T>
336    where
337        T: Send + Sync + 'static,
338    {
339        self.typed_resources
340            .get(&TypeId::of::<T>())
341            .and_then(|value| value.downcast_ref::<T>())
342    }
343
344    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
345    where
346        T: Send + Sync + 'static,
347    {
348        self.get_resource::<T>()
349            .ok_or(ContextError::MissingTypedResource(
350                std::any::type_name::<T>(),
351            ))
352    }
353
354    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
355    where
356        T: Send + Sync + 'static,
357    {
358        self.named_resources.insert(name.into(), Box::new(resource));
359    }
360
361    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
362    where
363        T: Send + Sync + 'static,
364    {
365        self.named_resources
366            .get(name)
367            .and_then(|value| value.downcast_ref::<T>())
368    }
369
370    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
371    where
372        T: Send + Sync + 'static,
373    {
374        self.get_named_resource::<T>(name)
375            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
376    }
377
378    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
379        self.locals.insert(key.into(), value.into());
380    }
381
382    pub fn local(&self, key: &str) -> Option<&Value> {
383        self.locals.get(key)
384    }
385
386    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
387        self.locals.remove(key)
388    }
389
390    pub fn has_repository(&self, entity: &str) -> bool {
391        let in_registry = self
392            .repository_registry
393            .as_ref()
394            .map(|registry| registry.contains(entity))
395            .unwrap_or(false);
396        in_registry || self.entity(entity).is_some()
397    }
398
399    pub fn repository_behavior(
400        &self,
401        entity: &str,
402    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
403        self.repository_behavior_registry
404            .as_ref()
405            .and_then(|registry| registry.behavior(entity))
406    }
407
408    pub fn has_checker(&self, entity: &str) -> bool {
409        self.checker_registry
410            .as_ref()
411            .and_then(|registry| registry.checker(entity))
412            .is_some()
413    }
414
415    pub fn check_and_fix_record(
416        &self,
417        entity: &str,
418        record: &mut Record,
419    ) -> Result<(), RuntimeError> {
420        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
421    }
422
423    pub fn check_and_fix_record_at(
424        &self,
425        entity: &str,
426        record: &mut Record,
427        location: &ObjectLocation,
428    ) -> Result<(), RuntimeError> {
429        let Some(checker) = self
430            .checker_registry
431            .as_ref()
432            .and_then(|registry| registry.checker(entity))
433        else {
434            return Ok(());
435        };
436        let mut results = CheckResults::new();
437        checker.check_and_fix(self, record, location, &mut results);
438        if results.is_empty() {
439            Ok(())
440        } else {
441            self.translate_check_results(&mut results);
442            Err(RuntimeError::Check(results))
443        }
444    }
445
446    pub fn translate_check_results(&self, results: &mut CheckResults) {
447        for result in results {
448            result.message = Some(translate_check_result(self.language, result));
449        }
450    }
451
452    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
453        let Some(sink) = self.event_sink.as_ref() else {
454            return Ok(());
455        };
456        sink.on_event(self, &event)
457    }
458
459    pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
460    where
461        D: SqlDialect + Send + Sync + 'static,
462        E: QueryExecutor + Send + Sync + 'static,
463    {
464        let dialect = self.require_resource::<D>().map_err(|err| {
465            RepositoryError::Runtime(RuntimeError::Graph(format!(
466                "cannot commit changes without dialect: {err}"
467            )))
468        })?;
469        let executor = self.require_resource::<E>().map_err(|err| {
470            RepositoryError::Runtime(RuntimeError::Graph(format!(
471                "cannot commit changes without executor: {err}"
472            )))
473        })?;
474        let change_set = self.entity_root.current_change_set();
475
476        for (key, changes) in change_set.changes() {
477            if changes.is_empty() {
478                continue;
479            }
480            let entity = self
481                .require_entity(&key.entity)
482                .map_err(RepositoryError::Runtime)?;
483            let mut command = UpdateCommand::new(&key.entity, key.id.clone());
484            for (field, value) in changes {
485                command = command.value(field.clone(), value.clone());
486            }
487            let query = dialect
488                .compile_update(entity, &command)
489                .map_err(RuntimeError::from)
490                .map_err(RepositoryError::Runtime)?;
491            executor
492                .execute(&query)
493                .map_err(RepositoryError::Executor)?;
494        }
495
496        self.entity_root.clear_current_change_set();
497        Ok(())
498    }
499}