Skip to main content

teaql_runtime/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Mutex;
6use std::time::{Duration, SystemTime};
7
8use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
9use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
10
11use crate::{
12    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
13    InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
14    RepositoryBehaviorRegistry, RepositoryRegistry, RuntimeError, local_id_generator,
15    translate_check_result,
16};
17use crate::{EntityRoot, QueryExecutor, RepositoryError};
18
/// Kind of SQL statement captured by the context's SQL log.
///
/// `Select` is the only read; every other variant is classified as a mutation.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum SqlLogOperation {
    /// A read query.
    Select,
    /// An insert statement.
    Insert,
    /// An update statement.
    Update,
    /// A delete statement.
    Delete,
    /// A recover statement.
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` exactly when the operation is [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation other than [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
37
/// Selects which kinds of SQL statements are recorded in a context's SQL log.
///
/// The derived `Default` has both flags `false`, i.e. logging disabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SqlLogOptions {
    /// Record `SqlLogOperation::Select` statements.
    pub select: bool,
    /// Record every non-select statement (insert, update, delete, recover).
    pub mutation: bool,
}
43
44impl SqlLogOptions {
45    pub fn select_only() -> Self {
46        Self {
47            select: true,
48            mutation: false,
49        }
50    }
51
52    pub fn mutation_only() -> Self {
53        Self {
54            select: false,
55            mutation: true,
56        }
57    }
58
59    pub fn all() -> Self {
60        Self {
61            select: true,
62            mutation: true,
63        }
64    }
65
66    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
67        if operation.is_select() {
68            self.select
69        } else {
70            self.mutation
71        }
72    }
73}
74
/// One recorded SQL statement together with its timing and result information.
#[derive(Debug, Clone, PartialEq)]
pub struct SqlLogEntry {
    /// Kind of statement that was executed.
    pub operation: SqlLogOperation,
    /// SQL text as compiled (copied from `CompiledQuery::sql`).
    pub sql: String,
    /// Bound parameter values (copied from `CompiledQuery::params`).
    pub params: Vec<Value>,
    /// Output of `CompiledQuery::debug_sql` for the active database kind
    /// (presumably the SQL with parameters rendered inline — confirm in teaql_sql).
    pub debug_sql: String,
    /// `debug_sql` reformatted across lines for human reading.
    pub pretty_sql: String,
    /// Wall-clock time the statement started, as supplied by the caller.
    pub started_at: SystemTime,
    /// Wall-clock time the statement finished, as supplied by the caller.
    pub ended_at: SystemTime,
    /// Execution duration, as supplied by the caller.
    pub elapsed: Duration,
    /// Number of results; used in `result_summary` for selects.
    pub result_count: Option<usize>,
    /// Name of the result type; used in `result_summary` for selects when present.
    pub result_type: Option<String>,
    /// Rows affected; used in `result_summary` for non-select statements.
    pub affected_rows: Option<u64>,
    /// Short human-readable description of the outcome, e.g. `"3 rows affected"`.
    pub result_summary: String,
}
90
91pub trait SchemaProvider: Send + Sync {
92    fn ensure_schema<'a>(
93        &'a self,
94        ctx: &'a UserContext,
95    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
96}
97
/// Runtime context that bundles the pluggable registries and providers, typed/named resources,
/// string-keyed locals, the entity root, and the SQL log used by the runtime.
///
/// Every optional component defaults to `None`; methods on the context fall back or no-op
/// when a component is missing (see the individual methods).
#[derive(Default)]
pub struct UserContext {
    // Entity descriptors served by `entity` / `all_entities` / `require_entity`.
    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
    // Consulted by `has_repository`.
    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
    // Source of per-entity behaviors returned by `repository_behavior`.
    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
    // Source of per-entity checkers run by `check_and_fix_record_at`.
    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
    // Receives events from `send_event`; when `None`, events are dropped.
    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
    // Optional id generator; `next_id` falls back to `local_id_generator()` without it.
    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
    // Used by `ensure_schema`; a missing provider is an error there.
    schema_provider: Option<Box<dyn SchemaProvider>>,
    // Language used when translating check-result messages.
    language: Language,
    // Resources keyed by their concrete type (at most one value per type).
    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    // Resources keyed by an explicit name; lookups also check the stored type.
    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
    // Arbitrary string-keyed values set via `put_local`.
    locals: BTreeMap<String, Value>,
    pub(crate) initial_graphs: Vec<GraphNode>,
    // Holds the current change set that `commit_changes` flushes and clears.
    entity_root: EntityRoot,
    // Which statement kinds `record_sql_log` keeps.
    sql_log_options: SqlLogOptions,
    // Recorded statements; behind a `Mutex` so recording works through `&self`.
    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
}
116
117impl UserContext {
118    pub fn new() -> Self {
119        Self::default()
120    }
121
122    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
123        module.apply_to(&mut self);
124        self
125    }
126
127    pub fn entity_root(&self) -> EntityRoot {
128        self.entity_root.clone()
129    }
130
131    pub fn initial_graphs(&self) -> &[GraphNode] {
132        &self.initial_graphs
133    }
134
135    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
136        self.initial_graphs = graphs;
137    }
138
139    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
140        self.metadata = Some(Box::new(metadata));
141        self
142    }
143
144    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
145        self.metadata = Some(Box::new(metadata));
146    }
147
148    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
149        self.repository_registry = Some(Box::new(registry));
150        self
151    }
152
153    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
154        self.repository_registry = Some(Box::new(registry));
155    }
156
157    pub fn with_repository_behavior_registry(
158        mut self,
159        registry: impl RepositoryBehaviorRegistry + 'static,
160    ) -> Self {
161        self.repository_behavior_registry = Some(Box::new(registry));
162        self
163    }
164
165    pub fn set_repository_behavior_registry(
166        &mut self,
167        registry: impl RepositoryBehaviorRegistry + 'static,
168    ) {
169        self.repository_behavior_registry = Some(Box::new(registry));
170    }
171
172    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
173        self.checker_registry = Some(Box::new(registry));
174        self
175    }
176
177    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
178        self.checker_registry = Some(Box::new(registry));
179    }
180
181    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
182        self.event_sink = Some(Box::new(sink));
183        self
184    }
185
186    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
187        self.event_sink = Some(Box::new(sink));
188    }
189
190    pub fn with_internal_id_generator(
191        mut self,
192        generator: impl InternalIdGenerator + 'static,
193    ) -> Self {
194        self.internal_id_generator = Some(Box::new(generator));
195        self
196    }
197
198    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
199        self.internal_id_generator = Some(Box::new(generator));
200    }
201
202    pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
203        self.schema_provider = Some(Box::new(provider));
204        self
205    }
206
207    pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
208        self.schema_provider = Some(Box::new(provider));
209    }
210
211    pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
212        let provider = self
213            .schema_provider
214            .as_ref()
215            .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
216        provider.ensure_schema(self).await
217    }
218
219    pub fn with_language(mut self, language: Language) -> Self {
220        self.language = language;
221        self
222    }
223
224    pub fn set_language(&mut self, language: Language) {
225        self.language = language;
226    }
227
228    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
229        self.sql_log_options = options;
230        self
231    }
232
233    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
234        self.sql_log_options = options;
235    }
236
237    pub fn enable_select_sql_log(&mut self) {
238        self.sql_log_options.select = true;
239    }
240
241    pub fn enable_mutation_sql_log(&mut self) {
242        self.sql_log_options.mutation = true;
243    }
244
245    pub fn enable_all_sql_log(&mut self) {
246        self.sql_log_options = SqlLogOptions::all();
247    }
248
249    pub fn disable_sql_log(&mut self) {
250        self.sql_log_options = SqlLogOptions::default();
251        self.clear_sql_logs();
252    }
253
254    pub fn sql_log_options(&self) -> SqlLogOptions {
255        self.sql_log_options
256    }
257
258    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
259        self.sql_log_entries
260            .lock()
261            .map(|entries| entries.clone())
262            .unwrap_or_default()
263    }
264
265    pub fn clear_sql_logs(&self) {
266        if let Ok(mut entries) = self.sql_log_entries.lock() {
267            entries.clear();
268        }
269    }
270
271    pub(crate) fn record_sql_log(
272        &self,
273        operation: SqlLogOperation,
274        query: &CompiledQuery,
275        database_kind: DatabaseKind,
276        started_at: SystemTime,
277        ended_at: SystemTime,
278        elapsed: Duration,
279        result_count: Option<usize>,
280        result_type: Option<String>,
281        affected_rows: Option<u64>,
282    ) {
283        if !self.sql_log_options.enabled_for(operation) {
284            return;
285        }
286        let debug_sql = query.debug_sql(database_kind);
287        if let Ok(mut entries) = self.sql_log_entries.lock() {
288            entries.push(SqlLogEntry {
289                operation,
290                sql: query.sql.clone(),
291                params: query.params.clone(),
292                pretty_sql: pretty_sql(&debug_sql),
293                debug_sql,
294                started_at,
295                ended_at,
296                elapsed,
297                result_summary: sql_result_summary(
298                    operation,
299                    result_count,
300                    result_type.as_deref(),
301                    affected_rows,
302                ),
303                result_count,
304                result_type,
305                affected_rows,
306            });
307        }
308    }
309
310    pub fn language(&self) -> Language {
311        self.language
312    }
313
314    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
315        let Some(language) = Language::from_code(code) else {
316            return Err(RuntimeError::Language(format!(
317                "unsupported language code: {code}"
318            )));
319        };
320        self.language = language;
321        Ok(())
322    }
323
324    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
325        self.internal_id_generator
326            .as_ref()
327            .map(|generator| generator.generate_id(entity))
328            .transpose()
329    }
330
331    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
332        match self.generate_id(entity)? {
333            Some(id) => Ok(id),
334            None => local_id_generator().generate_id(entity),
335        }
336    }
337
338    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
339        self.metadata
340            .as_ref()
341            .and_then(|metadata| metadata.entity(name))
342    }
343
344    pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
345        self.metadata
346            .as_ref()
347            .map(|metadata| metadata.all_entities())
348            .unwrap_or_default()
349    }
350
351    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
352        self.entity(name)
353            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
354    }
355
356    pub fn insert_resource<T>(&mut self, resource: T)
357    where
358        T: Send + Sync + 'static,
359    {
360        self.typed_resources
361            .insert(TypeId::of::<T>(), Box::new(resource));
362    }
363
364    pub fn get_resource<T>(&self) -> Option<&T>
365    where
366        T: Send + Sync + 'static,
367    {
368        self.typed_resources
369            .get(&TypeId::of::<T>())
370            .and_then(|value| value.downcast_ref::<T>())
371    }
372
373    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
374    where
375        T: Send + Sync + 'static,
376    {
377        self.get_resource::<T>()
378            .ok_or(ContextError::MissingTypedResource(
379                std::any::type_name::<T>(),
380            ))
381    }
382
383    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
384    where
385        T: Send + Sync + 'static,
386    {
387        self.named_resources.insert(name.into(), Box::new(resource));
388    }
389
390    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
391    where
392        T: Send + Sync + 'static,
393    {
394        self.named_resources
395            .get(name)
396            .and_then(|value| value.downcast_ref::<T>())
397    }
398
399    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
400    where
401        T: Send + Sync + 'static,
402    {
403        self.get_named_resource::<T>(name)
404            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
405    }
406
407    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
408        self.locals.insert(key.into(), value.into());
409    }
410
411    pub fn local(&self, key: &str) -> Option<&Value> {
412        self.locals.get(key)
413    }
414
415    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
416        self.locals.remove(key)
417    }
418
419    pub fn has_repository(&self, entity: &str) -> bool {
420        let in_registry = self
421            .repository_registry
422            .as_ref()
423            .map(|registry| registry.contains(entity))
424            .unwrap_or(false);
425        in_registry || self.entity(entity).is_some()
426    }
427
428    pub fn repository_behavior(
429        &self,
430        entity: &str,
431    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
432        self.repository_behavior_registry
433            .as_ref()
434            .and_then(|registry| registry.behavior(entity))
435    }
436
437    pub fn has_checker(&self, entity: &str) -> bool {
438        self.checker_registry
439            .as_ref()
440            .and_then(|registry| registry.checker(entity))
441            .is_some()
442    }
443
444    pub fn check_and_fix_record(
445        &self,
446        entity: &str,
447        record: &mut Record,
448    ) -> Result<(), RuntimeError> {
449        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
450    }
451
452    pub fn check_and_fix_record_at(
453        &self,
454        entity: &str,
455        record: &mut Record,
456        location: &ObjectLocation,
457    ) -> Result<(), RuntimeError> {
458        let Some(checker) = self
459            .checker_registry
460            .as_ref()
461            .and_then(|registry| registry.checker(entity))
462        else {
463            return Ok(());
464        };
465        let mut results = CheckResults::new();
466        checker.check_and_fix(self, record, location, &mut results);
467        if results.is_empty() {
468            Ok(())
469        } else {
470            self.translate_check_results(&mut results);
471            Err(RuntimeError::Check(results))
472        }
473    }
474
475    pub fn translate_check_results(&self, results: &mut CheckResults) {
476        for result in results {
477            result.message = Some(translate_check_result(self.language, result));
478        }
479    }
480
481    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
482        let Some(sink) = self.event_sink.as_ref() else {
483            return Ok(());
484        };
485        sink.on_event(self, &event)
486    }
487
488    pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
489    where
490        D: SqlDialect + Send + Sync + 'static,
491        E: QueryExecutor + Send + Sync + 'static,
492    {
493        let dialect = self.require_resource::<D>().map_err(|err| {
494            RepositoryError::Runtime(RuntimeError::Graph(format!(
495                "cannot commit changes without dialect: {err}"
496            )))
497        })?;
498        let executor = self.require_resource::<E>().map_err(|err| {
499            RepositoryError::Runtime(RuntimeError::Graph(format!(
500                "cannot commit changes without executor: {err}"
501            )))
502        })?;
503        let change_set = self.entity_root.current_change_set();
504
505        for (key, changes) in change_set.changes() {
506            if changes.is_empty() {
507                continue;
508            }
509            let entity = self
510                .require_entity(&key.entity)
511                .map_err(RepositoryError::Runtime)?;
512            let mut command = UpdateCommand::new(&key.entity, key.id.clone());
513            for (field, value) in changes {
514                command = command.value(field.clone(), value.clone());
515            }
516            let query = dialect
517                .compile_update(entity, &command)
518                .map_err(RuntimeError::from)
519                .map_err(RepositoryError::Runtime)?;
520            executor
521                .execute(&query)
522                .map_err(RepositoryError::Executor)?;
523        }
524
525        self.entity_root.clear_current_change_set();
526        Ok(())
527    }
528}
529
530fn sql_result_summary(
531    operation: SqlLogOperation,
532    result_count: Option<usize>,
533    result_type: Option<&str>,
534    affected_rows: Option<u64>,
535) -> String {
536    match operation {
537        SqlLogOperation::Select => {
538            let count = result_count.unwrap_or(0);
539            match result_type {
540                Some(result_type) => format!("{count} x {result_type}"),
541                None => format!("{count} rows"),
542            }
543        }
544        _ => format!("{} rows affected", affected_rows.unwrap_or(0)),
545    }
546}
547
/// Reformats single-line SQL for human reading.
///
/// Each major clause keyword (`FROM`, `WHERE`, `GROUP BY`, `HAVING`, `ORDER BY`, `LIMIT`,
/// `OFFSET`, `RETURNING`) starts a new line, and `AND` / `OR` start a new line indented by
/// two spaces. Keywords are matched only in upper case with a single space on each side.
///
/// Text inside `'…'`, `"…"` or `` `…` `` quotes is copied verbatim, so inlined string
/// parameters (e.g. `'rock AND roll'`) are never split. A doubled quote (`''`) inside a
/// literal works naturally, as a close immediately followed by a reopen. Backslash-escaped
/// quotes are not recognised.
fn pretty_sql(sql: &str) -> String {
    const CLAUSES: [&str; 8] = [
        " FROM ",
        " WHERE ",
        " GROUP BY ",
        " HAVING ",
        " ORDER BY ",
        " LIMIT ",
        " OFFSET ",
        " RETURNING ",
    ];
    const CONNECTIVES: [&str; 2] = [" AND ", " OR "];

    let mut pretty = String::with_capacity(sql.len() + 32);
    let mut quote: Option<char> = None;
    let mut rest = sql;

    while let Some(ch) = rest.chars().next() {
        match quote {
            Some(open) => {
                if ch == open {
                    quote = None;
                }
            }
            None if matches!(ch, '\'' | '"' | '`') => quote = Some(ch),
            None if ch == ' ' => {
                let clause = CLAUSES.iter().find(|kw| rest.starts_with(**kw));
                let connective = CONNECTIVES.iter().find(|kw| rest.starts_with(**kw));
                let replacement = match (clause, connective) {
                    (Some(kw), _) => Some((kw, "\n")),
                    (None, Some(kw)) => Some((kw, "\n  ")),
                    (None, None) => None,
                };
                if let Some((kw, prefix)) = replacement {
                    pretty.push_str(prefix);
                    pretty.push_str(kw.trim());
                    // Consume the keyword but keep its trailing space, so a keyword that
                    // immediately follows (e.g. "FROM t WHERE") is still recognised.
                    rest = &rest[kw.len() - 1..];
                    continue;
                }
            }
            None => {}
        }
        pretty.push(ch);
        rest = &rest[ch.len_utf8()..];
    }
    pretty
}