1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4
5use std::pin::Pin;
6use std::sync::Mutex;
7use std::time::{Duration, SystemTime};
8
9use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
10use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
11
12use crate::{
13 CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
14 InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
15 RepositoryBehaviorRegistry, RepositoryRegistry, RequestPolicy, RuntimeError,
16 local_id_generator, translate_check_result,
17};
18use crate::{EntityRoot, QueryExecutor, RepositoryError};
19
/// Kind of SQL statement reported to the SQL log.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    /// The only variant treated as a read; every other variant is a mutation.
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every variant other than [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
40pub struct SqlLogOptions {
41 pub select: bool,
42 pub mutation: bool,
43}
44
45impl SqlLogOptions {
46 pub fn disabled() -> Self {
47 Self {
48 select: false,
49 mutation: false,
50 }
51 }
52
53 pub fn select_only() -> Self {
54 Self {
55 select: true,
56 mutation: false,
57 }
58 }
59
60 pub fn mutation_only() -> Self {
61 Self {
62 select: false,
63 mutation: true,
64 }
65 }
66
67 pub fn all() -> Self {
68 Self {
69 select: true,
70 mutation: true,
71 }
72 }
73
74 pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
75 if operation.is_select() {
76 self.select
77 } else {
78 self.mutation
79 }
80 }
81}
82
/// One executed statement as captured by `UserContext::record_sql_log`.
#[derive(Debug, Clone, PartialEq)]
pub struct SqlLogEntry {
    /// Kind of statement that was executed.
    pub operation: SqlLogOperation,
    /// SQL text copied from the compiled query.
    pub sql: String,
    /// Parameters copied from the compiled query.
    pub params: Vec<Value>,
    /// Output of `CompiledQuery::debug_sql` for the database kind in use.
    pub debug_sql: String,
    /// `debug_sql` with line breaks inserted before major clauses.
    pub pretty_sql: String,
    /// Start time as supplied by the caller.
    pub started_at: SystemTime,
    /// End time as supplied by the caller.
    pub ended_at: SystemTime,
    /// Elapsed time as supplied by the caller.
    pub elapsed: Duration,
    /// Number of results, when the caller reported one.
    pub result_count: Option<usize>,
    /// Name of the result type, when the caller reported one.
    pub result_type: Option<String>,
    /// Number of affected rows, when the caller reported one.
    pub affected_rows: Option<u64>,
    /// Short outcome label such as `MISS`, `3*rows` or `1 UPDATED`.
    pub result_summary: String,
    /// User identifier of the context at the time of recording.
    pub user_identifier: Option<String>,
    /// The context's comment stack joined with `->` when non-empty,
    /// otherwise the comment passed by the caller.
    pub comment: Option<String>,
}
100
/// A pre-formatted log line stored in a [`TuiLogBuffer`].
#[derive(Debug, Clone)]
pub struct TuiLogEntry {
    /// Time associated with the line (the statement start time for SQL logs).
    pub timestamp: SystemTime,
    /// Fully formatted, single-line log text.
    pub line: String,
}
106
/// Shared buffer of [`TuiLogEntry`] values.
///
/// Clones share the same underlying vector. When a `TuiLogBuffer` is
/// registered as a typed resource on a `UserContext`, `record_sql_log`
/// appends one line per logged statement to it.
#[derive(Clone, Default)]
pub struct TuiLogBuffer {
    /// The shared, mutex-protected list of entries.
    pub entries: std::sync::Arc<Mutex<Vec<TuiLogEntry>>>,
}
111
/// Hook invoked by `UserContext::ensure_schema`.
pub trait SchemaProvider: Send + Sync {
    /// Asynchronously performs the provider's schema work for `ctx`.
    ///
    /// The returned future borrows both the provider and the context and
    /// resolves to `Ok(())` or a [`RuntimeError`].
    fn ensure_schema<'a>(
        &'a self,
        ctx: &'a UserContext,
    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
}
118
/// Per-user runtime context: pluggable services, resources, locals and the
/// SQL log.
pub struct UserContext {
    /// Entity metadata consulted by `entity` / `all_entities`.
    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
    /// Registry consulted by `has_repository`.
    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
    /// Registry consulted by `repository_behavior`.
    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
    /// Optional request policy; this file only sets and clears it.
    pub(crate) request_policy: Option<Box<dyn RequestPolicy>>,
    /// Registry consulted by `has_checker` and `check_and_fix_record_at`.
    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
    /// Sink that receives events passed to `send_event`.
    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
    /// Generator used by `generate_id`; `next_id` falls back to the local
    /// generator when this is absent.
    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
    /// Provider invoked by `ensure_schema`.
    schema_provider: Option<Box<dyn SchemaProvider>>,
    /// Language used when translating check results.
    language: Language,
    /// Resources keyed by their concrete type.
    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    /// Resources keyed by name.
    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
    /// Free-form key/value locals.
    locals: BTreeMap<String, Value>,
    /// Graph nodes exposed through `initial_graphs`.
    pub(crate) initial_graphs: Vec<GraphNode>,
    /// Entity root whose current change set `commit_changes` writes out.
    entity_root: EntityRoot,
    /// Which statement kinds `record_sql_log` records.
    sql_log_options: SqlLogOptions,
    /// In-memory SQL log entries.
    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
    /// Identifier included in SQL log output.
    user_identifier: Option<String>,
    /// Stack of query comments; joined with `->` when logging SQL.
    pub(crate) comment_stack: Mutex<Vec<String>>,
}
139
140impl Default for UserContext {
141 fn default() -> Self {
142 let pid = std::process::id();
143 let thread_id_str = format!("{:?}", std::thread::current().id());
144 let numeric_thread_id = thread_id_str
145 .strip_prefix("ThreadId(")
146 .and_then(|s| s.strip_suffix(")"))
147 .unwrap_or(&thread_id_str);
148 let os_user = std::env::var("USER")
149 .or_else(|_| std::env::var("USERNAME"))
150 .unwrap_or_else(|_| "main".to_owned());
151 let user_id = format!("{os_user}@pid-{pid}.tid-{numeric_thread_id}");
152 Self {
153 metadata: None,
154 repository_registry: None,
155 repository_behavior_registry: None,
156 request_policy: None,
157 checker_registry: None,
158 event_sink: None,
159 internal_id_generator: None,
160 schema_provider: None,
161 language: Language::default(),
162 typed_resources: HashMap::new(),
163 named_resources: BTreeMap::new(),
164 locals: BTreeMap::new(),
165 initial_graphs: Vec::new(),
166 entity_root: EntityRoot::default(),
167 sql_log_options: SqlLogOptions::all(),
168 sql_log_entries: Mutex::new(Vec::new()),
169 user_identifier: Some(user_id),
170 comment_stack: Mutex::new(Vec::new()),
171 }
172 }
173}
174
175impl UserContext {
    /// Creates a context equal to `UserContext::default()`.
    pub fn new() -> Self {
        Self::default()
    }
179
    /// Returns the current user identifier, if one is set.
    pub fn user_identifier(&self) -> Option<&str> {
        self.user_identifier.as_deref()
    }
183
184 pub fn set_user_identifier(&mut self, user_identifier: impl Into<String>) {
185 self.user_identifier = Some(user_identifier.into());
186 }
187
188 pub fn with_user_identifier(mut self, user_identifier: impl Into<String>) -> Self {
189 self.user_identifier = Some(user_identifier.into());
190 self
191 }
192
    /// Sets or clears the user identifier; `None` removes it.
    pub fn set_user_identifier_option(&mut self, user_identifier: Option<String>) {
        self.user_identifier = user_identifier;
    }
196
197 pub fn with_user_identifier_option(mut self, user_identifier: Option<String>) -> Self {
198 self.user_identifier = user_identifier;
199 self
200 }
201
    /// Lets `module` configure this context via `RuntimeModule::apply_to`
    /// and returns the context.
    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
        module.apply_to(&mut self);
        self
    }
206
    /// Returns a clone of the context's entity root.
    pub fn entity_root(&self) -> EntityRoot {
        self.entity_root.clone()
    }
210
211 pub fn initial_graphs(&self) -> &[GraphNode] {
212 &self.initial_graphs
213 }
214
    /// Replaces the stored graph nodes with `graphs`.
    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
        self.initial_graphs = graphs;
    }
218
219 pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
220 self.metadata = Some(Box::new(metadata));
221 self
222 }
223
224 pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
225 self.metadata = Some(Box::new(metadata));
226 }
227
228 pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
229 self.repository_registry = Some(Box::new(registry));
230 self
231 }
232
233 pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
234 self.repository_registry = Some(Box::new(registry));
235 }
236
237 pub fn with_repository_behavior_registry(
238 mut self,
239 registry: impl RepositoryBehaviorRegistry + 'static,
240 ) -> Self {
241 self.repository_behavior_registry = Some(Box::new(registry));
242 self
243 }
244
245 pub fn set_repository_behavior_registry(
246 &mut self,
247 registry: impl RepositoryBehaviorRegistry + 'static,
248 ) {
249 self.repository_behavior_registry = Some(Box::new(registry));
250 }
251
252 pub fn with_request_policy(mut self, policy: impl RequestPolicy + 'static) -> Self {
253 self.request_policy = Some(Box::new(policy));
254 self
255 }
256
257 pub fn set_request_policy(&mut self, policy: impl RequestPolicy + 'static) {
258 self.request_policy = Some(Box::new(policy));
259 }
260
    /// Removes the request policy, if any.
    pub fn clear_request_policy(&mut self) {
        self.request_policy = None;
    }
264
265 pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
266 self.checker_registry = Some(Box::new(registry));
267 self
268 }
269
270 pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
271 self.checker_registry = Some(Box::new(registry));
272 }
273
274 pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
275 self.event_sink = Some(Box::new(sink));
276 self
277 }
278
279 pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
280 self.event_sink = Some(Box::new(sink));
281 }
282
283 pub fn with_internal_id_generator(
284 mut self,
285 generator: impl InternalIdGenerator + 'static,
286 ) -> Self {
287 self.internal_id_generator = Some(Box::new(generator));
288 self
289 }
290
291 pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
292 self.internal_id_generator = Some(Box::new(generator));
293 }
294
295 pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
296 self.schema_provider = Some(Box::new(provider));
297 self
298 }
299
300 pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
301 self.schema_provider = Some(Box::new(provider));
302 }
303
304 pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
305 let provider = self
306 .schema_provider
307 .as_ref()
308 .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
309 provider.ensure_schema(self).await
310 }
311
312 pub fn with_language(mut self, language: Language) -> Self {
313 self.language = language;
314 self
315 }
316
    /// Sets the language used when translating check results.
    pub fn set_language(&mut self, language: Language) {
        self.language = language;
    }
320
321 pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
322 self.sql_log_options = options;
323 self
324 }
325
    /// Replaces the SQL log options; already recorded entries are kept.
    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
        self.sql_log_options = options;
    }
329
    /// Turns on logging of select statements; the mutation switch is untouched.
    pub fn enable_select_sql_log(&mut self) {
        self.sql_log_options.select = true;
    }
333
    /// Turns on logging of mutation statements; the select switch is untouched.
    pub fn enable_mutation_sql_log(&mut self) {
        self.sql_log_options.mutation = true;
    }
337
338 pub fn enable_all_sql_log(&mut self) {
339 self.sql_log_options = SqlLogOptions::all();
340 }
341
342 pub fn disable_sql_log(&mut self) {
343 self.sql_log_options = SqlLogOptions::disabled();
344 self.clear_sql_logs();
345 }
346
    /// Returns a copy of the current SQL log options.
    pub fn sql_log_options(&self) -> SqlLogOptions {
        self.sql_log_options
    }
350
351 pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
352 self.sql_log_entries
353 .lock()
354 .map(|entries| entries.clone())
355 .unwrap_or_default()
356 }
357
358 pub fn clear_sql_logs(&self) {
359 if let Ok(mut entries) = self.sql_log_entries.lock() {
360 entries.clear();
361 }
362 }
363
364 pub(crate) fn record_sql_log(
365 &self,
366 operation: SqlLogOperation,
367 query: &CompiledQuery,
368 database_kind: DatabaseKind,
369 started_at: SystemTime,
370 ended_at: SystemTime,
371 elapsed: Duration,
372 result_count: Option<usize>,
373 result_type: Option<String>,
374 affected_rows: Option<u64>,
375 comment: Option<String>,
376 ) {
377 if !self.sql_log_options.enabled_for(operation) {
378 return;
379 }
380 let debug_sql = query.debug_sql(database_kind);
381 let result_summary = sql_result_summary(
382 operation,
383 result_count,
384 result_type.as_deref(),
385 affected_rows,
386 &debug_sql,
387 );
388
389 let stack_comment = self.comment_stack.lock().ok().and_then(|stack| {
391 if stack.is_empty() {
392 None
393 } else {
394 Some(stack.join("->"))
395 }
396 });
397 let final_comment = stack_comment.or(comment);
398
399 if let Ok(mut file) = std::fs::OpenOptions::new()
401 .create(true)
402 .append(true)
403 .open("app.log")
404 {
405 use std::io::Write;
406 let local_time: chrono::DateTime<chrono::Local> = started_at.into();
407 let timestamp_str = local_time.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
408 let user_id_str = self.user_identifier.as_deref().unwrap_or("");
409 let comment_part = if let Some(ref c) = final_comment {
410 format!(" - [{c}]")
411 } else {
412 "".to_owned()
413 };
414 let elapsed_ms = elapsed.as_secs_f64() * 1000.0;
415 let log_line = format!(
416 "{timestamp_str}-[{user_id_str}]--DEBUG - SqlLogEntry{comment_part} - [{result_summary}] {} (took {:.3}ms)\n",
417 debug_sql, elapsed_ms
418 );
419 let _ = file.write_all(log_line.as_bytes());
420 }
421
422 if let Ok(mut entries) = self.sql_log_entries.lock() {
423 entries.push(SqlLogEntry {
424 operation,
425 sql: query.sql.clone(),
426 params: query.params.clone(),
427 pretty_sql: pretty_sql(&debug_sql),
428 debug_sql: debug_sql.clone(),
429 started_at,
430 ended_at,
431 elapsed,
432 result_summary: result_summary.clone(),
433 result_count,
434 result_type,
435 affected_rows,
436 user_identifier: self.user_identifier.clone(),
437 comment: final_comment.clone(),
438 });
439 }
440
441 if let Some(buf) = self.get_resource::<TuiLogBuffer>() {
442 let local_time: chrono::DateTime<chrono::Local> = started_at.into();
443 let timestamp_str = local_time.format("%H:%M:%S%.3f").to_string();
444 let user_id_str = self.user_identifier.as_deref().unwrap_or("");
445 let comment_part = if let Some(ref c) = final_comment {
446 format!(" - [{c}]")
447 } else {
448 "".to_owned()
449 };
450 let elapsed_ms = elapsed.as_secs_f64() * 1000.0;
451 let single_line_sql = debug_sql
452 .lines()
453 .map(|line| line.trim())
454 .filter(|line| !line.is_empty())
455 .collect::<Vec<_>>()
456 .join(" ");
457 let log_line = format!(
458 "[{}]-[{}]-[DEBUG]-SqlLogEntry{} - [{}] {} (took {:.3}ms)",
459 timestamp_str, user_id_str, comment_part, result_summary, single_line_sql, elapsed_ms
460 );
461 if let Ok(mut entries) = buf.entries.lock() {
462 entries.push(TuiLogEntry {
463 timestamp: started_at,
464 line: log_line,
465 });
466 }
467 }
468 }
469
    /// Returns the language currently configured on the context.
    pub fn language(&self) -> Language {
        self.language
    }
473
474 pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
475 let Some(language) = Language::from_code(code) else {
476 return Err(RuntimeError::Language(format!(
477 "unsupported language code: {code}"
478 )));
479 };
480 self.language = language;
481 Ok(())
482 }
483
484 pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
485 self.internal_id_generator
486 .as_ref()
487 .map(|generator| generator.generate_id(entity))
488 .transpose()
489 }
490
491 pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
492 match self.generate_id(entity)? {
493 Some(id) => Ok(id),
494 None => local_id_generator().generate_id(entity),
495 }
496 }
497
498 pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
499 self.metadata
500 .as_ref()
501 .and_then(|metadata| metadata.entity(name))
502 }
503
504 pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
505 self.metadata
506 .as_ref()
507 .map(|metadata| metadata.all_entities())
508 .unwrap_or_default()
509 }
510
511 pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
512 self.entity(name)
513 .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
514 }
515
516 pub fn insert_resource<T>(&mut self, resource: T)
517 where
518 T: Send + Sync + 'static,
519 {
520 self.typed_resources
521 .insert(TypeId::of::<T>(), Box::new(resource));
522 }
523
524 pub fn get_resource<T>(&self) -> Option<&T>
525 where
526 T: Send + Sync + 'static,
527 {
528 self.typed_resources
529 .get(&TypeId::of::<T>())
530 .and_then(|value| value.downcast_ref::<T>())
531 }
532
533 pub fn require_resource<T>(&self) -> Result<&T, ContextError>
534 where
535 T: Send + Sync + 'static,
536 {
537 self.get_resource::<T>()
538 .ok_or(ContextError::MissingTypedResource(
539 std::any::type_name::<T>(),
540 ))
541 }
542
543 pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
544 where
545 T: Send + Sync + 'static,
546 {
547 self.named_resources.insert(name.into(), Box::new(resource));
548 }
549
550 pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
551 where
552 T: Send + Sync + 'static,
553 {
554 self.named_resources
555 .get(name)
556 .and_then(|value| value.downcast_ref::<T>())
557 }
558
559 pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
560 where
561 T: Send + Sync + 'static,
562 {
563 self.get_named_resource::<T>(name)
564 .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
565 }
566
567 pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
568 self.locals.insert(key.into(), value.into());
569 }
570
    /// Returns the local stored under `key`, if any.
    pub fn local(&self, key: &str) -> Option<&Value> {
        self.locals.get(key)
    }
574
    /// Removes the local stored under `key` and returns it, if it existed.
    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
        self.locals.remove(key)
    }
578
579 pub fn has_repository(&self, entity: &str) -> bool {
580 let in_registry = self
581 .repository_registry
582 .as_ref()
583 .map(|registry| registry.contains(entity))
584 .unwrap_or(false);
585 in_registry || self.entity(entity).is_some()
586 }
587
588 pub fn repository_behavior(
589 &self,
590 entity: &str,
591 ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
592 self.repository_behavior_registry
593 .as_ref()
594 .and_then(|registry| registry.behavior(entity))
595 }
596
597 pub fn has_checker(&self, entity: &str) -> bool {
598 self.checker_registry
599 .as_ref()
600 .and_then(|registry| registry.checker(entity))
601 .is_some()
602 }
603
604 pub fn check_and_fix_record(
605 &self,
606 entity: &str,
607 record: &mut Record,
608 ) -> Result<(), RuntimeError> {
609 self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
610 }
611
612 pub fn check_and_fix_record_at(
613 &self,
614 entity: &str,
615 record: &mut Record,
616 location: &ObjectLocation,
617 ) -> Result<(), RuntimeError> {
618 let Some(checker) = self
619 .checker_registry
620 .as_ref()
621 .and_then(|registry| registry.checker(entity))
622 else {
623 return Ok(());
624 };
625 let mut results = CheckResults::new();
626 checker.check_and_fix(self, record, location, &mut results);
627 if results.is_empty() {
628 Ok(())
629 } else {
630 self.translate_check_results(&mut results);
631 Err(RuntimeError::Check(results))
632 }
633 }
634
635 pub fn translate_check_results(&self, results: &mut CheckResults) {
636 for result in results {
637 result.message = Some(translate_check_result(self.language, result));
638 }
639 }
640
641 pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
642 let Some(sink) = self.event_sink.as_ref() else {
643 return Ok(());
644 };
645 sink.on_event(self, &event)
646 }
647
648 pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
649 where
650 D: SqlDialect + Send + Sync + 'static,
651 E: QueryExecutor + Send + Sync + 'static,
652 {
653 let dialect = self.require_resource::<D>().map_err(|err| {
654 RepositoryError::Runtime(RuntimeError::Graph(format!(
655 "cannot commit changes without dialect: {err}"
656 )))
657 })?;
658 let executor = self.require_resource::<E>().map_err(|err| {
659 RepositoryError::Runtime(RuntimeError::Graph(format!(
660 "cannot commit changes without executor: {err}"
661 )))
662 })?;
663 let change_set = self.entity_root.current_change_set();
664
665 for (key, changes) in change_set.changes() {
666 if changes.is_empty() {
667 continue;
668 }
669 let entity = self
670 .require_entity(&key.entity)
671 .map_err(RepositoryError::Runtime)?;
672 let mut command = UpdateCommand::new(&key.entity, key.id.clone());
673 for (field, value) in changes {
674 command = command.value(field.clone(), value.clone());
675 }
676 let query = dialect
677 .compile_update(entity, &command)
678 .map_err(RuntimeError::from)
679 .map_err(RepositoryError::Runtime)?;
680 executor
681 .execute(&query)
682 .map_err(RepositoryError::Executor)?;
683 }
684
685 self.entity_root.clear_current_change_set();
686 Ok(())
687 }
688}
689
/// Extracts the value compared against a bare `id` column in the `WHERE`
/// clause of `sql`.
///
/// Scanning starts after the first case-insensitive occurrence of `where`
/// and looks for `id` as a whole word, not preceded by `.`, followed by
/// optional whitespace and `=`. Table-qualified columns such as `t.id` are
/// therefore skipped. The value is either:
///
/// - a single-quoted literal, returned without the quotes (an unterminated
///   literal runs to the end of the text), or
/// - a run of ASCII alphanumerics, `_` and `-`.
///
/// The value is returned exactly as written in `sql`, with its case and any
/// non-ASCII characters preserved. Returns `None` when there is no `where`
/// or no matching comparison. Placeholders such as `id = ?` do not match.
///
/// This is a textual heuristic, not a SQL parser: `where` and `id = ...`
/// inside string literals or longer identifiers can still match.
fn extract_id_from_sql(sql: &str) -> Option<String> {
    // ASCII-only lowering never changes byte lengths, so every offset found
    // in `lowered` is also a valid offset (and char boundary) in `sql`.
    let lowered = sql.to_ascii_lowercase();
    let clause_start = lowered.find("where")? + "where".len();
    let haystack = lowered[clause_start..].as_bytes();
    let original = &sql[clause_start..];

    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    // Only ASCII whitespace is skipped, so positions stay on char boundaries.
    let skip_whitespace = |mut pos: usize| {
        while pos < haystack.len() && matches!(haystack[pos], b'\t'..=b'\r' | b' ') {
            pos += 1;
        }
        pos
    };

    for (start, window) in haystack.windows(2).enumerate() {
        if window != b"id" {
            continue;
        }

        // `id` must stand alone: not part of a longer identifier and not
        // table-qualified.
        let prev_ok = start == 0 || {
            let prev = haystack[start - 1];
            !is_ident_byte(prev) && prev != b'.'
        };
        let next_ok = haystack
            .get(start + 2)
            .map_or(true, |&next| !is_ident_byte(next));
        if !(prev_ok && next_ok) {
            continue;
        }

        let eq_pos = skip_whitespace(start + 2);
        if haystack.get(eq_pos) != Some(&b'=') {
            continue;
        }
        let value_start = skip_whitespace(eq_pos + 1);

        if haystack.get(value_start) == Some(&b'\'') {
            // Slice the original text so case and multi-byte characters survive.
            let quoted = &original[value_start + 1..];
            let end = quoted.find('\'').unwrap_or(quoted.len());
            return Some(quoted[..end].to_owned());
        }

        let bare = &original[value_start..];
        let len = bare
            .bytes()
            .take_while(|&b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-')
            .count();
        if len > 0 {
            return Some(bare[..len].to_owned());
        }
        // No usable value here (e.g. a `?` placeholder): keep scanning.
    }
    None
}
756
757fn sql_result_summary(
758 operation: SqlLogOperation,
759 result_count: Option<usize>,
760 result_type: Option<&str>,
761 affected_rows: Option<u64>,
762 debug_sql: &str,
763) -> String {
764 match operation {
765 SqlLogOperation::Select => {
766 let count = result_count.unwrap_or(0);
767 if count == 0 {
768 "MISS".to_owned()
769 } else if count > 1 {
770 match result_type {
771 Some(result_type) => format!("{count}*{result_type}"),
772 None => format!("{count}*rows"),
773 }
774 } else {
775 match result_type {
776 Some(result_type) => {
777 if let Some(id) = extract_id_from_sql(debug_sql) {
778 format!("{result_type}({id})")
779 } else {
780 result_type.to_owned()
781 }
782 }
783 None => "row".to_owned(),
784 }
785 }
786 }
787 _ => {
788 let affected = affected_rows.unwrap_or(0);
789 format!("{affected} UPDATED")
790 }
791 }
792}
793
/// Inserts line breaks into `sql` for display.
///
/// Each upper-case clause keyword surrounded by single spaces (`FROM`,
/// `WHERE`, `GROUP BY`, `HAVING`, `ORDER BY`, `LIMIT`, `OFFSET`,
/// `RETURNING`) is moved to the start of a new line. ` AND ` and ` OR ` are
/// then moved to new lines with a one-space indent. Matching is purely
/// textual and case-sensitive.
fn pretty_sql(sql: &str) -> String {
    const CLAUSE_KEYWORDS: [&str; 8] = [
        " FROM ",
        " WHERE ",
        " GROUP BY ",
        " HAVING ",
        " ORDER BY ",
        " LIMIT ",
        " OFFSET ",
        " RETURNING ",
    ];

    // Replacements are applied one keyword at a time, in list order.
    let with_clause_breaks = CLAUSE_KEYWORDS
        .iter()
        .fold(sql.to_owned(), |text, &keyword| {
            let mut replacement = String::from("\n");
            replacement.push_str(keyword.trim_start());
            text.replace(keyword, &replacement)
        });

    let with_and_breaks = with_clause_breaks.replace(" AND ", "\n AND ");
    with_and_breaks.replace(" OR ", "\n OR ")
}
812
813pub struct QueryCommentGuard<'a> {
814 context: &'a UserContext,
815 has_pushed: bool,
816}
817
818impl<'a> QueryCommentGuard<'a> {
819 pub fn new(context: &'a UserContext, comment: Option<String>) -> Self {
820 let mut has_pushed = false;
821 if let Some(comment) = comment {
822 if !comment.is_empty() {
823 if let Ok(mut stack) = context.comment_stack.lock() {
824 if stack.last() != Some(&comment) {
825 stack.push(comment);
826 has_pushed = true;
827 }
828 }
829 }
830 }
831 Self { context, has_pushed }
832 }
833}
834
835impl<'a> Drop for QueryCommentGuard<'a> {
836 fn drop(&mut self) {
837 if self.has_pushed {
838 if let Ok(mut stack) = self.context.comment_stack.lock() {
839 stack.pop();
840 }
841 }
842 }
843}