1use crate::{
2 client::{prepare_command, Client, PreparedCommand},
3 commands::{GraphCache, GraphValue, GraphValueArraySeed},
4 resp::{
5 cmd, CollectionResponse, Command, CommandArgs, KeyValueCollectionResponse,
6 PrimitiveResponse, RespBuf, RespDeserializer, SingleArg, ToArgs,
7 },
8 Error, Future, Result,
9};
10use serde::{
11 de::{self, DeserializeOwned, DeserializeSeed, Visitor},
12 Deserialize, Deserializer,
13};
14use smallvec::SmallVec;
15use std::{collections::HashMap, fmt, future, str::FromStr};
16
17pub trait GraphCommands<'a> {
22 #[must_use]
34 fn graph_config_get<N, V, R>(self, name: impl SingleArg) -> PreparedCommand<'a, Self, R>
35 where
36 Self: Sized,
37 N: PrimitiveResponse,
38 V: PrimitiveResponse,
39 R: KeyValueCollectionResponse<N, V>,
40 {
41 prepare_command(self, cmd("GRAPH.CONFIG").arg("GET").arg(name))
42 }
43
44 #[must_use]
57 fn graph_config_set(
58 self,
59 name: impl SingleArg,
60 value: impl SingleArg,
61 ) -> PreparedCommand<'a, Self, ()>
62 where
63 Self: Sized,
64 {
65 prepare_command(self, cmd("GRAPH.CONFIG").arg("SET").arg(name).arg(value))
66 }
67
68 #[must_use]
76 fn graph_delete(self, graph: impl SingleArg) -> PreparedCommand<'a, Self, String>
77 where
78 Self: Sized,
79 {
80 prepare_command(self, cmd("GRAPH.DELETE").arg(graph))
81 }
82
83 #[must_use]
97 fn graph_explain<R: PrimitiveResponse + DeserializeOwned, RR: CollectionResponse<R>>(
98 self,
99 graph: impl SingleArg,
100 query: impl SingleArg,
101 ) -> PreparedCommand<'a, Self, RR>
102 where
103 Self: Sized,
104 {
105 prepare_command(self, cmd("GRAPH.EXPLAIN").arg(graph).arg(query))
106 }
107
108 #[must_use]
116 fn graph_list<R: PrimitiveResponse + DeserializeOwned, RR: CollectionResponse<R>>(
117 self,
118 ) -> PreparedCommand<'a, Self, RR>
119 where
120 Self: Sized,
121 {
122 prepare_command(self, cmd("GRAPH.LIST"))
123 }
124
125 #[must_use]
138 fn graph_profile<R: PrimitiveResponse + DeserializeOwned, RR: CollectionResponse<R>>(
139 self,
140 graph: impl SingleArg,
141 query: impl SingleArg,
142 options: GraphQueryOptions,
143 ) -> PreparedCommand<'a, Self, RR>
144 where
145 Self: Sized,
146 {
147 prepare_command(self, cmd("GRAPH.LIST").arg(graph).arg(query).arg(options))
148 }
149
150 #[must_use]
164 fn graph_query(
165 self,
166 graph: impl SingleArg,
167 query: impl SingleArg,
168 options: GraphQueryOptions,
169 ) -> PreparedCommand<'a, Self, GraphResultSet>
170 where
171 Self: Sized,
172 {
173 prepare_command(
174 self,
175 cmd("GRAPH.QUERY")
176 .arg(graph)
177 .arg(query)
178 .arg(options)
179 .arg("--compact"),
180 )
181 .custom_converter(Box::new(GraphResultSet::custom_conversion))
182 }
183
184 #[must_use]
197 fn graph_ro_query(
198 self,
199 graph: impl SingleArg,
200 query: impl SingleArg,
201 options: GraphQueryOptions,
202 ) -> PreparedCommand<'a, Self, GraphResultSet>
203 where
204 Self: Sized,
205 {
206 prepare_command(
207 self,
208 cmd("GRAPH.RO_QUERY")
209 .arg(graph)
210 .arg(query)
211 .arg(options)
212 .arg("--compact"),
213 )
214 .custom_converter(Box::new(GraphResultSet::custom_conversion))
215 }
216
217 #[must_use]
228 fn graph_slowlog<R: CollectionResponse<GraphSlowlogResult>>(
229 self,
230 graph: impl SingleArg,
231 ) -> PreparedCommand<'a, Self, R>
232 where
233 Self: Sized,
234 {
235 prepare_command(self, cmd("GRAPH.SLOWLOG").arg(graph))
236 }
237}
238
239#[derive(Default)]
241pub struct GraphQueryOptions {
242 command_args: CommandArgs,
243}
244
245impl GraphQueryOptions {
246 #[must_use]
248 pub fn timeout(timeout: u64) -> Self {
249 Self {
250 command_args: CommandArgs::default().arg("TIMEOUT").arg(timeout).build(),
251 }
252 }
253}
254
255impl ToArgs for GraphQueryOptions {
256 fn write_args(&self, args: &mut CommandArgs) {
257 args.arg(&self.command_args);
258 }
259}
260
261#[derive(Debug, Deserialize)]
263pub struct GraphResultSet {
264 pub header: GraphHeader,
265 pub rows: Vec<GraphResultRow>,
266 pub statistics: GraphQueryStatistics,
267}
268
269impl GraphResultSet {
270 pub(crate) fn custom_conversion(
271 resp_buffer: RespBuf,
272 command: Command,
273 client: &Client,
274 ) -> Future<Self> {
275 let Some(graph_name) = command.args.iter().next() else {
276 return Box::pin(future::ready(Err(Error::Client("Cannot parse graph command".to_owned()))));
277 };
278
279 let Ok(graph_name) = std::str::from_utf8(graph_name) else {
280 return Box::pin(future::ready(Err(Error::Client("Cannot parse graph command".to_owned()))));
281 };
282
283 let graph_name = graph_name.to_owned();
284
285 Box::pin(async move {
286 let cache_key = format!("graph:{graph_name}");
287 let (cache_hit, num_node_labels, num_prop_keys, num_rel_types) = {
288 let client_state = client.get_client_state();
289 match client_state.get_state::<GraphCache>(&cache_key)? {
290 Some(cache) => {
291 let mut deserializer = RespDeserializer::new(&resp_buffer);
292 if cache.check_for_result(&mut deserializer)? {
293 (true, 0, 0, 0)
294 } else {
295 (
296 false,
297 cache.node_labels.len(),
298 cache.property_keys.len(),
299 cache.relationship_types.len(),
300 )
301 }
302 }
303 None => {
304 let cache = GraphCache::default();
305 let mut deserializer = RespDeserializer::new(&resp_buffer);
306 if cache.check_for_result(&mut deserializer)? {
307 (true, 0, 0, 0)
308 } else {
309 (false, 0, 0, 0)
310 }
311 }
312 }
313 };
314
315 if !cache_hit {
316 let (node_labels, prop_keys, rel_types) = Self::load_missing_ids(
317 &graph_name,
318 client,
319 num_node_labels,
320 num_prop_keys,
321 num_rel_types,
322 )
323 .await?;
324
325 let mut client_state = client.get_client_state_mut();
326 let cache = client_state.get_state_mut::<GraphCache>(&cache_key)?;
327
328 cache.update(
329 num_node_labels,
330 num_prop_keys,
331 num_rel_types,
332 node_labels,
333 prop_keys,
334 rel_types,
335 );
336
337 log::debug!("cache updated: {cache:?}");
338 } else if num_node_labels == 0 && num_prop_keys == 0 && num_rel_types == 0 {
339 let mut client_state = client.get_client_state_mut();
341 client_state.get_state_mut::<GraphCache>(&cache_key)?;
342
343 log::debug!("graph cache created");
344 }
345
346 let mut deserializer = RespDeserializer::new(&resp_buffer);
347 Self::deserialize(&mut deserializer, client, &cache_key)
348 })
349 }
350
351 fn deserialize<'de, D>(
352 deserializer: D,
353 client: &Client,
354 cache_key: &str,
355 ) -> std::result::Result<GraphResultSet, D::Error>
356 where
357 D: Deserializer<'de>,
358 {
359 struct GraphResultSetVisitor<'a, 'b> {
360 client: &'a Client,
361 cache_key: &'b str,
362 }
363
364 impl<'a, 'b, 'de> Visitor<'de> for GraphResultSetVisitor<'a, 'b> {
365 type Value = GraphResultSet;
366
367 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
368 formatter.write_str("GraphResultSet")
369 }
370
371 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
372 where
373 A: de::SeqAccess<'de>,
374 {
375 let Some(size) = seq.size_hint() else {
376 return Err(de::Error::custom("size hint is mandatory for GraphResultSet"));
377 };
378
379 if size == 1 {
380 let Some(statistics) = seq.next_element::<GraphQueryStatistics>()? else {
381 return Err(de::Error::invalid_length(0, &"more elements in sequence"));
382 };
383
384 Ok(GraphResultSet {
385 header: Default::default(),
386 rows: Default::default(),
387 statistics,
388 })
389 } else {
390 let Some(header) = seq.next_element::<GraphHeader>()? else {
391 return Err(de::Error::invalid_length(0, &"more elements in sequence"));
392 };
393
394 let client_state = self.client.get_client_state();
395 let Ok(Some(cache)) = client_state.get_state::<GraphCache>(self.cache_key) else {
396 return Err(de::Error::custom("Cannot find graph cache"));
397 };
398
399 let Some(rows) = seq.next_element_seed(GraphResultRowsSeed { cache })? else {
400 return Err(de::Error::invalid_length(1, &"more elements in sequence"));
401 };
402
403 let Some(statistics) = seq.next_element::<GraphQueryStatistics>()? else {
404 return Err(de::Error::invalid_length(2, &"more elements in sequence"));
405 };
406
407 Ok(GraphResultSet {
408 header,
409 rows,
410 statistics,
411 })
412 }
413 }
414 }
415
416 deserializer.deserialize_seq(GraphResultSetVisitor { client, cache_key })
417 }
418
419 async fn load_missing_ids(
420 graph_name: &str,
421 client: &Client,
422 num_node_labels: usize,
423 num_prop_keys: usize,
424 num_rel_types: usize,
425 ) -> Result<(Vec<String>, Vec<String>, Vec<String>)> {
426 let mut pipeline = client.create_pipeline();
427
428 pipeline.queue(cmd("GRAPH.QUERY").arg(graph_name.to_owned()).arg(format!(
430 "CALL db.labels() YIELD label RETURN label SKIP {}",
431 num_node_labels
432 )));
433
434 pipeline.queue(cmd("GRAPH.QUERY").arg(graph_name.to_owned()).arg(format!(
436 "CALL db.propertyKeys() YIELD propertyKey RETURN propertyKey SKIP {}",
437 num_prop_keys
438 )));
439
440 pipeline.queue(cmd("GRAPH.QUERY").arg(graph_name.to_owned()).arg(format!(
442 "CALL db.relationshipTypes() YIELD relationshipType RETURN relationshipType SKIP {}",
443 num_rel_types
444 )));
445
446 let (MappingsResult(node_labels), MappingsResult(prop_keys), MappingsResult(rel_types)) =
447 pipeline
448 .execute::<(MappingsResult, MappingsResult, MappingsResult)>()
449 .await?;
450
451 Ok((node_labels, prop_keys, rel_types))
452 }
453}
454
455struct MappingsResult(Vec<String>);
458
459impl<'de> Deserialize<'de> for MappingsResult {
460 #[inline]
461 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
462 where
463 D: serde::Deserializer<'de>,
464 {
465 struct MappingsSeed;
466
467 impl<'de> DeserializeSeed<'de> for MappingsSeed {
468 type Value = Vec<String>;
469
470 #[inline]
471 fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
472 where
473 D: Deserializer<'de>,
474 {
475 struct MappingSeed;
476
477 impl<'de> DeserializeSeed<'de> for MappingSeed {
478 type Value = String;
479
480 #[inline]
481 fn deserialize<D>(
482 self,
483 deserializer: D,
484 ) -> std::result::Result<Self::Value, D::Error>
485 where
486 D: Deserializer<'de>,
487 {
488 struct MappingVisitor;
489
490 impl<'de> Visitor<'de> for MappingVisitor {
491 type Value = String;
492
493 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
494 formatter.write_str("String")
495 }
496
497 fn visit_seq<A>(
498 self,
499 mut seq: A,
500 ) -> std::result::Result<Self::Value, A::Error>
501 where
502 A: de::SeqAccess<'de>,
503 {
504 let Some(mapping) = seq.next_element::<String>()? else {
505 return Err(de::Error::invalid_length(0, &"more elements in sequence"));
506 };
507
508 Ok(mapping)
509 }
510 }
511
512 deserializer.deserialize_seq(MappingVisitor)
513 }
514 }
515
516 struct MappingsVisitor;
517
518 impl<'de> Visitor<'de> for MappingsVisitor {
519 type Value = Vec<String>;
520
521 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
522 formatter.write_str("Vec<String>")
523 }
524
525 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
526 where
527 A: de::SeqAccess<'de>,
528 {
529 let mut mappings = if let Some(size_hint) = seq.size_hint() {
530 Vec::with_capacity(size_hint)
531 } else {
532 Vec::new()
533 };
534
535 while let Some(mapping) = seq.next_element_seed(MappingSeed)? {
536 mappings.push(mapping);
537 }
538
539 Ok(mappings)
540 }
541 }
542
543 deserializer.deserialize_seq(MappingsVisitor)
544 }
545 }
546
547 struct MappingsResultVisitor;
548
549 impl<'de> Visitor<'de> for MappingsResultVisitor {
550 type Value = MappingsResult;
551
552 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
553 formatter.write_str("MappingsResult")
554 }
555
556 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
557 where
558 A: serde::de::SeqAccess<'de>,
559 {
560 let Some(_header) = seq.next_element::<Vec::<String>>()? else {
561 return Err(de::Error::invalid_length(0, &"more elements in sequence"));
562 };
563
564 let Some(mappings) = seq.next_element_seed(MappingsSeed)? else {
565 return Err(de::Error::invalid_length(1, &"more elements in sequence"));
566 };
567
568 let Some(_stats) = seq.next_element::<Vec::<String>>()? else {
569 return Err(de::Error::invalid_length(2, &"more elements in sequence"));
570 };
571
572 Ok(MappingsResult(mappings))
573 }
574 }
575
576 deserializer.deserialize_seq(MappingsResultVisitor)
577 }
578}
579
580#[derive(Debug, Default)]
582pub struct GraphHeader {
583 pub column_names: Vec<String>,
584}
585
586impl<'de> Deserialize<'de> for GraphHeader {
587 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
588 where
589 D: Deserializer<'de>,
590 {
591 let header = SmallVec::<[(u16, String); 10]>::deserialize(deserializer)?;
592 let column_names = header
593 .into_iter()
594 .map(|(_colmun_type, column_name)| column_name)
595 .collect();
596
597 Ok(Self { column_names })
598 }
599}
600
601#[derive(Debug, Deserialize)]
603pub struct GraphResultRow {
604 pub values: Vec<GraphValue>,
608}
609
610pub struct GraphResultRowSeed<'a> {
611 cache: &'a GraphCache,
612}
613
614impl<'de, 'a> DeserializeSeed<'de> for GraphResultRowSeed<'a> {
615 type Value = GraphResultRow;
616
617 #[inline]
618 fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
619 where
620 D: Deserializer<'de>,
621 {
622 let values = GraphValueArraySeed { cache: self.cache }.deserialize(deserializer)?;
623
624 Ok(GraphResultRow { values })
625 }
626}
627
628struct GraphResultRowsSeed<'a> {
629 cache: &'a GraphCache,
630}
631
632impl<'de, 'a> Visitor<'de> for GraphResultRowsSeed<'a> {
633 type Value = Vec<GraphResultRow>;
634
635 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
636 formatter.write_str("Vec<GraphResultRow>")
637 }
638
639 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
640 where
641 A: de::SeqAccess<'de>,
642 {
643 let mut rows = if let Some(size) = seq.size_hint() {
644 Vec::with_capacity(size)
645 } else {
646 Vec::new()
647 };
648
649 while let Some(row) = seq.next_element_seed(GraphResultRowSeed { cache: self.cache })? {
650 rows.push(row);
651 }
652
653 Ok(rows)
654 }
655}
656
657impl<'de, 'a> DeserializeSeed<'de> for GraphResultRowsSeed<'a> {
658 type Value = Vec<GraphResultRow>;
659
660 #[inline]
661 fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
662 where
663 D: Deserializer<'de>,
664 {
665 deserializer.deserialize_seq(self)
666 }
667}
668
669#[derive(Debug, Default)]
671pub struct GraphQueryStatistics {
672 pub labels_added: usize,
673 pub labels_removed: usize,
674 pub nodes_created: usize,
675 pub nodes_deleted: usize,
676 pub properties_set: usize,
677 pub properties_removed: usize,
678 pub relationships_created: usize,
679 pub relationships_deleted: usize,
680 pub indices_created: usize,
681 pub indices_deleted: usize,
682 pub cached_execution: bool,
683 pub query_internal_execution_time: f64,
684 pub additional_statistics: HashMap<String, String>,
685}
686
687impl<'de> Deserialize<'de> for GraphQueryStatistics {
688 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
689 where
690 D: Deserializer<'de>,
691 {
692 struct GraphQueryStatisticsVisitor;
693
694 impl<'de> Visitor<'de> for GraphQueryStatisticsVisitor {
695 type Value = GraphQueryStatistics;
696
697 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
698 formatter.write_str("GraphQueryStatistics")
699 }
700
701 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
702 where
703 A: de::SeqAccess<'de>,
704 {
705 fn parse<'de, A, F>(value: &str) -> std::result::Result<F, A::Error>
706 where
707 A: de::SeqAccess<'de>,
708 F: FromStr,
709 {
710 match value.parse::<F>() {
711 Ok(value) => Ok(value),
712 Err(_) => Err(de::Error::custom(format!(
713 "Cannot parse GraphQueryStatistics: {value}"
714 ))),
715 }
716 }
717
718 fn parse_query_execution_time<'de, A>(
719 value: &str,
720 ) -> std::result::Result<f64, A::Error>
721 where
722 A: de::SeqAccess<'de>,
723 {
724 let Some((value, _milliseconds))= value.split_once(' ') else {
725 return Err(de::Error::custom("Cannot parse GraphQueryStatistics (query exuction time)"));
726 };
727
728 match value.parse::<f64>() {
729 Ok(value) => Ok(value),
730 Err(_) => Err(de::Error::custom(
731 "Cannot parse GraphQueryStatistics (query exuction time)",
732 )),
733 }
734 }
735
736 let mut stats = GraphQueryStatistics::default();
737
738 while let Some(str) = seq.next_element::<&str>()? {
739 let Some((name, value))= str.split_once(": ") else {
740 return Err(de::Error::custom("Cannot parse GraphQueryStatistics"));
741 };
742
743 match name {
744 "Labels added" => stats.labels_added = parse::<A, _>(value)?,
745 "Labels removed" => stats.labels_removed = parse::<A, _>(value)?,
746 "Nodes created" => stats.nodes_created = parse::<A, _>(value)?,
747 "Nodes deleted:" => stats.nodes_deleted = parse::<A, _>(value)?,
748 "Properties set" => stats.properties_set = parse::<A, _>(value)?,
749 "Properties removed" => stats.properties_removed = parse::<A, _>(value)?,
750 "Relationships created" => {
751 stats.relationships_created = parse::<A, _>(value)?
752 }
753 "Relationships deleted" => {
754 stats.relationships_deleted = parse::<A, _>(value)?
755 }
756 "Indices created" => stats.indices_created = parse::<A, _>(value)?,
757 "Indices deleted" => stats.indices_deleted = parse::<A, _>(value)?,
758 "Cached execution" => stats.cached_execution = parse::<A, u8>(value)? != 0,
759 "Query internal execution time" => {
760 stats.query_internal_execution_time =
761 parse_query_execution_time::<A>(value)?
762 }
763 _ => {
764 stats
765 .additional_statistics
766 .insert(name.to_owned(), value.to_owned());
767 }
768 }
769 }
770
771 Ok(stats)
772 }
773 }
774
775 deserializer.deserialize_seq(GraphQueryStatisticsVisitor)
776 }
777}
778
779#[derive(Debug, Deserialize)]
781pub struct GraphSlowlogResult {
782 pub processing_time: u64,
784 pub issued_command: String,
786 pub issued_query: String,
788 pub execution_duration: f64,
790}