Skip to main content

tetcore_runtime/
inference.rs

1use crate::RuntimeError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use tetcore_primitives::{Address, Hash32};
5
6#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
7pub enum PromptState {
8    Submitted,
9    Locked,
10    Executing,
11    Completed,
12    Failed,
13}
14
15#[derive(Clone, Debug, Serialize, Deserialize)]
16pub struct Prompt {
17    pub prompt_id: Hash32,
18    pub model_id: Hash32,
19    pub version: u32,
20    pub sender: Address,
21    pub prompt_data: Vec<u8>,
22    pub state: PromptState,
23    pub escrow_amount: u128,
24    pub fee: u64,
25    pub submitted_at: u64,
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct Receipt {
30    pub receipt_id: Hash32,
31    pub prompt_id: Hash32,
32    pub operator: Address,
33    pub inference_output: Vec<u8>,
34    pub execution_proof: Vec<u8>,
35    pub submitted_at: u64,
36    pub verified: bool,
37}
38
39#[derive(Clone, Debug, Serialize, Deserialize)]
40pub struct InferenceSettlement {
41    pub receipt_id: Hash32,
42    pub prompt_id: Hash32,
43    pub model_id: Hash32,
44    pub total_fee: u128,
45    pub model_owner_amount: u128,
46    pub operator_amount: u128,
47    pub shard_provider_amount: u128,
48    pub validator_amount: u128,
49    pub vault_amount: u128,
50    pub settled: bool,
51}
52
53pub struct InferenceModule {
54    prompts: HashMap<Hash32, Prompt>,
55    receipts: HashMap<Hash32, Receipt>,
56    settlements: HashMap<Hash32, InferenceSettlement>,
57    prompt_counter: u64,
58    receipt_counter: u64,
59}
60
61impl InferenceModule {
62    pub fn new() -> Self {
63        Self {
64            prompts: HashMap::new(),
65            receipts: HashMap::new(),
66            settlements: HashMap::new(),
67            prompt_counter: 0,
68            receipt_counter: 0,
69        }
70    }
71
72    pub fn submit_prompt(
73        &mut self,
74        model_id: Hash32,
75        version: u32,
76        sender: Address,
77        prompt_data: Vec<u8>,
78        fee: u64,
79        current_height: u64,
80    ) -> Result<Hash32, RuntimeError> {
81        self.prompt_counter += 1;
82
83        let mut data = Vec::new();
84        data.extend_from_slice(&self.prompt_counter.to_le_bytes());
85        data.extend_from_slice(model_id.as_bytes());
86        data.extend_from_slice(&prompt_data);
87
88        use sha2::{Digest, Sha256};
89        let hash = Sha256::digest(&data);
90        let mut id = [0u8; 32];
91        id.copy_from_slice(&hash[..32]);
92
93        let prompt = Prompt {
94            prompt_id: Hash32(id),
95            model_id,
96            version,
97            sender,
98            prompt_data,
99            state: PromptState::Submitted,
100            escrow_amount: 0,
101            fee,
102            submitted_at: current_height,
103        };
104
105        self.prompts.insert(Hash32(id), prompt);
106
107        Ok(Hash32(id))
108    }
109
110    pub fn lock_prompt(
111        &mut self,
112        prompt_id: &Hash32,
113        escrow_amount: u128,
114    ) -> Result<(), RuntimeError> {
115        let prompt = self
116            .prompts
117            .get_mut(prompt_id)
118            .ok_or(RuntimeError::InvalidState)?;
119
120        if prompt.state != PromptState::Submitted {
121            return Err(RuntimeError::InvalidState);
122        }
123
124        prompt.state = PromptState::Locked;
125        prompt.escrow_amount = escrow_amount;
126
127        Ok(())
128    }
129
130    pub fn execute_prompt(&mut self, prompt_id: &Hash32) -> Result<(), RuntimeError> {
131        let prompt = self
132            .prompts
133            .get_mut(prompt_id)
134            .ok_or(RuntimeError::InvalidState)?;
135
136        if prompt.state != PromptState::Locked {
137            return Err(RuntimeError::InvalidState);
138        }
139
140        prompt.state = PromptState::Executing;
141
142        Ok(())
143    }
144
145    pub fn submit_receipt(
146        &mut self,
147        prompt_id: &Hash32,
148        operator: Address,
149        inference_output: Vec<u8>,
150        execution_proof: Vec<u8>,
151        current_height: u64,
152    ) -> Result<Hash32, RuntimeError> {
153        let prompt = self
154            .prompts
155            .get(prompt_id)
156            .ok_or(RuntimeError::InvalidState)?;
157
158        if prompt.state != PromptState::Executing {
159            return Err(RuntimeError::InvalidState);
160        }
161
162        self.receipt_counter += 1;
163
164        let mut data = Vec::new();
165        data.extend_from_slice(&self.receipt_counter.to_le_bytes());
166        data.extend_from_slice(prompt_id.as_bytes());
167        data.extend_from_slice(&inference_output);
168
169        use sha2::{Digest, Sha256};
170        let hash = Sha256::digest(&data);
171        let mut id = [0u8; 32];
172        id.copy_from_slice(&hash[..32]);
173
174        let receipt = Receipt {
175            receipt_id: Hash32(id),
176            prompt_id: *prompt_id,
177            operator,
178            inference_output,
179            execution_proof,
180            submitted_at: current_height,
181            verified: false,
182        };
183
184        self.receipts.insert(Hash32(id), receipt);
185
186        let mut prompt = self.prompts.get_mut(prompt_id).unwrap();
187        prompt.state = PromptState::Completed;
188
189        Ok(Hash32(id))
190    }
191
192    pub fn verify_receipt(&mut self, receipt_id: &Hash32) -> Result<(), RuntimeError> {
193        let receipt = self
194            .receipts
195            .get_mut(receipt_id)
196            .ok_or(RuntimeError::InvalidState)?;
197
198        receipt.verified = true;
199
200        Ok(())
201    }
202
203    pub fn create_settlement(
204        &mut self,
205        receipt_id: &Hash32,
206        model_id: &Hash32,
207        revenue_split: &crate::model_registry::RevenueSplit,
208    ) -> Result<InferenceSettlement, RuntimeError> {
209        let receipt = self
210            .receipts
211            .get(receipt_id)
212            .ok_or(RuntimeError::InvalidState)?;
213        let prompt = self
214            .prompts
215            .get(&receipt.prompt_id)
216            .ok_or(RuntimeError::InvalidState)?;
217
218        let total_fee = prompt.fee as u128;
219
220        let settlement = InferenceSettlement {
221            receipt_id: *receipt_id,
222            prompt_id: receipt.prompt_id,
223            model_id: *model_id,
224            total_fee,
225            model_owner_amount: (total_fee * revenue_split.model_owner_bps as u128) / 10000,
226            operator_amount: (total_fee * revenue_split.operator_bps as u128) / 10000,
227            shard_provider_amount: (total_fee * revenue_split.shard_provider_bps as u128) / 10000,
228            validator_amount: (total_fee * revenue_split.validator_bps as u128) / 10000,
229            vault_amount: (total_fee * revenue_split.vault_bps as u128) / 10000,
230            settled: false,
231        };
232
233        self.settlements.insert(*receipt_id, settlement.clone());
234
235        Ok(settlement)
236    }
237
238    pub fn mark_settled(&mut self, receipt_id: &Hash32) -> Result<(), RuntimeError> {
239        let settlement = self
240            .settlements
241            .get_mut(receipt_id)
242            .ok_or(RuntimeError::InvalidState)?;
243        settlement.settled = true;
244        Ok(())
245    }
246
247    pub fn get_prompt(&self, prompt_id: &Hash32) -> Option<&Prompt> {
248        self.prompts.get(prompt_id)
249    }
250
251    pub fn get_receipt(&self, receipt_id: &Hash32) -> Option<&Receipt> {
252        self.receipts.get(receipt_id)
253    }
254
255    pub fn get_settlement(&self, receipt_id: &Hash32) -> Option<&InferenceSettlement> {
256        self.settlements.get(receipt_id)
257    }
258}
259
260impl Default for InferenceModule {
261    fn default() -> Self {
262        Self::new()
263    }
264}