Skip to main content

teaql_runtime/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4
5use std::pin::Pin;
6use std::sync::Mutex;
7use std::time::{Duration, SystemTime};
8
9use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
10use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
11
12use crate::{
13    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
14    InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
15    RepositoryBehaviorRegistry, RepositoryRegistry, RequestPolicy, RuntimeError,
16    local_id_generator, translate_check_result,
17};
18use crate::{EntityRoot, QueryExecutor, RepositoryError};
19
/// Category of statement recorded by the SQL log.
///
/// Only [`SqlLogOperation::Select`] counts as a read; every other variant is
/// classified as a mutation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    /// A read; the only variant for which [`Self::is_select`] is `true`.
    Select,
    /// An insert; classified as a mutation.
    Insert,
    /// An update; classified as a mutation.
    Update,
    /// A delete; classified as a mutation.
    Delete,
    /// A recover operation; classified as a mutation.
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only when the operation is [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation other than [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
40pub struct SqlLogOptions {
41    pub select: bool,
42    pub mutation: bool,
43}
44
45impl SqlLogOptions {
46    pub fn disabled() -> Self {
47        Self {
48            select: false,
49            mutation: false,
50        }
51    }
52
53    pub fn select_only() -> Self {
54        Self {
55            select: true,
56            mutation: false,
57        }
58    }
59
60    pub fn mutation_only() -> Self {
61        Self {
62            select: false,
63            mutation: true,
64        }
65    }
66
67    pub fn all() -> Self {
68        Self {
69            select: true,
70            mutation: true,
71        }
72    }
73
74    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
75        if operation.is_select() {
76            self.select
77        } else {
78            self.mutation
79        }
80    }
81}
82
/// One captured SQL execution, as stored in a `UserContext`'s in-memory SQL log.
#[derive(Debug, Clone, PartialEq)]
pub struct SqlLogEntry {
    /// Kind of statement that was executed.
    pub operation: SqlLogOperation,
    /// Copy of the compiled query's `sql` text.
    pub sql: String,
    /// Copy of the compiled query's `params`.
    pub params: Vec<Value>,
    /// Output of the compiled query's `debug_sql` for the database kind in use.
    pub debug_sql: String,
    /// `debug_sql` with line breaks inserted before major clauses and `AND`/`OR`.
    pub pretty_sql: String,
    /// Start time supplied by the caller that executed the statement.
    pub started_at: SystemTime,
    /// End time supplied by the caller that executed the statement.
    pub ended_at: SystemTime,
    /// Elapsed duration supplied by the caller.
    pub elapsed: Duration,
    /// Number of results, when the caller reported one.
    pub result_count: Option<usize>,
    /// Name of the result type, when the caller reported one.
    pub result_type: Option<String>,
    /// Number of affected rows, when the caller reported one.
    pub affected_rows: Option<u64>,
    /// Short summary derived from the operation and the counts above.
    pub result_summary: String,
    /// User identifier of the context at the time the entry was recorded.
    pub user_identifier: Option<String>,
    /// Comment stack joined with `->` when it is non-empty; otherwise the
    /// comment passed explicitly when the entry was recorded.
    pub comment: Option<String>,
}
100
/// Hook invoked by `UserContext::ensure_schema`.
///
/// Implementations must be `Send + Sync` and return a boxed, `Send` future so
/// the trait stays usable as a trait object.
pub trait SchemaProvider: Send + Sync {
    /// Performs the provider's schema work for `ctx`.
    ///
    /// The returned future borrows both the provider and the context for `'a`
    /// and resolves to `Ok(())` or a [`RuntimeError`].
    fn ensure_schema<'a>(
        &'a self,
        ctx: &'a UserContext,
    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
}
107
/// Per-user runtime context bundling optional pluggable services, resources,
/// local values and SQL logging state.
pub struct UserContext {
    /// Entity metadata lookup; `None` until one is installed.
    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
    /// Registry consulted by `has_repository`.
    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
    /// Registry consulted by `repository_behavior`.
    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
    /// Optional request policy.
    pub(crate) request_policy: Option<Box<dyn RequestPolicy>>,
    /// Registry supplying per-entity checkers.
    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
    /// Sink receiving entity events from `send_event`.
    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
    /// Id generator used before falling back to the local generator.
    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
    /// Provider invoked by `ensure_schema`.
    schema_provider: Option<Box<dyn SchemaProvider>>,
    /// Language used when translating check results.
    language: Language,
    /// Resources stored one per concrete type, keyed by `TypeId`.
    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    /// Resources stored under caller-chosen names.
    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
    /// Named local values.
    locals: BTreeMap<String, Value>,
    /// Graph nodes exposed through `initial_graphs`.
    pub(crate) initial_graphs: Vec<GraphNode>,
    /// Entity root whose current change set is flushed by `commit_changes`.
    entity_root: EntityRoot,
    /// Which SQL operations are recorded.
    sql_log_options: SqlLogOptions,
    /// Recorded SQL log entries; behind a mutex so they can be appended through `&self`.
    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
    /// Identifier copied into every SQL log entry.
    user_identifier: Option<String>,
    /// Stack of active query comments; joined with `->` when a log entry is recorded.
    pub(crate) comment_stack: Mutex<Vec<String>>,
}
128
129impl Default for UserContext {
130    fn default() -> Self {
131        let pid = std::process::id();
132        let thread_id_str = format!("{:?}", std::thread::current().id());
133        let numeric_thread_id = thread_id_str
134            .strip_prefix("ThreadId(")
135            .and_then(|s| s.strip_suffix(")"))
136            .unwrap_or(&thread_id_str);
137        let os_user = std::env::var("USER")
138            .or_else(|_| std::env::var("USERNAME"))
139            .unwrap_or_else(|_| "main".to_owned());
140        let user_id = format!("{os_user}@pid-{pid}.tid-{numeric_thread_id}");
141        Self {
142            metadata: None,
143            repository_registry: None,
144            repository_behavior_registry: None,
145            request_policy: None,
146            checker_registry: None,
147            event_sink: None,
148            internal_id_generator: None,
149            schema_provider: None,
150            language: Language::default(),
151            typed_resources: HashMap::new(),
152            named_resources: BTreeMap::new(),
153            locals: BTreeMap::new(),
154            initial_graphs: Vec::new(),
155            entity_root: EntityRoot::default(),
156            sql_log_options: SqlLogOptions::all(),
157            sql_log_entries: Mutex::new(Vec::new()),
158            user_identifier: Some(user_id),
159            comment_stack: Mutex::new(Vec::new()),
160        }
161    }
162}
163
164impl UserContext {
165    pub fn new() -> Self {
166        Self::default()
167    }
168
169    pub fn user_identifier(&self) -> Option<&str> {
170        self.user_identifier.as_deref()
171    }
172
173    pub fn set_user_identifier(&mut self, user_identifier: impl Into<String>) {
174        self.user_identifier = Some(user_identifier.into());
175    }
176
177    pub fn with_user_identifier(mut self, user_identifier: impl Into<String>) -> Self {
178        self.user_identifier = Some(user_identifier.into());
179        self
180    }
181
182    pub fn set_user_identifier_option(&mut self, user_identifier: Option<String>) {
183        self.user_identifier = user_identifier;
184    }
185
186    pub fn with_user_identifier_option(mut self, user_identifier: Option<String>) -> Self {
187        self.user_identifier = user_identifier;
188        self
189    }
190
191    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
192        module.apply_to(&mut self);
193        self
194    }
195
196    pub fn entity_root(&self) -> EntityRoot {
197        self.entity_root.clone()
198    }
199
200    pub fn initial_graphs(&self) -> &[GraphNode] {
201        &self.initial_graphs
202    }
203
204    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
205        self.initial_graphs = graphs;
206    }
207
208    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
209        self.metadata = Some(Box::new(metadata));
210        self
211    }
212
213    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
214        self.metadata = Some(Box::new(metadata));
215    }
216
217    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
218        self.repository_registry = Some(Box::new(registry));
219        self
220    }
221
222    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
223        self.repository_registry = Some(Box::new(registry));
224    }
225
226    pub fn with_repository_behavior_registry(
227        mut self,
228        registry: impl RepositoryBehaviorRegistry + 'static,
229    ) -> Self {
230        self.repository_behavior_registry = Some(Box::new(registry));
231        self
232    }
233
234    pub fn set_repository_behavior_registry(
235        &mut self,
236        registry: impl RepositoryBehaviorRegistry + 'static,
237    ) {
238        self.repository_behavior_registry = Some(Box::new(registry));
239    }
240
241    pub fn with_request_policy(mut self, policy: impl RequestPolicy + 'static) -> Self {
242        self.request_policy = Some(Box::new(policy));
243        self
244    }
245
246    pub fn set_request_policy(&mut self, policy: impl RequestPolicy + 'static) {
247        self.request_policy = Some(Box::new(policy));
248    }
249
250    pub fn clear_request_policy(&mut self) {
251        self.request_policy = None;
252    }
253
254    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
255        self.checker_registry = Some(Box::new(registry));
256        self
257    }
258
259    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
260        self.checker_registry = Some(Box::new(registry));
261    }
262
263    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
264        self.event_sink = Some(Box::new(sink));
265        self
266    }
267
268    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
269        self.event_sink = Some(Box::new(sink));
270    }
271
272    pub fn with_internal_id_generator(
273        mut self,
274        generator: impl InternalIdGenerator + 'static,
275    ) -> Self {
276        self.internal_id_generator = Some(Box::new(generator));
277        self
278    }
279
280    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
281        self.internal_id_generator = Some(Box::new(generator));
282    }
283
284    pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
285        self.schema_provider = Some(Box::new(provider));
286        self
287    }
288
289    pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
290        self.schema_provider = Some(Box::new(provider));
291    }
292
293    pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
294        let provider = self
295            .schema_provider
296            .as_ref()
297            .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
298        provider.ensure_schema(self).await
299    }
300
301    pub fn with_language(mut self, language: Language) -> Self {
302        self.language = language;
303        self
304    }
305
306    pub fn set_language(&mut self, language: Language) {
307        self.language = language;
308    }
309
310    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
311        self.sql_log_options = options;
312        self
313    }
314
315    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
316        self.sql_log_options = options;
317    }
318
319    pub fn enable_select_sql_log(&mut self) {
320        self.sql_log_options.select = true;
321    }
322
323    pub fn enable_mutation_sql_log(&mut self) {
324        self.sql_log_options.mutation = true;
325    }
326
327    pub fn enable_all_sql_log(&mut self) {
328        self.sql_log_options = SqlLogOptions::all();
329    }
330
331    pub fn disable_sql_log(&mut self) {
332        self.sql_log_options = SqlLogOptions::disabled();
333        self.clear_sql_logs();
334    }
335
336    pub fn sql_log_options(&self) -> SqlLogOptions {
337        self.sql_log_options
338    }
339
340    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
341        self.sql_log_entries
342            .lock()
343            .map(|entries| entries.clone())
344            .unwrap_or_default()
345    }
346
347    pub fn clear_sql_logs(&self) {
348        if let Ok(mut entries) = self.sql_log_entries.lock() {
349            entries.clear();
350        }
351    }
352
353    pub(crate) fn record_sql_log(
354        &self,
355        operation: SqlLogOperation,
356        query: &CompiledQuery,
357        database_kind: DatabaseKind,
358        started_at: SystemTime,
359        ended_at: SystemTime,
360        elapsed: Duration,
361        result_count: Option<usize>,
362        result_type: Option<String>,
363        affected_rows: Option<u64>,
364        comment: Option<String>,
365    ) {
366        if !self.sql_log_options.enabled_for(operation) {
367            return;
368        }
369        let debug_sql = query.debug_sql(database_kind);
370        let result_summary = sql_result_summary(
371            operation,
372            result_count,
373            result_type.as_deref(),
374            affected_rows,
375        );
376
377        // Resolve final comment from stack or parameter
378        let stack_comment = self.comment_stack.lock().ok().and_then(|stack| {
379            if stack.is_empty() {
380                None
381            } else {
382                Some(stack.join("->"))
383            }
384        });
385        let final_comment = stack_comment.or(comment);
386
387        // Append log line to app.log file
388        if let Ok(mut file) = std::fs::OpenOptions::new()
389            .create(true)
390            .append(true)
391            .open("app.log")
392        {
393            use std::io::Write;
394            let local_time: chrono::DateTime<chrono::Local> = started_at.into();
395            let timestamp_str = local_time.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
396            let user_id_str = self.user_identifier.as_deref().unwrap_or("");
397            let comment_part = if let Some(ref c) = final_comment {
398                format!(" - [{c}]")
399            } else {
400                "".to_owned()
401            };
402            let elapsed_ms = elapsed.as_secs_f64() * 1000.0;
403            let log_line = format!(
404                "{timestamp_str}-[{user_id_str}]--DEBUG - SqlLogEntry{comment_part} - [{result_summary}] {} (took {:.3}ms)\n",
405                debug_sql, elapsed_ms
406            );
407            let _ = file.write_all(log_line.as_bytes());
408        }
409
410        if let Ok(mut entries) = self.sql_log_entries.lock() {
411            entries.push(SqlLogEntry {
412                operation,
413                sql: query.sql.clone(),
414                params: query.params.clone(),
415                pretty_sql: pretty_sql(&debug_sql),
416                debug_sql,
417                started_at,
418                ended_at,
419                elapsed,
420                result_summary,
421                result_count,
422                result_type,
423                affected_rows,
424                user_identifier: self.user_identifier.clone(),
425                comment: final_comment,
426            });
427        }
428    }
429
430    pub fn language(&self) -> Language {
431        self.language
432    }
433
434    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
435        let Some(language) = Language::from_code(code) else {
436            return Err(RuntimeError::Language(format!(
437                "unsupported language code: {code}"
438            )));
439        };
440        self.language = language;
441        Ok(())
442    }
443
444    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
445        self.internal_id_generator
446            .as_ref()
447            .map(|generator| generator.generate_id(entity))
448            .transpose()
449    }
450
451    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
452        match self.generate_id(entity)? {
453            Some(id) => Ok(id),
454            None => local_id_generator().generate_id(entity),
455        }
456    }
457
458    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
459        self.metadata
460            .as_ref()
461            .and_then(|metadata| metadata.entity(name))
462    }
463
464    pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
465        self.metadata
466            .as_ref()
467            .map(|metadata| metadata.all_entities())
468            .unwrap_or_default()
469    }
470
471    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
472        self.entity(name)
473            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
474    }
475
476    pub fn insert_resource<T>(&mut self, resource: T)
477    where
478        T: Send + Sync + 'static,
479    {
480        self.typed_resources
481            .insert(TypeId::of::<T>(), Box::new(resource));
482    }
483
484    pub fn get_resource<T>(&self) -> Option<&T>
485    where
486        T: Send + Sync + 'static,
487    {
488        self.typed_resources
489            .get(&TypeId::of::<T>())
490            .and_then(|value| value.downcast_ref::<T>())
491    }
492
493    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
494    where
495        T: Send + Sync + 'static,
496    {
497        self.get_resource::<T>()
498            .ok_or(ContextError::MissingTypedResource(
499                std::any::type_name::<T>(),
500            ))
501    }
502
503    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
504    where
505        T: Send + Sync + 'static,
506    {
507        self.named_resources.insert(name.into(), Box::new(resource));
508    }
509
510    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
511    where
512        T: Send + Sync + 'static,
513    {
514        self.named_resources
515            .get(name)
516            .and_then(|value| value.downcast_ref::<T>())
517    }
518
519    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
520    where
521        T: Send + Sync + 'static,
522    {
523        self.get_named_resource::<T>(name)
524            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
525    }
526
527    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
528        self.locals.insert(key.into(), value.into());
529    }
530
531    pub fn local(&self, key: &str) -> Option<&Value> {
532        self.locals.get(key)
533    }
534
535    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
536        self.locals.remove(key)
537    }
538
539    pub fn has_repository(&self, entity: &str) -> bool {
540        let in_registry = self
541            .repository_registry
542            .as_ref()
543            .map(|registry| registry.contains(entity))
544            .unwrap_or(false);
545        in_registry || self.entity(entity).is_some()
546    }
547
548    pub fn repository_behavior(
549        &self,
550        entity: &str,
551    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
552        self.repository_behavior_registry
553            .as_ref()
554            .and_then(|registry| registry.behavior(entity))
555    }
556
557    pub fn has_checker(&self, entity: &str) -> bool {
558        self.checker_registry
559            .as_ref()
560            .and_then(|registry| registry.checker(entity))
561            .is_some()
562    }
563
564    pub fn check_and_fix_record(
565        &self,
566        entity: &str,
567        record: &mut Record,
568    ) -> Result<(), RuntimeError> {
569        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
570    }
571
572    pub fn check_and_fix_record_at(
573        &self,
574        entity: &str,
575        record: &mut Record,
576        location: &ObjectLocation,
577    ) -> Result<(), RuntimeError> {
578        let Some(checker) = self
579            .checker_registry
580            .as_ref()
581            .and_then(|registry| registry.checker(entity))
582        else {
583            return Ok(());
584        };
585        let mut results = CheckResults::new();
586        checker.check_and_fix(self, record, location, &mut results);
587        if results.is_empty() {
588            Ok(())
589        } else {
590            self.translate_check_results(&mut results);
591            Err(RuntimeError::Check(results))
592        }
593    }
594
595    pub fn translate_check_results(&self, results: &mut CheckResults) {
596        for result in results {
597            result.message = Some(translate_check_result(self.language, result));
598        }
599    }
600
601    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
602        let Some(sink) = self.event_sink.as_ref() else {
603            return Ok(());
604        };
605        sink.on_event(self, &event)
606    }
607
608    pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
609    where
610        D: SqlDialect + Send + Sync + 'static,
611        E: QueryExecutor + Send + Sync + 'static,
612    {
613        let dialect = self.require_resource::<D>().map_err(|err| {
614            RepositoryError::Runtime(RuntimeError::Graph(format!(
615                "cannot commit changes without dialect: {err}"
616            )))
617        })?;
618        let executor = self.require_resource::<E>().map_err(|err| {
619            RepositoryError::Runtime(RuntimeError::Graph(format!(
620                "cannot commit changes without executor: {err}"
621            )))
622        })?;
623        let change_set = self.entity_root.current_change_set();
624
625        for (key, changes) in change_set.changes() {
626            if changes.is_empty() {
627                continue;
628            }
629            let entity = self
630                .require_entity(&key.entity)
631                .map_err(RepositoryError::Runtime)?;
632            let mut command = UpdateCommand::new(&key.entity, key.id.clone());
633            for (field, value) in changes {
634                command = command.value(field.clone(), value.clone());
635            }
636            let query = dialect
637                .compile_update(entity, &command)
638                .map_err(RuntimeError::from)
639                .map_err(RepositoryError::Runtime)?;
640            executor
641                .execute(&query)
642                .map_err(RepositoryError::Executor)?;
643        }
644
645        self.entity_root.clear_current_change_set();
646        Ok(())
647    }
648}
649
650fn sql_result_summary(
651    operation: SqlLogOperation,
652    result_count: Option<usize>,
653    result_type: Option<&str>,
654    affected_rows: Option<u64>,
655) -> String {
656    match operation {
657        SqlLogOperation::Select => {
658            let count = result_count.unwrap_or(0);
659            if count == 0 {
660                "NO ROWS".to_owned()
661            } else if count > 1 {
662                match result_type {
663                    Some(result_type) => format!("{count}*{result_type}"),
664                    None => format!("{count}*rows"),
665                }
666            } else {
667                match result_type {
668                    Some(result_type) => result_type.to_owned(),
669                    None => "row".to_owned(),
670                }
671            }
672        }
673        _ => {
674            let affected = affected_rows.unwrap_or(0);
675            format!("{affected} UPDATED")
676        }
677    }
678}
679
680fn pretty_sql(sql: &str) -> String {
681    let mut pretty = sql.to_owned();
682    for keyword in [
683        " FROM ",
684        " WHERE ",
685        " GROUP BY ",
686        " HAVING ",
687        " ORDER BY ",
688        " LIMIT ",
689        " OFFSET ",
690        " RETURNING ",
691    ] {
692        pretty = pretty.replace(keyword, &format!("\n{}", keyword.trim_start()));
693    }
694    pretty
695        .replace(" AND ", "\n  AND ")
696        .replace(" OR ", "\n  OR ")
697}
698
699pub struct QueryCommentGuard<'a> {
700    context: &'a UserContext,
701    has_pushed: bool,
702}
703
704impl<'a> QueryCommentGuard<'a> {
705    pub fn new(context: &'a UserContext, comment: Option<String>) -> Self {
706        let mut has_pushed = false;
707        if let Some(comment) = comment {
708            if !comment.is_empty() {
709                if let Ok(mut stack) = context.comment_stack.lock() {
710                    if stack.last() != Some(&comment) {
711                        stack.push(comment);
712                        has_pushed = true;
713                    }
714                }
715            }
716        }
717        Self { context, has_pushed }
718    }
719}
720
721impl<'a> Drop for QueryCommentGuard<'a> {
722    fn drop(&mut self) {
723        if self.has_pushed {
724            if let Ok(mut stack) = self.context.comment_stack.lock() {
725                stack.pop();
726            }
727        }
728    }
729}