1use crate::RuntimeError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use tetcore_primitives::{Address, Hash32};
5
6#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
7pub enum VaultState {
8 Active,
9 Paused,
10 Closed,
11}
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct Vault {
15 pub vault_id: Hash32,
16 pub model_id: Hash32,
17 pub owner: Address,
18 pub state: VaultState,
19 pub total_staked: u128,
20 pub share_token_supply: u128,
21 pub reward_accumulator: u128,
22 pub created_at: u64,
23}
24
25#[derive(Clone, Debug, Serialize, Deserialize)]
26pub struct VaultShare {
27 pub holder: Address,
28 pub vault_id: Hash32,
29 pub staked_amount: u128,
30 pub share_count: u128,
31 pub last_reward_update: u64,
32}
33
34pub struct VaultModule {
35 vaults: HashMap<Hash32, Vault>,
36 shares: HashMap<Address, Vec<VaultShare>>,
37 vault_counter: u64,
38}
39
40impl VaultModule {
41 pub fn new() -> Self {
42 Self {
43 vaults: HashMap::new(),
44 shares: HashMap::new(),
45 vault_counter: 0,
46 }
47 }
48
49 pub fn create_vault(
50 &mut self,
51 model_id: Hash32,
52 owner: Address,
53 initial_stake: u128,
54 current_height: u64,
55 ) -> Result<Hash32, RuntimeError> {
56 self.vault_counter += 1;
57
58 let mut data = Vec::new();
59 data.extend_from_slice(&self.vault_counter.to_le_bytes());
60 data.extend_from_slice(model_id.as_bytes());
61 data.extend_from_slice(owner.as_bytes());
62
63 use sha2::{Digest, Sha256};
64 let hash = Sha256::digest(&data);
65 let mut id = [0u8; 32];
66 id.copy_from_slice(&hash[..32]);
67
68 let vault = Vault {
69 vault_id: Hash32(id),
70 model_id,
71 owner,
72 state: VaultState::Active,
73 total_staked: initial_stake,
74 share_token_supply: initial_stake,
75 reward_accumulator: 0,
76 created_at: current_height,
77 };
78
79 self.vaults.insert(Hash32(id), vault);
80
81 Ok(Hash32(id))
82 }
83
84 pub fn stake(
85 &mut self,
86 vault_id: &Hash32,
87 staker: Address,
88 amount: u128,
89 ) -> Result<u128, RuntimeError> {
90 let vault = self
91 .vaults
92 .get_mut(vault_id)
93 .ok_or(RuntimeError::InvalidState)?;
94
95 if vault.state != VaultState::Active {
96 return Err(RuntimeError::InvalidState);
97 }
98
99 let share_price = if vault.share_token_supply > 0 {
100 (vault.total_staked as u128 * 1_000_000) / vault.share_token_supply
101 } else {
102 1_000_000
103 };
104
105 let shares_to_mint = (amount * 1_000_000) / share_price;
106
107 vault.total_staked += amount;
108 vault.share_token_supply += shares_to_mint;
109
110 let holder_shares = self.shares.entry(staker).or_insert_with(Vec::new);
111
112 if let Some(existing) = holder_shares.iter_mut().find(|s| s.vault_id == *vault_id) {
113 existing.staked_amount += amount;
114 existing.share_count += shares_to_mint;
115 } else {
116 holder_shares.push(VaultShare {
117 holder: staker,
118 vault_id: *vault_id,
119 staked_amount: amount,
120 share_count: shares_to_mint,
121 last_reward_update: vault.created_at,
122 });
123 }
124
125 Ok(shares_to_mint)
126 }
127
128 pub fn unstake(
129 &mut self,
130 vault_id: &Hash32,
131 staker: &Address,
132 share_count: u128,
133 ) -> Result<u128, RuntimeError> {
134 let vault = self
135 .vaults
136 .get_mut(vault_id)
137 .ok_or(RuntimeError::InvalidState)?;
138
139 let holder_shares = self
140 .shares
141 .get_mut(staker)
142 .ok_or(RuntimeError::InvalidState)?;
143
144 let share_entry = holder_shares
145 .iter_mut()
146 .find(|s| s.vault_id == *vault_id)
147 .ok_or(RuntimeError::InvalidState)?;
148
149 if share_entry.share_count < share_count {
150 return Err(RuntimeError::InvalidState);
151 }
152
153 let share_ratio = (share_count as u128 * 1_000_000) / vault.share_token_supply;
154 let withdraw_amount = (vault.total_staked * share_ratio) / 1_000_000;
155
156 share_entry.share_count -= share_count;
157 share_entry.staked_amount -= withdraw_amount;
158
159 vault.total_staked -= withdraw_amount;
160 vault.share_token_supply -= share_count;
161
162 if share_entry.share_count == 0 {
163 holder_shares.retain(|s| s.vault_id != *vault_id);
164 }
165
166 Ok(withdraw_amount)
167 }
168
169 pub fn distribute_reward(
170 &mut self,
171 vault_id: &Hash32,
172 reward_amount: u128,
173 ) -> Result<(), RuntimeError> {
174 let vault = self
175 .vaults
176 .get_mut(vault_id)
177 .ok_or(RuntimeError::InvalidState)?;
178
179 if vault.state != VaultState::Active {
180 return Err(RuntimeError::InvalidState);
181 }
182
183 vault.reward_accumulator += reward_amount;
184
185 Ok(())
186 }
187
188 pub fn claim_rewards(
189 &mut self,
190 vault_id: &Hash32,
191 staker: &Address,
192 ) -> Result<u128, RuntimeError> {
193 let vault = self
194 .vaults
195 .get(vault_id)
196 .ok_or(RuntimeError::InvalidState)?;
197
198 let holder_shares = self.shares.get(staker).ok_or(RuntimeError::InvalidState)?;
199
200 let share_entry = holder_shares
201 .iter()
202 .find(|s| s.vault_id == *vault_id)
203 .ok_or(RuntimeError::InvalidState)?;
204
205 let reward =
206 (vault.reward_accumulator * share_entry.share_count) / vault.share_token_supply;
207
208 Ok(reward)
209 }
210
211 pub fn pause_vault(&mut self, vault_id: &Hash32) -> Result<(), RuntimeError> {
212 let vault = self
213 .vaults
214 .get_mut(vault_id)
215 .ok_or(RuntimeError::InvalidState)?;
216 vault.state = VaultState::Paused;
217 Ok(())
218 }
219
220 pub fn resume_vault(&mut self, vault_id: &Hash32) -> Result<(), RuntimeError> {
221 let vault = self
222 .vaults
223 .get_mut(vault_id)
224 .ok_or(RuntimeError::InvalidState)?;
225
226 if vault.state != VaultState::Paused {
227 return Err(RuntimeError::InvalidState);
228 }
229
230 vault.state = VaultState::Active;
231
232 Ok(())
233 }
234
235 pub fn get_vault(&self, vault_id: &Hash32) -> Option<&Vault> {
236 self.vaults.get(vault_id)
237 }
238
239 pub fn get_shares(&self, holder: &Address) -> Option<&Vec<VaultShare>> {
240 self.shares.get(holder)
241 }
242
243 pub fn all_vaults(&self) -> &HashMap<Hash32, Vault> {
244 &self.vaults
245 }
246}
247
248impl Default for VaultModule {
249 fn default() -> Self {
250 Self::new()
251 }
252}