1use crate::proxy::{LlmMessage, LlmProxy, LlmRequest, LlmResponse, LlmRole};
9use anyhow::Result;
10use chrono::Utc;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::fs;
14use std::path::{Path, PathBuf};
15
/// A single conversation history, keyed by scope id and serialized to JSON
/// for persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationScope {
    /// Scope identifier (also used as the key in `ProxyMemory`'s map).
    pub id: String,
    /// Rolling message history for this scope.
    pub messages: Vec<LlmMessage>,
    /// Timestamp of the most recent update to this scope.
    pub last_updated: chrono::DateTime<Utc>,
}
23
/// Store of conversation scopes, persisted to a JSON file on disk unless
/// constructed via `in_memory_only`, in which case nothing touches the filesystem.
pub struct ProxyMemory {
    // Path of the JSON storage file used by `load`/`save`;
    // left empty when `in_memory_only` is set.
    storage_path: PathBuf,
    // Scope id -> conversation history.
    scopes: HashMap<String, ConversationScope>,
    // When true, `load` and `save` are no-ops and state lives only in memory.
    in_memory_only: bool,
}
31
32impl ProxyMemory {
33 pub fn new() -> Result<Self> {
34 let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
35 let storage_path = Path::new(&home).join(".st").join("proxy_memory.json");
36
37 if let Some(parent) = storage_path.parent() {
38 fs::create_dir_all(parent)?;
39 }
40
41 let mut memory = Self {
42 storage_path,
43 scopes: HashMap::new(),
44 in_memory_only: false,
45 };
46
47 memory.load()?;
48 Ok(memory)
49 }
50
51 pub fn in_memory_only() -> Self {
54 Self {
55 storage_path: PathBuf::new(), scopes: HashMap::new(),
57 in_memory_only: true,
58 }
59 }
60
61 pub fn get_scope(&self, scope_id: &str) -> Option<&ConversationScope> {
62 self.scopes.get(scope_id)
63 }
64
65 pub fn update_scope(&mut self, scope_id: &str, messages: Vec<LlmMessage>) -> Result<()> {
66 let scope = self
67 .scopes
68 .entry(scope_id.to_string())
69 .or_insert_with(|| ConversationScope {
70 id: scope_id.to_string(),
71 messages: Vec::new(),
72 last_updated: Utc::now(),
73 });
74
75 scope.messages.extend(messages);
76 scope.last_updated = Utc::now();
77
78 if scope.messages.len() > 20 {
80 scope.messages = scope.messages.split_off(scope.messages.len() - 20);
81 }
82
83 self.save()?;
84 Ok(())
85 }
86
87 pub fn clear_scope(&mut self, scope_id: &str) -> Result<()> {
88 self.scopes.remove(scope_id);
89 self.save()?;
90 Ok(())
91 }
92
93 fn load(&mut self) -> Result<()> {
94 if self.in_memory_only {
96 return Ok(());
97 }
98 if self.storage_path.exists() {
99 let content = fs::read_to_string(&self.storage_path)?;
100 self.scopes = serde_json::from_str(&content).unwrap_or_default();
101 }
102 Ok(())
103 }
104
105 fn save(&self) -> Result<()> {
106 if self.in_memory_only {
108 return Ok(());
109 }
110 let content = serde_json::to_string_pretty(&self.scopes)?;
111 fs::write(&self.storage_path, content)?;
112 Ok(())
113 }
114}
115
/// An [`LlmProxy`] wrapper that threads stored conversation history into each
/// request and records the resulting exchange back into memory.
pub struct MemoryProxy {
    /// The underlying proxy that performs the actual completion calls.
    pub inner: LlmProxy,
    /// Conversation history store, keyed by scope id.
    pub memory: ProxyMemory,
}
121
122impl MemoryProxy {
123 pub fn new() -> Result<Self> {
124 Ok(Self {
125 inner: LlmProxy::default(),
126 memory: ProxyMemory::new()?,
127 })
128 }
129
130 pub async fn with_local_detection() -> Result<Self> {
132 Ok(Self {
133 inner: LlmProxy::with_local_detection().await,
134 memory: ProxyMemory::new()?,
135 })
136 }
137
138 pub async fn complete_with_memory(
139 &mut self,
140 provider_name: &str,
141 scope_id: &str,
142 mut request: LlmRequest,
143 ) -> Result<LlmResponse> {
144 if let Some(scope) = self.memory.get_scope(scope_id) {
146 let mut new_messages = Vec::new();
148
149 if let Some(system_msg) = request
151 .messages
152 .iter()
153 .find(|m| m.role == LlmRole::System)
154 .cloned()
155 {
156 new_messages.push(system_msg);
157 }
158
159 for msg in &scope.messages {
161 if msg.role != LlmRole::System {
162 new_messages.push(msg.clone());
163 }
164 }
165
166 for msg in request.messages {
168 if msg.role != LlmRole::System {
169 new_messages.push(msg);
170 }
171 }
172
173 request.messages = new_messages;
174 }
175
176 let response = self.inner.complete(provider_name, request.clone()).await?;
178
179 let mut new_history = Vec::new();
181 if let Some(last_user_msg) = request
183 .messages
184 .iter()
185 .rev()
186 .find(|m| m.role == LlmRole::User)
187 {
188 new_history.push(last_user_msg.clone());
189 }
190 new_history.push(LlmMessage {
192 role: LlmRole::Assistant,
193 content: response.content.clone(),
194 });
195
196 self.memory.update_scope(scope_id, new_history)?;
197
198 Ok(response)
199 }
200}