systemprompt_agent/services/
context_provider.rs1use async_trait::async_trait;
2use systemprompt_database::DbPool;
3use systemprompt_identifiers::{ContextId, SessionId, UserId};
4use systemprompt_traits::{ContextProvider, ContextProviderError, ContextWithStats};
5
6use crate::repository::ContextRepository;
7
8#[derive(Debug, Clone)]
9pub struct ContextProviderService {
10 repo: ContextRepository,
11}
12
13impl ContextProviderService {
14 pub fn new(db_pool: &DbPool) -> anyhow::Result<Self> {
15 Ok(Self {
16 repo: ContextRepository::new(db_pool)?,
17 })
18 }
19}
20
21#[async_trait]
22impl ContextProvider for ContextProviderService {
23 async fn list_contexts_with_stats(
24 &self,
25 user_id: &UserId,
26 ) -> Result<Vec<ContextWithStats>, ContextProviderError> {
27 let contexts = self
28 .repo
29 .list_contexts_with_stats(user_id)
30 .await
31 .map_err(|e| ContextProviderError::Database(e.to_string()))?;
32
33 Ok(contexts
34 .into_iter()
35 .map(|c| ContextWithStats {
36 context_id: c.context_id,
37 user_id: c.user_id,
38 name: c.name,
39 created_at: c.created_at,
40 updated_at: c.updated_at,
41 task_count: c.task_count,
42 message_count: c.message_count,
43 last_message_at: c.last_message_at,
44 })
45 .collect())
46 }
47
48 async fn get_context(
49 &self,
50 context_id: &ContextId,
51 user_id: &UserId,
52 ) -> Result<ContextWithStats, ContextProviderError> {
53 let context = self
54 .repo
55 .get_context(context_id, user_id)
56 .await
57 .map_err(|e| match e {
58 systemprompt_traits::RepositoryError::NotFound(msg) => {
59 ContextProviderError::NotFound(msg)
60 },
61 other => ContextProviderError::Database(other.to_string()),
62 })?;
63
64 let all_contexts = self
65 .repo
66 .list_contexts_with_stats(user_id)
67 .await
68 .map_err(|e| ContextProviderError::Database(e.to_string()))?;
69
70 let context_with_stats = all_contexts
71 .into_iter()
72 .find(|c| c.context_id == context.context_id)
73 .ok_or_else(|| {
74 ContextProviderError::NotFound(format!("Context {} not found", context_id))
75 })?;
76
77 Ok(ContextWithStats {
78 context_id: context_with_stats.context_id,
79 user_id: context_with_stats.user_id,
80 name: context_with_stats.name,
81 created_at: context_with_stats.created_at,
82 updated_at: context_with_stats.updated_at,
83 task_count: context_with_stats.task_count,
84 message_count: context_with_stats.message_count,
85 last_message_at: context_with_stats.last_message_at,
86 })
87 }
88
89 async fn create_context(
90 &self,
91 user_id: &UserId,
92 session_id: Option<&SessionId>,
93 name: &str,
94 ) -> Result<ContextId, ContextProviderError> {
95 self.repo
96 .create_context(user_id, session_id, name)
97 .await
98 .map_err(|e| ContextProviderError::Database(e.to_string()))
99 }
100
101 async fn update_context_name(
102 &self,
103 context_id: &ContextId,
104 user_id: &UserId,
105 name: &str,
106 ) -> Result<(), ContextProviderError> {
107 self.repo
108 .update_context_name(context_id, user_id, name)
109 .await
110 .map_err(|e| match e {
111 systemprompt_traits::RepositoryError::NotFound(msg) => {
112 ContextProviderError::NotFound(msg)
113 },
114 other => ContextProviderError::Database(other.to_string()),
115 })
116 }
117
118 async fn delete_context(
119 &self,
120 context_id: &ContextId,
121 user_id: &UserId,
122 ) -> Result<(), ContextProviderError> {
123 self.repo
124 .delete_context(context_id, user_id)
125 .await
126 .map_err(|e| match e {
127 systemprompt_traits::RepositoryError::NotFound(msg) => {
128 ContextProviderError::NotFound(msg)
129 },
130 other => ContextProviderError::Database(other.to_string()),
131 })
132 }
133}