Skip to main content

smos_application/use_cases/
import_opencode_session.rs

1//! `ImportOpencodeSession` — bulk import of an opencode transcript (Slice-8).
2//!
//! Consumes already-flattened assistant turns ([`AssistantTurn`], produced by
//! the adapter layer's `smos::opencode::transcript`) and re-runs the SAME
//! extraction pipeline the live proxy runs after each chat completion.
//! Concretely: every turn that passes the filters below is fed to
//! [`ExtractFactsFromResponse`], so dedup, embedding, cross-session
//! confirmation, and the `MIN_INPUT_CHARS` floor are reused verbatim — the
//! import path is DRY with the live path.
10//!
11//! # Filtering
12//!
13//! The use case applies two pre-extraction filters that mirror the POC
14//! `iter_assistant_turns`:
15//!
16//! 1. **Agent filter** — optional `&[String]` allow-list. Turns whose `agent`
17//!    is not in the list are skipped (`turns_skipped`).
18//! 2. **Min-chars floor** — turns with fewer than `min_chars` content chars AND
19//!    no tool calls are skipped. Tool-call-only turns survive because the
20//!    extraction pipeline renders tool calls into the input, so a turn with
21//!    zero prose still carries extractable signal.
22//!
//! `min_chars` is wired from the SAME const as the live extraction pipeline
//! ([`MIN_INPUT_CHARS`](crate::use_cases::extract_facts_from_response::MIN_INPUT_CHARS))
//! by the CLI binary, so the import path and the live response path cannot
//! drift apart. The use case keeps the field as a runtime knob (not a const)
//! so future callers can override it explicitly when they have a stronger
//! reason than "match the live path".
29//!
30//! # Stats
31//!
32//! [`ImportStats`] is the wire shape surfaced by the `smos-import` binary. The
33//! `facts_extracted` counter is the sum of `ExtractFactsFromResponse::execute`
34//! return values — i.e. ONLY newly-created pending facts. Cross-session
35//! confirmations on pre-existing facts do NOT increment the counter (they
36//! update an existing fact's provenance instead), so re-importing the same
37//! session is idempotent on the new-fact axis.
38
39use std::sync::Arc;
40
41use smos_domain::chat::ToolCall;
42use smos_domain::config::{ConfidenceConfig, ExtractionConfig};
43use smos_domain::{MemoryKey, SessionId};
44
45use crate::errors::UseCaseError;
46use crate::ports::{
47    Clock, Delay, EmbeddingProvider, FactRepository, LlmExtractor, SessionRepository,
48};
49use crate::use_cases::extract_facts_from_response::ExtractFactsFromResponse;
50
/// One assistant turn parsed from an opencode transcript.
///
/// Pure data — no IO concerns. Produced by
/// `smos::opencode::transcript::parse_transcript` and consumed by
/// [`ImportOpencodeSession::execute`].
#[derive(Debug, Clone, PartialEq)]
pub struct AssistantTurn {
    /// Identifier of the source message (presumably the opencode message id —
    /// confirm against the transcript parser). The import use case in this
    /// module carries it along but never reads it.
    pub message_id: String,
    /// Name of the agent that produced the turn. Compared for exact equality
    /// against the optional agent allow-list.
    pub agent: String,
    /// Text of the turn. Its length in `char`s is compared with `min_chars`,
    /// and the text is handed to the extraction pipeline.
    pub content: String,
    /// Tool calls attached to the turn. A non-empty list exempts the turn
    /// from the min-chars floor; the calls are handed to the extraction
    /// pipeline together with `content`.
    pub tool_calls: Vec<ToolCall>,
}
63
/// Aggregate outcome counters for one import run.
///
/// Surfaced to operators by the `smos-import` CLI. `facts_extracted` is the
/// number of NEWLY-stored pending facts (cross-session confirmations on
/// pre-existing facts do NOT count — see the module docs for the idempotency
/// contract).
///
/// Compares by value (`PartialEq`/`Eq`), so callers and tests can assert on a
/// whole stats record at once. `Default` yields an empty session id and all
/// counters at zero.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImportStats {
    /// String form of the session the import ran under.
    pub session_id: String,
    /// Turns that passed the filters and were handed to the extraction
    /// pipeline.
    pub turns_processed: usize,
    /// Turns rejected by the agent filter or the min-chars floor.
    pub turns_skipped: usize,
    /// Sum of the per-turn new-fact counts returned by the extraction
    /// pipeline.
    pub facts_extracted: usize,
}
77
/// Import an opencode transcript by re-running the live extraction pipeline.
///
/// Owns the same six port dependencies `ExtractFactsFromResponse` needs
/// (`facts`, `sessions`, `embedder`, `extractor`, `clock`, `delay`) plus the
/// configuration knobs the per-turn extraction relies on. The concrete
/// `TokioDelay` adapter is wired by the CLI binary; unit tests inject a
/// no-op delay so the retry backoff is instantaneous.
///
/// The type parameters carry no bounds here; the port-trait bounds live on
/// the `impl` block that needs them.
pub struct ImportOpencodeSession<FR, SR, EP, LE, C, D> {
    /// Fact store, lent to the per-turn extraction bundle.
    pub facts: FR,
    /// Session store. Used directly to ensure the session row exists, then
    /// lent to the per-turn extraction bundle.
    pub sessions: SR,
    /// Embedding provider, lent to the per-turn extraction bundle.
    pub embedder: EP,
    /// LLM fact extractor, lent to the per-turn extraction bundle.
    pub extractor: LE,
    /// Time source, lent to the per-turn extraction bundle.
    pub clock: C,
    /// Delay port, lent to the per-turn extraction bundle.
    pub delay: D,
    /// Confidence configuration, lent to the per-turn extraction bundle.
    pub confidence_cfg: Arc<ConfidenceConfig>,
    /// Semantic-dedup safety net, threaded into the per-turn
    /// [`ExtractFactsFromResponse`] bundle so the import path and the live
    /// response path share one source of truth.
    pub extraction_cfg: Arc<ExtractionConfig>,
    /// Forwarded unchanged to [`ExtractFactsFromResponse`]. Turns are still
    /// counted as processed when this is `false`; the unit tests expect zero
    /// facts to be stored in that case.
    pub enable_response_extraction: bool,
    /// Pre-extraction content floor, measured in `char`s. Turns below this
    /// length AND without tool calls are skipped. Wired from
    /// [`MIN_INPUT_CHARS`](crate::use_cases::extract_facts_from_response::MIN_INPUT_CHARS)
    /// by the CLI binary so the import path and the live response path share
    /// one source of truth.
    pub min_chars: usize,
}
104
105impl<FR, SR, EP, LE, C, D> ImportOpencodeSession<FR, SR, EP, LE, C, D>
106where
107    FR: FactRepository,
108    SR: SessionRepository,
109    EP: EmbeddingProvider,
110    LE: LlmExtractor,
111    C: Clock,
112    D: Delay,
113{
114    /// Import `turns` under `(memory_key, session_id)`.
115    ///
116    /// Reuses [`ExtractFactsFromResponse`] per turn so the extraction contract
117    /// is identical to the live response pipeline. Returns aggregate stats;
118    /// never raises on a per-turn extraction failure (the extractor's retry
119    /// loop already swallows transient failures per §12 fail-open).
120    pub async fn execute(
121        &self,
122        turns: Vec<AssistantTurn>,
123        memory_key: &MemoryKey,
124        session_id: &SessionId,
125        agent_filter: Option<&[String]>,
126    ) -> Result<ImportStats, UseCaseError> {
127        let mut stats = ImportStats {
128            session_id: session_id.as_str().to_string(),
129            ..Default::default()
130        };
131
132        // Ensure the session row exists so `add_pending` registrations land on
133        // a real row. The session also serves as the cross-session
134        // confirmation key inside `ExtractFactsFromResponse::persist_facts`.
135        self.sessions.get_or_create(session_id, memory_key).await?;
136
137        for turn in &turns {
138            if self.should_skip(turn, agent_filter) {
139                stats.turns_skipped += 1;
140                continue;
141            }
142
143            stats.turns_processed += 1;
144            let new_count = self.extract_turn(turn, memory_key, session_id).await?;
145            stats.facts_extracted += new_count;
146        }
147
148        tracing::info!(
149            session = %session_id,
150            memory_key = %memory_key,
151            processed = stats.turns_processed,
152            skipped = stats.turns_skipped,
153            new_facts = stats.facts_extracted,
154            "import complete"
155        );
156        Ok(stats)
157    }
158
159    /// Apply the agent + min-chars filters. Returns `true` when the turn must
160    /// be skipped, `false` when it should be processed.
161    fn should_skip(&self, turn: &AssistantTurn, agent_filter: Option<&[String]>) -> bool {
162        if let Some(filter) = agent_filter
163            && !filter.iter().any(|a| a == &turn.agent)
164        {
165            return true;
166        }
167        let too_short = turn.content.chars().count() < self.min_chars;
168        too_short && turn.tool_calls.is_empty()
169    }
170
171    /// Delegate one turn to `ExtractFactsFromResponse` (DRY with the live
172    /// response path). The borrow bundle is rebuilt per turn so the use case
173    /// does not hold references across awaits between turns.
174    async fn extract_turn(
175        &self,
176        turn: &AssistantTurn,
177        memory_key: &MemoryKey,
178        session_id: &SessionId,
179    ) -> Result<usize, UseCaseError> {
180        let extractor = ExtractFactsFromResponse {
181            facts: &self.facts,
182            sessions: &self.sessions,
183            embedder: &self.embedder,
184            extractor: &self.extractor,
185            clock: &self.clock,
186            delay: &self.delay,
187            confidence_cfg: &self.confidence_cfg,
188            extraction_cfg: &self.extraction_cfg,
189            enable_response_extraction: self.enable_response_extraction,
190        };
191        extractor
192            .execute(&turn.content, &turn.tool_calls, memory_key, session_id)
193            .await
194    }
195}
196
#[cfg(test)]
mod tests {
    //! Import use case unit tests.
    //!
    //! Classicist style: in-memory repos + scripted providers. The full
    //! pipeline (SurrealStore + extraction) is exercised by the
    //! `tests/e2e_import.rs` integration suite.

    use super::*;
    use crate::types::SearchHit;
    use smos_domain::{Fact, FactId, Heat, SessionState, Timestamp};
    use std::collections::{HashMap, VecDeque};
    use std::sync::Mutex;
    use std::time::Duration;

    // ---- Fakes mirroring the `extract_facts_from_response` test kit ----

    /// Clock that always reports the same instant.
    #[derive(Clone)]
    struct FixedClock(Timestamp);
    impl Clock for FixedClock {
        fn now(&self) -> Timestamp {
            self.0
        }
    }

    /// Delay that returns immediately, so retry backoff costs no wall time.
    #[derive(Clone, Copy)]
    struct NoOpDelay;
    impl Delay for NoOpDelay {
        async fn delay(&self, _duration: Duration) {}
    }

    /// Extractor that replays a scripted queue of results, one per call, in
    /// order. Once the queue is drained every further call yields no facts.
    struct ScriptedExtractor {
        // A deque so the front can be popped in O(1) without shifting.
        results: Mutex<VecDeque<Vec<String>>>,
    }
    impl ScriptedExtractor {
        fn new(results: Vec<Vec<String>>) -> Self {
            Self {
                results: Mutex::new(VecDeque::from(results)),
            }
        }
    }
    impl LlmExtractor for ScriptedExtractor {
        async fn extract_facts(
            &self,
            _content: &str,
            _tool_calls: &[ToolCall],
        ) -> Result<Vec<String>, crate::errors::ProviderError> {
            // The guard is a temporary dropped at the end of this statement;
            // nothing is awaited while the lock is held.
            let next = self.results.lock().unwrap().pop_front();
            Ok(next.unwrap_or_default())
        }
    }

    /// Embedder that returns the same vector for every input.
    struct ConstantEmbedder(Vec<f32>);
    impl EmbeddingProvider for ConstantEmbedder {
        async fn embed(
            &self,
            _text: &str,
        ) -> Result<Option<Vec<f32>>, crate::errors::ProviderError> {
            Ok(Some(self.0.clone()))
        }
    }

    /// Fact repository backed by a shared map keyed by fact id string.
    ///
    /// Only `save` and `get` touch the map; every listing/search method
    /// reports "nothing found" and `update_heat_batch` is a no-op. Clones
    /// share the same underlying map.
    #[derive(Default, Clone)]
    struct InMemoryFacts {
        store: Arc<Mutex<HashMap<String, Fact>>>,
    }
    impl FactRepository for InMemoryFacts {
        async fn save(&self, fact: &Fact) -> Result<(), crate::errors::RepoError> {
            self.store
                .lock()
                .unwrap()
                .insert(fact.id().as_str().to_string(), fact.clone());
            Ok(())
        }
        async fn get(
            &self,
            id: &FactId,
            _mk: &MemoryKey,
        ) -> Result<Option<Fact>, crate::errors::RepoError> {
            Ok(self.store.lock().unwrap().get(id.as_str()).cloned())
        }
        async fn list_accepted(
            &self,
            _mk: &MemoryKey,
        ) -> Result<Vec<Fact>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn list_pending(
            &self,
            _mk: &MemoryKey,
        ) -> Result<Vec<Fact>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn list_memory_keys_for_session(
            &self,
            _session_id: &SessionId,
        ) -> Result<Vec<MemoryKey>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn list_memory_keys(&self) -> Result<Vec<MemoryKey>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn search_similar(
            &self,
            _e: Vec<f32>,
            _mk: &MemoryKey,
            _l: usize,
        ) -> Result<Vec<SearchHit>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn update_heat_batch(
            &self,
            _ids: &[FactId],
            _mk: &MemoryKey,
            _h: Heat,
            _t: Timestamp,
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
    }

    /// Session repository that only records whether `get_or_create` ran.
    ///
    /// Every other method is a no-op returning an empty/`Ok` value. Clones
    /// share the same flag.
    #[derive(Default, Clone)]
    struct RecordingSessions {
        created: Arc<Mutex<bool>>,
    }
    impl SessionRepository for RecordingSessions {
        async fn get_or_create(
            &self,
            id: &SessionId,
            _m: &MemoryKey,
        ) -> Result<SessionState, crate::errors::RepoError> {
            *self.created.lock().unwrap() = true;
            // NOTE(review): the requested memory key is ignored and "proj" is
            // always reported. That matches `mk()`, which is the only key
            // these tests use.
            Ok(SessionState::new(
                id.clone(),
                MemoryKey::from_raw("proj").unwrap(),
                Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            ))
        }
        async fn add_pending(
            &self,
            _i: &SessionId,
            _ids: &[FactId],
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
        async fn collect_expired(
            &self,
            _t: Duration,
        ) -> Result<Vec<(SessionId, SessionState)>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn snapshot_all(
            &self,
        ) -> Result<Vec<(SessionId, SessionState)>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn remove_pending_owned(
            &self,
            _i: &SessionId,
            _o: &[FactId],
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
        async fn clear_session(&self, _i: &SessionId) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
        async fn dedup_and_mark(
            &self,
            _i: &SessionId,
            _m: &MemoryKey,
            _c: &[FactId],
        ) -> Result<Vec<FactId>, crate::errors::RepoError> {
            Ok(Vec::new())
        }
        async fn save(
            &self,
            _i: &SessionId,
            _s: &SessionState,
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
    }

    /// The single memory key used by every test.
    fn mk() -> MemoryKey {
        MemoryKey::from_raw("proj").unwrap()
    }

    /// Session id of the form `sess_` + `tag` as 12 zero-padded hex digits.
    fn sid(tag: u8) -> SessionId {
        SessionId::from_raw(&format!("sess_{tag:012x}")).unwrap()
    }

    /// Shared fixture. `build` hands the use case clones of `facts` and
    /// `sessions`, so the test can inspect their state afterwards.
    struct Fix {
        facts: InMemoryFacts,
        sessions: RecordingSessions,
        embedder: ConstantEmbedder,
        clock: FixedClock,
        cfg: ConfidenceConfig,
        extraction_cfg: ExtractionConfig,
    }
    impl Fix {
        fn new() -> Self {
            Self {
                facts: InMemoryFacts::default(),
                sessions: RecordingSessions::default(),
                embedder: ConstantEmbedder(vec![0.1, 0.2, 0.3]),
                clock: FixedClock(Timestamp::from_unix_secs(1_700_000_000).unwrap()),
                cfg: ConfidenceConfig::default(),
                extraction_cfg: ExtractionConfig::default(),
            }
        }

        /// Assemble the use case around `extractor` with extraction enabled
        /// and the given `min_chars` floor.
        fn build(
            &self,
            extractor: ScriptedExtractor,
            min_chars: usize,
        ) -> ImportOpencodeSession<
            InMemoryFacts,
            RecordingSessions,
            ConstantEmbedder,
            ScriptedExtractor,
            FixedClock,
            NoOpDelay,
        > {
            ImportOpencodeSession {
                facts: self.facts.clone(),
                sessions: self.sessions.clone(),
                embedder: ConstantEmbedder(self.embedder.0.clone()),
                extractor,
                clock: FixedClock(self.clock.0),
                delay: NoOpDelay,
                confidence_cfg: Arc::new(self.cfg.clone()),
                extraction_cfg: Arc::new(self.extraction_cfg.clone()),
                enable_response_extraction: true,
                min_chars,
            }
        }
    }

    /// Turn with the given agent and content and no tool calls.
    fn turn(agent: &str, content: &str) -> AssistantTurn {
        AssistantTurn {
            message_id: format!("msg_{agent}"),
            agent: agent.to_string(),
            content: content.to_string(),
            tool_calls: Vec::new(),
        }
    }

    #[tokio::test]
    async fn execute_imports_each_turn_and_counts_new_facts() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![
            vec!["fact one".to_string()],
            vec!["fact two".to_string()],
        ]);
        let import = fix.build(extractor, 15);

        let turns = vec![
            turn("head-of-development", "TTL=10 prevents refresh loop"),
            turn("head-of-development", "Auth uses JWT for tokens"),
        ];
        let stats = import.execute(turns, &mk(), &sid(1), None).await.unwrap();

        assert_eq!(stats.turns_processed, 2);
        assert_eq!(stats.turns_skipped, 0);
        assert_eq!(stats.facts_extracted, 2);
    }

    #[tokio::test]
    async fn execute_skips_turns_below_min_chars_without_tool_calls() {
        let fix = Fix::new();
        // Only one extraction result is scripted. The short turn must be
        // skipped before extraction so that it does not consume the result
        // meant for the second turn.
        let extractor = ScriptedExtractor::new(vec![vec!["real fact".to_string()]]);
        let import = fix.build(extractor, 15);

        let turns = vec![
            turn("a", "ok"), // 2 chars < 15 → skipped
            turn("a", "TTL=10 prevents refresh loop"),
        ];
        let stats = import.execute(turns, &mk(), &sid(1), None).await.unwrap();

        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.turns_skipped, 1);
        assert_eq!(stats.facts_extracted, 1);
    }

    #[tokio::test]
    async fn execute_keeps_short_turn_when_it_has_tool_calls() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![vec!["from tool".to_string()]]);
        let import = fix.build(extractor, 15);

        // Content is below the floor, but the tool call exempts the turn.
        let mut short_with_tool = turn("a", "ok");
        short_with_tool.tool_calls.push(ToolCall {
            name: "read_file".into(),
            arguments: smos_domain::chat::ToolArguments::from_json(r#"{"path":"auth.rs"}"#),
        });
        let stats = import
            .execute(vec![short_with_tool], &mk(), &sid(1), None)
            .await
            .unwrap();

        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.turns_skipped, 0);
        assert_eq!(stats.facts_extracted, 1);
    }

    #[tokio::test]
    async fn execute_applies_agent_filter() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![
            vec!["hod fact".to_string()],
            vec!["hod fact 2".to_string()],
        ]);
        let import = fix.build(extractor, 15);

        let turns = vec![
            turn("head-of-development", "TTL=10 prevents refresh loop"),
            turn("dreaming", "Internal analysis content here"), // not in the allow-list
            turn("head-of-development", "Auth uses JWT for tokens"),
        ];
        let filter = vec!["head-of-development".to_string()];
        let stats = import
            .execute(turns, &mk(), &sid(1), Some(&filter))
            .await
            .unwrap();

        assert_eq!(stats.turns_processed, 2);
        assert_eq!(stats.turns_skipped, 1);
        assert_eq!(stats.facts_extracted, 2);
    }

    #[tokio::test]
    async fn execute_ensures_session_row_exists_before_first_turn() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![]);
        let import = fix.build(extractor, 15);

        let _ = import.execute(vec![], &mk(), &sid(7), None).await.unwrap();

        assert!(
            *fix.sessions.created.lock().unwrap(),
            "get_or_create must run even for an empty turn list"
        );
    }

    #[tokio::test]
    async fn execute_with_extraction_disabled_returns_zero_facts() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![vec!["should not be stored".to_string()]]);
        let mut import = fix.build(extractor, 15);
        import.enable_response_extraction = false;

        let stats = import
            .execute(
                vec![turn("a", "TTL=10 prevents refresh loop")],
                &mk(),
                &sid(1),
                None,
            )
            .await
            .unwrap();

        // The turn still passes the import-level filters; only the delegated
        // extraction is switched off.
        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.turns_skipped, 0);
        assert_eq!(stats.facts_extracted, 0);
        assert!(fix.facts.store.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn execute_confirms_existing_fact_instead_of_counting_it_new() {
        let fix = Fix::new();
        // The fact is seeded directly into the store as if session 1 had
        // produced it. Importing the same content under session 2 must then
        // be treated as a cross-session confirmation, not as a new fact.
        let seeded_content = "shared fact content here";
        let first = Fact::new_pending(
            seeded_content,
            mk(),
            sid(1),
            smos_domain::Embedding::new(vec![1.0]).unwrap(),
            Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            ConfidenceConfig::default().base,
        )
        .unwrap();
        let fid = first.id().clone();
        fix.facts
            .store
            .lock()
            .unwrap()
            .insert(fid.as_str().to_string(), first);

        let extractor = ScriptedExtractor::new(vec![vec![seeded_content.to_string()]]);
        let import = fix.build(extractor, 15);

        let stats = import
            .execute(vec![turn("a", seeded_content)], &mk(), &sid(2), None)
            .await
            .unwrap();

        assert_eq!(stats.facts_extracted, 0, "confirmation is not a new fact");
        let confirmed = fix
            .facts
            .store
            .lock()
            .unwrap()
            .get(fid.as_str())
            .cloned()
            .expect("fact present");
        // Cross-session confirmation promotes the fact through the validation
        // gate (multi-source bonus + no-contradiction bonus clears accept
        // threshold). The exact status depends on the confidence formula; we
        // only assert provenance growth so the test is robust to formula
        // tweaks.
        assert_eq!(
            confirmed.source_sessions().distinct_count(),
            2,
            "provenance grew to two sessions"
        );
    }
}