1use anyhow::Result;
10use arrow_array::{
11 Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
12 UInt64Array, cast::AsArray,
13};
14use arrow_schema::{DataType, Field, Schema};
15use futures::TryStreamExt;
16use lance_graph::{CypherQuery, DirNamespace, GraphConfig};
17use lancedb::{
18 Connection, Table as LanceTable, connect,
19 index::{Index, scalar::FullTextSearchQuery},
20 query::{ExecutableQuery, QueryBase},
21};
22use std::{collections::HashMap, path::Path, sync::Arc};
23
// Table names inside the LanceDB directory.
const ENTITIES_TABLE: &str = "entities";
const RELATIONS_TABLE: &str = "relations";
const JOURNALS_TABLE: &str = "journals";
// Hard cap on rows returned by graph connection (Cypher) queries.
const CONNECTIONS_MAX: usize = 100;

/// Dimensionality of the embedding vectors stored in every `vector` column.
pub const EMBED_DIM: i32 = 384;
31
/// Borrowed input row for [`LanceStore::upsert_entity`].
pub struct EntityRow<'a> {
    /// Stable identifier; migration code elsewhere in this file writes ids as
    /// "{entity_type}:{key}" — presumably callers follow the same convention.
    pub id: &'a str,
    pub entity_type: &'a str,
    pub key: &'a str,
    pub value: &'a str,
    /// Embedding vector; must be exactly EMBED_DIM long or batch construction
    /// will fail — TODO confirm callers enforce this.
    pub vector: Vec<f32>,
}
40
/// Borrowed input row for [`LanceStore::upsert_relation`]; `source` and
/// `target` reference entity ids.
pub struct RelationRow<'a> {
    pub source: &'a str,
    pub relation: &'a str,
    pub target: &'a str,
}
47
/// Owned entity row returned by queries.
///
/// Derives added so callers can log, compare, and duplicate results —
/// public result types should be `Debug` at minimum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EntityResult {
    pub id: String,
    pub entity_type: String,
    pub key: String,
    pub value: String,
    /// Unix timestamp (seconds) when the row was first created.
    pub created_at: u64,
}
56
/// Owned relation (edge) row returned by queries.
///
/// Derives added so callers can log, compare, and duplicate results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RelationResult {
    pub source: String,
    pub relation: String,
    pub target: String,
    /// Unix timestamp (seconds); zero when the query path cannot supply it.
    pub created_at: u64,
}
64
/// Owned journal entry returned by queries.
///
/// Derives added so callers can log, compare, and duplicate results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JournalResult {
    pub summary: String,
    pub agent: String,
    /// Unix timestamp (seconds) when the entry was written.
    pub created_at: u64,
}
71
/// LanceDB-backed store holding entities, relations, and journal entries,
/// with Cypher graph queries served over the same on-disk directory.
pub struct LanceStore {
    // Connection kept alive for the lifetime of the table handles;
    // not used directly after open.
    _db: Connection,
    entities: LanceTable,
    relations: LanceTable,
    journals: LanceTable,
    // Namespace handed to lance_graph so Cypher queries read the same dir.
    namespace: Arc<DirNamespace>,
    // Maps the entities table to node labels and the relations table to
    // edges (source -> target) for Cypher execution.
    graph_config: GraphConfig,
}
84
impl LanceStore {
    /// Open (or create) the store at `path`, running any pending schema
    /// migration. `embed_fn` turns a text string into an EMBED_DIM-length
    /// embedding and is only invoked when a migration or backfill runs.
    pub async fn open<F>(path: impl AsRef<Path>, embed_fn: F) -> Result<Self>
    where
        F: Fn(&str) -> Result<Vec<f32>>,
    {
        let path = path.as_ref();
        let db = connect(path.to_str().unwrap_or(".")).execute().await?;

        let mut entities = open_or_create(&db, ENTITIES_TABLE, entity_schema()).await?;
        let mut relations = open_or_create(&db, RELATIONS_TABLE, relation_schema()).await?;
        let journals = open_or_create(&db, JOURNALS_TABLE, journal_schema()).await?;

        // Schema sniffing: a legacy "agent" column marks the v1 layout;
        // otherwise a missing "vector" column marks v2 rows written before
        // embeddings were added.
        let schema = entities.schema().await?;
        let has_agent = schema.fields().iter().any(|f| f.name() == "agent");
        if has_agent {
            tracing::info!("detected v1 schema — migrating entities and relations");
            let (e, r) = migrate_v1_to_v2(&db, &entities, &relations, &embed_fn).await?;
            entities = e;
            relations = r;
            tracing::info!("v1 → v2 migration complete");
        } else {
            let has_vector = schema.fields().iter().any(|f| f.name() == "vector");
            if !has_vector {
                tracing::info!("detected v2 schema — backfilling entity embeddings");
                entities = backfill_entity_vectors(&db, &entities, &embed_fn).await?;
                tracing::info!("entity vector backfill complete");
            }
        }

        // Graph layer reads the same directory as the tables above.
        let namespace = Arc::new(DirNamespace::new(path.to_str().unwrap_or(".")));
        let graph_config = GraphConfig::builder()
            .with_node_label(ENTITIES_TABLE, "id")
            .with_relationship(RELATIONS_TABLE, "source", "target")
            .build()?;

        let store = Self {
            _db: db,
            entities,
            relations,
            journals,
            namespace,
            graph_config,
        };
        // Index creation is best-effort; failures are logged, not fatal.
        store.ensure_entity_indices().await;
        store.ensure_relation_indices().await;
        store.ensure_journal_indices().await;
        Ok(store)
    }

    /// Insert the entity or, if a row with the same `id` exists, overwrite
    /// all of its columns (merge-insert keyed on `id`).
    pub async fn upsert_entity(&self, row: &EntityRow<'_>) -> Result<()> {
        let batch = make_entity_batch(row)?;
        let schema = batch.schema();
        let batches = RecordBatchIterator::new(std::iter::once(Ok(batch)), schema);

        let mut merge = self.entities.merge_insert(&["id"]);
        merge
            .when_matched_update_all(None)
            .when_not_matched_insert_all();
        merge.execute(Box::new(batches)).await?;
        Ok(())
    }

    /// Full-text search over the entities FTS index (key + value columns),
    /// optionally filtered to one entity type.
    pub async fn search_entities(
        &self,
        query: &str,
        entity_type: Option<&str>,
        limit: usize,
    ) -> Result<Vec<EntityResult>> {
        let mut q = self
            .entities
            .query()
            .full_text_search(FullTextSearchQuery::new(query.to_owned()));
        if let Some(et) = entity_type {
            // Filter value is single-quote escaped before interpolation.
            q = q.only_if(format!("entity_type = '{}'", escape_sql(et)));
        }
        let batches: Vec<RecordBatch> = q.limit(limit).execute().await?.try_collect().await?;
        batches_to_entities(&batches)
    }

    /// Vector (nearest-neighbour) search over entity embeddings,
    /// optionally filtered to one entity type.
    pub async fn search_entities_semantic(
        &self,
        query_vector: &[f32],
        entity_type: Option<&str>,
        limit: usize,
    ) -> Result<Vec<EntityResult>> {
        let mut q = self.entities.query().nearest_to(query_vector)?;
        if let Some(et) = entity_type {
            q = q.only_if(format!("entity_type = '{}'", escape_sql(et)));
        }
        let batches: Vec<RecordBatch> = q.limit(limit).execute().await?.try_collect().await?;
        batches_to_entities(&batches)
    }

    /// List up to `limit` entities of a single type.
    pub async fn query_by_type(
        &self,
        entity_type: &str,
        limit: usize,
    ) -> Result<Vec<EntityResult>> {
        let filter = format!("entity_type = '{}'", escape_sql(entity_type));
        let batches: Vec<RecordBatch> = self
            .entities
            .query()
            .only_if(filter)
            .limit(limit)
            .execute()
            .await?
            .try_collect()
            .await?;
        batches_to_entities(&batches)
    }

    /// Return the first entity whose `key` matches exactly, if any.
    /// NOTE(review): keys are not unique across entity types, so which row
    /// wins on collision is undefined — confirm this is acceptable.
    pub async fn find_entity_by_key(&self, key: &str) -> Result<Option<EntityResult>> {
        let filter = format!("key = '{}'", escape_sql(key));
        let batches: Vec<RecordBatch> = self
            .entities
            .query()
            .only_if(filter)
            .limit(1)
            .execute()
            .await?
            .try_collect()
            .await?;
        let entities = batches_to_entities(&batches)?;
        Ok(entities.into_iter().next())
    }

    /// List up to `limit` entities, optionally restricted to one type.
    pub async fn list_entities(
        &self,
        entity_type: Option<&str>,
        limit: usize,
    ) -> Result<Vec<EntityResult>> {
        let mut q = self.entities.query();
        if let Some(et) = entity_type {
            q = q.only_if(format!("entity_type = '{}'", escape_sql(et)));
        }
        let batches: Vec<RecordBatch> = q.limit(limit).execute().await?.try_collect().await?;
        batches_to_entities(&batches)
    }

    /// Insert the relation or overwrite the existing edge with the same
    /// (source, relation, target) triple.
    pub async fn upsert_relation(&self, row: &RelationRow<'_>) -> Result<()> {
        let batch = make_relation_batch(row)?;
        let schema = batch.schema();
        let batches = RecordBatchIterator::new(std::iter::once(Ok(batch)), schema);

        let mut merge = self
            .relations
            .merge_insert(&["source", "relation", "target"]);
        merge
            .when_matched_update_all(None)
            .when_not_matched_insert_all();
        merge.execute(Box::new(batches)).await?;
        Ok(())
    }

    /// Find edges touching `entity_id` via a Cypher query, optionally
    /// constrained to one relation type and a direction. `limit` is capped
    /// at CONNECTIONS_MAX.
    pub async fn find_connections(
        &self,
        entity_id: &str,
        relation: Option<&str>,
        direction: Direction,
        limit: usize,
    ) -> Result<Vec<RelationResult>> {
        let limit = limit.min(CONNECTIONS_MAX);
        let cypher = build_connections_cypher(entity_id, relation, direction, limit);
        let query = CypherQuery::new(&cypher)?.with_config(self.graph_config.clone());
        let batch = query
            .execute_with_namespace_arc(Arc::clone(&self.namespace), None)
            .await?;

        batch_to_relations(&batch)
    }

    /// List up to `limit` relations; when `entity_id` is given, only edges
    /// where it appears as source or target.
    pub async fn list_relations(
        &self,
        entity_id: Option<&str>,
        limit: usize,
    ) -> Result<Vec<RelationResult>> {
        let mut q = self.relations.query();
        if let Some(eid) = entity_id {
            let escaped = escape_sql(eid);
            q = q.only_if(format!("source = '{escaped}' OR target = '{escaped}'"));
        }
        let batches: Vec<RecordBatch> = q.limit(limit).execute().await?.try_collect().await?;
        batches_to_relation_results(&batches)
    }

    /// Best-effort creation of the entity table indices; failures (e.g.
    /// empty table, index already exists) are logged and ignored.
    async fn ensure_entity_indices(&self) {
        let idx = [
            (
                vec!["key", "value"],
                Index::FTS(Default::default()),
                "entities FTS",
            ),
            (vec!["id"], Index::BTree(Default::default()), "entities id"),
            (
                vec!["key"],
                Index::BTree(Default::default()),
                "entities key",
            ),
            (
                vec!["entity_type"],
                Index::Bitmap(Default::default()),
                "entities entity_type",
            ),
        ];
        for (cols, index, name) in idx {
            if let Err(e) = self.entities.create_index(&cols, index).execute().await {
                tracing::warn!("{name} index creation skipped: {e}");
            }
        }
    }

    /// Append a journal entry (no dedup; every call adds a new row).
    pub async fn insert_journal(&self, agent: &str, summary: &str, vector: Vec<f32>) -> Result<()> {
        let batch = make_journal_batch(agent, summary, vector)?;
        let schema = batch.schema();
        let batches = RecordBatchIterator::new(std::iter::once(Ok(batch)), schema);
        self.journals.add(Box::new(batches)).execute().await?;
        Ok(())
    }

    /// List journals (optionally one agent's), newest first.
    /// NOTE(review): `limit` is applied by the scan before the in-memory
    /// sort, so with more than `limit` rows this is "some N rows, sorted",
    /// not necessarily the newest N — confirm whether that is intended.
    pub async fn list_journals(
        &self,
        agent: Option<&str>,
        limit: usize,
    ) -> Result<Vec<JournalResult>> {
        let mut q = self.journals.query();
        if let Some(a) = agent {
            q = q.only_if(format!("agent = '{}'", escape_sql(a)));
        }
        let batches: Vec<RecordBatch> = q.limit(limit).execute().await?.try_collect().await?;
        let mut results = batches_to_journals(&batches)?;
        results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
        Ok(results)
    }

    /// List one agent's journals, newest first.
    /// NOTE(review): same limit-before-sort caveat as `list_journals`.
    pub async fn recent_journals(&self, agent: &str, limit: usize) -> Result<Vec<JournalResult>> {
        let filter = format!("agent = '{}'", escape_sql(agent));
        let batches: Vec<RecordBatch> = self
            .journals
            .query()
            .only_if(filter)
            .limit(limit)
            .execute()
            .await?
            .try_collect()
            .await?;
        let mut results = batches_to_journals(&batches)?;
        results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
        Ok(results)
    }

    /// Best-effort creation of the journal table indices.
    async fn ensure_journal_indices(&self) {
        let idx = [
            (
                vec!["agent"],
                Index::Bitmap(Default::default()),
                "journals agent",
            ),
            (vec!["id"], Index::BTree(Default::default()), "journals id"),
        ];
        for (cols, index, name) in idx {
            if let Err(e) = self.journals.create_index(&cols, index).execute().await {
                tracing::warn!("{name} index creation skipped: {e}");
            }
        }
    }

    /// Best-effort creation of the relation table indices.
    async fn ensure_relation_indices(&self) {
        let idx = [
            (
                vec!["source"],
                Index::BTree(Default::default()),
                "relations source",
            ),
            (
                vec!["target"],
                Index::BTree(Default::default()),
                "relations target",
            ),
            (
                vec!["relation"],
                Index::Bitmap(Default::default()),
                "relations relation",
            ),
        ];
        for (cols, index, name) in idx {
            if let Err(e) = self.relations.create_index(&cols, index).execute().await {
                tracing::warn!("{name} index creation skipped: {e}");
            }
        }
    }
}
403
/// Edge orientation relative to the anchor entity in connection queries.
///
/// Derives added: the enum is a plain three-variant flag, so `Copy`
/// comparison and `Debug` logging are free and useful to callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Edges where the anchor entity is the source.
    Outgoing,
    /// Edges where the anchor entity is the target.
    Incoming,
    /// Edges touching the anchor entity on either side.
    Both,
}
410
411async fn open_or_create(db: &Connection, name: &str, schema: Arc<Schema>) -> Result<LanceTable> {
414 match db.open_table(name).execute().await {
415 Ok(t) => Ok(t),
416 Err(_) => {
417 let batches = RecordBatchIterator::new(std::iter::empty(), Arc::clone(&schema));
418 Ok(db.create_table(name, Box::new(batches)).execute().await?)
419 }
420 }
421}
422
/// One-time backfill for v2 tables written before embeddings existed:
/// reads every entity, computes an embedding of "{key} {value}" via
/// `embed_fn`, then drops and recreates the entities table with the full
/// v2 schema (including the `vector` column).
///
/// NOTE(review): the source table is read fully into memory and the drop/
/// recreate is not transactional — a crash between drop and create loses
/// the data. Acceptable for a local store, but worth confirming.
async fn backfill_entity_vectors<F>(
    db: &Connection,
    entities: &LanceTable,
    embed_fn: &F,
) -> Result<LanceTable>
where
    F: Fn(&str) -> Result<Vec<f32>>,
{
    let batches: Vec<RecordBatch> = entities.query().execute().await?.try_collect().await?;
    // (id, entity_type, key, value, vector, created_at, updated_at)
    #[allow(clippy::type_complexity)]
    let mut rows: Vec<(String, String, String, String, Vec<f32>, u64, u64)> = Vec::new();
    for batch in &batches {
        let ids = migration_col(batch, "id")?.as_string::<i32>();
        let types = migration_col(batch, "entity_type")?.as_string::<i32>();
        let keys = migration_col(batch, "key")?.as_string::<i32>();
        let values = migration_col(batch, "value")?.as_string::<i32>();
        let created =
            migration_col(batch, "created_at")?.as_primitive::<arrow_array::types::UInt64Type>();
        let updated =
            migration_col(batch, "updated_at")?.as_primitive::<arrow_array::types::UInt64Type>();
        for i in 0..batch.num_rows() {
            let key = keys.value(i);
            let value = values.value(i);
            // Embedding input mirrors what upsert paths index: key + value.
            let text = format!("{key} {value}");
            let vector = embed_fn(&text)?;
            rows.push((
                ids.value(i).to_string(),
                types.value(i).to_string(),
                key.to_string(),
                value.to_string(),
                vector,
                created.value(i),
                updated.value(i),
            ));
        }
    }

    let count = rows.len();
    tracing::info!("backfilling {count} entities with embeddings");

    // Recreate the table from scratch with the vector-bearing schema.
    db.drop_table(ENTITIES_TABLE, &[]).await?;
    let schema = entity_schema();
    if rows.is_empty() {
        let batches = RecordBatchIterator::new(std::iter::empty(), Arc::clone(&schema));
        return Ok(db
            .create_table(ENTITIES_TABLE, Box::new(batches))
            .execute()
            .await?);
    }

    // Transpose row tuples into per-column builders; vectors are flattened
    // into one contiguous f32 buffer for the FixedSizeList column.
    let mut ids = Vec::with_capacity(count);
    let mut types = Vec::with_capacity(count);
    let mut keys_vec = Vec::with_capacity(count);
    let mut values = Vec::with_capacity(count);
    let mut all_vectors: Vec<f32> = Vec::with_capacity(count * EMBED_DIM as usize);
    let mut created_ats = Vec::with_capacity(count);
    let mut updated_ats = Vec::with_capacity(count);
    for (id, et, key, value, vector, crt, upd) in rows {
        ids.push(id);
        types.push(et);
        keys_vec.push(key);
        values.push(value);
        all_vectors.extend(vector);
        created_ats.push(crt);
        updated_ats.push(upd);
    }

    let float_array = Float32Array::from(all_vectors);
    let field = Arc::new(Field::new("item", DataType::Float32, true));
    let vector_array = FixedSizeListArray::new(field, EMBED_DIM, Arc::new(float_array), None);

    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(StringArray::from(ids)) as Arc<dyn Array>,
            Arc::new(StringArray::from(types)) as Arc<dyn Array>,
            Arc::new(StringArray::from(keys_vec)) as Arc<dyn Array>,
            Arc::new(StringArray::from(values)) as Arc<dyn Array>,
            Arc::new(vector_array) as Arc<dyn Array>,
            Arc::new(UInt64Array::from(created_ats)) as Arc<dyn Array>,
            Arc::new(UInt64Array::from(updated_ats)) as Arc<dyn Array>,
        ],
    )?;
    let batches = RecordBatchIterator::new(std::iter::once(Ok(batch)), schema);
    Ok(db
        .create_table(ENTITIES_TABLE, Box::new(batches))
        .execute()
        .await?)
}
517
/// Migrate v1 tables to the v2 layout:
/// - entity ids become deterministic "{entity_type}:{key}";
/// - entities duplicated per (type, key) are deduplicated, keeping the row
///   with the newest `updated_at`;
/// - relation endpoints are remapped from old ids to the new ids;
/// - duplicate (source, relation, target) edges collapse to one, keeping
///   the first `created_at` seen;
/// - embeddings are computed for every surviving entity.
///
/// Both tables are dropped and recreated; same non-transactional caveat as
/// `backfill_entity_vectors`.
async fn migrate_v1_to_v2<F>(
    db: &Connection,
    entities: &LanceTable,
    relations: &LanceTable,
    embed_fn: &F,
) -> Result<(LanceTable, LanceTable)>
where
    F: Fn(&str) -> Result<Vec<f32>>,
{
    let entity_batches: Vec<RecordBatch> = entities.query().execute().await?.try_collect().await?;

    // old id -> new "{type}:{key}" id, used later to rewrite relations.
    let mut id_remap: HashMap<String, String> = HashMap::new();
    // (type, key) -> (value, created_at, updated_at); newest update wins.
    let mut deduped: HashMap<(String, String), (String, u64, u64)> = HashMap::new();
    for batch in &entity_batches {
        let ids = migration_col(batch, "id")?.as_string::<i32>();
        let types = migration_col(batch, "entity_type")?.as_string::<i32>();
        let keys = migration_col(batch, "key")?.as_string::<i32>();
        let values = migration_col(batch, "value")?.as_string::<i32>();
        let updated =
            migration_col(batch, "updated_at")?.as_primitive::<arrow_array::types::UInt64Type>();
        let created =
            migration_col(batch, "created_at")?.as_primitive::<arrow_array::types::UInt64Type>();

        for i in 0..batch.num_rows() {
            let old_id = ids.value(i).to_string();
            let et = types.value(i).to_string();
            let key = keys.value(i).to_string();
            let new_id = format!("{et}:{key}");
            id_remap.insert(old_id, new_id);

            let value = values.value(i).to_string();
            let upd = updated.value(i);
            let crt = created.value(i);
            let map_key = (et, key);
            // First occurrence seeds the entry; later rows replace it only
            // if strictly newer by updated_at.
            let entry = deduped.entry(map_key).or_insert((value.clone(), crt, upd));
            if upd > entry.2 {
                *entry = (value, crt, upd);
            }
        }
    }

    let entity_count = deduped.len();
    tracing::info!("migrating {entity_count} deduplicated entities");

    db.drop_table(ENTITIES_TABLE, &[]).await?;
    let schema = entity_schema();
    let new_entities = if deduped.is_empty() {
        let batches = RecordBatchIterator::new(std::iter::empty(), Arc::clone(&schema));
        db.create_table(ENTITIES_TABLE, Box::new(batches))
            .execute()
            .await?
    } else {
        // Column builders; vectors flattened for the FixedSizeList column.
        let mut ids = Vec::with_capacity(entity_count);
        let mut types = Vec::with_capacity(entity_count);
        let mut keys_vec = Vec::with_capacity(entity_count);
        let mut values = Vec::with_capacity(entity_count);
        let mut all_vectors: Vec<f32> = Vec::with_capacity(entity_count * EMBED_DIM as usize);
        let mut created_ats = Vec::with_capacity(entity_count);
        let mut updated_ats = Vec::with_capacity(entity_count);

        for ((et, key), (value, crt, upd)) in &deduped {
            let text = format!("{key} {value}");
            let vector = embed_fn(&text)?;
            ids.push(format!("{et}:{key}"));
            types.push(et.clone());
            keys_vec.push(key.clone());
            values.push(value.clone());
            all_vectors.extend(vector);
            created_ats.push(*crt);
            updated_ats.push(*upd);
        }

        let float_array = Float32Array::from(all_vectors);
        let field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector_array = FixedSizeListArray::new(field, EMBED_DIM, Arc::new(float_array), None);

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(ids)) as Arc<dyn Array>,
                Arc::new(StringArray::from(types)) as Arc<dyn Array>,
                Arc::new(StringArray::from(keys_vec)) as Arc<dyn Array>,
                Arc::new(StringArray::from(values)) as Arc<dyn Array>,
                Arc::new(vector_array) as Arc<dyn Array>,
                Arc::new(UInt64Array::from(created_ats)) as Arc<dyn Array>,
                Arc::new(UInt64Array::from(updated_ats)) as Arc<dyn Array>,
            ],
        )?;
        let batches = RecordBatchIterator::new(std::iter::once(Ok(batch)), schema);
        db.create_table(ENTITIES_TABLE, Box::new(batches))
            .execute()
            .await?
    };

    let relation_batches: Vec<RecordBatch> =
        relations.query().execute().await?.try_collect().await?;

    // (source, relation, target) -> created_at; first-seen timestamp wins.
    let mut rel_deduped: HashMap<(String, String, String), u64> = HashMap::new();
    for batch in &relation_batches {
        let sources = migration_col(batch, "source")?.as_string::<i32>();
        let rels = migration_col(batch, "relation")?.as_string::<i32>();
        let targets = migration_col(batch, "target")?.as_string::<i32>();
        let created =
            migration_col(batch, "created_at")?.as_primitive::<arrow_array::types::UInt64Type>();

        for i in 0..batch.num_rows() {
            let raw_source = sources.value(i);
            let raw_target = targets.value(i);
            let rel = rels.value(i).to_string();
            let crt = created.value(i);

            // Endpoints not present in the remap (dangling references) are
            // kept verbatim rather than dropped.
            let source = id_remap
                .get(raw_source)
                .cloned()
                .unwrap_or_else(|| raw_source.to_string());
            let target = id_remap
                .get(raw_target)
                .cloned()
                .unwrap_or_else(|| raw_target.to_string());

            rel_deduped.entry((source, rel, target)).or_insert(crt);
        }
    }

    let rel_count = rel_deduped.len();
    tracing::info!("migrating {rel_count} deduplicated relations");

    db.drop_table(RELATIONS_TABLE, &[]).await?;
    let rel_schema = relation_schema();
    let new_relations = if rel_deduped.is_empty() {
        let batches = RecordBatchIterator::new(std::iter::empty(), Arc::clone(&rel_schema));
        db.create_table(RELATIONS_TABLE, Box::new(batches))
            .execute()
            .await?
    } else {
        let mut sources = Vec::with_capacity(rel_count);
        let mut rels = Vec::with_capacity(rel_count);
        let mut targets = Vec::with_capacity(rel_count);
        let mut created_ats = Vec::with_capacity(rel_count);

        for ((source, rel, target), crt) in &rel_deduped {
            sources.push(source.clone());
            rels.push(rel.clone());
            targets.push(target.clone());
            created_ats.push(*crt);
        }

        let batch = RecordBatch::try_new(
            Arc::clone(&rel_schema),
            vec![
                Arc::new(StringArray::from(sources)) as Arc<dyn Array>,
                Arc::new(StringArray::from(rels)) as Arc<dyn Array>,
                Arc::new(StringArray::from(targets)) as Arc<dyn Array>,
                Arc::new(UInt64Array::from(created_ats)) as Arc<dyn Array>,
            ],
        )?;
        let batches = RecordBatchIterator::new(std::iter::once(Ok(batch)), rel_schema);
        db.create_table(RELATIONS_TABLE, Box::new(batches))
            .execute()
            .await?
    };

    Ok((new_entities, new_relations))
}
693
694fn entity_schema() -> Arc<Schema> {
695 Arc::new(Schema::new(vec![
696 Field::new("id", DataType::Utf8, false),
697 Field::new("entity_type", DataType::Utf8, false),
698 Field::new("key", DataType::Utf8, false),
699 Field::new("value", DataType::Utf8, false),
700 Field::new(
701 "vector",
702 DataType::FixedSizeList(
703 Arc::new(Field::new("item", DataType::Float32, true)),
704 EMBED_DIM,
705 ),
706 false,
707 ),
708 Field::new("created_at", DataType::UInt64, false),
709 Field::new("updated_at", DataType::UInt64, false),
710 ]))
711}
712
713fn relation_schema() -> Arc<Schema> {
714 Arc::new(Schema::new(vec![
715 Field::new("source", DataType::Utf8, false),
716 Field::new("relation", DataType::Utf8, false),
717 Field::new("target", DataType::Utf8, false),
718 Field::new("created_at", DataType::UInt64, false),
719 ]))
720}
721
722fn journal_schema() -> Arc<Schema> {
723 Arc::new(Schema::new(vec![
724 Field::new("id", DataType::Utf8, false),
725 Field::new("agent", DataType::Utf8, false),
726 Field::new("summary", DataType::Utf8, false),
727 Field::new(
728 "vector",
729 DataType::FixedSizeList(
730 Arc::new(Field::new("item", DataType::Float32, true)),
731 EMBED_DIM,
732 ),
733 false,
734 ),
735 Field::new("created_at", DataType::UInt64, false),
736 ]))
737}
738
739fn make_journal_batch(agent: &str, summary: &str, vector: Vec<f32>) -> Result<RecordBatch> {
740 let schema = journal_schema();
741 let now = now_unix();
742 let id = format!("{agent}:{now}");
743 let values = Float32Array::from(vector);
744 let field = Arc::new(Field::new("item", DataType::Float32, true));
745 let vector_array = FixedSizeListArray::new(field, EMBED_DIM, Arc::new(values), None);
746 Ok(RecordBatch::try_new(
747 schema,
748 vec![
749 Arc::new(StringArray::from(vec![id.as_str()])) as Arc<dyn Array>,
750 Arc::new(StringArray::from(vec![agent])) as Arc<dyn Array>,
751 Arc::new(StringArray::from(vec![summary])) as Arc<dyn Array>,
752 Arc::new(vector_array) as Arc<dyn Array>,
753 Arc::new(UInt64Array::from(vec![now])) as Arc<dyn Array>,
754 ],
755 )?)
756}
757
758fn batches_to_journals(batches: &[RecordBatch]) -> Result<Vec<JournalResult>> {
759 let mut results = Vec::new();
760 for batch in batches {
761 let summaries = batch
762 .column_by_name("summary")
763 .ok_or_else(|| anyhow::anyhow!("missing column: summary"))?
764 .as_string::<i32>();
765 let agents = batch
766 .column_by_name("agent")
767 .ok_or_else(|| anyhow::anyhow!("missing column: agent"))?
768 .as_string::<i32>();
769 let timestamps = batch
770 .column_by_name("created_at")
771 .ok_or_else(|| anyhow::anyhow!("missing column: created_at"))?
772 .as_primitive::<arrow_array::types::UInt64Type>();
773 for i in 0..batch.num_rows() {
774 results.push(JournalResult {
775 summary: summaries.value(i).to_string(),
776 agent: agents.value(i).to_string(),
777 created_at: timestamps.value(i),
778 });
779 }
780 }
781 Ok(results)
782}
783
784fn make_entity_batch(row: &EntityRow<'_>) -> Result<RecordBatch> {
785 let schema = entity_schema();
786 let now = now_unix();
787 let values = Float32Array::from(row.vector.clone());
788 let field = Arc::new(Field::new("item", DataType::Float32, true));
789 let vector_array = FixedSizeListArray::new(field, EMBED_DIM, Arc::new(values), None);
790 Ok(RecordBatch::try_new(
791 schema,
792 vec![
793 Arc::new(StringArray::from(vec![row.id])) as Arc<dyn Array>,
794 Arc::new(StringArray::from(vec![row.entity_type])) as Arc<dyn Array>,
795 Arc::new(StringArray::from(vec![row.key])) as Arc<dyn Array>,
796 Arc::new(StringArray::from(vec![row.value])) as Arc<dyn Array>,
797 Arc::new(vector_array) as Arc<dyn Array>,
798 Arc::new(UInt64Array::from(vec![now])) as Arc<dyn Array>,
799 Arc::new(UInt64Array::from(vec![now])) as Arc<dyn Array>,
800 ],
801 )?)
802}
803
804fn make_relation_batch(row: &RelationRow<'_>) -> Result<RecordBatch> {
805 let schema = relation_schema();
806 let now = now_unix();
807 Ok(RecordBatch::try_new(
808 schema,
809 vec![
810 Arc::new(StringArray::from(vec![row.source])) as Arc<dyn Array>,
811 Arc::new(StringArray::from(vec![row.relation])) as Arc<dyn Array>,
812 Arc::new(StringArray::from(vec![row.target])) as Arc<dyn Array>,
813 Arc::new(UInt64Array::from(vec![now])) as Arc<dyn Array>,
814 ],
815 )?)
816}
817
818fn batches_to_entities(batches: &[RecordBatch]) -> Result<Vec<EntityResult>> {
819 let mut results = Vec::new();
820 for batch in batches {
821 let ids = batch
822 .column_by_name("id")
823 .ok_or_else(|| anyhow::anyhow!("missing column: id"))?
824 .as_string::<i32>();
825 let types = batch
826 .column_by_name("entity_type")
827 .ok_or_else(|| anyhow::anyhow!("missing column: entity_type"))?
828 .as_string::<i32>();
829 let keys = batch
830 .column_by_name("key")
831 .ok_or_else(|| anyhow::anyhow!("missing column: key"))?
832 .as_string::<i32>();
833 let values = batch
834 .column_by_name("value")
835 .ok_or_else(|| anyhow::anyhow!("missing column: value"))?
836 .as_string::<i32>();
837 let timestamps = batch
838 .column_by_name("created_at")
839 .ok_or_else(|| anyhow::anyhow!("missing column: created_at"))?
840 .as_primitive::<arrow_array::types::UInt64Type>();
841 for i in 0..batch.num_rows() {
842 results.push(EntityResult {
843 id: ids.value(i).to_string(),
844 entity_type: types.value(i).to_string(),
845 key: keys.value(i).to_string(),
846 value: values.value(i).to_string(),
847 created_at: timestamps.value(i),
848 });
849 }
850 }
851 Ok(results)
852}
853
854fn batches_to_relation_results(batches: &[RecordBatch]) -> Result<Vec<RelationResult>> {
856 let mut results = Vec::new();
857 for batch in batches {
858 let sources = batch
859 .column_by_name("source")
860 .ok_or_else(|| anyhow::anyhow!("missing column: source"))?
861 .as_string::<i32>();
862 let relations = batch
863 .column_by_name("relation")
864 .ok_or_else(|| anyhow::anyhow!("missing column: relation"))?
865 .as_string::<i32>();
866 let targets = batch
867 .column_by_name("target")
868 .ok_or_else(|| anyhow::anyhow!("missing column: target"))?
869 .as_string::<i32>();
870 let timestamps = batch
871 .column_by_name("created_at")
872 .ok_or_else(|| anyhow::anyhow!("missing column: created_at"))?
873 .as_primitive::<arrow_array::types::UInt64Type>();
874 for i in 0..batch.num_rows() {
875 results.push(RelationResult {
876 source: sources.value(i).to_string(),
877 relation: relations.value(i).to_string(),
878 target: targets.value(i).to_string(),
879 created_at: timestamps.value(i),
880 });
881 }
882 }
883 Ok(results)
884}
885
886fn batch_to_relations(batch: &RecordBatch) -> Result<Vec<RelationResult>> {
887 if batch.num_rows() == 0 {
888 return Ok(Vec::new());
889 }
890 let sources = batch
893 .column_by_name("r__source")
894 .ok_or_else(|| anyhow::anyhow!("missing column: r__source"))?
895 .as_string::<i32>();
896 let relations = batch
897 .column_by_name("r__relation")
898 .ok_or_else(|| anyhow::anyhow!("missing column: r__relation"))?
899 .as_string::<i32>();
900 let targets = batch
901 .column_by_name("r__target")
902 .ok_or_else(|| anyhow::anyhow!("missing column: r__target"))?
903 .as_string::<i32>();
904 Ok((0..batch.num_rows())
906 .map(|i| RelationResult {
907 source: sources.value(i).to_string(),
908 relation: relations.value(i).to_string(),
909 target: targets.value(i).to_string(),
910 created_at: 0,
911 })
912 .collect())
913}
914
915fn build_connections_cypher(
917 entity_id: &str,
918 relation: Option<&str>,
919 direction: Direction,
920 limit: usize,
921) -> String {
922 let eid = escape_cypher(entity_id);
923
924 let rel_type = relation
925 .map(|r| format!(":`{}`", escape_cypher_ident(r)))
926 .unwrap_or_default();
927
928 let pattern = match direction {
929 Direction::Outgoing => {
930 format!("(e:entities {{id: '{eid}'}})-[r:relations{rel_type}]->(t:entities)")
931 }
932 Direction::Incoming => {
933 format!("(e:entities)<-[r:relations{rel_type}]-(s:entities {{id: '{eid}'}})")
934 }
935 Direction::Both => {
936 format!("(e:entities)-[r:relations{rel_type}]-(o:entities {{id: '{eid}'}})")
937 }
938 };
939
940 format!("MATCH {pattern} RETURN r.source, r.relation, r.target LIMIT {limit}")
941}
942
943fn migration_col<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a Arc<dyn Array>> {
945 batch
946 .column_by_name(name)
947 .ok_or_else(|| anyhow::anyhow!("migration: missing column '{name}'"))
948}
949
/// Escape a string for interpolation into a single-quoted SQL literal by
/// doubling every single quote.
fn escape_sql(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
953
/// Escape a string for a single-quoted Cypher literal: backslashes and
/// single quotes are backslash-escaped.
fn escape_cypher(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '\'' => escaped.push_str("\\'"),
            other => escaped.push(other),
        }
    }
    escaped
}
957
/// Escape an identifier for use inside Cypher backticks by doubling any
/// embedded backtick.
fn escape_cypher_ident(s: &str) -> String {
    s.split('`').collect::<Vec<_>>().join("``")
}
962
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Panics only if the system clock reads earlier than the epoch.
fn now_unix() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before unix epoch")
        .as_secs()
}