// teaql_runtime/context.rs

1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4
5use std::pin::Pin;
6use std::sync::Mutex;
7use std::time::{Duration, SystemTime};
8
9use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
10use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
11
12use crate::{
13    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
14    InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
15    RepositoryBehaviorRegistry, RepositoryRegistry, RequestPolicy, RuntimeError,
16    local_id_generator, translate_check_result,
17};
18use crate::{EntityRoot, QueryExecutor, RepositoryError};
19
/// Kind of SQL statement a log entry was recorded for.
///
/// `Select` is the only read operation; every other variant is treated as a
/// mutation by [`SqlLogOperation::is_mutation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation other than [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        self != Self::Select
    }
}
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
40pub struct SqlLogOptions {
41    pub select: bool,
42    pub mutation: bool,
43}
44
45impl SqlLogOptions {
46    pub fn disabled() -> Self {
47        Self {
48            select: false,
49            mutation: false,
50        }
51    }
52
53    pub fn select_only() -> Self {
54        Self {
55            select: true,
56            mutation: false,
57        }
58    }
59
60    pub fn mutation_only() -> Self {
61        Self {
62            select: false,
63            mutation: true,
64        }
65    }
66
67    pub fn all() -> Self {
68        Self {
69            select: true,
70            mutation: true,
71        }
72    }
73
74    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
75        if operation.is_select() {
76            self.select
77        } else {
78            self.mutation
79        }
80    }
81}
82
83#[derive(Debug, Clone, PartialEq)]
84pub struct SqlLogEntry {
85    pub operation: SqlLogOperation,
86    pub sql: String,
87    pub params: Vec<Value>,
88    pub debug_sql: String,
89    pub pretty_sql: String,
90    pub started_at: SystemTime,
91    pub ended_at: SystemTime,
92    pub elapsed: Duration,
93    pub result_count: Option<usize>,
94    pub result_type: Option<String>,
95    pub affected_rows: Option<u64>,
96    pub result_summary: String,
97}
98
99#[derive(Debug, Clone, PartialEq)]
100pub struct UnifiedLogEntry {
101    pub timestamp: SystemTime,
102    pub user_identifier: Option<String>,
103    pub trace_chain: Vec<teaql_core::TraceNode>,
104    pub payload: LogPayload,
105}
106
107#[derive(Debug, Clone, PartialEq)]
108pub enum LogPayload {
109    Sql(SqlLogEntry),
110    Info(InfoLogEntry),
111}
112
/// Free-form informational log payload.
#[derive(Debug, Clone, PartialEq)]
pub struct InfoLogEntry {
    /// Text of the message.
    pub message: String,
}
117
118#[derive(Clone, Default)]
119pub struct UnifiedLogBuffer {
120    pub entries: std::sync::Arc<Mutex<Vec<UnifiedLogEntry>>>,
121}
122
123pub trait SchemaProvider: Send + Sync {
124    fn ensure_schema<'a>(
125        &'a self,
126        ctx: &'a UserContext,
127    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
128}
129
130pub struct UserContext {
131    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
132    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
133    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
134    pub(crate) request_policy: Option<Box<dyn RequestPolicy>>,
135    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
136    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
137    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
138    schema_provider: Option<Box<dyn SchemaProvider>>,
139    language: Language,
140    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
141    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
142    locals: BTreeMap<String, Value>,
143    pub(crate) initial_graphs: Vec<GraphNode>,
144    entity_root: EntityRoot,
145    sql_log_options: SqlLogOptions,
146    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
147    user_identifier: Option<String>,
148}
149
150impl Default for UserContext {
151    fn default() -> Self {
152        let pid = std::process::id();
153        let thread_id_str = format!("{:?}", std::thread::current().id());
154        let numeric_thread_id = thread_id_str
155            .strip_prefix("ThreadId(")
156            .and_then(|s| s.strip_suffix(")"))
157            .unwrap_or(&thread_id_str);
158        let os_user = std::env::var("USER")
159            .or_else(|_| std::env::var("USERNAME"))
160            .unwrap_or_else(|_| "main".to_owned());
161        let user_id = format!("{os_user}@pid-{pid}.tid-{numeric_thread_id}");
162        Self {
163            metadata: None,
164            repository_registry: None,
165            repository_behavior_registry: None,
166            request_policy: None,
167            checker_registry: None,
168            event_sink: None,
169            internal_id_generator: None,
170            schema_provider: None,
171            language: Language::default(),
172            typed_resources: HashMap::new(),
173            named_resources: BTreeMap::new(),
174            locals: BTreeMap::new(),
175            initial_graphs: Vec::new(),
176            entity_root: EntityRoot::default(),
177            sql_log_options: SqlLogOptions::all(),
178            sql_log_entries: Mutex::new(Vec::new()),
179            user_identifier: Some(user_id),
180        }
181    }
182}
183
184impl UserContext {
185    pub fn new() -> Self {
186        Self::default()
187    }
188
189    pub fn user_identifier(&self) -> Option<&str> {
190        self.user_identifier.as_deref()
191    }
192
193    pub fn set_user_identifier(&mut self, user_identifier: impl Into<String>) {
194        self.user_identifier = Some(user_identifier.into());
195    }
196
197    pub fn with_user_identifier(mut self, user_identifier: impl Into<String>) -> Self {
198        self.user_identifier = Some(user_identifier.into());
199        self
200    }
201
202    pub fn set_user_identifier_option(&mut self, user_identifier: Option<String>) {
203        self.user_identifier = user_identifier;
204    }
205
206    pub fn with_user_identifier_option(mut self, user_identifier: Option<String>) -> Self {
207        self.user_identifier = user_identifier;
208        self
209    }
210
211    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
212        module.apply_to(&mut self);
213        self
214    }
215
216    pub fn entity_root(&self) -> EntityRoot {
217        self.entity_root.clone()
218    }
219
220    pub fn initial_graphs(&self) -> &[GraphNode] {
221        &self.initial_graphs
222    }
223
224    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
225        self.initial_graphs = graphs;
226    }
227
228    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
229        self.metadata = Some(Box::new(metadata));
230        self
231    }
232
233    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
234        self.metadata = Some(Box::new(metadata));
235    }
236
237    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
238        self.repository_registry = Some(Box::new(registry));
239        self
240    }
241
242    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
243        self.repository_registry = Some(Box::new(registry));
244    }
245
246    pub fn with_repository_behavior_registry(
247        mut self,
248        registry: impl RepositoryBehaviorRegistry + 'static,
249    ) -> Self {
250        self.repository_behavior_registry = Some(Box::new(registry));
251        self
252    }
253
254    pub fn set_repository_behavior_registry(
255        &mut self,
256        registry: impl RepositoryBehaviorRegistry + 'static,
257    ) {
258        self.repository_behavior_registry = Some(Box::new(registry));
259    }
260
261    pub fn with_request_policy(mut self, policy: impl RequestPolicy + 'static) -> Self {
262        self.request_policy = Some(Box::new(policy));
263        self
264    }
265
266    pub fn set_request_policy(&mut self, policy: impl RequestPolicy + 'static) {
267        self.request_policy = Some(Box::new(policy));
268    }
269
270    pub fn clear_request_policy(&mut self) {
271        self.request_policy = None;
272    }
273
274    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
275        self.checker_registry = Some(Box::new(registry));
276        self
277    }
278
279    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
280        self.checker_registry = Some(Box::new(registry));
281    }
282
283    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
284        self.event_sink = Some(Box::new(sink));
285        self
286    }
287
288    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
289        self.event_sink = Some(Box::new(sink));
290    }
291
292    pub fn with_internal_id_generator(
293        mut self,
294        generator: impl InternalIdGenerator + 'static,
295    ) -> Self {
296        self.internal_id_generator = Some(Box::new(generator));
297        self
298    }
299
300    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
301        self.internal_id_generator = Some(Box::new(generator));
302    }
303
304    pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
305        self.schema_provider = Some(Box::new(provider));
306        self
307    }
308
309    pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
310        self.schema_provider = Some(Box::new(provider));
311    }
312
313    pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
314        let provider = self
315            .schema_provider
316            .as_ref()
317            .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
318        provider.ensure_schema(self).await
319    }
320
321    pub fn with_language(mut self, language: Language) -> Self {
322        self.language = language;
323        self
324    }
325
326    pub fn set_language(&mut self, language: Language) {
327        self.language = language;
328    }
329
330    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
331        self.sql_log_options = options;
332        self
333    }
334
335    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
336        self.sql_log_options = options;
337    }
338
339    pub fn enable_select_sql_log(&mut self) {
340        self.sql_log_options.select = true;
341    }
342
343    pub fn enable_mutation_sql_log(&mut self) {
344        self.sql_log_options.mutation = true;
345    }
346
347    pub fn enable_all_sql_log(&mut self) {
348        self.sql_log_options = SqlLogOptions::all();
349    }
350
351    pub fn disable_sql_log(&mut self) {
352        self.sql_log_options = SqlLogOptions::disabled();
353        self.clear_sql_logs();
354    }
355
356    pub fn sql_log_options(&self) -> SqlLogOptions {
357        self.sql_log_options
358    }
359
360    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
361        self.sql_log_entries
362            .lock()
363            .map(|entries| entries.clone())
364            .unwrap_or_default()
365    }
366
367    pub fn clear_sql_logs(&self) {
368        if let Ok(mut entries) = self.sql_log_entries.lock() {
369            entries.clear();
370        }
371    }
372
373    pub(crate) fn record_sql_log(
374        &self,
375        operation: SqlLogOperation,
376        query: &CompiledQuery,
377        database_kind: DatabaseKind,
378        started_at: SystemTime,
379        ended_at: SystemTime,
380        elapsed: Duration,
381        result_count: Option<usize>,
382        result_type: Option<String>,
383        affected_rows: Option<u64>,
384        trace_chain: Vec<teaql_core::TraceNode>,
385    ) {
386        if !self.sql_log_options.enabled_for(operation) {
387            return;
388        }
389        let debug_sql = query.debug_sql(database_kind);
390        let result_summary = sql_result_summary(
391            operation,
392            result_count,
393            result_type.as_deref(),
394            affected_rows,
395            &debug_sql,
396        );
397
398        let sql_log_entry = SqlLogEntry {
399            operation,
400            sql: query.sql.clone(),
401            params: query.params.clone(),
402            pretty_sql: pretty_sql(&debug_sql),
403            debug_sql: debug_sql.clone(),
404            started_at,
405            ended_at,
406            elapsed,
407            result_summary: result_summary.clone(),
408            result_count,
409            result_type,
410            affected_rows,
411        };
412
413        if let Ok(mut entries) = self.sql_log_entries.lock() {
414            // Keep sql_log_entries backwards-compatible for now if needed,
415            // wait, we modified SqlLogEntry. We can just push it directly since we removed comment.
416            // Wait, we need to push a cloned SqlLogEntry since it doesn't have comment.
417            entries.push(sql_log_entry.clone());
418        }
419
420        if let Some(buf) = self.get_resource::<UnifiedLogBuffer>() {
421            if let Ok(mut entries) = buf.entries.lock() {
422                entries.push(UnifiedLogEntry {
423                    timestamp: started_at,
424                    user_identifier: self.user_identifier.clone(),
425                    trace_chain,
426                    payload: LogPayload::Sql(sql_log_entry),
427                });
428            }
429        }
430    }
431
432    pub fn language(&self) -> Language {
433        self.language
434    }
435
436    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
437        let Some(language) = Language::from_code(code) else {
438            return Err(RuntimeError::Language(format!(
439                "unsupported language code: {code}"
440            )));
441        };
442        self.language = language;
443        Ok(())
444    }
445
446    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
447        self.internal_id_generator
448            .as_ref()
449            .map(|generator| generator.generate_id(entity))
450            .transpose()
451    }
452
453    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
454        match self.generate_id(entity)? {
455            Some(id) => Ok(id),
456            None => local_id_generator().generate_id(entity),
457        }
458    }
459
460    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
461        self.metadata
462            .as_ref()
463            .and_then(|metadata| metadata.entity(name))
464    }
465
466    pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
467        self.metadata
468            .as_ref()
469            .map(|metadata| metadata.all_entities())
470            .unwrap_or_default()
471    }
472
473    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
474        self.entity(name)
475            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
476    }
477
478    pub fn insert_resource<T>(&mut self, resource: T)
479    where
480        T: Send + Sync + 'static,
481    {
482        self.typed_resources
483            .insert(TypeId::of::<T>(), Box::new(resource));
484    }
485
486    pub fn get_resource<T>(&self) -> Option<&T>
487    where
488        T: Send + Sync + 'static,
489    {
490        self.typed_resources
491            .get(&TypeId::of::<T>())
492            .and_then(|value| value.downcast_ref::<T>())
493    }
494
495    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
496    where
497        T: Send + Sync + 'static,
498    {
499        self.get_resource::<T>()
500            .ok_or(ContextError::MissingTypedResource(
501                std::any::type_name::<T>(),
502            ))
503    }
504
505    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
506    where
507        T: Send + Sync + 'static,
508    {
509        self.named_resources.insert(name.into(), Box::new(resource));
510    }
511
512    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
513    where
514        T: Send + Sync + 'static,
515    {
516        self.named_resources
517            .get(name)
518            .and_then(|value| value.downcast_ref::<T>())
519    }
520
521    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
522    where
523        T: Send + Sync + 'static,
524    {
525        self.get_named_resource::<T>(name)
526            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
527    }
528
529    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
530        self.locals.insert(key.into(), value.into());
531    }
532
533    pub fn local(&self, key: &str) -> Option<&Value> {
534        self.locals.get(key)
535    }
536
537    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
538        self.locals.remove(key)
539    }
540
541    pub fn has_repository(&self, entity: &str) -> bool {
542        let in_registry = self
543            .repository_registry
544            .as_ref()
545            .map(|registry| registry.contains(entity))
546            .unwrap_or(false);
547        in_registry || self.entity(entity).is_some()
548    }
549
550    pub fn repository_behavior(
551        &self,
552        entity: &str,
553    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
554        self.repository_behavior_registry
555            .as_ref()
556            .and_then(|registry| registry.behavior(entity))
557    }
558
559    pub fn has_checker(&self, entity: &str) -> bool {
560        self.checker_registry
561            .as_ref()
562            .and_then(|registry| registry.checker(entity))
563            .is_some()
564    }
565
566    pub fn check_and_fix_record(
567        &self,
568        entity: &str,
569        record: &mut Record,
570    ) -> Result<(), RuntimeError> {
571        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
572    }
573
574    pub fn check_and_fix_record_at(
575        &self,
576        entity: &str,
577        record: &mut Record,
578        location: &ObjectLocation,
579    ) -> Result<(), RuntimeError> {
580        let Some(checker) = self
581            .checker_registry
582            .as_ref()
583            .and_then(|registry| registry.checker(entity))
584        else {
585            return Ok(());
586        };
587        let mut results = CheckResults::new();
588        checker.check_and_fix(self, record, location, &mut results);
589        if results.is_empty() {
590            Ok(())
591        } else {
592            self.translate_check_results(&mut results);
593            Err(RuntimeError::Check(results))
594        }
595    }
596
597    pub fn translate_check_results(&self, results: &mut CheckResults) {
598        for result in results {
599            result.message = Some(translate_check_result(self.language, result));
600        }
601    }
602
603    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
604        let Some(sink) = self.event_sink.as_ref() else {
605            return Ok(());
606        };
607        sink.on_event(self, &event)
608    }
609
610    pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
611    where
612        D: SqlDialect + Send + Sync + 'static,
613        E: QueryExecutor + Send + Sync + 'static,
614    {
615        let dialect = self.require_resource::<D>().map_err(|err| {
616            RepositoryError::Runtime(RuntimeError::Graph(format!(
617                "cannot commit changes without dialect: {err}"
618            )))
619        })?;
620        let executor = self.require_resource::<E>().map_err(|err| {
621            RepositoryError::Runtime(RuntimeError::Graph(format!(
622                "cannot commit changes without executor: {err}"
623            )))
624        })?;
625        let change_set = self.entity_root.current_change_set();
626
627        for (key, changes) in change_set.changes() {
628            if changes.is_empty() {
629                continue;
630            }
631            let entity = self
632                .require_entity(&key.entity)
633                .map_err(RepositoryError::Runtime)?;
634            let mut command = UpdateCommand::new(&key.entity, key.id.clone());
635            for (field, value) in changes {
636                command = command.value(field.clone(), value.clone());
637            }
638            let query = dialect
639                .compile_update(entity, &command)
640                .map_err(RuntimeError::from)
641                .map_err(RepositoryError::Runtime)?;
642            executor
643                .execute(&query)
644                .map_err(RepositoryError::Executor)?;
645        }
646
647        self.entity_root.clear_current_change_set();
648        Ok(())
649    }
650}
651
/// Best-effort extraction of the value compared against a standalone `id`
/// column in the `WHERE` clause of `sql`.
///
/// The text after the first `where` (matched case-insensitively) is scanned
/// for the word `id`, also case-insensitively. A match must not be part of a
/// longer identifier (`user_id`, `idx`) and must not be qualified (`t.id`).
/// The first such `id` followed by `=` and a value wins:
///
/// * a single-quoted value is returned without its quotes (up to the next
///   quote, or to the end of the text when the quote is never closed);
/// * otherwise the leading run of ASCII letters, digits, `_` and `-` is
///   returned, and an empty run makes the scan continue.
///
/// The value is returned exactly as written in `sql`, including its case and
/// any non-ASCII characters. Returns `None` when there is no `where` or no
/// matching comparison.
fn extract_id_from_sql(sql: &str) -> Option<String> {
    const WHERE: &str = "where";
    const ID: &str = "id";

    let is_ident = |byte: u8| byte.is_ascii_alphanumeric() || byte == b'_';

    // ASCII-only lowercasing never changes byte lengths, so every offset found
    // in `lowered` is also a valid offset into the original `sql`.
    let lowered = sql.to_ascii_lowercase();
    let bytes = lowered.as_bytes();
    let clause_start = lowered.find(WHERE)? + WHERE.len();

    let mut cursor = clause_start;
    while let Some(offset) = lowered[cursor..].find(ID) {
        let start = cursor + offset;
        let end = start + ID.len();
        // Resume one byte later on the next round so overlapping candidates are seen.
        cursor = start + 1;

        let boundary_before =
            start == clause_start || !(is_ident(bytes[start - 1]) || bytes[start - 1] == b'.');
        let boundary_after = bytes.get(end).map_or(true, |&next| !is_ident(next));
        if !(boundary_before && boundary_after) {
            continue;
        }

        // Read the right-hand side from the original text so the value keeps
        // its case and multi-byte characters.
        let Some(rhs) = sql[end..].trim_start().strip_prefix('=') else {
            continue;
        };
        let rhs = rhs.trim_start();

        if let Some(quoted) = rhs.strip_prefix('\'') {
            let value = quoted.find('\'').map_or(quoted, |close| &quoted[..close]);
            return Some(value.to_owned());
        }

        let len = rhs
            .bytes()
            .take_while(|&byte| is_ident(byte) || byte == b'-')
            .count();
        if len > 0 {
            return Some(rhs[..len].to_owned());
        }
    }
    None
}
718
719fn sql_result_summary(
720    operation: SqlLogOperation,
721    result_count: Option<usize>,
722    result_type: Option<&str>,
723    affected_rows: Option<u64>,
724    debug_sql: &str,
725) -> String {
726    match operation {
727        SqlLogOperation::Select => {
728            let count = result_count.unwrap_or(0);
729            if count == 0 {
730                "MISS".to_owned()
731            } else if count > 1 {
732                match result_type {
733                    Some(result_type) => format!("{count}*{result_type}"),
734                    None => format!("{count}*rows"),
735                }
736            } else {
737                match result_type {
738                    Some(result_type) => {
739                        if let Some(id) = extract_id_from_sql(debug_sql) {
740                            format!("{result_type}({id})")
741                        } else {
742                            result_type.to_owned()
743                        }
744                    }
745                    None => "row".to_owned(),
746                }
747            }
748        }
749        _ => {
750            let affected = affected_rows.unwrap_or(0);
751            format!("{affected} UPDATED")
752        }
753    }
754}
755
/// Returns `sql` with a line break before each major clause keyword and an
/// indented line break before each ` AND ` / ` OR `.
///
/// This is plain text substitution on upper-case, space-delimited keywords;
/// it does not parse the statement, so matching text inside string literals
/// is rewritten as well.
fn pretty_sql(sql: &str) -> String {
    // Applied in order; each pair is (text to find, text to put in its place).
    const BREAKS: [(&str, &str); 10] = [
        (" FROM ", "\nFROM "),
        (" WHERE ", "\nWHERE "),
        (" GROUP BY ", "\nGROUP BY "),
        (" HAVING ", "\nHAVING "),
        (" ORDER BY ", "\nORDER BY "),
        (" LIMIT ", "\nLIMIT "),
        (" OFFSET ", "\nOFFSET "),
        (" RETURNING ", "\nRETURNING "),
        (" AND ", "\n  AND "),
        (" OR ", "\n  OR "),
    ];

    BREAKS
        .iter()
        .fold(sql.to_owned(), |text, &(needle, replacement)| {
            text.replace(needle, replacement)
        })
}
774
775