Skip to main content

starpod_session/
lib.rs

1use chrono::{DateTime, Duration, Utc};
2use serde::{Deserialize, Serialize};
3use sqlx::{Row, SqlitePool};
4use tracing::debug;
5use uuid::Uuid;
6
7use starpod_core::{Result, StarpodError};
8
/// The channel a session is scoped to.
///
/// Stored in the database in its [`Channel::as_str`] form; use
/// [`Channel::from_channel_str`] to parse that form back.
#[derive(Debug, Clone, PartialEq)]
pub enum Channel {
    /// Explicit sessions: the client controls the lifecycle (web, REPL, CLI).
    Main,
    /// Time-gap sessions: a new session begins after an inactivity threshold (6h).
    Telegram,
    /// Time-gap sessions over email: a new session begins after an inactivity threshold (24h).
    Email,
}

impl Channel {
    /// Lowercase identifier used when persisting this channel.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Main => "main",
            Self::Telegram => "telegram",
            Self::Email => "email",
        }
    }

    /// Parse a stored channel identifier (case-sensitive).
    ///
    /// Anything other than `"telegram"` or `"email"` (including `"main"` and
    /// unrecognised strings) yields [`Channel::Main`].
    pub fn from_channel_str(s: &str) -> Self {
        // Compare against `as_str` so the two mappings can never drift apart.
        if s == Self::Telegram.as_str() {
            Self::Telegram
        } else if s == Self::Email.as_str() {
            Self::Email
        } else {
            Self::Main
        }
    }
}
37
/// Outcome of session resolution: reuse an open session or start a fresh one.
#[derive(Clone, Debug)]
pub enum SessionDecision {
    /// Keep using the open session whose ID is carried here.
    Continue(String),
    /// Begin a new session.
    ///
    /// When resolution had to auto-close the previous session (for example after
    /// a time-gap expiry), `closed_session_id` holds that session's ID so the
    /// caller can export it; otherwise it is `None`.
    New { closed_session_id: Option<String> },
}
47
48/// Metadata for a session.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SessionMeta {
51    pub id: String,
52    pub created_at: String,
53    pub last_message_at: String,
54    pub is_closed: bool,
55    pub summary: Option<String>,
56    pub title: Option<String>,
57    pub message_count: i64,
58    pub channel: String,
59    pub channel_session_key: Option<String>,
60    pub user_id: String,
61    pub is_read: bool,
62    /// Cron job name or `"__heartbeat__"` if this session was triggered by a scheduled job.
63    /// `None` for regular user sessions.
64    pub triggered_by: Option<String>,
65}
66
67/// A stored message in a session.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct SessionMessage {
70    pub id: i64,
71    pub session_id: String,
72    pub role: String,
73    pub content: String,
74    pub timestamp: String,
75}
76
77/// Usage record for a single turn.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct UsageRecord {
80    pub input_tokens: u64,
81    pub output_tokens: u64,
82    pub cache_read: u64,
83    pub cache_write: u64,
84    pub cost_usd: f64,
85    pub model: String,
86    pub user_id: String,
87}
88
89/// Manages session lifecycle — creation, channel-aware resolution, closure, and usage tracking.
90pub struct SessionManager {
91    pool: SqlitePool,
92}
93
94impl SessionManager {
95    /// Create a SessionManager from a shared pool.
96    ///
97    /// The pool should already have migrations applied (via `CoreDb`).
98    pub fn from_pool(pool: SqlitePool) -> Self {
99        Self { pool }
100    }
101
102    /// Resolve session for a given channel and key.
103    ///
104    /// - **Main** (explicit): always continues the matching open session if one exists.
105    /// - **Telegram** (time-gap): continues if last message was within the gap threshold,
106    ///   otherwise auto-closes the old session and returns `New`.
107    ///
108    /// `gap_minutes` is the inactivity gap from config. Pass `None` for explicit
109    /// channels that don't use time-gap sessions.
110    pub async fn resolve_session(
111        &self,
112        channel: &Channel,
113        key: &str,
114        gap_minutes: Option<i64>,
115    ) -> Result<SessionDecision> {
116        self.resolve_session_for_user(channel, key, gap_minutes, "admin")
117            .await
118    }
119
120    /// Resolve session for a given channel, key, and user.
121    pub async fn resolve_session_for_user(
122        &self,
123        channel: &Channel,
124        key: &str,
125        gap_minutes: Option<i64>,
126        user_id: &str,
127    ) -> Result<SessionDecision> {
128        let row = sqlx::query(
129            "SELECT id, last_message_at
130             FROM session_metadata
131             WHERE channel = ?1 AND channel_session_key = ?2 AND is_closed = 0 AND user_id = ?3
132             ORDER BY last_message_at DESC
133             LIMIT 1",
134        )
135        .bind(channel.as_str())
136        .bind(key)
137        .bind(user_id)
138        .fetch_optional(&self.pool)
139        .await
140        .map_err(|e| StarpodError::Database(format!("Resolve session query failed: {}", e)))?;
141
142        let row = match row {
143            Some(r) => r,
144            None => {
145                return Ok(SessionDecision::New {
146                    closed_session_id: None,
147                })
148            }
149        };
150
151        let session_id: String = row.get("id");
152
153        // For explicit channels (no gap), always continue.
154        let gap_threshold = match gap_minutes {
155            None => {
156                debug!(session_id = %session_id, channel = %channel.as_str(), "Continuing session (explicit channel)");
157                return Ok(SessionDecision::Continue(session_id));
158            }
159            Some(gap) => gap,
160        };
161
162        // For time-gap channels, check inactivity
163        let last_msg_str: String = row.get("last_message_at");
164        let last_msg = DateTime::parse_from_rfc3339(&last_msg_str)
165            .map_err(|e| StarpodError::Session(format!("Bad timestamp: {}", e)))?
166            .with_timezone(&Utc);
167
168        let gap = Utc::now() - last_msg;
169
170        if gap < Duration::minutes(gap_threshold) {
171            debug!(session_id = %session_id, gap_mins = gap.num_minutes(), "Continuing session (within gap)");
172            Ok(SessionDecision::Continue(session_id))
173        } else {
174            debug!(session_id = %session_id, gap_mins = gap.num_minutes(), "Auto-closing session (gap exceeded)");
175            self.close_session(&session_id, "Auto-closed: inactivity")
176                .await?;
177            Ok(SessionDecision::New {
178                closed_session_id: Some(session_id),
179            })
180        }
181    }
182
183    /// Create a new session for a channel and key, returning its ID.
184    pub async fn create_session(&self, channel: &Channel, key: &str) -> Result<String> {
185        self.create_session_full(channel, key, "admin", None).await
186    }
187
188    /// Create a new session for a channel, key, and user, returning its ID.
189    pub async fn create_session_for_user(
190        &self,
191        channel: &Channel,
192        key: &str,
193        user_id: &str,
194    ) -> Result<String> {
195        self.create_session_full(channel, key, user_id, None).await
196    }
197
198    /// Create a new session with full metadata, including an optional trigger source.
199    ///
200    /// `triggered_by` records the cron job name (e.g. `"daily-digest"`) or
201    /// `"__heartbeat__"` when the session is created by the scheduler.
202    pub async fn create_session_full(
203        &self,
204        channel: &Channel,
205        key: &str,
206        user_id: &str,
207        triggered_by: Option<&str>,
208    ) -> Result<String> {
209        let id = Uuid::new_v4().to_string();
210        let now = Utc::now().to_rfc3339();
211
212        sqlx::query(
213            "INSERT INTO session_metadata (id, created_at, last_message_at, is_closed, message_count, channel, channel_session_key, user_id, triggered_by)
214             VALUES (?1, ?2, ?2, 0, 0, ?3, ?4, ?5, ?6)",
215        )
216        .bind(&id)
217        .bind(&now)
218        .bind(channel.as_str())
219        .bind(key)
220        .bind(user_id)
221        .bind(triggered_by)
222        .execute(&self.pool)
223        .await
224        .map_err(|e| StarpodError::Database(format!("Create session failed: {}", e)))?;
225
226        debug!(session_id = %id, channel = %channel.as_str(), key = %key, "Created new session");
227        Ok(id)
228    }
229
230    /// Mark a session as closed with an optional summary.
231    pub async fn close_session(&self, id: &str, summary: &str) -> Result<()> {
232        sqlx::query("UPDATE session_metadata SET is_closed = 1, summary = ?2 WHERE id = ?1")
233            .bind(id)
234            .bind(summary)
235            .execute(&self.pool)
236            .await
237            .map_err(|e| StarpodError::Database(format!("Close session failed: {}", e)))?;
238
239        debug!(session_id = %id, "Closed session");
240        Ok(())
241    }
242
243    /// Mark a session as read or unread.
244    pub async fn mark_read(&self, id: &str, is_read: bool) -> Result<()> {
245        sqlx::query("UPDATE session_metadata SET is_read = ?2 WHERE id = ?1")
246            .bind(id)
247            .bind(is_read as i64)
248            .execute(&self.pool)
249            .await
250            .map_err(|e| StarpodError::Database(format!("Mark read failed: {}", e)))?;
251        Ok(())
252    }
253
254    /// Update the last_message_at timestamp and increment message_count.
255    pub async fn touch_session(&self, id: &str) -> Result<()> {
256        let now = Utc::now().to_rfc3339();
257        sqlx::query(
258            "UPDATE session_metadata SET last_message_at = ?2, message_count = message_count + 1 WHERE id = ?1",
259        )
260        .bind(id)
261        .bind(&now)
262        .execute(&self.pool)
263        .await
264        .map_err(|e| StarpodError::Database(format!("Touch session failed: {}", e)))?;
265        Ok(())
266    }
267
268    /// Set the session title if it hasn't been set yet.
269    pub async fn set_title_if_empty(&self, id: &str, title: &str) -> Result<()> {
270        let truncated = if title.len() > 100 {
271            let mut end = 100;
272            while end > 0 && !title.is_char_boundary(end) {
273                end -= 1;
274            }
275            format!("{}...", &title[..end])
276        } else {
277            title.to_string()
278        };
279        sqlx::query("UPDATE session_metadata SET title = ?2 WHERE id = ?1 AND title IS NULL")
280            .bind(id)
281            .bind(&truncated)
282            .execute(&self.pool)
283            .await
284            .map_err(|e| StarpodError::Database(format!("Set title failed: {}", e)))?;
285        Ok(())
286    }
287
288    /// Record token usage for a turn.
289    pub async fn record_usage(
290        &self,
291        session_id: &str,
292        usage: &UsageRecord,
293        turn: u32,
294    ) -> Result<()> {
295        let now = Utc::now().to_rfc3339();
296        sqlx::query(
297            "INSERT INTO usage_stats (session_id, turn, input_tokens, output_tokens, cache_read, cache_write, cost_usd, model, user_id, timestamp)
298             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
299        )
300        .bind(session_id)
301        .bind(turn as i64)
302        .bind(usage.input_tokens as i64)
303        .bind(usage.output_tokens as i64)
304        .bind(usage.cache_read as i64)
305        .bind(usage.cache_write as i64)
306        .bind(usage.cost_usd)
307        .bind(&usage.model)
308        .bind(&usage.user_id)
309        .bind(&now)
310        .execute(&self.pool)
311        .await
312        .map_err(|e| StarpodError::Database(format!("Record usage failed: {}", e)))?;
313
314        Ok(())
315    }
316
317    /// List sessions, most recent first.
318    pub async fn list_sessions(&self, limit: usize) -> Result<Vec<SessionMeta>> {
319        let rows = sqlx::query(
320            "SELECT id, created_at, last_message_at, is_closed, summary, title, message_count, channel, channel_session_key, user_id, is_read, triggered_by
321             FROM session_metadata
322             ORDER BY last_message_at DESC
323             LIMIT ?1",
324        )
325        .bind(limit as i64)
326        .fetch_all(&self.pool)
327        .await
328        .map_err(|e| StarpodError::Database(format!("Query failed: {}", e)))?;
329
330        let sessions: Vec<SessionMeta> = rows.iter().map(session_meta_from_row).collect();
331
332        Ok(sessions)
333    }
334
335    /// Get a specific session by ID.
336    pub async fn get_session(&self, id: &str) -> Result<Option<SessionMeta>> {
337        let row = sqlx::query(
338            "SELECT id, created_at, last_message_at, is_closed, summary, title, message_count, channel, channel_session_key, user_id, is_read, triggered_by
339             FROM session_metadata WHERE id = ?1",
340        )
341        .bind(id)
342        .fetch_optional(&self.pool)
343        .await
344        .map_err(|e| StarpodError::Database(format!("Get session failed: {}", e)))?;
345
346        Ok(row.map(|r| session_meta_from_row(&r)))
347    }
348
349    /// Get total usage stats for a session.
350    ///
351    /// `total_input_tokens` includes uncached, cache-read, and cache-write
352    /// tokens so the caller gets the true context size. Cache breakdown is
353    /// available via `total_cache_read` / `total_cache_write`.
354    pub async fn session_usage(&self, session_id: &str) -> Result<UsageSummary> {
355        let row = sqlx::query(
356            "SELECT COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti, COALESCE(SUM(output_tokens), 0) as to_,
357                    COALESCE(SUM(cache_read), 0) as cr, COALESCE(SUM(cache_write), 0) as cw,
358                    COALESCE(SUM(cost_usd), 0.0) as cost, COUNT(*) as turns
359             FROM usage_stats WHERE session_id = ?1",
360        )
361        .bind(session_id)
362        .fetch_one(&self.pool)
363        .await
364        .map_err(|e| StarpodError::Database(format!("Usage query failed: {}", e)))?;
365
366        Ok(UsageSummary {
367            total_input_tokens: row.get::<i64, _>("ti") as u64,
368            total_output_tokens: row.get::<i64, _>("to_") as u64,
369            total_cache_read: row.get::<i64, _>("cr") as u64,
370            total_cache_write: row.get::<i64, _>("cw") as u64,
371            total_cost_usd: row.get::<f64, _>("cost"),
372            total_turns: row.get::<i64, _>("turns") as u32,
373        })
374    }
375
376    /// Get a full cost overview with breakdowns by user and model.
377    ///
378    /// If `since` is provided (RFC 3339 timestamp), only usage after that time is included.
379    pub async fn cost_overview(&self, since: Option<&str>) -> Result<CostOverview> {
380        let (where_clause, bind_val) = match since {
381            Some(ts) => ("WHERE timestamp >= ?1", Some(ts)),
382            None => ("", None),
383        };
384
385        // Total
386        let total_sql = format!(
387            "SELECT COALESCE(SUM(cost_usd), 0.0) as cost,
388                    COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti,
389                    COALESCE(SUM(output_tokens), 0) as to_,
390                    COALESCE(SUM(cache_read), 0) as cr,
391                    COALESCE(SUM(cache_write), 0) as cw,
392                    COUNT(*) as turns
393             FROM usage_stats {}",
394            where_clause
395        );
396        let mut q = sqlx::query(&total_sql);
397        if let Some(ts) = bind_val {
398            q = q.bind(ts);
399        }
400        let total_row = q
401            .fetch_one(&self.pool)
402            .await
403            .map_err(|e| StarpodError::Database(format!("Cost total query failed: {}", e)))?;
404
405        // By user
406        let user_sql = format!(
407            "SELECT user_id,
408                    COALESCE(SUM(cost_usd), 0.0) as cost,
409                    COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti,
410                    COALESCE(SUM(output_tokens), 0) as to_,
411                    COALESCE(SUM(cache_read), 0) as cr,
412                    COALESCE(SUM(cache_write), 0) as cw,
413                    COUNT(*) as turns
414             FROM usage_stats {} GROUP BY user_id ORDER BY cost DESC",
415            where_clause
416        );
417        let mut q = sqlx::query(&user_sql);
418        if let Some(ts) = bind_val {
419            q = q.bind(ts);
420        }
421        let user_rows = q
422            .fetch_all(&self.pool)
423            .await
424            .map_err(|e| StarpodError::Database(format!("Cost by-user query failed: {}", e)))?;
425
426        // By model
427        let model_sql = format!(
428            "SELECT model,
429                    COALESCE(SUM(cost_usd), 0.0) as cost,
430                    COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti,
431                    COALESCE(SUM(output_tokens), 0) as to_,
432                    COALESCE(SUM(cache_read), 0) as cr,
433                    COALESCE(SUM(cache_write), 0) as cw,
434                    COUNT(*) as turns
435             FROM usage_stats {} GROUP BY model ORDER BY cost DESC",
436            where_clause
437        );
438        let mut q = sqlx::query(&model_sql);
439        if let Some(ts) = bind_val {
440            q = q.bind(ts);
441        }
442        let model_rows = q
443            .fetch_all(&self.pool)
444            .await
445            .map_err(|e| StarpodError::Database(format!("Cost by-model query failed: {}", e)))?;
446
447        // By day + model
448        let day_sql = format!(
449            "SELECT DATE(timestamp) as day, COALESCE(model, 'unknown') as model,
450                    COALESCE(SUM(cost_usd), 0.0) as cost
451             FROM usage_stats {} GROUP BY day, model ORDER BY day ASC",
452            where_clause
453        );
454        let mut q = sqlx::query(&day_sql);
455        if let Some(ts) = bind_val {
456            q = q.bind(ts);
457        }
458        let day_rows = q
459            .fetch_all(&self.pool)
460            .await
461            .map_err(|e| StarpodError::Database(format!("Cost by-day query failed: {}", e)))?;
462
463        // Group day rows into DayCostSummary
464        let mut by_day: Vec<DayCostSummary> = Vec::new();
465        for row in &day_rows {
466            let date: String = row.get("day");
467            let model: String = row.get("model");
468            let cost: f64 = row.get::<f64, _>("cost");
469            if let Some(last) = by_day.last_mut().filter(|d| d.date == date) {
470                last.total_cost_usd += cost;
471                last.by_model.push(DayModelCost {
472                    model,
473                    cost_usd: cost,
474                });
475            } else {
476                by_day.push(DayCostSummary {
477                    date,
478                    total_cost_usd: cost,
479                    by_model: vec![DayModelCost {
480                        model,
481                        cost_usd: cost,
482                    }],
483                });
484            }
485        }
486
487        // By tool (from session_messages)
488        let tool_sql = format!(
489            "SELECT json_extract(sm.content, '$.name') AS tool_name,
490                    COUNT(*) AS invocations,
491                    COALESCE(SUM(
492                      CASE WHEN tr.content IS NOT NULL
493                           AND json_extract(tr.content, '$.is_error') = 1
494                      THEN 1 ELSE 0 END
495                    ), 0) AS errors
496             FROM session_messages sm
497             LEFT JOIN session_messages tr
498               ON tr.session_id = sm.session_id
499               AND tr.role = 'tool_result'
500               AND json_extract(tr.content, '$.tool_use_id') = json_extract(sm.content, '$.id')
501             WHERE sm.role = 'tool_use'
502               {}
503             GROUP BY tool_name
504             ORDER BY invocations DESC",
505            if bind_val.is_some() {
506                "AND sm.timestamp >= ?1"
507            } else {
508                ""
509            }
510        );
511        let mut q = sqlx::query(&tool_sql);
512        if let Some(ts) = bind_val {
513            q = q.bind(ts);
514        }
515        let tool_rows = q
516            .fetch_all(&self.pool)
517            .await
518            .map_err(|e| StarpodError::Database(format!("Cost by-tool query failed: {}", e)))?;
519
520        let by_tool: Vec<ToolUsageSummary> = tool_rows
521            .iter()
522            .map(|r| ToolUsageSummary {
523                tool_name: r
524                    .try_get("tool_name")
525                    .unwrap_or_else(|_| "unknown".to_string()),
526                invocations: r.get::<i64, _>("invocations") as u32,
527                errors: r.get::<i64, _>("errors") as u32,
528            })
529            .collect();
530
531        Ok(CostOverview {
532            total_cost_usd: total_row.get::<f64, _>("cost"),
533            total_input_tokens: total_row.get::<i64, _>("ti") as u64,
534            total_output_tokens: total_row.get::<i64, _>("to_") as u64,
535            total_cache_read: total_row.get::<i64, _>("cr") as u64,
536            total_cache_write: total_row.get::<i64, _>("cw") as u64,
537            total_turns: total_row.get::<i64, _>("turns") as u32,
538            by_user: user_rows
539                .iter()
540                .map(|r| UserCostSummary {
541                    user_id: r.get("user_id"),
542                    total_cost_usd: r.get::<f64, _>("cost"),
543                    total_input_tokens: r.get::<i64, _>("ti") as u64,
544                    total_output_tokens: r.get::<i64, _>("to_") as u64,
545                    total_cache_read: r.get::<i64, _>("cr") as u64,
546                    total_cache_write: r.get::<i64, _>("cw") as u64,
547                    total_turns: r.get::<i64, _>("turns") as u32,
548                })
549                .collect(),
550            by_model: model_rows
551                .iter()
552                .map(|r| ModelCostSummary {
553                    model: r.try_get("model").unwrap_or_else(|_| "unknown".to_string()),
554                    total_cost_usd: r.get::<f64, _>("cost"),
555                    total_input_tokens: r.get::<i64, _>("ti") as u64,
556                    total_output_tokens: r.get::<i64, _>("to_") as u64,
557                    total_cache_read: r.get::<i64, _>("cr") as u64,
558                    total_cache_write: r.get::<i64, _>("cw") as u64,
559                    total_turns: r.get::<i64, _>("turns") as u32,
560                })
561                .collect(),
562            by_day,
563            by_tool,
564        })
565    }
566
567    /// Record a compaction event for a session.
568    pub async fn record_compaction(
569        &self,
570        session_id: &str,
571        trigger: &str,
572        pre_tokens: u64,
573        summary: &str,
574        messages_compacted: usize,
575    ) -> Result<()> {
576        let now = Utc::now().to_rfc3339();
577        sqlx::query(
578            "INSERT INTO compaction_log (session_id, timestamp, trigger, pre_tokens, summary, messages_compacted)
579             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
580        )
581        .bind(session_id)
582        .bind(&now)
583        .bind(trigger)
584        .bind(pre_tokens as i64)
585        .bind(summary)
586        .bind(messages_compacted as i64)
587        .execute(&self.pool)
588        .await
589        .map_err(|e| StarpodError::Database(format!("Record compaction failed: {}", e)))?;
590
591        debug!(session_id = %session_id, pre_tokens, messages_compacted, "Recorded compaction event");
592        Ok(())
593    }
594
595    /// Save a message to a session.
596    ///
597    /// When the first "user" message is saved, the session title is automatically
598    /// set to the message text (truncated to 100 chars).
599    pub async fn save_message(&self, session_id: &str, role: &str, content: &str) -> Result<()> {
600        let now = Utc::now().to_rfc3339();
601        sqlx::query(
602            "INSERT INTO session_messages (session_id, role, content, timestamp)
603             VALUES (?1, ?2, ?3, ?4)",
604        )
605        .bind(session_id)
606        .bind(role)
607        .bind(content)
608        .bind(&now)
609        .execute(&self.pool)
610        .await
611        .map_err(|e| StarpodError::Database(format!("Save message failed: {}", e)))?;
612
613        // Auto-set title from first user message
614        if role == "user" {
615            let title = if content.len() > 100 {
616                let mut end = 100;
617                while end > 0 && !content.is_char_boundary(end) {
618                    end -= 1;
619                }
620                format!("{}...", &content[..end])
621            } else {
622                content.to_string()
623            };
624            // Only set if title is currently NULL (first message)
625            sqlx::query("UPDATE session_metadata SET title = ?2 WHERE id = ?1 AND title IS NULL")
626                .bind(session_id)
627                .bind(&title)
628                .execute(&self.pool)
629                .await
630                .map_err(|e| StarpodError::Database(format!("Set title failed: {}", e)))?;
631        }
632
633        Ok(())
634    }
635
636    /// Get all messages for a session, ordered by ID.
637    pub async fn get_messages(&self, session_id: &str) -> Result<Vec<SessionMessage>> {
638        let rows = sqlx::query(
639            "SELECT id, session_id, role, content, timestamp
640             FROM session_messages
641             WHERE session_id = ?1
642             ORDER BY id ASC",
643        )
644        .bind(session_id)
645        .fetch_all(&self.pool)
646        .await
647        .map_err(|e| StarpodError::Database(format!("Get messages failed: {}", e)))?;
648
649        Ok(rows
650            .iter()
651            .map(|r| SessionMessage {
652                id: r.get("id"),
653                session_id: r.get("session_id"),
654                role: r.get("role"),
655                content: r.get("content"),
656                timestamp: r.get("timestamp"),
657            })
658            .collect())
659    }
660}
661
662/// Extract a SessionMeta from a database row.
663fn session_meta_from_row(row: &sqlx::sqlite::SqliteRow) -> SessionMeta {
664    SessionMeta {
665        id: row.get("id"),
666        created_at: row.get("created_at"),
667        last_message_at: row.get("last_message_at"),
668        is_closed: row.get::<i64, _>("is_closed") != 0,
669        summary: row.get("summary"),
670        title: row.get("title"),
671        message_count: row.get("message_count"),
672        channel: row.get("channel"),
673        channel_session_key: row.get("channel_session_key"),
674        user_id: row
675            .try_get("user_id")
676            .unwrap_or_else(|_| "admin".to_string()),
677        is_read: row.try_get::<i64, _>("is_read").unwrap_or(1) != 0,
678        triggered_by: row.try_get("triggered_by").unwrap_or(None),
679    }
680}
681
682/// Aggregated usage summary for a session.
683///
684/// ## Token accounting
685///
686/// `total_input_tokens` is the **total** input context size across all turns,
687/// i.e. `SUM(input_tokens + cache_read + cache_write)` from the per-turn
688/// records. This is what the UI displays as "X in".
689///
690/// `total_cache_read` and `total_cache_write` are the cached subsets of
691/// that total — useful for showing cache efficiency (e.g. "2.1k cached").
692#[derive(Debug, Clone, Default, Serialize, Deserialize)]
693pub struct UsageSummary {
694    /// Total input tokens (uncached + cache_read + cache_write).
695    pub total_input_tokens: u64,
696    pub total_output_tokens: u64,
697    /// Tokens served from prompt cache.
698    pub total_cache_read: u64,
699    /// Tokens written to prompt cache.
700    pub total_cache_write: u64,
701    pub total_cost_usd: f64,
702    pub total_turns: u32,
703}
704
705/// Cost summary per user.
706#[derive(Debug, Clone, Serialize, Deserialize)]
707pub struct UserCostSummary {
708    pub user_id: String,
709    pub total_cost_usd: f64,
710    /// Total input tokens (uncached + cache_read + cache_write).
711    pub total_input_tokens: u64,
712    pub total_output_tokens: u64,
713    pub total_cache_read: u64,
714    pub total_cache_write: u64,
715    pub total_turns: u32,
716}
717
718/// Cost summary per model.
719#[derive(Debug, Clone, Serialize, Deserialize)]
720pub struct ModelCostSummary {
721    pub model: String,
722    pub total_cost_usd: f64,
723    /// Total input tokens (uncached + cache_read + cache_write).
724    pub total_input_tokens: u64,
725    pub total_output_tokens: u64,
726    pub total_cache_read: u64,
727    pub total_cache_write: u64,
728    pub total_turns: u32,
729}
730
731/// Cost summary for a single day, broken down by model.
732#[derive(Debug, Clone, Serialize, Deserialize)]
733pub struct DayCostSummary {
734    /// Date string (YYYY-MM-DD).
735    pub date: String,
736    /// Cost per model on this day.
737    pub by_model: Vec<DayModelCost>,
738    /// Total cost for this day.
739    pub total_cost_usd: f64,
740}
741
742/// Cost for a single model on a single day.
743#[derive(Debug, Clone, Serialize, Deserialize)]
744pub struct DayModelCost {
745    pub model: String,
746    pub cost_usd: f64,
747}
748
749/// Aggregated tool invocation statistics, grouped by tool name.
750///
751/// Extracted from `session_messages` rows with `role = "tool_use"` and
752/// `role = "tool_result"`.  The error count is derived by joining each
753/// `tool_use` to its matching `tool_result` and checking the `is_error` flag.
754#[derive(Debug, Clone, Serialize, Deserialize)]
755pub struct ToolUsageSummary {
756    /// The tool name (e.g. `"MemorySearch"`, `"VaultGet"`).
757    pub tool_name: String,
758    /// Total number of times this tool was invoked.
759    pub invocations: u32,
760    /// How many of those invocations resulted in an error.
761    pub errors: u32,
762}
763
764/// Full cost overview with breakdowns by user and model.
765///
766/// All `total_input_tokens` fields include cached tokens — see [`UsageSummary`]
767/// for the full accounting explanation.
768#[derive(Debug, Clone, Serialize, Deserialize)]
769pub struct CostOverview {
770    pub total_cost_usd: f64,
771    /// Total input tokens (uncached + cache_read + cache_write).
772    pub total_input_tokens: u64,
773    pub total_output_tokens: u64,
774    pub total_cache_read: u64,
775    pub total_cache_write: u64,
776    pub total_turns: u32,
777    pub by_user: Vec<UserCostSummary>,
778    pub by_model: Vec<ModelCostSummary>,
779    pub by_day: Vec<DayCostSummary>,
780    /// Tool invocation counts, sorted by invocations descending.
781    pub by_tool: Vec<ToolUsageSummary>,
782}
783
784#[cfg(test)]
785mod tests {
786    use super::*;
787    use tempfile::TempDir;
788
789    async fn setup() -> (TempDir, SessionManager) {
790        let tmp = TempDir::new().unwrap();
791        let db = starpod_db::CoreDb::in_memory().await.unwrap();
792        let mgr = SessionManager::from_pool(db.pool().clone());
793        (tmp, mgr)
794    }
795
796    #[tokio::test]
797    async fn test_create_and_get_session() {
798        let (_tmp, mgr) = setup().await;
799        let id = mgr
800            .create_session(&Channel::Main, "test-key")
801            .await
802            .unwrap();
803
804        let session = mgr.get_session(&id).await.unwrap().unwrap();
805        assert_eq!(session.id, id);
806        assert!(!session.is_closed);
807        assert_eq!(session.message_count, 0);
808        assert_eq!(session.channel, "main");
809        assert_eq!(session.channel_session_key.as_deref(), Some("test-key"));
810    }
811
812    #[tokio::test]
813    async fn test_close_session() {
814        let (_tmp, mgr) = setup().await;
815        let id = mgr
816            .create_session(&Channel::Main, "test-key")
817            .await
818            .unwrap();
819
820        mgr.close_session(&id, "Discussed Rust memory management")
821            .await
822            .unwrap();
823
824        let session = mgr.get_session(&id).await.unwrap().unwrap();
825        assert!(session.is_closed);
826        assert_eq!(
827            session.summary.as_deref(),
828            Some("Discussed Rust memory management")
829        );
830    }
831
832    #[tokio::test]
833    async fn test_touch_session() {
834        let (_tmp, mgr) = setup().await;
835        let id = mgr
836            .create_session(&Channel::Main, "test-key")
837            .await
838            .unwrap();
839
840        mgr.touch_session(&id).await.unwrap();
841        mgr.touch_session(&id).await.unwrap();
842
843        let session = mgr.get_session(&id).await.unwrap().unwrap();
844        assert_eq!(session.message_count, 2);
845    }
846
847    #[tokio::test]
848    async fn test_resolve_session_new_when_empty() {
849        let (_tmp, mgr) = setup().await;
850
851        match mgr
852            .resolve_session(&Channel::Main, "some-key", None)
853            .await
854            .unwrap()
855        {
856            SessionDecision::New { .. } => {} // expected
857            SessionDecision::Continue(_) => panic!("Should be New when no sessions exist"),
858        }
859    }
860
861    #[tokio::test]
862    async fn test_resolve_session_continue_recent() {
863        let (_tmp, mgr) = setup().await;
864        let id = mgr.create_session(&Channel::Main, "key-1").await.unwrap();
865        mgr.touch_session(&id).await.unwrap();
866
867        match mgr
868            .resolve_session(&Channel::Main, "key-1", None)
869            .await
870            .unwrap()
871        {
872            SessionDecision::Continue(sid) => assert_eq!(sid, id),
873            SessionDecision::New { .. } => panic!("Should continue recent session"),
874        }
875    }
876
877    #[tokio::test]
878    async fn test_resolve_session_new_when_closed() {
879        let (_tmp, mgr) = setup().await;
880        let id = mgr.create_session(&Channel::Main, "key-1").await.unwrap();
881        mgr.touch_session(&id).await.unwrap();
882        mgr.close_session(&id, "done").await.unwrap();
883
884        match mgr
885            .resolve_session(&Channel::Main, "key-1", None)
886            .await
887            .unwrap()
888        {
889            SessionDecision::New { .. } => {} // expected
890            SessionDecision::Continue(_) => panic!("Should not continue closed session"),
891        }
892    }
893
894    #[tokio::test]
895    async fn test_list_sessions() {
896        let (_tmp, mgr) = setup().await;
897        mgr.create_session(&Channel::Main, "k1").await.unwrap();
898        mgr.create_session(&Channel::Main, "k2").await.unwrap();
899        mgr.create_session(&Channel::Telegram, "chat-1")
900            .await
901            .unwrap();
902
903        let sessions = mgr.list_sessions(10).await.unwrap();
904        assert_eq!(sessions.len(), 3);
905    }
906
907    #[tokio::test]
908    async fn test_record_and_query_usage() {
909        let (_tmp, mgr) = setup().await;
910        let id = mgr
911            .create_session(&Channel::Main, "test-key")
912            .await
913            .unwrap();
914
915        mgr.record_usage(
916            &id,
917            &UsageRecord {
918                input_tokens: 1000,
919                output_tokens: 500,
920                cache_read: 200,
921                cache_write: 100,
922                cost_usd: 0.01,
923                model: "claude-sonnet".into(),
924                user_id: "admin".into(),
925            },
926            1,
927        )
928        .await
929        .unwrap();
930
931        mgr.record_usage(
932            &id,
933            &UsageRecord {
934                input_tokens: 800,
935                output_tokens: 400,
936                cache_read: 150,
937                cache_write: 50,
938                cost_usd: 0.008,
939                model: "claude-sonnet".into(),
940                user_id: "admin".into(),
941            },
942            2,
943        )
944        .await
945        .unwrap();
946
947        let summary = mgr.session_usage(&id).await.unwrap();
948        // total_input_tokens includes input_tokens + cache_read + cache_write
949        // Turn 1: 1000 + 200 + 100 = 1300, Turn 2: 800 + 150 + 50 = 1000
950        assert_eq!(summary.total_input_tokens, 2300);
951        assert_eq!(summary.total_output_tokens, 900);
952        assert_eq!(summary.total_turns, 2);
953        assert!((summary.total_cost_usd - 0.018).abs() < 0.001);
954    }
955
956    #[tokio::test]
957    async fn test_usage_cache_breakdown() {
958        let (_tmp, mgr) = setup().await;
959        let id = mgr
960            .create_session(&Channel::Main, "cache-test")
961            .await
962            .unwrap();
963
964        // Turn 1: cache miss — all tokens go to cache_write
965        mgr.record_usage(
966            &id,
967            &UsageRecord {
968                input_tokens: 500,
969                output_tokens: 200,
970                cache_read: 0,
971                cache_write: 4000,
972                cost_usd: 0.05,
973                model: "claude-sonnet".into(),
974                user_id: "admin".into(),
975            },
976            1,
977        )
978        .await
979        .unwrap();
980
981        // Turn 2: cache hit — most tokens served from cache
982        mgr.record_usage(
983            &id,
984            &UsageRecord {
985                input_tokens: 100,
986                output_tokens: 300,
987                cache_read: 4000,
988                cache_write: 0,
989                cost_usd: 0.01,
990                model: "claude-sonnet".into(),
991                user_id: "admin".into(),
992            },
993            2,
994        )
995        .await
996        .unwrap();
997
998        let summary = mgr.session_usage(&id).await.unwrap();
999
1000        // total_input_tokens = (500 + 0 + 4000) + (100 + 4000 + 0) = 8600
1001        assert_eq!(summary.total_input_tokens, 8600);
1002        assert_eq!(summary.total_output_tokens, 500);
1003        // Cache breakdown preserved separately
1004        assert_eq!(summary.total_cache_read, 4000);
1005        assert_eq!(summary.total_cache_write, 4000);
1006        assert_eq!(summary.total_turns, 2);
1007        assert!((summary.total_cost_usd - 0.06).abs() < 0.001);
1008    }
1009
    // --- Channel-specific session resolution tests ---
1011
1012    #[tokio::test]
1013    async fn test_main_explicit_sessions() {
1014        let (_tmp, mgr) = setup().await;
1015
1016        // Create session for key "abc"
1017        let id = mgr.create_session(&Channel::Main, "abc").await.unwrap();
1018        mgr.touch_session(&id).await.unwrap();
1019
1020        // Same key → continue
1021        match mgr
1022            .resolve_session(&Channel::Main, "abc", None)
1023            .await
1024            .unwrap()
1025        {
1026            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1027            SessionDecision::New { .. } => panic!("Should continue with same key"),
1028        }
1029
1030        // Different key → new
1031        match mgr
1032            .resolve_session(&Channel::Main, "xyz", None)
1033            .await
1034            .unwrap()
1035        {
1036            SessionDecision::New { .. } => {} // expected
1037            SessionDecision::Continue(_) => panic!("Different key should get new session"),
1038        }
1039    }
1040
1041    #[tokio::test]
1042    async fn test_telegram_time_gap() {
1043        let (_tmp, mgr) = setup().await;
1044        let gap = Some(360); // 6h, as configured via [channels.telegram] gap_minutes
1045
1046        // Create a telegram session
1047        let id = mgr
1048            .create_session(&Channel::Telegram, "chat-123")
1049            .await
1050            .unwrap();
1051        mgr.touch_session(&id).await.unwrap();
1052
1053        // Within 6h → continue
1054        match mgr
1055            .resolve_session(&Channel::Telegram, "chat-123", gap)
1056            .await
1057            .unwrap()
1058        {
1059            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1060            SessionDecision::New { .. } => panic!("Should continue within gap"),
1061        }
1062
1063        // Manually set last_message_at to 7h ago to simulate inactivity
1064        let old_time = (Utc::now() - Duration::hours(7)).to_rfc3339();
1065        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1066            .bind(&old_time)
1067            .bind(&id)
1068            .execute(&mgr.pool)
1069            .await
1070            .unwrap();
1071
1072        // Beyond 6h → new (old session auto-closed)
1073        match mgr
1074            .resolve_session(&Channel::Telegram, "chat-123", gap)
1075            .await
1076            .unwrap()
1077        {
1078            SessionDecision::New { .. } => {} // expected
1079            SessionDecision::Continue(_) => panic!("Should start new session after 7h gap"),
1080        }
1081
1082        // Verify old session was auto-closed
1083        let old = mgr.get_session(&id).await.unwrap().unwrap();
1084        assert!(old.is_closed);
1085        assert_eq!(old.summary.as_deref(), Some("Auto-closed: inactivity"));
1086    }
1087
1088    #[tokio::test]
1089    async fn test_record_compaction() {
1090        let (_tmp, mgr) = setup().await;
1091        let id = mgr
1092            .create_session(&Channel::Main, "test-key")
1093            .await
1094            .unwrap();
1095
1096        mgr.record_compaction(&id, "auto", 150_000, "Summary of old messages", 12)
1097            .await
1098            .unwrap();
1099
1100        // Verify via raw query
1101        let row = sqlx::query(
1102            "SELECT trigger, pre_tokens, summary, messages_compacted FROM compaction_log WHERE session_id = ?1",
1103        )
1104        .bind(&id)
1105        .fetch_one(&mgr.pool)
1106        .await
1107        .unwrap();
1108
1109        assert_eq!(row.get::<String, _>("trigger"), "auto");
1110        assert_eq!(row.get::<i64, _>("pre_tokens"), 150_000);
1111        assert_eq!(row.get::<String, _>("summary"), "Summary of old messages");
1112        assert_eq!(row.get::<i64, _>("messages_compacted"), 12);
1113    }
1114
1115    #[tokio::test]
1116    async fn test_telegram_custom_gap_override() {
1117        let (_tmp, mgr) = setup().await;
1118
1119        // Create a Telegram session
1120        let id = mgr
1121            .create_session(&Channel::Telegram, "chat-gap")
1122            .await
1123            .unwrap();
1124        mgr.touch_session(&id).await.unwrap();
1125
1126        // Set last_message_at to 2 hours ago
1127        let two_hours_ago = (Utc::now() - Duration::hours(2)).to_rfc3339();
1128        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1129            .bind(&two_hours_ago)
1130            .bind(&id)
1131            .execute(&mgr.pool)
1132            .await
1133            .unwrap();
1134
1135        // gap_minutes=60 (1h) — 2h ago exceeds 1h → should be New
1136        match mgr
1137            .resolve_session(&Channel::Telegram, "chat-gap", Some(60))
1138            .await
1139            .unwrap()
1140        {
1141            SessionDecision::New { .. } => {} // expected
1142            SessionDecision::Continue(_) => panic!("Should start new session when 2h > 1h gap"),
1143        }
1144
1145        // The old session was auto-closed, create a fresh one and backdate it again
1146        let id2 = mgr
1147            .create_session(&Channel::Telegram, "chat-gap")
1148            .await
1149            .unwrap();
1150        mgr.touch_session(&id2).await.unwrap();
1151        let two_hours_ago = (Utc::now() - Duration::hours(2)).to_rfc3339();
1152        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1153            .bind(&two_hours_ago)
1154            .bind(&id2)
1155            .execute(&mgr.pool)
1156            .await
1157            .unwrap();
1158
1159        // gap_minutes=180 (3h) — 2h ago is within 3h → should Continue
1160        match mgr
1161            .resolve_session(&Channel::Telegram, "chat-gap", Some(180))
1162            .await
1163            .unwrap()
1164        {
1165            SessionDecision::Continue(sid) => assert_eq!(sid, id2),
1166            SessionDecision::New { .. } => panic!("Should continue session when 2h < 3h gap"),
1167        }
1168    }
1169
1170    #[tokio::test]
1171    async fn test_main_channel_ignores_gap() {
1172        let (_tmp, mgr) = setup().await;
1173
1174        // Create a Main session
1175        let id = mgr
1176            .create_session(&Channel::Main, "main-gap")
1177            .await
1178            .unwrap();
1179        mgr.touch_session(&id).await.unwrap();
1180
1181        // Without a gap_minutes override, Main channel always continues (explicit)
1182        match mgr
1183            .resolve_session(&Channel::Main, "main-gap", None)
1184            .await
1185            .unwrap()
1186        {
1187            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1188            SessionDecision::New { .. } => {
1189                panic!("Main channel should always continue without gap override")
1190            }
1191        }
1192
1193        // Even backdating last_message_at to 24h ago, Main without gap override still continues
1194        let old = (Utc::now() - Duration::hours(24)).to_rfc3339();
1195        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1196            .bind(&old)
1197            .bind(&id)
1198            .execute(&mgr.pool)
1199            .await
1200            .unwrap();
1201
1202        match mgr.resolve_session(&Channel::Main, "main-gap", None).await.unwrap() {
1203            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1204            SessionDecision::New { .. } => panic!("Main channel should continue even with old last_message_at when gap_minutes is None"),
1205        }
1206    }
1207
1208    #[tokio::test]
1209    async fn test_channel_isolation() {
1210        let (_tmp, mgr) = setup().await;
1211
1212        // Create sessions with same key on different channels
1213        let main_id = mgr
1214            .create_session(&Channel::Main, "shared-key")
1215            .await
1216            .unwrap();
1217        let tg_id = mgr
1218            .create_session(&Channel::Telegram, "shared-key")
1219            .await
1220            .unwrap();
1221        mgr.touch_session(&main_id).await.unwrap();
1222        mgr.touch_session(&tg_id).await.unwrap();
1223
1224        // Each channel resolves to its own session
1225        match mgr
1226            .resolve_session(&Channel::Main, "shared-key", None)
1227            .await
1228            .unwrap()
1229        {
1230            SessionDecision::Continue(sid) => assert_eq!(sid, main_id),
1231            SessionDecision::New { .. } => panic!("Main should find its session"),
1232        }
1233        match mgr
1234            .resolve_session(&Channel::Telegram, "shared-key", None)
1235            .await
1236            .unwrap()
1237        {
1238            SessionDecision::Continue(sid) => assert_eq!(sid, tg_id),
1239            SessionDecision::New { .. } => panic!("Telegram should find its session"),
1240        }
1241    }
1242
1243    #[tokio::test]
1244    async fn test_auto_close_returns_closed_session_id() {
1245        let (_tmp, mgr) = setup().await;
1246        let gap = Some(60); // 1h
1247
1248        // Create and backdate a Telegram session
1249        let id = mgr
1250            .create_session(&Channel::Telegram, "export-test")
1251            .await
1252            .unwrap();
1253        mgr.touch_session(&id).await.unwrap();
1254        mgr.save_message(&id, "user", "Hello!").await.unwrap();
1255        mgr.save_message(&id, "assistant", "Hi there!")
1256            .await
1257            .unwrap();
1258
1259        let two_hours_ago = (Utc::now() - Duration::hours(2)).to_rfc3339();
1260        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1261            .bind(&two_hours_ago)
1262            .bind(&id)
1263            .execute(&mgr.pool)
1264            .await
1265            .unwrap();
1266
1267        // Resolve should return New with the closed session's ID
1268        match mgr
1269            .resolve_session(&Channel::Telegram, "export-test", gap)
1270            .await
1271            .unwrap()
1272        {
1273            SessionDecision::New { closed_session_id } => {
1274                assert_eq!(
1275                    closed_session_id,
1276                    Some(id.clone()),
1277                    "Should return the closed session ID"
1278                );
1279            }
1280            SessionDecision::Continue(_) => panic!("Should start new session after 2h > 1h gap"),
1281        }
1282
1283        // First resolve with no prior session → New without closed ID
1284        match mgr
1285            .resolve_session(&Channel::Main, "fresh-key", None)
1286            .await
1287            .unwrap()
1288        {
1289            SessionDecision::New { closed_session_id } => {
1290                assert!(
1291                    closed_session_id.is_none(),
1292                    "No prior session means no closed ID"
1293                );
1294            }
1295            SessionDecision::Continue(_) => panic!("Should be new"),
1296        }
1297    }
1298
1299    #[tokio::test]
1300    async fn test_auto_close_closed_id_is_correct_session() {
1301        let (_tmp, mgr) = setup().await;
1302        let gap = Some(60); // 1h
1303
1304        // Create two Telegram sessions for different keys
1305        let id_a = mgr
1306            .create_session(&Channel::Telegram, "chat-a")
1307            .await
1308            .unwrap();
1309        mgr.touch_session(&id_a).await.unwrap();
1310        mgr.save_message(&id_a, "user", "Message in chat A")
1311            .await
1312            .unwrap();
1313        mgr.save_message(&id_a, "assistant", "Reply in chat A")
1314            .await
1315            .unwrap();
1316
1317        let id_b = mgr
1318            .create_session(&Channel::Telegram, "chat-b")
1319            .await
1320            .unwrap();
1321        mgr.touch_session(&id_b).await.unwrap();
1322        mgr.save_message(&id_b, "user", "Message in chat B")
1323            .await
1324            .unwrap();
1325
1326        // Backdate only chat-a beyond the gap
1327        let old_time = (Utc::now() - Duration::hours(2)).to_rfc3339();
1328        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1329            .bind(&old_time)
1330            .bind(&id_a)
1331            .execute(&mgr.pool)
1332            .await
1333            .unwrap();
1334
1335        // Resolve chat-a → should auto-close and return its ID
1336        match mgr
1337            .resolve_session(&Channel::Telegram, "chat-a", gap)
1338            .await
1339            .unwrap()
1340        {
1341            SessionDecision::New { closed_session_id } => {
1342                assert_eq!(
1343                    closed_session_id,
1344                    Some(id_a.clone()),
1345                    "closed_session_id must match the session that was auto-closed"
1346                );
1347            }
1348            SessionDecision::Continue(_) => panic!("Should start new session after gap"),
1349        }
1350
1351        // Verify the closed session's messages are still accessible
1352        let messages = mgr.get_messages(&id_a).await.unwrap();
1353        assert_eq!(messages.len(), 2);
1354        assert_eq!(messages[0].content, "Message in chat A");
1355        assert_eq!(messages[1].content, "Reply in chat A");
1356
1357        // Verify chat-b is unaffected (still open, still continuable)
1358        match mgr
1359            .resolve_session(&Channel::Telegram, "chat-b", gap)
1360            .await
1361            .unwrap()
1362        {
1363            SessionDecision::Continue(sid) => assert_eq!(sid, id_b),
1364            SessionDecision::New { .. } => panic!("chat-b should still be continuable"),
1365        }
1366    }
1367
1368    #[tokio::test]
1369    async fn test_no_closed_id_for_main_channel() {
1370        let (_tmp, mgr) = setup().await;
1371
1372        // Create a Main session and backdate it far in the past
1373        let id = mgr
1374            .create_session(&Channel::Main, "main-key")
1375            .await
1376            .unwrap();
1377        mgr.touch_session(&id).await.unwrap();
1378
1379        let old_time = (Utc::now() - Duration::hours(48)).to_rfc3339();
1380        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1381            .bind(&old_time)
1382            .bind(&id)
1383            .execute(&mgr.pool)
1384            .await
1385            .unwrap();
1386
1387        // Main channel uses gap_minutes=None → never auto-closes
1388        match mgr
1389            .resolve_session(&Channel::Main, "main-key", None)
1390            .await
1391            .unwrap()
1392        {
1393            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1394            SessionDecision::New { .. } => panic!("Main channel should never auto-close"),
1395        }
1396
1397        // Even with a fresh key (no session), New should have closed_session_id=None
1398        match mgr
1399            .resolve_session(&Channel::Main, "new-main-key", None)
1400            .await
1401            .unwrap()
1402        {
1403            SessionDecision::New { closed_session_id } => {
1404                assert!(
1405                    closed_session_id.is_none(),
1406                    "Main channel should never produce a closed_session_id"
1407                );
1408            }
1409            SessionDecision::Continue(_) => panic!("No session for this key, should be New"),
1410        }
1411    }
1412
1413    #[tokio::test]
1414    async fn test_no_closed_id_when_session_manually_closed() {
1415        let (_tmp, mgr) = setup().await;
1416        let gap = Some(60); // 1h
1417
1418        // Create a Telegram session and manually close it
1419        let id = mgr
1420            .create_session(&Channel::Telegram, "manual-close")
1421            .await
1422            .unwrap();
1423        mgr.touch_session(&id).await.unwrap();
1424        mgr.save_message(&id, "user", "Hello").await.unwrap();
1425        mgr.close_session(&id, "Manually closed by user")
1426            .await
1427            .unwrap();
1428
1429        // Resolve should return New with closed_session_id=None because
1430        // there's no open session to auto-close
1431        match mgr
1432            .resolve_session(&Channel::Telegram, "manual-close", gap)
1433            .await
1434            .unwrap()
1435        {
1436            SessionDecision::New { closed_session_id } => {
1437                assert!(
1438                    closed_session_id.is_none(),
1439                    "Manually closed session should not produce closed_session_id on resolve"
1440                );
1441            }
1442            SessionDecision::Continue(_) => panic!("Closed session should not be continued"),
1443        }
1444    }
1445
1446    #[tokio::test]
1447    async fn test_cost_overview_empty() {
1448        let (_tmp, mgr) = setup().await;
1449
1450        let overview = mgr.cost_overview(None).await.unwrap();
1451        assert_eq!(overview.total_cost_usd, 0.0);
1452        assert_eq!(overview.total_input_tokens, 0);
1453        assert_eq!(overview.total_output_tokens, 0);
1454        assert_eq!(overview.total_turns, 0);
1455        assert!(overview.by_user.is_empty());
1456        assert!(overview.by_model.is_empty());
1457    }
1458
1459    #[tokio::test]
1460    async fn test_cost_overview_by_user() {
1461        let (_tmp, mgr) = setup().await;
1462        let sid = mgr
1463            .create_session(&Channel::Main, "cost-test")
1464            .await
1465            .unwrap();
1466
1467        // Record usage for two different users
1468        mgr.record_usage(
1469            &sid,
1470            &UsageRecord {
1471                input_tokens: 1000,
1472                output_tokens: 500,
1473                cache_read: 0,
1474                cache_write: 0,
1475                cost_usd: 0.05,
1476                model: "claude-sonnet".into(),
1477                user_id: "alice".into(),
1478            },
1479            1,
1480        )
1481        .await
1482        .unwrap();
1483
1484        mgr.record_usage(
1485            &sid,
1486            &UsageRecord {
1487                input_tokens: 2000,
1488                output_tokens: 800,
1489                cache_read: 0,
1490                cache_write: 0,
1491                cost_usd: 0.10,
1492                model: "claude-sonnet".into(),
1493                user_id: "bob".into(),
1494            },
1495            2,
1496        )
1497        .await
1498        .unwrap();
1499
1500        mgr.record_usage(
1501            &sid,
1502            &UsageRecord {
1503                input_tokens: 500,
1504                output_tokens: 200,
1505                cache_read: 0,
1506                cache_write: 0,
1507                cost_usd: 0.02,
1508                model: "claude-haiku".into(),
1509                user_id: "alice".into(),
1510            },
1511            3,
1512        )
1513        .await
1514        .unwrap();
1515
1516        let overview = mgr.cost_overview(None).await.unwrap();
1517
1518        // Totals
1519        assert_eq!(overview.total_turns, 3);
1520        assert!((overview.total_cost_usd - 0.17).abs() < 0.001);
1521        assert_eq!(overview.total_input_tokens, 3500);
1522        assert_eq!(overview.total_output_tokens, 1500);
1523
1524        // By user (sorted by cost desc)
1525        assert_eq!(overview.by_user.len(), 2);
1526        assert_eq!(overview.by_user[0].user_id, "bob");
1527        assert!((overview.by_user[0].total_cost_usd - 0.10).abs() < 0.001);
1528        assert_eq!(overview.by_user[0].total_turns, 1);
1529        assert_eq!(overview.by_user[1].user_id, "alice");
1530        assert!((overview.by_user[1].total_cost_usd - 0.07).abs() < 0.001);
1531        assert_eq!(overview.by_user[1].total_turns, 2);
1532
1533        // By model (sorted by cost desc)
1534        assert_eq!(overview.by_model.len(), 2);
1535        assert_eq!(overview.by_model[0].model, "claude-sonnet");
1536        assert!((overview.by_model[0].total_cost_usd - 0.15).abs() < 0.001);
1537        assert_eq!(overview.by_model[1].model, "claude-haiku");
1538        assert!((overview.by_model[1].total_cost_usd - 0.02).abs() < 0.001);
1539    }
1540
1541    #[tokio::test]
1542    async fn test_cost_overview_since_filter() {
1543        let (_tmp, mgr) = setup().await;
1544        let sid = mgr
1545            .create_session(&Channel::Main, "cost-filter")
1546            .await
1547            .unwrap();
1548
1549        // Record usage now
1550        mgr.record_usage(
1551            &sid,
1552            &UsageRecord {
1553                input_tokens: 1000,
1554                output_tokens: 500,
1555                cache_read: 0,
1556                cache_write: 0,
1557                cost_usd: 0.05,
1558                model: "claude-sonnet".into(),
1559                user_id: "admin".into(),
1560            },
1561            1,
1562        )
1563        .await
1564        .unwrap();
1565
1566        // "Since" far in the future should return nothing
1567        let future = (Utc::now() + Duration::hours(1)).to_rfc3339();
1568        let overview = mgr.cost_overview(Some(&future)).await.unwrap();
1569        assert_eq!(overview.total_turns, 0);
1570        assert_eq!(overview.total_cost_usd, 0.0);
1571
1572        // "Since" far in the past should return everything
1573        let past = (Utc::now() - Duration::days(365)).to_rfc3339();
1574        let overview = mgr.cost_overview(Some(&past)).await.unwrap();
1575        assert_eq!(overview.total_turns, 1);
1576        assert!((overview.total_cost_usd - 0.05).abs() < 0.001);
1577    }
1578
1579    #[tokio::test]
1580    async fn test_cost_overview_user_id_recorded() {
1581        let (_tmp, mgr) = setup().await;
1582        let sid = mgr
1583            .create_session(&Channel::Main, "uid-test")
1584            .await
1585            .unwrap();
1586
1587        mgr.record_usage(
1588            &sid,
1589            &UsageRecord {
1590                input_tokens: 100,
1591                output_tokens: 50,
1592                cache_read: 0,
1593                cache_write: 0,
1594                cost_usd: 0.01,
1595                model: "m".into(),
1596                user_id: "user-42".into(),
1597            },
1598            1,
1599        )
1600        .await
1601        .unwrap();
1602
1603        let overview = mgr.cost_overview(None).await.unwrap();
1604        assert_eq!(overview.by_user.len(), 1);
1605        assert_eq!(overview.by_user[0].user_id, "user-42");
1606        assert_eq!(overview.by_user[0].total_input_tokens, 100);
1607        assert_eq!(overview.by_user[0].total_output_tokens, 50);
1608    }
1609
1610    #[tokio::test]
1611    async fn test_cost_overview_cache_breakdown() {
1612        let (_tmp, mgr) = setup().await;
1613        let sid = mgr
1614            .create_session(&Channel::Main, "cache-cost")
1615            .await
1616            .unwrap();
1617
1618        // Alice: cache miss (writes to cache)
1619        mgr.record_usage(
1620            &sid,
1621            &UsageRecord {
1622                input_tokens: 200,
1623                output_tokens: 100,
1624                cache_read: 0,
1625                cache_write: 3000,
1626                cost_usd: 0.04,
1627                model: "claude-sonnet".into(),
1628                user_id: "alice".into(),
1629            },
1630            1,
1631        )
1632        .await
1633        .unwrap();
1634
1635        // Alice: cache hit (reads from cache)
1636        mgr.record_usage(
1637            &sid,
1638            &UsageRecord {
1639                input_tokens: 50,
1640                output_tokens: 150,
1641                cache_read: 3000,
1642                cache_write: 0,
1643                cost_usd: 0.01,
1644                model: "claude-sonnet".into(),
1645                user_id: "alice".into(),
1646            },
1647            2,
1648        )
1649        .await
1650        .unwrap();
1651
1652        // Bob: no caching
1653        mgr.record_usage(
1654            &sid,
1655            &UsageRecord {
1656                input_tokens: 800,
1657                output_tokens: 400,
1658                cache_read: 0,
1659                cache_write: 0,
1660                cost_usd: 0.03,
1661                model: "claude-haiku".into(),
1662                user_id: "bob".into(),
1663            },
1664            3,
1665        )
1666        .await
1667        .unwrap();
1668
1669        let overview = mgr.cost_overview(None).await.unwrap();
1670
1671        // Totals: input = (200+0+3000) + (50+3000+0) + (800+0+0) = 7050
1672        assert_eq!(overview.total_input_tokens, 7050);
1673        assert_eq!(overview.total_output_tokens, 650);
1674        assert_eq!(overview.total_cache_read, 3000);
1675        assert_eq!(overview.total_cache_write, 3000);
1676
1677        // By user: alice first (higher cost)
1678        assert_eq!(overview.by_user.len(), 2);
1679        let alice = overview
1680            .by_user
1681            .iter()
1682            .find(|u| u.user_id == "alice")
1683            .unwrap();
1684        assert_eq!(alice.total_input_tokens, 6250); // (200+3000) + (50+3000)
1685        assert_eq!(alice.total_cache_read, 3000);
1686        assert_eq!(alice.total_cache_write, 3000);
1687
1688        let bob = overview
1689            .by_user
1690            .iter()
1691            .find(|u| u.user_id == "bob")
1692            .unwrap();
1693        assert_eq!(bob.total_input_tokens, 800);
1694        assert_eq!(bob.total_cache_read, 0);
1695        assert_eq!(bob.total_cache_write, 0);
1696
1697        // By model
1698        let sonnet = overview
1699            .by_model
1700            .iter()
1701            .find(|m| m.model == "claude-sonnet")
1702            .unwrap();
1703        assert_eq!(sonnet.total_cache_read, 3000);
1704        assert_eq!(sonnet.total_cache_write, 3000);
1705
1706        let haiku = overview
1707            .by_model
1708            .iter()
1709            .find(|m| m.model == "claude-haiku")
1710            .unwrap();
1711        assert_eq!(haiku.total_cache_read, 0);
1712        assert_eq!(haiku.total_cache_write, 0);
1713    }
1714
1715    // ── Read/unread state tests ────────────────────────────────────────
1716
1717    #[tokio::test]
1718    async fn test_new_session_is_read_by_default() {
1719        let (_tmp, mgr) = setup().await;
1720        let id = mgr.create_session(&Channel::Main, "key").await.unwrap();
1721
1722        let session = mgr.get_session(&id).await.unwrap().unwrap();
1723        assert!(
1724            session.is_read,
1725            "New sessions should default to is_read=true"
1726        );
1727    }
1728
1729    #[tokio::test]
1730    async fn test_mark_read_false() {
1731        let (_tmp, mgr) = setup().await;
1732        let id = mgr.create_session(&Channel::Main, "key").await.unwrap();
1733
1734        mgr.mark_read(&id, false).await.unwrap();
1735
1736        let session = mgr.get_session(&id).await.unwrap().unwrap();
1737        assert!(
1738            !session.is_read,
1739            "Session should be unread after mark_read(false)"
1740        );
1741    }
1742
1743    #[tokio::test]
1744    async fn test_mark_read_true() {
1745        let (_tmp, mgr) = setup().await;
1746        let id = mgr.create_session(&Channel::Main, "key").await.unwrap();
1747
1748        // Mark unread, then mark read again
1749        mgr.mark_read(&id, false).await.unwrap();
1750        mgr.mark_read(&id, true).await.unwrap();
1751
1752        let session = mgr.get_session(&id).await.unwrap().unwrap();
1753        assert!(
1754            session.is_read,
1755            "Session should be read after mark_read(true)"
1756        );
1757    }
1758
1759    #[tokio::test]
1760    async fn test_list_sessions_includes_is_read() {
1761        let (_tmp, mgr) = setup().await;
1762        let id1 = mgr.create_session(&Channel::Main, "key1").await.unwrap();
1763        let id2 = mgr.create_session(&Channel::Main, "key2").await.unwrap();
1764
1765        mgr.mark_read(&id1, false).await.unwrap();
1766
1767        let sessions = mgr.list_sessions(10).await.unwrap();
1768        let s1 = sessions.iter().find(|s| s.id == id1).unwrap();
1769        let s2 = sessions.iter().find(|s| s.id == id2).unwrap();
1770
1771        assert!(!s1.is_read, "Session 1 should be unread");
1772        assert!(s2.is_read, "Session 2 should still be read");
1773    }
1774
1775    #[tokio::test]
1776    async fn test_mark_read_nonexistent_session_succeeds() {
1777        let (_tmp, mgr) = setup().await;
1778        // Should not error — just a no-op UPDATE matching zero rows
1779        mgr.mark_read("nonexistent-id", true).await.unwrap();
1780    }
1781
1782    // --- Email channel tests ---
1783
1784    #[test]
1785    fn test_email_channel_as_str() {
1786        assert_eq!(Channel::Email.as_str(), "email");
1787    }
1788
1789    #[test]
1790    fn test_email_channel_from_str() {
1791        assert_eq!(Channel::from_channel_str("email"), Channel::Email);
1792    }
1793
1794    #[test]
1795    fn test_unknown_channel_defaults_to_main() {
1796        assert_eq!(Channel::from_channel_str("unknown"), Channel::Main);
1797        assert_eq!(Channel::from_channel_str(""), Channel::Main);
1798    }
1799
1800    #[tokio::test]
1801    async fn test_create_email_session() {
1802        let (_tmp, mgr) = setup().await;
1803        let id = mgr
1804            .create_session(&Channel::Email, "user@example.com")
1805            .await
1806            .unwrap();
1807
1808        let session = mgr.get_session(&id).await.unwrap().unwrap();
1809        assert_eq!(session.channel, "email");
1810        assert_eq!(
1811            session.channel_session_key.as_deref(),
1812            Some("user@example.com")
1813        );
1814    }
1815
1816    #[tokio::test]
1817    async fn test_resolve_email_session_continues_for_same_sender() {
1818        let (_tmp, mgr) = setup().await;
1819        let id = mgr
1820            .create_session(&Channel::Email, "sender@test.com")
1821            .await
1822            .unwrap();
1823        mgr.touch_session(&id).await.unwrap();
1824
1825        match mgr
1826            .resolve_session(&Channel::Email, "sender@test.com", None)
1827            .await
1828            .unwrap()
1829        {
1830            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1831            SessionDecision::New { .. } => panic!("Should continue recent email session"),
1832        }
1833    }
1834
1835    #[tokio::test]
1836    async fn test_resolve_email_session_new_for_different_sender() {
1837        let (_tmp, mgr) = setup().await;
1838        let id = mgr
1839            .create_session(&Channel::Email, "sender-a@test.com")
1840            .await
1841            .unwrap();
1842        mgr.touch_session(&id).await.unwrap();
1843
1844        match mgr
1845            .resolve_session(&Channel::Email, "sender-b@test.com", None)
1846            .await
1847            .unwrap()
1848        {
1849            SessionDecision::New { .. } => {} // expected — different sender
1850            SessionDecision::Continue(_) => {
1851                panic!("Should not continue session for different sender")
1852            }
1853        }
1854    }
1855
1856    #[tokio::test]
1857    async fn test_email_and_telegram_sessions_are_separate() {
1858        let (_tmp, mgr) = setup().await;
1859        let email_id = mgr
1860            .create_session(&Channel::Email, "user@test.com")
1861            .await
1862            .unwrap();
1863        let tg_id = mgr
1864            .create_session(&Channel::Telegram, "user@test.com")
1865            .await
1866            .unwrap();
1867
1868        assert_ne!(email_id, tg_id);
1869
1870        // Each channel resolves independently
1871        mgr.touch_session(&email_id).await.unwrap();
1872        mgr.touch_session(&tg_id).await.unwrap();
1873
1874        match mgr
1875            .resolve_session(&Channel::Email, "user@test.com", None)
1876            .await
1877            .unwrap()
1878        {
1879            SessionDecision::Continue(sid) => assert_eq!(sid, email_id),
1880            SessionDecision::New { .. } => panic!("Should continue email session"),
1881        }
1882        match mgr
1883            .resolve_session(&Channel::Telegram, "user@test.com", None)
1884            .await
1885            .unwrap()
1886        {
1887            SessionDecision::Continue(sid) => assert_eq!(sid, tg_id),
1888            SessionDecision::New { .. } => panic!("Should continue telegram session"),
1889        }
1890    }
1891
1892    // ── Tool usage stats tests ────────────────────────────────────────
1893
1894    #[tokio::test]
1895    async fn test_cost_overview_by_tool_empty() {
1896        let (_tmp, mgr) = setup().await;
1897
1898        let overview = mgr.cost_overview(None).await.unwrap();
1899        assert!(
1900            overview.by_tool.is_empty(),
1901            "No tool messages → empty by_tool"
1902        );
1903    }
1904
1905    #[tokio::test]
1906    async fn test_cost_overview_by_tool_counts() {
1907        let (_tmp, mgr) = setup().await;
1908        let sid = mgr
1909            .create_session(&Channel::Main, "tool-test")
1910            .await
1911            .unwrap();
1912
1913        // Simulate 3 MemorySearch invocations (all successful)
1914        for i in 0..3 {
1915            let tool_use = serde_json::json!({
1916                "type": "tool_use",
1917                "id": format!("tu_mem_{i}"),
1918                "name": "MemorySearch",
1919                "input": {"query": "test"}
1920            });
1921            mgr.save_message(&sid, "tool_use", &tool_use.to_string())
1922                .await
1923                .unwrap();
1924
1925            let tool_result = serde_json::json!({
1926                "type": "tool_result",
1927                "tool_use_id": format!("tu_mem_{i}"),
1928                "content": "some result",
1929                "is_error": false
1930            });
1931            mgr.save_message(&sid, "tool_result", &tool_result.to_string())
1932                .await
1933                .unwrap();
1934        }
1935
1936        // Simulate 2 VaultGet invocations: 1 success, 1 error
1937        let tool_use = serde_json::json!({
1938            "type": "tool_use", "id": "tu_vault_0", "name": "VaultGet",
1939            "input": {"key": "api_key"}
1940        });
1941        mgr.save_message(&sid, "tool_use", &tool_use.to_string())
1942            .await
1943            .unwrap();
1944        let tool_result = serde_json::json!({
1945            "type": "tool_result", "tool_use_id": "tu_vault_0",
1946            "content": "secret-value", "is_error": false
1947        });
1948        mgr.save_message(&sid, "tool_result", &tool_result.to_string())
1949            .await
1950            .unwrap();
1951
1952        let tool_use = serde_json::json!({
1953            "type": "tool_use", "id": "tu_vault_1", "name": "VaultGet",
1954            "input": {"key": "missing"}
1955        });
1956        mgr.save_message(&sid, "tool_use", &tool_use.to_string())
1957            .await
1958            .unwrap();
1959        let tool_result = serde_json::json!({
1960            "type": "tool_result", "tool_use_id": "tu_vault_1",
1961            "content": "key not found", "is_error": true
1962        });
1963        mgr.save_message(&sid, "tool_result", &tool_result.to_string())
1964            .await
1965            .unwrap();
1966
1967        let overview = mgr.cost_overview(None).await.unwrap();
1968
1969        // Sorted by invocations DESC: MemorySearch(3), VaultGet(2)
1970        assert_eq!(overview.by_tool.len(), 2);
1971        assert_eq!(overview.by_tool[0].tool_name, "MemorySearch");
1972        assert_eq!(overview.by_tool[0].invocations, 3);
1973        assert_eq!(overview.by_tool[0].errors, 0);
1974        assert_eq!(overview.by_tool[1].tool_name, "VaultGet");
1975        assert_eq!(overview.by_tool[1].invocations, 2);
1976        assert_eq!(overview.by_tool[1].errors, 1);
1977    }
1978
1979    #[tokio::test]
1980    async fn test_cost_overview_by_tool_since_filter() {
1981        let (_tmp, mgr) = setup().await;
1982        let sid = mgr
1983            .create_session(&Channel::Main, "tool-filter")
1984            .await
1985            .unwrap();
1986
1987        // Save a tool_use message now
1988        let tool_use = serde_json::json!({
1989            "type": "tool_use", "id": "tu_1", "name": "CronList", "input": {}
1990        });
1991        mgr.save_message(&sid, "tool_use", &tool_use.to_string())
1992            .await
1993            .unwrap();
1994
1995        // "Since" far in the future should exclude it
1996        let future = (Utc::now() + Duration::hours(1)).to_rfc3339();
1997        let overview = mgr.cost_overview(Some(&future)).await.unwrap();
1998        assert!(overview.by_tool.is_empty());
1999
2000        // "Since" far in the past should include it
2001        let past = (Utc::now() - Duration::days(365)).to_rfc3339();
2002        let overview = mgr.cost_overview(Some(&past)).await.unwrap();
2003        assert_eq!(overview.by_tool.len(), 1);
2004        assert_eq!(overview.by_tool[0].tool_name, "CronList");
2005        assert_eq!(overview.by_tool[0].invocations, 1);
2006    }
2007
2008    #[tokio::test]
2009    async fn test_cost_overview_by_tool_without_result() {
2010        let (_tmp, mgr) = setup().await;
2011        let sid = mgr
2012            .create_session(&Channel::Main, "tool-no-result")
2013            .await
2014            .unwrap();
2015
2016        // tool_use without a matching tool_result (e.g. stream interrupted)
2017        let tool_use = serde_json::json!({
2018            "type": "tool_use", "id": "tu_orphan", "name": "SkillList", "input": {}
2019        });
2020        mgr.save_message(&sid, "tool_use", &tool_use.to_string())
2021            .await
2022            .unwrap();
2023
2024        let overview = mgr.cost_overview(None).await.unwrap();
2025        assert_eq!(overview.by_tool.len(), 1);
2026        assert_eq!(overview.by_tool[0].tool_name, "SkillList");
2027        assert_eq!(overview.by_tool[0].invocations, 1);
2028        assert_eq!(
2029            overview.by_tool[0].errors, 0,
2030            "No result means no error, not an error"
2031        );
2032    }
2033}