1use crate::RuntimeError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use tetcore_primitives::{Address, Hash32};
5
6#[derive(Clone, Debug, Serialize, Deserialize)]
7pub struct RevenueRoute {
8 pub recipient: Address,
9 pub basis_points: u16,
10 pub recipient_type: RecipientType,
11}
12
13#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
14pub enum RecipientType {
15 ModelOwner,
16 Operator,
17 ShardProvider,
18 Validator,
19 Vault,
20 Treasury,
21}
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
24pub struct RevenueDistribution {
25 pub distribution_id: Hash32,
26 pub source: Address,
27 pub routes: Vec<RevenueRoute>,
28 pub total_amount: u128,
29 pub distributed: bool,
30 pub block_height: u64,
31}
32
33pub struct RevenueModule {
34 distributions: HashMap<Hash32, RevenueDistribution>,
35 pending_distributions: Vec<RevenueDistribution>,
36 distribution_counter: u64,
37 treasury_address: Address,
38}
39
40impl RevenueModule {
41 pub fn new() -> Self {
42 Self {
43 distributions: HashMap::new(),
44 pending_distributions: Vec::new(),
45 distribution_counter: 0,
46 treasury_address: Address::from_bytes([0u8; 32]),
47 }
48 }
49
50 pub fn set_treasury(&mut self, address: Address) {
51 self.treasury_address = address;
52 }
53
54 pub fn create_distribution(
55 &mut self,
56 source: Address,
57 routes: Vec<RevenueRoute>,
58 total_amount: u128,
59 current_height: u64,
60 ) -> Result<Hash32, RuntimeError> {
61 let total_bps: u16 = routes.iter().map(|r| r.basis_points).sum();
62
63 if total_bps != 10000 {
64 return Err(RuntimeError::InvalidState);
65 }
66
67 self.distribution_counter += 1;
68
69 let mut data = Vec::new();
70 data.extend_from_slice(&self.distribution_counter.to_le_bytes());
71 data.extend_from_slice(source.as_bytes());
72 data.extend_from_slice(&total_amount.to_le_bytes());
73
74 use sha2::{Digest, Sha256};
75 let hash = Sha256::digest(&data);
76 let mut id = [0u8; 32];
77 id.copy_from_slice(&hash[..32]);
78
79 let distribution = RevenueDistribution {
80 distribution_id: Hash32(id),
81 source,
82 routes,
83 total_amount,
84 distributed: false,
85 block_height: current_height,
86 };
87
88 self.distributions.insert(Hash32(id), distribution.clone());
89 self.pending_distributions.push(distribution);
90
91 Ok(Hash32(id))
92 }
93
94 pub fn distribute(
95 &mut self,
96 distribution_id: &Hash32,
97 balances: &mut HashMap<Address, u128>,
98 ) -> Result<Vec<(Address, u128)>, RuntimeError> {
99 let distribution = self
100 .distributions
101 .get_mut(distribution_id)
102 .ok_or(RuntimeError::InvalidState)?;
103
104 if distribution.distributed {
105 return Err(RuntimeError::InvalidState);
106 }
107
108 let mut results = Vec::new();
109
110 for route in &distribution.routes {
111 let amount = (distribution.total_amount * route.basis_points as u128) / 10000;
112
113 let balance = balances.entry(route.recipient).or_insert(0);
114 *balance += amount;
115
116 results.push((route.recipient, amount));
117 }
118
119 distribution.distributed = true;
120
121 Ok(results)
122 }
123
124 pub fn calculate_fee(
125 &self,
126 prompt_tokens: u64,
127 output_tokens: u64,
128 base_fee: u64,
129 token_fee: u64,
130 ) -> u64 {
131 let total_tokens = prompt_tokens + output_tokens;
132 base_fee + (total_tokens * token_fee)
133 }
134
135 pub fn route_revenue(
136 &self,
137 amount: u128,
138 split: &crate::model_registry::RevenueSplit,
139 model_owner: Address,
140 operator: Address,
141 validator: Address,
142 ) -> Vec<RevenueRoute> {
143 vec![
144 RevenueRoute {
145 recipient: model_owner,
146 basis_points: split.model_owner_bps,
147 recipient_type: RecipientType::ModelOwner,
148 },
149 RevenueRoute {
150 recipient: operator,
151 basis_points: split.operator_bps,
152 recipient_type: RecipientType::Operator,
153 },
154 RevenueRoute {
155 recipient: validator,
156 basis_points: split.validator_bps,
157 recipient_type: RecipientType::Validator,
158 },
159 RevenueRoute {
160 recipient: self.treasury_address,
161 basis_points: (split.shard_provider_bps + split.vault_bps) as u16,
162 recipient_type: RecipientType::Treasury,
163 },
164 ]
165 }
166
167 pub fn get_distribution(&self, distribution_id: &Hash32) -> Option<&RevenueDistribution> {
168 self.distributions.get(distribution_id)
169 }
170
171 pub fn pending_count(&self) -> usize {
172 self.pending_distributions.len()
173 }
174
175 pub fn all_distributions(&self) -> &HashMap<Hash32, RevenueDistribution> {
176 &self.distributions
177 }
178}
179
180impl Default for RevenueModule {
181 fn default() -> Self {
182 Self::new()
183 }
184}