smp_tee_runtime/tee_interface/
traits.rs1use std::collections::HashMap;
2
3use crate::aggregation::{federated_averaging, multi_krum};
4
/// Aggregation rule that `InMemoryTee::execute_computation` applies to the decoded
/// input vectors.
///
/// This is a small, plain-data selector, so it is `Copy` (cheap to pass by value) and
/// `Hash` (usable as a map/set key, e.g. for per-algorithm configuration or metrics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationAlgorithm {
    /// Combine the inputs with `crate::aggregation::federated_averaging`.
    FederatedAveraging,
    /// Combine the inputs with `crate::aggregation::multi_krum`.
    MultiKrum {
        /// Passed unchanged as the second argument of `multi_krum`; presumably the number
        /// of Byzantine (faulty/malicious) inputs to tolerate — TODO confirm against
        /// `aggregation::multi_krum`.
        byzantine_tolerance: usize,
    },
}

11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct ComputationParams {
13 pub algorithm: AggregationAlgorithm,
14}
15
/// Failures reported by `TeeGuard` operations.
///
/// Implements `Display` and `std::error::Error` so it composes with `?` into
/// `Box<dyn Error>` and other standard error-handling plumbing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TeeError {
    /// An operation was attempted before `initialize` was called.
    NotInitialized,
    /// A size was unusable. `InMemoryTee` returns this for zero-byte allocations and for
    /// writes larger than the target buffer.
    InvalidAllocationSize,
    /// A pointer did not identify a buffer the TEE is tracking.
    InvalidPointer,
    /// Input data was malformed; the payload is a short description of the problem.
    InvalidInput(&'static str),
}

impl std::fmt::Display for TeeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TeeError::NotInitialized => f.write_str("TEE has not been initialized"),
            TeeError::InvalidAllocationSize => f.write_str("invalid allocation size"),
            TeeError::InvalidPointer => f.write_str("pointer does not refer to a TEE allocation"),
            TeeError::InvalidInput(reason) => write!(f, "invalid input: {}", reason),
        }
    }
}

impl std::error::Error for TeeError {}

24pub trait TeeGuard {
25 fn initialize(&mut self) -> Result<(), TeeError>;
26 fn allocate_memory(&mut self, size: usize) -> Result<*mut u8, TeeError>;
27 fn write_data(&mut self, ptr: *mut u8, data: &[u8]) -> Result<(), TeeError>;
28 fn execute_computation(
29 &self,
30 input_ptrs: &[*const u8],
31 params: &ComputationParams,
32 ) -> Result<Vec<u8>, TeeError>;
33}
34
/// Reference `TeeGuard` implementation that keeps every buffer in ordinary process
/// heap memory (`Vec<u8>`); it provides no isolation of its own.
#[derive(Default, Debug)]
pub struct InMemoryTee {
    // Flipped to `true` by `initialize`; every other operation checks it first.
    initialized: bool,
    // Live buffers, keyed by the address of each buffer's first byte, i.e. the pointer
    // value returned from `allocate_memory`.
    allocations: HashMap<usize, Vec<u8>>,
}

41impl InMemoryTee {
42 fn read_vector(&self, ptr: *const u8) -> Result<Vec<f32>, TeeError> {
43 let bytes = self
44 .allocations
45 .get(&(ptr as usize))
46 .ok_or(TeeError::InvalidPointer)?;
47 if bytes.len() % 4 != 0 {
48 return Err(TeeError::InvalidInput(
49 "payload length must be a multiple of 4",
50 ));
51 }
52
53 Ok(bytes
54 .chunks_exact(4)
55 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
56 .collect())
57 }
58
59 fn encode_vector(values: &[f32]) -> Vec<u8> {
60 values
61 .iter()
62 .flat_map(|value| value.to_le_bytes())
63 .collect::<Vec<u8>>()
64 }
65}
66
67impl TeeGuard for InMemoryTee {
68 fn initialize(&mut self) -> Result<(), TeeError> {
69 self.initialized = true;
70 Ok(())
71 }
72
73 fn allocate_memory(&mut self, size: usize) -> Result<*mut u8, TeeError> {
74 if !self.initialized {
75 return Err(TeeError::NotInitialized);
76 }
77 if size == 0 {
78 return Err(TeeError::InvalidAllocationSize);
79 }
80
81 let mut allocation = vec![0_u8; size];
82 let ptr = allocation.as_mut_ptr();
83 self.allocations.insert(ptr as usize, allocation);
84 Ok(ptr)
85 }
86
87 fn write_data(&mut self, ptr: *mut u8, data: &[u8]) -> Result<(), TeeError> {
88 if !self.initialized {
89 return Err(TeeError::NotInitialized);
90 }
91
92 let buffer = self
93 .allocations
94 .get_mut(&(ptr as usize))
95 .ok_or(TeeError::InvalidPointer)?;
96
97 if data.len() > buffer.len() {
98 return Err(TeeError::InvalidAllocationSize);
99 }
100
101 buffer[..data.len()].copy_from_slice(data);
102 Ok(())
103 }
104
105 fn execute_computation(
106 &self,
107 input_ptrs: &[*const u8],
108 params: &ComputationParams,
109 ) -> Result<Vec<u8>, TeeError> {
110 if !self.initialized {
111 return Err(TeeError::NotInitialized);
112 }
113
114 let vectors = input_ptrs
115 .iter()
116 .map(|ptr| self.read_vector(*ptr))
117 .collect::<Result<Vec<_>, _>>()?;
118
119 let result = match params.algorithm {
120 AggregationAlgorithm::FederatedAveraging => federated_averaging(&vectors)
121 .ok_or(TeeError::InvalidInput("invalid federated averaging input"))?,
122 AggregationAlgorithm::MultiKrum {
123 byzantine_tolerance,
124 } => multi_krum(&vectors, byzantine_tolerance)
125 .ok_or(TeeError::InvalidInput("invalid multi-krum input"))?,
126 };
127
128 Ok(Self::encode_vector(&result))
129 }
130}
131
#[cfg(test)]
mod tests {
    //! Unit tests for `InMemoryTee`: the happy path plus every error branch that does
    //! not depend on the behaviour of the aggregation routines themselves.

    use super::*;

    /// Serializes `values` as consecutive little-endian `f32` words.
    fn to_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Inverse of `to_bytes`; a trailing partial word is ignored.
    fn to_f32(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }

    /// Parameters selecting plain federated averaging.
    fn fedavg() -> ComputationParams {
        ComputationParams {
            algorithm: AggregationAlgorithm::FederatedAveraging,
        }
    }

    /// A TEE on which `initialize` has already succeeded.
    fn ready_tee() -> InMemoryTee {
        let mut tee = InMemoryTee::default();
        tee.initialize().unwrap();
        tee
    }

    #[test]
    fn tee_executes_federated_averaging() {
        let mut tee = ready_tee();

        let p1 = tee.allocate_memory(8).unwrap();
        let p2 = tee.allocate_memory(8).unwrap();

        tee.write_data(p1, &to_bytes(&[1.0, 3.0])).unwrap();
        tee.write_data(p2, &to_bytes(&[3.0, 5.0])).unwrap();

        let out = tee
            .execute_computation(&[p1.cast_const(), p2.cast_const()], &fedavg())
            .unwrap();

        assert_eq!(to_f32(&out), vec![2.0, 4.0]);
    }

    #[test]
    fn operations_require_initialization() {
        let mut tee = InMemoryTee::default();

        assert_eq!(tee.allocate_memory(4), Err(TeeError::NotInitialized));
        assert_eq!(
            tee.write_data(std::ptr::null_mut(), &[0; 4]),
            Err(TeeError::NotInitialized)
        );
        assert_eq!(
            tee.execute_computation(&[], &fedavg()),
            Err(TeeError::NotInitialized)
        );
    }

    #[test]
    fn allocation_rejects_zero_size() {
        let mut tee = ready_tee();
        assert_eq!(tee.allocate_memory(0), Err(TeeError::InvalidAllocationSize));
    }

    #[test]
    fn write_rejects_unknown_pointer_and_oversized_payload() {
        let mut tee = ready_tee();
        let p = tee.allocate_memory(4).unwrap();

        // Null is never a key: live allocations are non-empty heap buffers.
        assert_eq!(
            tee.write_data(std::ptr::null_mut(), &[0; 4]),
            Err(TeeError::InvalidPointer)
        );
        assert_eq!(
            tee.write_data(p, &[0; 8]),
            Err(TeeError::InvalidAllocationSize)
        );
    }

    #[test]
    fn execute_rejects_unknown_pointer() {
        let tee = ready_tee();
        assert_eq!(
            tee.execute_computation(&[std::ptr::null()], &fedavg()),
            Err(TeeError::InvalidPointer)
        );
    }

    #[test]
    fn execute_rejects_payload_not_multiple_of_four() {
        let mut tee = ready_tee();
        let p = tee.allocate_memory(6).unwrap();

        assert_eq!(
            tee.execute_computation(&[p.cast_const()], &fedavg()),
            Err(TeeError::InvalidInput(
                "payload length must be a multiple of 4"
            ))
        );
    }
}