1use crate::RuntimeError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use tetcore_primitives::{Address, Hash32};
5
/// Lifecycle state of a [`Prompt`] tracked by `InferenceModule`.
///
/// Transitions performed by `InferenceModule` in this file:
/// `Submitted` (set by `submit_prompt`) -> `Locked` (`lock_prompt`) ->
/// `Executing` (`execute_prompt`) -> `Completed` (`submit_receipt`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum PromptState {
    /// Newly submitted; `escrow_amount` is still zero.
    Submitted,
    /// Escrow recorded by `lock_prompt`; execution has not started.
    Locked,
    /// Execution started; a receipt may now be submitted.
    Executing,
    /// A receipt has been submitted for this prompt.
    Completed,
    /// NOTE(review): nothing in this module assigns this state; presumably set
    /// elsewhere — confirm.
    Failed,
}
14
/// A prompt submitted for inference against a specific model and version.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Prompt {
    /// SHA-256-derived identifier; also this prompt's key in `InferenceModule`.
    pub prompt_id: Hash32,
    /// Model the prompt was submitted against.
    pub model_id: Hash32,
    /// Model version supplied at submission.
    pub version: u32,
    /// Address supplied as the sender at submission.
    pub sender: Address,
    /// Raw prompt payload, stored as opaque bytes.
    pub prompt_data: Vec<u8>,
    /// Current lifecycle state.
    pub state: PromptState,
    /// Amount recorded by `lock_prompt`; zero until the prompt is locked.
    pub escrow_amount: u128,
    /// Fee that `create_settlement` splits as the settlement's `total_fee`.
    pub fee: u64,
    /// `current_height` passed to `submit_prompt` (presumably a block height — confirm).
    pub submitted_at: u64,
}
27
/// An operator's result for an executed prompt.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Receipt {
    /// SHA-256-derived identifier; also this receipt's key in `InferenceModule`.
    pub receipt_id: Hash32,
    /// Prompt this receipt answers.
    pub prompt_id: Hash32,
    /// Address supplied as the operator when the receipt was submitted.
    pub operator: Address,
    /// Raw inference output, stored as opaque bytes.
    pub inference_output: Vec<u8>,
    /// Proof bytes as supplied by the operator; not inspected in this module.
    pub execution_proof: Vec<u8>,
    /// `current_height` passed to `submit_receipt` (presumably a block height — confirm).
    pub submitted_at: u64,
    /// `false` on submission; set to `true` by `verify_receipt`.
    pub verified: bool,
}
38
/// Fee distribution computed for one receipt.
///
/// Each `*_amount` is `total_fee * bps / 10000` with integer (truncating)
/// division, so the parts may sum to slightly less than `total_fee`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InferenceSettlement {
    /// Receipt being settled; also this settlement's key in `InferenceModule`.
    pub receipt_id: Hash32,
    /// Prompt the receipt belongs to.
    pub prompt_id: Hash32,
    /// Model whose revenue split was applied.
    pub model_id: Hash32,
    /// The prompt's `fee`, widened to `u128`.
    pub total_fee: u128,
    /// Share for the model owner.
    pub model_owner_amount: u128,
    /// Share for the operator.
    pub operator_amount: u128,
    /// Share for the shard provider.
    pub shard_provider_amount: u128,
    /// Share for validators.
    pub validator_amount: u128,
    /// Share for the vault.
    pub vault_amount: u128,
    /// `false` when created; set to `true` by `mark_settled`.
    pub settled: bool,
}
52
53pub struct InferenceModule {
54 prompts: HashMap<Hash32, Prompt>,
55 receipts: HashMap<Hash32, Receipt>,
56 settlements: HashMap<Hash32, InferenceSettlement>,
57 prompt_counter: u64,
58 receipt_counter: u64,
59}
60
61impl InferenceModule {
62 pub fn new() -> Self {
63 Self {
64 prompts: HashMap::new(),
65 receipts: HashMap::new(),
66 settlements: HashMap::new(),
67 prompt_counter: 0,
68 receipt_counter: 0,
69 }
70 }
71
72 pub fn submit_prompt(
73 &mut self,
74 model_id: Hash32,
75 version: u32,
76 sender: Address,
77 prompt_data: Vec<u8>,
78 fee: u64,
79 current_height: u64,
80 ) -> Result<Hash32, RuntimeError> {
81 self.prompt_counter += 1;
82
83 let mut data = Vec::new();
84 data.extend_from_slice(&self.prompt_counter.to_le_bytes());
85 data.extend_from_slice(model_id.as_bytes());
86 data.extend_from_slice(&prompt_data);
87
88 use sha2::{Digest, Sha256};
89 let hash = Sha256::digest(&data);
90 let mut id = [0u8; 32];
91 id.copy_from_slice(&hash[..32]);
92
93 let prompt = Prompt {
94 prompt_id: Hash32(id),
95 model_id,
96 version,
97 sender,
98 prompt_data,
99 state: PromptState::Submitted,
100 escrow_amount: 0,
101 fee,
102 submitted_at: current_height,
103 };
104
105 self.prompts.insert(Hash32(id), prompt);
106
107 Ok(Hash32(id))
108 }
109
110 pub fn lock_prompt(
111 &mut self,
112 prompt_id: &Hash32,
113 escrow_amount: u128,
114 ) -> Result<(), RuntimeError> {
115 let prompt = self
116 .prompts
117 .get_mut(prompt_id)
118 .ok_or(RuntimeError::InvalidState)?;
119
120 if prompt.state != PromptState::Submitted {
121 return Err(RuntimeError::InvalidState);
122 }
123
124 prompt.state = PromptState::Locked;
125 prompt.escrow_amount = escrow_amount;
126
127 Ok(())
128 }
129
130 pub fn execute_prompt(&mut self, prompt_id: &Hash32) -> Result<(), RuntimeError> {
131 let prompt = self
132 .prompts
133 .get_mut(prompt_id)
134 .ok_or(RuntimeError::InvalidState)?;
135
136 if prompt.state != PromptState::Locked {
137 return Err(RuntimeError::InvalidState);
138 }
139
140 prompt.state = PromptState::Executing;
141
142 Ok(())
143 }
144
145 pub fn submit_receipt(
146 &mut self,
147 prompt_id: &Hash32,
148 operator: Address,
149 inference_output: Vec<u8>,
150 execution_proof: Vec<u8>,
151 current_height: u64,
152 ) -> Result<Hash32, RuntimeError> {
153 let prompt = self
154 .prompts
155 .get(prompt_id)
156 .ok_or(RuntimeError::InvalidState)?;
157
158 if prompt.state != PromptState::Executing {
159 return Err(RuntimeError::InvalidState);
160 }
161
162 self.receipt_counter += 1;
163
164 let mut data = Vec::new();
165 data.extend_from_slice(&self.receipt_counter.to_le_bytes());
166 data.extend_from_slice(prompt_id.as_bytes());
167 data.extend_from_slice(&inference_output);
168
169 use sha2::{Digest, Sha256};
170 let hash = Sha256::digest(&data);
171 let mut id = [0u8; 32];
172 id.copy_from_slice(&hash[..32]);
173
174 let receipt = Receipt {
175 receipt_id: Hash32(id),
176 prompt_id: *prompt_id,
177 operator,
178 inference_output,
179 execution_proof,
180 submitted_at: current_height,
181 verified: false,
182 };
183
184 self.receipts.insert(Hash32(id), receipt);
185
186 let mut prompt = self.prompts.get_mut(prompt_id).unwrap();
187 prompt.state = PromptState::Completed;
188
189 Ok(Hash32(id))
190 }
191
192 pub fn verify_receipt(&mut self, receipt_id: &Hash32) -> Result<(), RuntimeError> {
193 let receipt = self
194 .receipts
195 .get_mut(receipt_id)
196 .ok_or(RuntimeError::InvalidState)?;
197
198 receipt.verified = true;
199
200 Ok(())
201 }
202
203 pub fn create_settlement(
204 &mut self,
205 receipt_id: &Hash32,
206 model_id: &Hash32,
207 revenue_split: &crate::model_registry::RevenueSplit,
208 ) -> Result<InferenceSettlement, RuntimeError> {
209 let receipt = self
210 .receipts
211 .get(receipt_id)
212 .ok_or(RuntimeError::InvalidState)?;
213 let prompt = self
214 .prompts
215 .get(&receipt.prompt_id)
216 .ok_or(RuntimeError::InvalidState)?;
217
218 let total_fee = prompt.fee as u128;
219
220 let settlement = InferenceSettlement {
221 receipt_id: *receipt_id,
222 prompt_id: receipt.prompt_id,
223 model_id: *model_id,
224 total_fee,
225 model_owner_amount: (total_fee * revenue_split.model_owner_bps as u128) / 10000,
226 operator_amount: (total_fee * revenue_split.operator_bps as u128) / 10000,
227 shard_provider_amount: (total_fee * revenue_split.shard_provider_bps as u128) / 10000,
228 validator_amount: (total_fee * revenue_split.validator_bps as u128) / 10000,
229 vault_amount: (total_fee * revenue_split.vault_bps as u128) / 10000,
230 settled: false,
231 };
232
233 self.settlements.insert(*receipt_id, settlement.clone());
234
235 Ok(settlement)
236 }
237
238 pub fn mark_settled(&mut self, receipt_id: &Hash32) -> Result<(), RuntimeError> {
239 let settlement = self
240 .settlements
241 .get_mut(receipt_id)
242 .ok_or(RuntimeError::InvalidState)?;
243 settlement.settled = true;
244 Ok(())
245 }
246
247 pub fn get_prompt(&self, prompt_id: &Hash32) -> Option<&Prompt> {
248 self.prompts.get(prompt_id)
249 }
250
251 pub fn get_receipt(&self, receipt_id: &Hash32) -> Option<&Receipt> {
252 self.receipts.get(receipt_id)
253 }
254
255 pub fn get_settlement(&self, receipt_id: &Hash32) -> Option<&InferenceSettlement> {
256 self.settlements.get(receipt_id)
257 }
258}
259
260impl Default for InferenceModule {
261 fn default() -> Self {
262 Self::new()
263 }
264}