1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Mutex;
6
7use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
8use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
9
10use crate::{
11 CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
12 InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
13 RepositoryBehaviorRegistry, RepositoryRegistry, RuntimeError, local_id_generator,
14 translate_check_result,
15};
16use crate::{EntityRoot, QueryExecutor, RepositoryError};
17
/// Kind of SQL statement that can be captured in the SQL log.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation other than [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
38pub struct SqlLogOptions {
39 pub select: bool,
40 pub mutation: bool,
41}
42
43impl SqlLogOptions {
44 pub fn select_only() -> Self {
45 Self {
46 select: true,
47 mutation: false,
48 }
49 }
50
51 pub fn mutation_only() -> Self {
52 Self {
53 select: false,
54 mutation: true,
55 }
56 }
57
58 pub fn all() -> Self {
59 Self {
60 select: true,
61 mutation: true,
62 }
63 }
64
65 pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
66 if operation.is_select() {
67 self.select
68 } else {
69 self.mutation
70 }
71 }
72}
73
74#[derive(Debug, Clone, PartialEq)]
75pub struct SqlLogEntry {
76 pub operation: SqlLogOperation,
77 pub sql: String,
78 pub params: Vec<Value>,
79 pub debug_sql: String,
80}
81
82pub trait SchemaProvider: Send + Sync {
83 fn ensure_schema<'a>(
84 &'a self,
85 ctx: &'a UserContext,
86 ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
87}
88
89#[derive(Default)]
90pub struct UserContext {
91 pub(crate) metadata: Option<Box<dyn MetadataStore>>,
92 pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
93 pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
94 pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
95 pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
96 pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
97 schema_provider: Option<Box<dyn SchemaProvider>>,
98 language: Language,
99 typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
100 named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
101 locals: BTreeMap<String, Value>,
102 pub(crate) initial_graphs: Vec<GraphNode>,
103 entity_root: EntityRoot,
104 sql_log_options: SqlLogOptions,
105 sql_log_entries: Mutex<Vec<SqlLogEntry>>,
106}
107
108impl UserContext {
109 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
114 module.apply_to(&mut self);
115 self
116 }
117
118 pub fn entity_root(&self) -> EntityRoot {
119 self.entity_root.clone()
120 }
121
122 pub fn initial_graphs(&self) -> &[GraphNode] {
123 &self.initial_graphs
124 }
125
126 pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
127 self.initial_graphs = graphs;
128 }
129
130 pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
131 self.metadata = Some(Box::new(metadata));
132 self
133 }
134
135 pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
136 self.metadata = Some(Box::new(metadata));
137 }
138
139 pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
140 self.repository_registry = Some(Box::new(registry));
141 self
142 }
143
144 pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
145 self.repository_registry = Some(Box::new(registry));
146 }
147
148 pub fn with_repository_behavior_registry(
149 mut self,
150 registry: impl RepositoryBehaviorRegistry + 'static,
151 ) -> Self {
152 self.repository_behavior_registry = Some(Box::new(registry));
153 self
154 }
155
156 pub fn set_repository_behavior_registry(
157 &mut self,
158 registry: impl RepositoryBehaviorRegistry + 'static,
159 ) {
160 self.repository_behavior_registry = Some(Box::new(registry));
161 }
162
163 pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
164 self.checker_registry = Some(Box::new(registry));
165 self
166 }
167
168 pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
169 self.checker_registry = Some(Box::new(registry));
170 }
171
172 pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
173 self.event_sink = Some(Box::new(sink));
174 self
175 }
176
177 pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
178 self.event_sink = Some(Box::new(sink));
179 }
180
181 pub fn with_internal_id_generator(
182 mut self,
183 generator: impl InternalIdGenerator + 'static,
184 ) -> Self {
185 self.internal_id_generator = Some(Box::new(generator));
186 self
187 }
188
189 pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
190 self.internal_id_generator = Some(Box::new(generator));
191 }
192
193 pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
194 self.schema_provider = Some(Box::new(provider));
195 self
196 }
197
198 pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
199 self.schema_provider = Some(Box::new(provider));
200 }
201
202 pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
203 let provider = self
204 .schema_provider
205 .as_ref()
206 .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
207 provider.ensure_schema(self).await
208 }
209
210 pub fn with_language(mut self, language: Language) -> Self {
211 self.language = language;
212 self
213 }
214
215 pub fn set_language(&mut self, language: Language) {
216 self.language = language;
217 }
218
219 pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
220 self.sql_log_options = options;
221 self
222 }
223
224 pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
225 self.sql_log_options = options;
226 }
227
228 pub fn enable_select_sql_log(&mut self) {
229 self.sql_log_options.select = true;
230 }
231
232 pub fn enable_mutation_sql_log(&mut self) {
233 self.sql_log_options.mutation = true;
234 }
235
236 pub fn enable_all_sql_log(&mut self) {
237 self.sql_log_options = SqlLogOptions::all();
238 }
239
240 pub fn disable_sql_log(&mut self) {
241 self.sql_log_options = SqlLogOptions::default();
242 self.clear_sql_logs();
243 }
244
245 pub fn sql_log_options(&self) -> SqlLogOptions {
246 self.sql_log_options
247 }
248
249 pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
250 self.sql_log_entries
251 .lock()
252 .map(|entries| entries.clone())
253 .unwrap_or_default()
254 }
255
256 pub fn clear_sql_logs(&self) {
257 if let Ok(mut entries) = self.sql_log_entries.lock() {
258 entries.clear();
259 }
260 }
261
262 pub(crate) fn record_sql_log(
263 &self,
264 operation: SqlLogOperation,
265 query: &CompiledQuery,
266 database_kind: DatabaseKind,
267 ) {
268 if !self.sql_log_options.enabled_for(operation) {
269 return;
270 }
271 if let Ok(mut entries) = self.sql_log_entries.lock() {
272 entries.push(SqlLogEntry {
273 operation,
274 sql: query.sql.clone(),
275 params: query.params.clone(),
276 debug_sql: query.debug_sql(database_kind),
277 });
278 }
279 }
280
281 pub fn language(&self) -> Language {
282 self.language
283 }
284
285 pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
286 let Some(language) = Language::from_code(code) else {
287 return Err(RuntimeError::Language(format!(
288 "unsupported language code: {code}"
289 )));
290 };
291 self.language = language;
292 Ok(())
293 }
294
295 pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
296 self.internal_id_generator
297 .as_ref()
298 .map(|generator| generator.generate_id(entity))
299 .transpose()
300 }
301
302 pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
303 match self.generate_id(entity)? {
304 Some(id) => Ok(id),
305 None => local_id_generator().generate_id(entity),
306 }
307 }
308
309 pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
310 self.metadata
311 .as_ref()
312 .and_then(|metadata| metadata.entity(name))
313 }
314
315 pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
316 self.metadata
317 .as_ref()
318 .map(|metadata| metadata.all_entities())
319 .unwrap_or_default()
320 }
321
322 pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
323 self.entity(name)
324 .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
325 }
326
327 pub fn insert_resource<T>(&mut self, resource: T)
328 where
329 T: Send + Sync + 'static,
330 {
331 self.typed_resources
332 .insert(TypeId::of::<T>(), Box::new(resource));
333 }
334
335 pub fn get_resource<T>(&self) -> Option<&T>
336 where
337 T: Send + Sync + 'static,
338 {
339 self.typed_resources
340 .get(&TypeId::of::<T>())
341 .and_then(|value| value.downcast_ref::<T>())
342 }
343
344 pub fn require_resource<T>(&self) -> Result<&T, ContextError>
345 where
346 T: Send + Sync + 'static,
347 {
348 self.get_resource::<T>()
349 .ok_or(ContextError::MissingTypedResource(
350 std::any::type_name::<T>(),
351 ))
352 }
353
354 pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
355 where
356 T: Send + Sync + 'static,
357 {
358 self.named_resources.insert(name.into(), Box::new(resource));
359 }
360
361 pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
362 where
363 T: Send + Sync + 'static,
364 {
365 self.named_resources
366 .get(name)
367 .and_then(|value| value.downcast_ref::<T>())
368 }
369
370 pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
371 where
372 T: Send + Sync + 'static,
373 {
374 self.get_named_resource::<T>(name)
375 .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
376 }
377
378 pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
379 self.locals.insert(key.into(), value.into());
380 }
381
382 pub fn local(&self, key: &str) -> Option<&Value> {
383 self.locals.get(key)
384 }
385
386 pub fn remove_local(&mut self, key: &str) -> Option<Value> {
387 self.locals.remove(key)
388 }
389
390 pub fn has_repository(&self, entity: &str) -> bool {
391 let in_registry = self
392 .repository_registry
393 .as_ref()
394 .map(|registry| registry.contains(entity))
395 .unwrap_or(false);
396 in_registry || self.entity(entity).is_some()
397 }
398
399 pub fn repository_behavior(
400 &self,
401 entity: &str,
402 ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
403 self.repository_behavior_registry
404 .as_ref()
405 .and_then(|registry| registry.behavior(entity))
406 }
407
408 pub fn has_checker(&self, entity: &str) -> bool {
409 self.checker_registry
410 .as_ref()
411 .and_then(|registry| registry.checker(entity))
412 .is_some()
413 }
414
415 pub fn check_and_fix_record(
416 &self,
417 entity: &str,
418 record: &mut Record,
419 ) -> Result<(), RuntimeError> {
420 self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
421 }
422
423 pub fn check_and_fix_record_at(
424 &self,
425 entity: &str,
426 record: &mut Record,
427 location: &ObjectLocation,
428 ) -> Result<(), RuntimeError> {
429 let Some(checker) = self
430 .checker_registry
431 .as_ref()
432 .and_then(|registry| registry.checker(entity))
433 else {
434 return Ok(());
435 };
436 let mut results = CheckResults::new();
437 checker.check_and_fix(self, record, location, &mut results);
438 if results.is_empty() {
439 Ok(())
440 } else {
441 self.translate_check_results(&mut results);
442 Err(RuntimeError::Check(results))
443 }
444 }
445
446 pub fn translate_check_results(&self, results: &mut CheckResults) {
447 for result in results {
448 result.message = Some(translate_check_result(self.language, result));
449 }
450 }
451
452 pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
453 let Some(sink) = self.event_sink.as_ref() else {
454 return Ok(());
455 };
456 sink.on_event(self, &event)
457 }
458
459 pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
460 where
461 D: SqlDialect + Send + Sync + 'static,
462 E: QueryExecutor + Send + Sync + 'static,
463 {
464 let dialect = self.require_resource::<D>().map_err(|err| {
465 RepositoryError::Runtime(RuntimeError::Graph(format!(
466 "cannot commit changes without dialect: {err}"
467 )))
468 })?;
469 let executor = self.require_resource::<E>().map_err(|err| {
470 RepositoryError::Runtime(RuntimeError::Graph(format!(
471 "cannot commit changes without executor: {err}"
472 )))
473 })?;
474 let change_set = self.entity_root.current_change_set();
475
476 for (key, changes) in change_set.changes() {
477 if changes.is_empty() {
478 continue;
479 }
480 let entity = self
481 .require_entity(&key.entity)
482 .map_err(RepositoryError::Runtime)?;
483 let mut command = UpdateCommand::new(&key.entity, key.id.clone());
484 for (field, value) in changes {
485 command = command.value(field.clone(), value.clone());
486 }
487 let query = dialect
488 .compile_update(entity, &command)
489 .map_err(RuntimeError::from)
490 .map_err(RepositoryError::Runtime)?;
491 executor
492 .execute(&query)
493 .map_err(RepositoryError::Executor)?;
494 }
495
496 self.entity_root.clear_current_change_set();
497 Ok(())
498 }
499}