Skip to main content

teaql_runtime/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Mutex;
6use std::time::{Duration, SystemTime};
7
8use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
9use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
10
11use crate::{
12    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
13    InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
14    RepositoryBehaviorRegistry, RepositoryRegistry, RequestPolicy, RuntimeError,
15    local_id_generator, translate_check_result,
16};
17use crate::{EntityRoot, QueryExecutor, RepositoryError};
18
/// Kind of SQL statement captured in a [`SqlLogEntry`].
///
/// `Select` is the only read operation; every other variant is treated as a
/// mutation by [`SqlLogOperation::is_mutation`] and by [`SqlLogOptions::enabled_for`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation other than [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
39pub struct SqlLogOptions {
40    pub select: bool,
41    pub mutation: bool,
42}
43
44impl SqlLogOptions {
45    pub fn disabled() -> Self {
46        Self {
47            select: false,
48            mutation: false,
49        }
50    }
51
52    pub fn select_only() -> Self {
53        Self {
54            select: true,
55            mutation: false,
56        }
57    }
58
59    pub fn mutation_only() -> Self {
60        Self {
61            select: false,
62            mutation: true,
63        }
64    }
65
66    pub fn all() -> Self {
67        Self {
68            select: true,
69            mutation: true,
70        }
71    }
72
73    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
74        if operation.is_select() {
75            self.select
76        } else {
77            self.mutation
78        }
79    }
80}
81
82#[derive(Debug, Clone, PartialEq)]
83pub struct SqlLogEntry {
84    pub operation: SqlLogOperation,
85    pub sql: String,
86    pub params: Vec<Value>,
87    pub debug_sql: String,
88    pub pretty_sql: String,
89    pub started_at: SystemTime,
90    pub ended_at: SystemTime,
91    pub elapsed: Duration,
92    pub result_count: Option<usize>,
93    pub result_type: Option<String>,
94    pub affected_rows: Option<u64>,
95    pub result_summary: String,
96}
97
98pub trait SchemaProvider: Send + Sync {
99    fn ensure_schema<'a>(
100        &'a self,
101        ctx: &'a UserContext,
102    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
103}
104
105pub struct UserContext {
106    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
107    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
108    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
109    pub(crate) request_policy: Option<Box<dyn RequestPolicy>>,
110    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
111    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
112    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
113    schema_provider: Option<Box<dyn SchemaProvider>>,
114    language: Language,
115    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
116    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
117    locals: BTreeMap<String, Value>,
118    pub(crate) initial_graphs: Vec<GraphNode>,
119    entity_root: EntityRoot,
120    sql_log_options: SqlLogOptions,
121    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
122}
123
124impl Default for UserContext {
125    fn default() -> Self {
126        Self {
127            metadata: None,
128            repository_registry: None,
129            repository_behavior_registry: None,
130            request_policy: None,
131            checker_registry: None,
132            event_sink: None,
133            internal_id_generator: None,
134            schema_provider: None,
135            language: Language::default(),
136            typed_resources: HashMap::new(),
137            named_resources: BTreeMap::new(),
138            locals: BTreeMap::new(),
139            initial_graphs: Vec::new(),
140            entity_root: EntityRoot::default(),
141            sql_log_options: SqlLogOptions::all(),
142            sql_log_entries: Mutex::new(Vec::new()),
143        }
144    }
145}
146
147impl UserContext {
148    pub fn new() -> Self {
149        Self::default()
150    }
151
152    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
153        module.apply_to(&mut self);
154        self
155    }
156
157    pub fn entity_root(&self) -> EntityRoot {
158        self.entity_root.clone()
159    }
160
161    pub fn initial_graphs(&self) -> &[GraphNode] {
162        &self.initial_graphs
163    }
164
165    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
166        self.initial_graphs = graphs;
167    }
168
169    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
170        self.metadata = Some(Box::new(metadata));
171        self
172    }
173
174    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
175        self.metadata = Some(Box::new(metadata));
176    }
177
178    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
179        self.repository_registry = Some(Box::new(registry));
180        self
181    }
182
183    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
184        self.repository_registry = Some(Box::new(registry));
185    }
186
187    pub fn with_repository_behavior_registry(
188        mut self,
189        registry: impl RepositoryBehaviorRegistry + 'static,
190    ) -> Self {
191        self.repository_behavior_registry = Some(Box::new(registry));
192        self
193    }
194
195    pub fn set_repository_behavior_registry(
196        &mut self,
197        registry: impl RepositoryBehaviorRegistry + 'static,
198    ) {
199        self.repository_behavior_registry = Some(Box::new(registry));
200    }
201
202    pub fn with_request_policy(mut self, policy: impl RequestPolicy + 'static) -> Self {
203        self.request_policy = Some(Box::new(policy));
204        self
205    }
206
207    pub fn set_request_policy(&mut self, policy: impl RequestPolicy + 'static) {
208        self.request_policy = Some(Box::new(policy));
209    }
210
211    pub fn clear_request_policy(&mut self) {
212        self.request_policy = None;
213    }
214
215    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
216        self.checker_registry = Some(Box::new(registry));
217        self
218    }
219
220    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
221        self.checker_registry = Some(Box::new(registry));
222    }
223
224    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
225        self.event_sink = Some(Box::new(sink));
226        self
227    }
228
229    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
230        self.event_sink = Some(Box::new(sink));
231    }
232
233    pub fn with_internal_id_generator(
234        mut self,
235        generator: impl InternalIdGenerator + 'static,
236    ) -> Self {
237        self.internal_id_generator = Some(Box::new(generator));
238        self
239    }
240
241    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
242        self.internal_id_generator = Some(Box::new(generator));
243    }
244
245    pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
246        self.schema_provider = Some(Box::new(provider));
247        self
248    }
249
250    pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
251        self.schema_provider = Some(Box::new(provider));
252    }
253
254    pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
255        let provider = self
256            .schema_provider
257            .as_ref()
258            .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
259        provider.ensure_schema(self).await
260    }
261
262    pub fn with_language(mut self, language: Language) -> Self {
263        self.language = language;
264        self
265    }
266
267    pub fn set_language(&mut self, language: Language) {
268        self.language = language;
269    }
270
271    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
272        self.sql_log_options = options;
273        self
274    }
275
276    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
277        self.sql_log_options = options;
278    }
279
280    pub fn enable_select_sql_log(&mut self) {
281        self.sql_log_options.select = true;
282    }
283
284    pub fn enable_mutation_sql_log(&mut self) {
285        self.sql_log_options.mutation = true;
286    }
287
288    pub fn enable_all_sql_log(&mut self) {
289        self.sql_log_options = SqlLogOptions::all();
290    }
291
292    pub fn disable_sql_log(&mut self) {
293        self.sql_log_options = SqlLogOptions::disabled();
294        self.clear_sql_logs();
295    }
296
297    pub fn sql_log_options(&self) -> SqlLogOptions {
298        self.sql_log_options
299    }
300
301    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
302        self.sql_log_entries
303            .lock()
304            .map(|entries| entries.clone())
305            .unwrap_or_default()
306    }
307
308    pub fn clear_sql_logs(&self) {
309        if let Ok(mut entries) = self.sql_log_entries.lock() {
310            entries.clear();
311        }
312    }
313
314    pub(crate) fn record_sql_log(
315        &self,
316        operation: SqlLogOperation,
317        query: &CompiledQuery,
318        database_kind: DatabaseKind,
319        started_at: SystemTime,
320        ended_at: SystemTime,
321        elapsed: Duration,
322        result_count: Option<usize>,
323        result_type: Option<String>,
324        affected_rows: Option<u64>,
325    ) {
326        if !self.sql_log_options.enabled_for(operation) {
327            return;
328        }
329        let debug_sql = query.debug_sql(database_kind);
330        if let Ok(mut entries) = self.sql_log_entries.lock() {
331            entries.push(SqlLogEntry {
332                operation,
333                sql: query.sql.clone(),
334                params: query.params.clone(),
335                pretty_sql: pretty_sql(&debug_sql),
336                debug_sql,
337                started_at,
338                ended_at,
339                elapsed,
340                result_summary: sql_result_summary(
341                    operation,
342                    result_count,
343                    result_type.as_deref(),
344                    affected_rows,
345                ),
346                result_count,
347                result_type,
348                affected_rows,
349            });
350        }
351    }
352
353    pub fn language(&self) -> Language {
354        self.language
355    }
356
357    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
358        let Some(language) = Language::from_code(code) else {
359            return Err(RuntimeError::Language(format!(
360                "unsupported language code: {code}"
361            )));
362        };
363        self.language = language;
364        Ok(())
365    }
366
367    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
368        self.internal_id_generator
369            .as_ref()
370            .map(|generator| generator.generate_id(entity))
371            .transpose()
372    }
373
374    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
375        match self.generate_id(entity)? {
376            Some(id) => Ok(id),
377            None => local_id_generator().generate_id(entity),
378        }
379    }
380
381    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
382        self.metadata
383            .as_ref()
384            .and_then(|metadata| metadata.entity(name))
385    }
386
387    pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
388        self.metadata
389            .as_ref()
390            .map(|metadata| metadata.all_entities())
391            .unwrap_or_default()
392    }
393
394    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
395        self.entity(name)
396            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
397    }
398
399    pub fn insert_resource<T>(&mut self, resource: T)
400    where
401        T: Send + Sync + 'static,
402    {
403        self.typed_resources
404            .insert(TypeId::of::<T>(), Box::new(resource));
405    }
406
407    pub fn get_resource<T>(&self) -> Option<&T>
408    where
409        T: Send + Sync + 'static,
410    {
411        self.typed_resources
412            .get(&TypeId::of::<T>())
413            .and_then(|value| value.downcast_ref::<T>())
414    }
415
416    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
417    where
418        T: Send + Sync + 'static,
419    {
420        self.get_resource::<T>()
421            .ok_or(ContextError::MissingTypedResource(
422                std::any::type_name::<T>(),
423            ))
424    }
425
426    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
427    where
428        T: Send + Sync + 'static,
429    {
430        self.named_resources.insert(name.into(), Box::new(resource));
431    }
432
433    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
434    where
435        T: Send + Sync + 'static,
436    {
437        self.named_resources
438            .get(name)
439            .and_then(|value| value.downcast_ref::<T>())
440    }
441
442    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
443    where
444        T: Send + Sync + 'static,
445    {
446        self.get_named_resource::<T>(name)
447            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
448    }
449
450    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
451        self.locals.insert(key.into(), value.into());
452    }
453
454    pub fn local(&self, key: &str) -> Option<&Value> {
455        self.locals.get(key)
456    }
457
458    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
459        self.locals.remove(key)
460    }
461
462    pub fn has_repository(&self, entity: &str) -> bool {
463        let in_registry = self
464            .repository_registry
465            .as_ref()
466            .map(|registry| registry.contains(entity))
467            .unwrap_or(false);
468        in_registry || self.entity(entity).is_some()
469    }
470
471    pub fn repository_behavior(
472        &self,
473        entity: &str,
474    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
475        self.repository_behavior_registry
476            .as_ref()
477            .and_then(|registry| registry.behavior(entity))
478    }
479
480    pub fn has_checker(&self, entity: &str) -> bool {
481        self.checker_registry
482            .as_ref()
483            .and_then(|registry| registry.checker(entity))
484            .is_some()
485    }
486
487    pub fn check_and_fix_record(
488        &self,
489        entity: &str,
490        record: &mut Record,
491    ) -> Result<(), RuntimeError> {
492        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
493    }
494
495    pub fn check_and_fix_record_at(
496        &self,
497        entity: &str,
498        record: &mut Record,
499        location: &ObjectLocation,
500    ) -> Result<(), RuntimeError> {
501        let Some(checker) = self
502            .checker_registry
503            .as_ref()
504            .and_then(|registry| registry.checker(entity))
505        else {
506            return Ok(());
507        };
508        let mut results = CheckResults::new();
509        checker.check_and_fix(self, record, location, &mut results);
510        if results.is_empty() {
511            Ok(())
512        } else {
513            self.translate_check_results(&mut results);
514            Err(RuntimeError::Check(results))
515        }
516    }
517
518    pub fn translate_check_results(&self, results: &mut CheckResults) {
519        for result in results {
520            result.message = Some(translate_check_result(self.language, result));
521        }
522    }
523
524    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
525        let Some(sink) = self.event_sink.as_ref() else {
526            return Ok(());
527        };
528        sink.on_event(self, &event)
529    }
530
531    pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
532    where
533        D: SqlDialect + Send + Sync + 'static,
534        E: QueryExecutor + Send + Sync + 'static,
535    {
536        let dialect = self.require_resource::<D>().map_err(|err| {
537            RepositoryError::Runtime(RuntimeError::Graph(format!(
538                "cannot commit changes without dialect: {err}"
539            )))
540        })?;
541        let executor = self.require_resource::<E>().map_err(|err| {
542            RepositoryError::Runtime(RuntimeError::Graph(format!(
543                "cannot commit changes without executor: {err}"
544            )))
545        })?;
546        let change_set = self.entity_root.current_change_set();
547
548        for (key, changes) in change_set.changes() {
549            if changes.is_empty() {
550                continue;
551            }
552            let entity = self
553                .require_entity(&key.entity)
554                .map_err(RepositoryError::Runtime)?;
555            let mut command = UpdateCommand::new(&key.entity, key.id.clone());
556            for (field, value) in changes {
557                command = command.value(field.clone(), value.clone());
558            }
559            let query = dialect
560                .compile_update(entity, &command)
561                .map_err(RuntimeError::from)
562                .map_err(RepositoryError::Runtime)?;
563            executor
564                .execute(&query)
565                .map_err(RepositoryError::Executor)?;
566        }
567
568        self.entity_root.clear_current_change_set();
569        Ok(())
570    }
571}
572
573fn sql_result_summary(
574    operation: SqlLogOperation,
575    result_count: Option<usize>,
576    result_type: Option<&str>,
577    affected_rows: Option<u64>,
578) -> String {
579    match operation {
580        SqlLogOperation::Select => {
581            let count = result_count.unwrap_or(0);
582            match result_type {
583                Some(result_type) => format!("{count} x {result_type}"),
584                None => format!("{count} rows"),
585            }
586        }
587        _ => format!("{} rows affected", affected_rows.unwrap_or(0)),
588    }
589}
590
/// Formats a single-line SQL string for human reading.
///
/// A line break is inserted before each of `FROM`, `WHERE`, `GROUP BY`, `HAVING`,
/// `ORDER BY`, `LIMIT`, `OFFSET` and `RETURNING`. Every `AND` / `OR` is moved onto
/// its own line, indented by two spaces. Keywords are recognized only in upper case
/// and only when surrounded by single spaces.
///
/// Text inside single-quoted literals and double-quoted identifiers is copied
/// verbatim, so a rendered value such as `'Tom AND Jerry'` is not split across lines.
/// Quotes are tracked by simple toggling, which also handles doubled-quote escapes
/// such as `'it''s'`. An unterminated quote leaves the rest of the string untouched.
/// NOTE(review): backslash-escaped quotes (MySQL style) are not recognized; confirm
/// how `debug_sql` renders string parameters.
fn pretty_sql(sql: &str) -> String {
    let mut pretty = String::with_capacity(sql.len());
    let mut rest = sql;
    loop {
        let Some(open) = rest.find(|c: char| c == '\'' || c == '"') else {
            pretty.push_str(&format_sql_clauses(rest));
            return pretty;
        };
        let (plain, quoted) = rest.split_at(open);
        pretty.push_str(&format_sql_clauses(plain));
        // `quoted` starts with a one-byte ASCII quote, so `[1..]` is a valid char boundary.
        let quote = if quoted.starts_with('\'') { '\'' } else { '"' };
        // Include the closing quote; without one, take everything that is left.
        let end = quoted[1..]
            .find(quote)
            .map_or(quoted.len(), |close| close + 2);
        pretty.push_str(&quoted[..end]);
        rest = &quoted[end..];
    }
}

/// Applies `pretty_sql`'s clause and `AND`/`OR` line breaks to SQL text that contains
/// no quoted sections.
fn format_sql_clauses(segment: &str) -> String {
    const CLAUSES: [&str; 8] = [
        " FROM ",
        " WHERE ",
        " GROUP BY ",
        " HAVING ",
        " ORDER BY ",
        " LIMIT ",
        " OFFSET ",
        " RETURNING ",
    ];
    let mut formatted = segment.to_owned();
    for clause in CLAUSES {
        formatted = formatted.replace(clause, &format!("\n{}", clause.trim_start()));
    }
    formatted
        .replace(" AND ", "\n  AND ")
        .replace(" OR ", "\n  OR ")
}