1use std::sync::Arc;
40
41use smos_domain::chat::ToolCall;
42use smos_domain::config::{ConfidenceConfig, ExtractionConfig};
43use smos_domain::{MemoryKey, SessionId};
44
45use crate::errors::UseCaseError;
46use crate::ports::{
47 Clock, Delay, EmbeddingProvider, FactRepository, LlmExtractor, SessionRepository,
48};
49use crate::use_cases::extract_facts_from_response::ExtractFactsFromResponse;
50
/// A single assistant message fed to [`ImportOpencodeSession::execute`].
///
/// Before extraction, each turn is checked against the optional agent filter
/// and against the configured minimum content length.
#[derive(Debug, Clone, PartialEq)]
pub struct AssistantTurn {
    /// Identifier of the message. Not read by the import use case itself.
    /// NOTE(review): presumably the source system's message id — confirm.
    pub message_id: String,
    /// Name of the agent that produced the message; compared for exact
    /// equality against the optional agent filter passed to `execute`.
    pub agent: String,
    /// Message text. Its length in `char`s is compared with `min_chars`, and
    /// it is handed to response fact extraction.
    pub content: String,
    /// Tool calls attached to the message, also handed to extraction. A turn
    /// with at least one tool call is never skipped for being too short.
    pub tool_calls: Vec<ToolCall>,
}
63
/// Summary counters returned by [`ImportOpencodeSession::execute`].
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare a whole
/// summary in one assertion instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImportStats {
    /// The session the turns were imported into (its `as_str()` form).
    pub session_id: String,
    /// Turns that passed filtering and were sent through fact extraction.
    pub turns_processed: usize,
    /// Turns dropped by the agent filter, or for being shorter than the
    /// minimum length while having no tool calls.
    pub turns_skipped: usize,
    /// Sum of the per-turn counts returned by extraction (new facts only;
    /// confirmations of already-stored facts are not counted).
    pub facts_extracted: usize,
}
77
/// Use case that imports a session's assistant turns and runs response fact
/// extraction on every turn that passes filtering.
///
/// Dependencies are owned here; each processed turn borrows them to build an
/// `ExtractFactsFromResponse` use case.
pub struct ImportOpencodeSession<FR, SR, EP, LE, C, D> {
    /// Fact repository, forwarded to per-turn extraction.
    pub facts: FR,
    /// Session repository. `get_or_create` is called once per import before
    /// any turn is processed; it is also forwarded to per-turn extraction.
    pub sessions: SR,
    /// Embedding provider, forwarded to per-turn extraction.
    pub embedder: EP,
    /// LLM fact extractor, forwarded to per-turn extraction.
    pub extractor: LE,
    /// Clock, forwarded to per-turn extraction.
    pub clock: C,
    /// Delay port, forwarded to per-turn extraction.
    pub delay: D,
    /// Confidence settings, forwarded to per-turn extraction.
    pub confidence_cfg: Arc<ConfidenceConfig>,
    /// Extraction settings, forwarded to per-turn extraction.
    pub extraction_cfg: Arc<ExtractionConfig>,
    /// Forwarded to `ExtractFactsFromResponse`. When `false`, turns still
    /// count as processed, but (per the tests) no facts are stored or counted.
    pub enable_response_extraction: bool,
    /// Minimum content length, in `char`s, for a turn without tool calls to be
    /// processed; shorter turns without tool calls are skipped.
    pub min_chars: usize,
}
104
105impl<FR, SR, EP, LE, C, D> ImportOpencodeSession<FR, SR, EP, LE, C, D>
106where
107 FR: FactRepository,
108 SR: SessionRepository,
109 EP: EmbeddingProvider,
110 LE: LlmExtractor,
111 C: Clock,
112 D: Delay,
113{
114 pub async fn execute(
121 &self,
122 turns: Vec<AssistantTurn>,
123 memory_key: &MemoryKey,
124 session_id: &SessionId,
125 agent_filter: Option<&[String]>,
126 ) -> Result<ImportStats, UseCaseError> {
127 let mut stats = ImportStats {
128 session_id: session_id.as_str().to_string(),
129 ..Default::default()
130 };
131
132 self.sessions.get_or_create(session_id, memory_key).await?;
136
137 for turn in &turns {
138 if self.should_skip(turn, agent_filter) {
139 stats.turns_skipped += 1;
140 continue;
141 }
142
143 stats.turns_processed += 1;
144 let new_count = self.extract_turn(turn, memory_key, session_id).await?;
145 stats.facts_extracted += new_count;
146 }
147
148 tracing::info!(
149 session = %session_id,
150 memory_key = %memory_key,
151 processed = stats.turns_processed,
152 skipped = stats.turns_skipped,
153 new_facts = stats.facts_extracted,
154 "import complete"
155 );
156 Ok(stats)
157 }
158
159 fn should_skip(&self, turn: &AssistantTurn, agent_filter: Option<&[String]>) -> bool {
162 if let Some(filter) = agent_filter
163 && !filter.iter().any(|a| a == &turn.agent)
164 {
165 return true;
166 }
167 let too_short = turn.content.chars().count() < self.min_chars;
168 too_short && turn.tool_calls.is_empty()
169 }
170
171 async fn extract_turn(
175 &self,
176 turn: &AssistantTurn,
177 memory_key: &MemoryKey,
178 session_id: &SessionId,
179 ) -> Result<usize, UseCaseError> {
180 let extractor = ExtractFactsFromResponse {
181 facts: &self.facts,
182 sessions: &self.sessions,
183 embedder: &self.embedder,
184 extractor: &self.extractor,
185 clock: &self.clock,
186 delay: &self.delay,
187 confidence_cfg: &self.confidence_cfg,
188 extraction_cfg: &self.extraction_cfg,
189 enable_response_extraction: self.enable_response_extraction,
190 };
191 extractor
192 .execute(&turn.content, &turn.tool_calls, memory_key, session_id)
193 .await
194 }
195}
196
#[cfg(test)]
mod tests {
    // Unit tests for `ImportOpencodeSession`. Every port is replaced by an
    // in-memory fake, so no external service is contacted.
    use super::*;
    use crate::types::SearchHit;
    use smos_domain::{Fact, FactId, Heat, SessionState, Timestamp};
    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::time::Duration;

    /// Clock that always reports the same instant.
    #[derive(Clone)]
    struct FixedClock(Timestamp);
    impl Clock for FixedClock {
        fn now(&self) -> Timestamp {
            self.0
        }
    }

    /// Delay that completes immediately, so tests never sleep.
    #[derive(Clone, Copy)]
    struct NoOpDelay;
    impl Delay for NoOpDelay {
        async fn delay(&self, _duration: Duration) {}
    }

    /// Extractor that returns pre-scripted fact lists, one list per call, in
    /// order. Once the script is exhausted, every call returns an empty list.
    struct ScriptedExtractor {
        results: Mutex<Vec<Vec<String>>>,
    }
    impl ScriptedExtractor {
        fn new(results: Vec<Vec<String>>) -> Self {
            Self {
                results: Mutex::new(results),
            }
        }
    }
    impl LlmExtractor for ScriptedExtractor {
        async fn extract_facts(
            &self,
            _content: &str,
            _tool_calls: &[ToolCall],
        ) -> Result<Vec<String>, crate::errors::ProviderError> {
            let mut guard = self.results.lock().unwrap();
            if guard.is_empty() {
                Ok(Vec::new())
            } else {
                // Take from the front so scripted results line up with call order.
                Ok(guard.remove(0))
            }
        }
    }

    /// Embedder that returns the same vector for every input text.
    struct ConstantEmbedder(Vec<f32>);
    impl EmbeddingProvider for ConstantEmbedder {
        async fn embed(
            &self,
            _text: &str,
        ) -> Result<Option<Vec<f32>>, crate::errors::ProviderError> {
            Ok(Some(self.0.clone()))
        }
    }

    /// Fact repository backed by a map keyed by the fact id string.
    ///
    /// Clones share the same `Arc`-wrapped map, so the fixture can inspect
    /// what the use case stored. Only `save`/`get` touch the map; the list and
    /// search methods return empty results and `update_heat_batch` is a no-op.
    #[derive(Default, Clone)]
    struct InMemoryFacts {
        store: std::sync::Arc<Mutex<HashMap<String, Fact>>>,
    }
    impl FactRepository for InMemoryFacts {
        async fn save(&self, fact: &Fact) -> Result<(), crate::errors::RepoError> {
            self.store
                .lock()
                .unwrap()
                .insert(fact.id().as_str().to_string(), fact.clone());
            Ok(())
        }
        async fn get(
            &self,
            id: &FactId,
            _mk: &MemoryKey,
        ) -> Result<Option<Fact>, crate::errors::RepoError> {
            Ok(self.store.lock().unwrap().get(id.as_str()).cloned())
        }
        async fn list_accepted(
            &self,
            _mk: &MemoryKey,
        ) -> Result<Vec<Fact>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn list_pending(
            &self,
            _mk: &MemoryKey,
        ) -> Result<Vec<Fact>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn list_memory_keys_for_session(
            &self,
            _session_id: &SessionId,
        ) -> Result<Vec<MemoryKey>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn list_memory_keys(&self) -> Result<Vec<MemoryKey>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn search_similar(
            &self,
            _e: Vec<f32>,
            _mk: &MemoryKey,
            _l: usize,
        ) -> Result<Vec<SearchHit>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn update_heat_batch(
            &self,
            _ids: &[FactId],
            _mk: &MemoryKey,
            _h: Heat,
            _t: Timestamp,
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
    }

    /// Session repository that records whether `get_or_create` was called.
    ///
    /// `get_or_create` always returns a fresh state under memory key "proj",
    /// whatever key is passed in. All other methods are no-ops or return
    /// empty results. Clones share the `created` flag.
    #[derive(Default, Clone)]
    struct RecordingSessions {
        created: std::sync::Arc<Mutex<bool>>,
    }
    impl SessionRepository for RecordingSessions {
        async fn get_or_create(
            &self,
            id: &SessionId,
            _m: &MemoryKey,
        ) -> Result<SessionState, crate::errors::RepoError> {
            *self.created.lock().unwrap() = true;
            Ok(SessionState::new(
                id.clone(),
                MemoryKey::from_raw("proj").unwrap(),
                Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            ))
        }
        async fn add_pending(
            &self,
            _i: &SessionId,
            _ids: &[FactId],
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
        async fn collect_expired(
            &self,
            _t: Duration,
        ) -> Result<Vec<(SessionId, SessionState)>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn snapshot_all(
            &self,
        ) -> Result<Vec<(SessionId, SessionState)>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn remove_pending_owned(
            &self,
            _i: &SessionId,
            _o: &[FactId],
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
        async fn clear_session(&self, _i: &SessionId) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
        async fn dedup_and_mark(
            &self,
            _i: &SessionId,
            _m: &MemoryKey,
            _c: &[FactId],
        ) -> Result<Vec<FactId>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn save(
            &self,
            _i: &SessionId,
            _s: &SessionState,
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
    }

    /// Memory key shared by every test.
    fn mk() -> MemoryKey {
        MemoryKey::from_raw("proj").unwrap()
    }
    /// Session id of the form `sess_` + 12 lowercase hex digits of `tag`.
    fn sid(tag: u8) -> SessionId {
        SessionId::from_raw(&format!("sess_{:012x}", tag as u64)).unwrap()
    }

    /// Shared fixture. `build` wires a use case whose fact store and session
    /// recorder are clones of (and therefore share state with) these fields.
    struct Fix {
        facts: InMemoryFacts,
        sessions: RecordingSessions,
        embedder: ConstantEmbedder,
        clock: FixedClock,
        cfg: ConfidenceConfig,
        extraction_cfg: ExtractionConfig,
    }
    impl Fix {
        fn new() -> Self {
            Self {
                facts: InMemoryFacts::default(),
                sessions: RecordingSessions::default(),
                embedder: ConstantEmbedder(vec![0.1, 0.2, 0.3]),
                clock: FixedClock(Timestamp::from_unix_secs(1_700_000_000).unwrap()),
                cfg: ConfidenceConfig::default(),
                extraction_cfg: ExtractionConfig::default(),
            }
        }
        /// Builds the use case with the given scripted extractor and minimum
        /// length. Response extraction is enabled.
        fn build(
            &self,
            extractor: ScriptedExtractor,
            min_chars: usize,
        ) -> ImportOpencodeSession<
            InMemoryFacts,
            RecordingSessions,
            ConstantEmbedder,
            ScriptedExtractor,
            FixedClock,
            NoOpDelay,
        > {
            ImportOpencodeSession {
                facts: self.facts.clone(),
                sessions: self.sessions.clone(),
                embedder: ConstantEmbedder(self.embedder.0.clone()),
                extractor,
                clock: FixedClock(self.clock.0),
                delay: NoOpDelay,
                confidence_cfg: Arc::new(self.cfg.clone()),
                extraction_cfg: Arc::new(self.extraction_cfg.clone()),
                enable_response_extraction: true,
                min_chars,
            }
        }
    }

    /// Builds a turn with no tool calls; `message_id` is derived from `agent`.
    fn turn(agent: &str, content: &str) -> AssistantTurn {
        AssistantTurn {
            message_id: format!("msg_{agent}"),
            agent: agent.to_string(),
            content: content.to_string(),
            tool_calls: Vec::new(),
        }
    }

    #[tokio::test]
    async fn execute_imports_each_turn_and_counts_new_facts() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![
            vec!["fact one".to_string()],
            vec!["fact two".to_string()],
        ]);
        let import = fix.build(extractor, 15);

        // Both contents are at least 15 chars long, so neither turn is skipped.
        let turns = vec![
            turn("head-of-development", "TTL=10 prevents refresh loop"),
            turn("head-of-development", "Auth uses JWT for tokens"),
        ];
        let stats = import.execute(turns, &mk(), &sid(1), None).await.unwrap();

        assert_eq!(stats.turns_processed, 2);
        assert_eq!(stats.turns_skipped, 0);
        assert_eq!(stats.facts_extracted, 2);
    }

    #[tokio::test]
    async fn execute_skips_turns_below_min_chars_without_tool_calls() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![vec!["real fact".to_string()]]);
        let import = fix.build(extractor, 15);

        // "ok" is 2 chars (< 15) with no tool calls, so only the second turn
        // reaches the extractor and consumes the single scripted result.
        let turns = vec![
            turn("a", "ok"),
            turn("a", "TTL=10 prevents refresh loop"),
        ];
        let stats = import.execute(turns, &mk(), &sid(1), None).await.unwrap();

        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.turns_skipped, 1);
        assert_eq!(stats.facts_extracted, 1);
    }

    #[tokio::test]
    async fn execute_keeps_short_turn_when_it_has_tool_calls() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![vec!["from tool".to_string()]]);
        let import = fix.build(extractor, 15);

        // Content is below min_chars, but the tool call keeps the turn in.
        let mut short_with_tool = turn("a", "ok");
        short_with_tool.tool_calls.push(ToolCall {
            name: "read_file".into(),
            arguments: smos_domain::chat::ToolArguments::from_json(r#"{"path":"auth.rs"}"#),
        });
        let stats = import
            .execute(vec![short_with_tool], &mk(), &sid(1), None)
            .await
            .unwrap();

        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.turns_skipped, 0);
        assert_eq!(stats.facts_extracted, 1);
    }

    #[tokio::test]
    async fn execute_applies_agent_filter() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![
            vec!["hod fact".to_string()],
            vec!["hod fact 2".to_string()],
        ]);
        let import = fix.build(extractor, 15);

        // The "dreaming" turn is long enough but not in the filter, so it is skipped.
        let turns = vec![
            turn("head-of-development", "TTL=10 prevents refresh loop"),
            turn("dreaming", "Internal analysis content here"),
            turn("head-of-development", "Auth uses JWT for tokens"),
        ];
        let filter = vec!["head-of-development".to_string()];
        let stats = import
            .execute(turns, &mk(), &sid(1), Some(&filter))
            .await
            .unwrap();

        assert_eq!(stats.turns_processed, 2);
        assert_eq!(stats.turns_skipped, 1);
        assert_eq!(stats.facts_extracted, 2);
    }

    #[tokio::test]
    async fn execute_ensures_session_row_exists_before_first_turn() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![]);
        let import = fix.build(extractor, 15);

        let _ = import.execute(vec![], &mk(), &sid(7), None).await.unwrap();

        // The fixture shares the `created` flag with the use case's repository clone.
        assert!(
            *fix.sessions.created.lock().unwrap(),
            "get_or_create must run even for an empty turn list"
        );
    }

    #[tokio::test]
    async fn execute_with_extraction_disabled_returns_zero_facts() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![vec!["should not be stored".to_string()]]);
        let mut import = fix.build(extractor, 15);
        import.enable_response_extraction = false;

        let stats = import
            .execute(
                vec![turn("a", "TTL=10 prevents refresh loop")],
                &mk(),
                &sid(1),
                None,
            )
            .await
            .unwrap();

        // The turn still counts as processed; extraction itself produces nothing.
        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.facts_extracted, 0);
        assert!(fix.facts.store.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn execute_confirms_existing_fact_instead_of_counting_it_new() {
        let fix = Fix::new();
        let seeded_content = "shared fact content here";
        // Seed a pending fact from session 1. Session 2's extractor will then
        // return the same content.
        let first = Fact::new_pending(
            seeded_content,
            mk(),
            sid(1),
            smos_domain::Embedding::new(vec![1.0]).unwrap(),
            Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            ConfidenceConfig::default().base,
        )
        .unwrap();
        let fid = first.id().clone();
        fix.facts
            .store
            .lock()
            .unwrap()
            .insert(fid.as_str().to_string(), first);

        let extractor = ScriptedExtractor::new(vec![vec![seeded_content.to_string()]]);
        let import = fix.build(extractor, 15);

        let stats = import
            .execute(vec![turn("a", seeded_content)], &mk(), &sid(2), None)
            .await
            .unwrap();

        assert_eq!(stats.facts_extracted, 0, "confirmation is not a new fact");
        let confirmed = fix
            .facts
            .store
            .lock()
            .unwrap()
            .get(fid.as_str())
            .cloned()
            .expect("fact present");
        // The stored fact should now list both session 1 and session 2 as sources.
        assert_eq!(
            confirmed.source_sessions().distinct_count(),
            2,
            "provenance grew to two sessions"
        );
    }
}