1use std::sync::Arc;
40
41use smos_domain::chat::ToolCall;
42use smos_domain::config::{ConfidenceConfig, ExtractionConfig};
43use smos_domain::{MemoryKey, SessionId};
44
45use crate::errors::UseCaseError;
46use crate::ports::{
47 Clock, Delay, EmbeddingProvider, FactRepository, LlmExtractor, SessionRepository,
48};
49use crate::use_cases::extract_facts_from_response::ExtractFactsFromResponse;
50
51#[derive(Debug, Clone, PartialEq)]
57pub struct AssistantTurn {
58 pub message_id: String,
59 pub agent: String,
60 pub content: String,
61 pub tool_calls: Vec<ToolCall>,
62}
63
/// Counters describing the outcome of one import run.
///
/// `Default` yields an empty session id and all counters at zero.
/// `PartialEq`/`Eq` are derived so that callers and tests can compare whole
/// results with `==` / `assert_eq!` instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImportStats {
    /// String form of the session id the import ran against.
    pub session_id: String,
    /// Number of turns that were handed to fact extraction.
    pub turns_processed: usize,
    /// Number of turns that were filtered out before extraction.
    pub turns_skipped: usize,
    /// Sum of the per-turn counts reported by fact extraction.
    pub facts_extracted: usize,
}
77
78pub struct ImportOpencodeSession<FR, SR, EP, LE, C, D> {
86 pub facts: FR,
87 pub sessions: SR,
88 pub embedder: EP,
89 pub extractor: LE,
90 pub clock: C,
91 pub delay: D,
92 pub confidence_cfg: Arc<ConfidenceConfig>,
93 pub extraction_cfg: Arc<ExtractionConfig>,
97 pub enable_response_extraction: bool,
98 pub min_chars: usize,
103}
104
105impl<FR, SR, EP, LE, C, D> ImportOpencodeSession<FR, SR, EP, LE, C, D>
106where
107 FR: FactRepository,
108 SR: SessionRepository,
109 EP: EmbeddingProvider,
110 LE: LlmExtractor,
111 C: Clock,
112 D: Delay,
113{
114 pub async fn execute(
121 &self,
122 turns: Vec<AssistantTurn>,
123 memory_key: &MemoryKey,
124 session_id: &SessionId,
125 agent_filter: Option<&[String]>,
126 ) -> Result<ImportStats, UseCaseError> {
127 let mut stats = ImportStats {
128 session_id: session_id.as_str().to_string(),
129 ..Default::default()
130 };
131
132 self.sessions.get_or_create(session_id, memory_key).await?;
136
137 for turn in &turns {
138 if self.should_skip(turn, agent_filter) {
139 stats.turns_skipped += 1;
140 continue;
141 }
142
143 stats.turns_processed += 1;
144 let new_count = self.extract_turn(turn, memory_key, session_id).await?;
145 stats.facts_extracted += new_count;
146 }
147
148 tracing::info!(
149 session = %session_id,
150 memory_key = %memory_key,
151 processed = stats.turns_processed,
152 skipped = stats.turns_skipped,
153 new_facts = stats.facts_extracted,
154 "import complete"
155 );
156 Ok(stats)
157 }
158
159 fn should_skip(&self, turn: &AssistantTurn, agent_filter: Option<&[String]>) -> bool {
162 if let Some(filter) = agent_filter
163 && !filter.iter().any(|a| a == &turn.agent)
164 {
165 return true;
166 }
167 let too_short = turn.content.chars().count() < self.min_chars;
168 too_short && turn.tool_calls.is_empty()
169 }
170
171 async fn extract_turn(
175 &self,
176 turn: &AssistantTurn,
177 memory_key: &MemoryKey,
178 session_id: &SessionId,
179 ) -> Result<usize, UseCaseError> {
180 let extractor = ExtractFactsFromResponse {
181 facts: &self.facts,
182 sessions: &self.sessions,
183 embedder: &self.embedder,
184 extractor: &self.extractor,
185 clock: &self.clock,
186 delay: &self.delay,
187 confidence_cfg: &self.confidence_cfg,
188 extraction_cfg: &self.extraction_cfg,
189 enable_response_extraction: self.enable_response_extraction,
190 };
191 extractor
192 .execute(&turn.content, &turn.tool_calls, memory_key, session_id)
193 .await
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{
        ConstantEmbedder, FixedClock, InMemoryFacts, NoOpDelay, ScriptedExtractor,
    };
    use smos_domain::{Fact, FactId, NewPendingRequest, SessionState, Timestamp};
    use std::sync::Mutex;
    use std::time::Duration;

    /// Result type shared by every method of the session repository double.
    type RepoResult<T> = Result<T, crate::errors::RepoError>;

    /// The use case wired up with the test doubles used in this module.
    type Importer = ImportOpencodeSession<
        InMemoryFacts,
        RecordingSessions,
        ConstantEmbedder,
        ScriptedExtractor,
        FixedClock,
        NoOpDelay,
    >;

    /// Minimum content length used by every test.
    const MIN_CHARS: usize = 15;

    /// Session repository double. Every method succeeds with an empty result;
    /// `get_or_create` additionally raises a flag shared between clones.
    #[derive(Clone, Default)]
    struct RecordingSessions {
        created: Arc<Mutex<bool>>,
    }

    impl SessionRepository for RecordingSessions {
        async fn get_or_create(
            &self,
            id: &SessionId,
            _memory_key: &MemoryKey,
        ) -> RepoResult<SessionState> {
            *self.created.lock().unwrap() = true;
            let state = SessionState::new(
                id.clone(),
                MemoryKey::from_raw("proj").unwrap(),
                Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            );
            Ok(state)
        }

        async fn add_pending(&self, _session: &SessionId, _ids: &[FactId]) -> RepoResult<()> {
            Ok(())
        }

        async fn collect_expired(
            &self,
            _ttl: Duration,
        ) -> RepoResult<Vec<(SessionId, SessionState)>> {
            Ok(vec![])
        }

        async fn snapshot_all(&self) -> RepoResult<Vec<(SessionId, SessionState)>> {
            Ok(vec![])
        }

        async fn remove_pending_owned(
            &self,
            _session: &SessionId,
            _owned: &[FactId],
        ) -> RepoResult<()> {
            Ok(())
        }

        async fn clear_session(&self, _session: &SessionId) -> RepoResult<()> {
            Ok(())
        }

        async fn dedup_and_mark(
            &self,
            _session: &SessionId,
            _memory_key: &MemoryKey,
            _candidates: &[FactId],
        ) -> RepoResult<Vec<FactId>> {
            Ok(vec![])
        }

        async fn save(&self, _session: &SessionId, _state: &SessionState) -> RepoResult<()> {
            Ok(())
        }
    }

    /// Memory key shared by all tests.
    fn mk() -> MemoryKey {
        MemoryKey::from_raw("proj").unwrap()
    }

    /// Builds a session id of the form `sess_` followed by `tag` as twelve
    /// zero-padded hex digits.
    fn sid(tag: u8) -> SessionId {
        let raw = format!("sess_{:012x}", u64::from(tag));
        SessionId::from_raw(&raw).unwrap()
    }

    /// Owns the test doubles so a test can inspect them after the import ran.
    struct Fixture {
        facts: InMemoryFacts,
        sessions: RecordingSessions,
        embedder: ConstantEmbedder,
        clock: FixedClock,
        cfg: ConfidenceConfig,
        extraction_cfg: ExtractionConfig,
    }

    impl Fixture {
        fn new() -> Self {
            let facts = InMemoryFacts::default();
            let sessions = RecordingSessions::default();
            let embedder = ConstantEmbedder(vec![0.1, 0.2, 0.3]);
            let clock = FixedClock(Timestamp::from_unix_secs(1_700_000_000).unwrap());
            Self {
                facts,
                sessions,
                embedder,
                clock,
                cfg: ConfidenceConfig::default(),
                extraction_cfg: ExtractionConfig::default(),
            }
        }

        /// Assembles the use case from copies of the fixture's doubles, with
        /// response extraction switched on.
        fn build(&self, extractor: ScriptedExtractor, min_chars: usize) -> Importer {
            let confidence_cfg = Arc::new(self.cfg.clone());
            let extraction_cfg = Arc::new(self.extraction_cfg.clone());
            ImportOpencodeSession {
                facts: self.facts.clone(),
                sessions: self.sessions.clone(),
                embedder: ConstantEmbedder(self.embedder.0.clone()),
                extractor,
                clock: FixedClock(self.clock.0),
                delay: NoOpDelay,
                confidence_cfg,
                extraction_cfg,
                enable_response_extraction: true,
                min_chars,
            }
        }
    }

    /// Turn without tool calls, with `msg_<agent>` as its message id.
    fn turn(agent: &str, content: &str) -> AssistantTurn {
        AssistantTurn {
            message_id: format!("msg_{agent}"),
            agent: agent.to_owned(),
            content: content.to_owned(),
            tool_calls: vec![],
        }
    }

    /// `(turns_processed, turns_skipped, facts_extracted)` of `stats`.
    fn counts(stats: &ImportStats) -> (usize, usize, usize) {
        (
            stats.turns_processed,
            stats.turns_skipped,
            stats.facts_extracted,
        )
    }

    #[tokio::test]
    async fn execute_imports_each_turn_and_counts_new_facts() {
        let fixture = Fixture::new();
        let scripted = ScriptedExtractor::new(vec![
            Ok(vec!["fact one".to_string()]),
            Ok(vec!["fact two".to_string()]),
        ]);
        let importer = fixture.build(scripted, MIN_CHARS);

        let stats = importer
            .execute(
                vec![
                    turn("head-of-development", "TTL=10 prevents refresh loop"),
                    turn("head-of-development", "Auth uses JWT for tokens"),
                ],
                &mk(),
                &sid(1),
                None,
            )
            .await
            .unwrap();

        assert_eq!(counts(&stats), (2, 0, 2));
    }

    #[tokio::test]
    async fn execute_skips_turns_below_min_chars_without_tool_calls() {
        let fixture = Fixture::new();
        // Only one reply is scripted: the short turn is skipped before the
        // extraction step is reached.
        let scripted = ScriptedExtractor::new(vec![Ok(vec!["real fact".to_string()])]);
        let importer = fixture.build(scripted, MIN_CHARS);

        let too_short = turn("a", "ok");
        let long_enough = turn("a", "TTL=10 prevents refresh loop");
        let stats = importer
            .execute(vec![too_short, long_enough], &mk(), &sid(1), None)
            .await
            .unwrap();

        assert_eq!(counts(&stats), (1, 1, 1));
    }

    #[tokio::test]
    async fn execute_keeps_short_turn_when_it_has_tool_calls() {
        let fixture = Fixture::new();
        let scripted = ScriptedExtractor::new(vec![Ok(vec!["from tool".to_string()])]);
        let importer = fixture.build(scripted, MIN_CHARS);

        // Same two-character content as the skipped case, plus one tool call.
        let short_with_tool = AssistantTurn {
            tool_calls: vec![ToolCall {
                name: "read_file".into(),
                arguments: smos_domain::chat::ToolArguments::from_json(r#"{"path":"auth.rs"}"#),
            }],
            ..turn("a", "ok")
        };

        let stats = importer
            .execute(vec![short_with_tool], &mk(), &sid(1), None)
            .await
            .unwrap();

        assert_eq!(counts(&stats), (1, 0, 1));
    }

    #[tokio::test]
    async fn execute_applies_agent_filter() {
        let fixture = Fixture::new();
        let scripted = ScriptedExtractor::new(vec![
            Ok(vec!["hod fact".to_string()]),
            Ok(vec!["hod fact 2".to_string()]),
        ]);
        let importer = fixture.build(scripted, MIN_CHARS);

        let allowed = vec!["head-of-development".to_string()];
        let stats = importer
            .execute(
                vec![
                    turn("head-of-development", "TTL=10 prevents refresh loop"),
                    turn("dreaming", "Internal analysis content here"),
                    turn("head-of-development", "Auth uses JWT for tokens"),
                ],
                &mk(),
                &sid(1),
                Some(allowed.as_slice()),
            )
            .await
            .unwrap();

        assert_eq!(counts(&stats), (2, 1, 2));
    }

    #[tokio::test]
    async fn execute_ensures_session_row_exists_before_first_turn() {
        let fixture = Fixture::new();
        let importer = fixture.build(ScriptedExtractor::new(vec![]), MIN_CHARS);

        let _ = importer
            .execute(Vec::new(), &mk(), &sid(7), None)
            .await
            .unwrap();

        // The importer holds a clone of the repository; the flag is shared.
        let session_created = *fixture.sessions.created.lock().unwrap();
        assert!(
            session_created,
            "get_or_create must run even for an empty turn list"
        );
    }

    #[tokio::test]
    async fn execute_with_extraction_disabled_returns_zero_facts() {
        let fixture = Fixture::new();
        let scripted =
            ScriptedExtractor::new(vec![Ok(vec!["should not be stored".to_string()])]);
        let mut importer = fixture.build(scripted, MIN_CHARS);
        importer.enable_response_extraction = false;

        let only_turn = turn("a", "TTL=10 prevents refresh loop");
        let stats = importer
            .execute(vec![only_turn], &mk(), &sid(1), None)
            .await
            .unwrap();

        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.facts_extracted, 0);
        assert!(fixture.facts.is_empty());
    }

    #[tokio::test]
    async fn execute_confirms_existing_fact_instead_of_counting_it_new() {
        let fixture = Fixture::new();
        let seeded_content = "shared fact content here";

        // Seed a pending fact under session 1, then import the same content
        // under session 2.
        let request = NewPendingRequest {
            content: seeded_content,
            memory_key: mk(),
            session: sid(1),
            embedding: smos_domain::Embedding::new(vec![1.0]).unwrap(),
            extracted_at: Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            base_confidence: ConfidenceConfig::default().base,
        };
        let seeded = Fact::new_pending(request).unwrap();
        let seeded_id = seeded.id().clone();
        fixture.facts.seed(seeded);

        let scripted = ScriptedExtractor::new(vec![Ok(vec![seeded_content.to_string()])]);
        let importer = fixture.build(scripted, MIN_CHARS);

        let stats = importer
            .execute(vec![turn("a", seeded_content)], &mk(), &sid(2), None)
            .await
            .unwrap();

        assert_eq!(stats.facts_extracted, 0, "confirmation is not a new fact");
        let confirmed = fixture.facts.get_clone(&seeded_id).expect("fact present");
        let distinct_sessions = confirmed.source_sessions().distinct_count();
        assert_eq!(distinct_sessions, 2, "provenance grew to two sessions");
    }
}