// smp_tee_runtime/tee_interface/traits.rs

use std::collections::HashMap;
use std::fmt;

use crate::aggregation::{federated_averaging, multi_krum};

/// Aggregation strategy that `TeeGuard::execute_computation` runs over the
/// decoded input vectors.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AggregationAlgorithm {
    /// Dispatches to `crate::aggregation::federated_averaging`.
    FederatedAveraging,
    /// Dispatches to `crate::aggregation::multi_krum`.
    MultiKrum {
        /// Forwarded unchanged as the second argument of `multi_krum`;
        /// presumably the number of Byzantine inputs to tolerate (confirm in
        /// the `aggregation` module).
        byzantine_tolerance: usize,
    },
}

11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct ComputationParams {
13    pub algorithm: AggregationAlgorithm,
14}
15
/// Errors reported by `TeeGuard` operations.
///
/// Implements `Display` and `std::error::Error`, so it can be printed and
/// propagated into generic error types (e.g. `Box<dyn std::error::Error>`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TeeError {
    /// An operation was attempted before `TeeGuard::initialize` was called.
    NotInitialized,
    /// In `InMemoryTee`: a zero-byte allocation was requested, or the data
    /// passed to `write_data` is longer than the target allocation.
    InvalidAllocationSize,
    /// The pointer does not identify an allocation known to the TEE.
    InvalidPointer,
    /// The input was rejected; the payload carries a short static reason.
    InvalidInput(&'static str),
}

impl fmt::Display for TeeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotInitialized => f.write_str("TEE has not been initialized"),
            Self::InvalidAllocationSize => f.write_str("invalid allocation size"),
            Self::InvalidPointer => f.write_str("pointer does not refer to a TEE allocation"),
            Self::InvalidInput(reason) => write!(f, "invalid input: {reason}"),
        }
    }
}

// No underlying cause is carried by any variant, so the default `source()`
// (returning `None`) is correct.
impl std::error::Error for TeeError {}

24pub trait TeeGuard {
25    fn initialize(&mut self) -> Result<(), TeeError>;
26    fn allocate_memory(&mut self, size: usize) -> Result<*mut u8, TeeError>;
27    fn write_data(&mut self, ptr: *mut u8, data: &[u8]) -> Result<(), TeeError>;
28    fn execute_computation(
29        &self,
30        input_ptrs: &[*const u8],
31        params: &ComputationParams,
32    ) -> Result<Vec<u8>, TeeError>;
33}
34
/// Software-only `TeeGuard` that keeps every allocation in a host-side
/// `HashMap`.
///
/// NOTE(review): there is no hardware isolation here; presumably intended for
/// tests and simulation. Confirm the intended use.
#[derive(Default, Debug)]
pub struct InMemoryTee {
    /// Set to `true` by `initialize`; all other operations check it first.
    initialized: bool,
    /// Buffers keyed by the address of their first byte, which is the value
    /// handed out by `allocate_memory`. Nothing in this file removes entries.
    allocations: HashMap<usize, Vec<u8>>,
}

41impl InMemoryTee {
42    fn read_vector(&self, ptr: *const u8) -> Result<Vec<f32>, TeeError> {
43        let bytes = self
44            .allocations
45            .get(&(ptr as usize))
46            .ok_or(TeeError::InvalidPointer)?;
47        if bytes.len() % 4 != 0 {
48            return Err(TeeError::InvalidInput(
49                "payload length must be a multiple of 4",
50            ));
51        }
52
53        Ok(bytes
54            .chunks_exact(4)
55            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
56            .collect())
57    }
58
59    fn encode_vector(values: &[f32]) -> Vec<u8> {
60        values
61            .iter()
62            .flat_map(|value| value.to_le_bytes())
63            .collect::<Vec<u8>>()
64    }
65}
66
67impl TeeGuard for InMemoryTee {
68    fn initialize(&mut self) -> Result<(), TeeError> {
69        self.initialized = true;
70        Ok(())
71    }
72
73    fn allocate_memory(&mut self, size: usize) -> Result<*mut u8, TeeError> {
74        if !self.initialized {
75            return Err(TeeError::NotInitialized);
76        }
77        if size == 0 {
78            return Err(TeeError::InvalidAllocationSize);
79        }
80
81        let mut allocation = vec![0_u8; size];
82        let ptr = allocation.as_mut_ptr();
83        self.allocations.insert(ptr as usize, allocation);
84        Ok(ptr)
85    }
86
87    fn write_data(&mut self, ptr: *mut u8, data: &[u8]) -> Result<(), TeeError> {
88        if !self.initialized {
89            return Err(TeeError::NotInitialized);
90        }
91
92        let buffer = self
93            .allocations
94            .get_mut(&(ptr as usize))
95            .ok_or(TeeError::InvalidPointer)?;
96
97        if data.len() > buffer.len() {
98            return Err(TeeError::InvalidAllocationSize);
99        }
100
101        buffer[..data.len()].copy_from_slice(data);
102        Ok(())
103    }
104
105    fn execute_computation(
106        &self,
107        input_ptrs: &[*const u8],
108        params: &ComputationParams,
109    ) -> Result<Vec<u8>, TeeError> {
110        if !self.initialized {
111            return Err(TeeError::NotInitialized);
112        }
113
114        let vectors = input_ptrs
115            .iter()
116            .map(|ptr| self.read_vector(*ptr))
117            .collect::<Result<Vec<_>, _>>()?;
118
119        let result = match params.algorithm {
120            AggregationAlgorithm::FederatedAveraging => federated_averaging(&vectors)
121                .ok_or(TeeError::InvalidInput("invalid federated averaging input"))?,
122            AggregationAlgorithm::MultiKrum {
123                byzantine_tolerance,
124            } => multi_krum(&vectors, byzantine_tolerance)
125                .ok_or(TeeError::InvalidInput("invalid multi-krum input"))?,
126        };
127
128        Ok(Self::encode_vector(&result))
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `values` as little-endian `f32` bytes (test-side mirror of the
    /// TEE's wire format).
    fn to_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Decodes little-endian `f32` bytes produced by `execute_computation`.
    fn to_f32(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }

    /// Returns a TEE on which `initialize` has already succeeded.
    fn initialized_tee() -> InMemoryTee {
        let mut tee = InMemoryTee::default();
        tee.initialize().unwrap();
        tee
    }

    fn federated_averaging_params() -> ComputationParams {
        ComputationParams {
            algorithm: AggregationAlgorithm::FederatedAveraging,
        }
    }

    #[test]
    fn tee_executes_federated_averaging() {
        let mut tee = initialized_tee();

        let p1 = tee.allocate_memory(8).unwrap();
        let p2 = tee.allocate_memory(8).unwrap();

        tee.write_data(p1, &to_bytes(&[1.0, 3.0])).unwrap();
        tee.write_data(p2, &to_bytes(&[3.0, 5.0])).unwrap();

        let out = tee
            .execute_computation(
                &[p1.cast_const(), p2.cast_const()],
                &federated_averaging_params(),
            )
            .unwrap();

        assert_eq!(to_f32(&out), vec![2.0, 4.0]);
    }

    #[test]
    fn operations_fail_before_initialize() {
        let mut tee = InMemoryTee::default();

        assert_eq!(
            tee.allocate_memory(4).unwrap_err(),
            TeeError::NotInitialized
        );
        assert_eq!(
            tee.write_data(std::ptr::null_mut(), &[0_u8]).unwrap_err(),
            TeeError::NotInitialized
        );
        assert_eq!(
            tee.execute_computation(&[], &federated_averaging_params())
                .unwrap_err(),
            TeeError::NotInitialized
        );
    }

    #[test]
    fn zero_sized_allocation_is_rejected() {
        let mut tee = initialized_tee();
        assert_eq!(
            tee.allocate_memory(0).unwrap_err(),
            TeeError::InvalidAllocationSize
        );
    }

    #[test]
    fn write_rejects_unknown_pointer() {
        let mut tee = initialized_tee();
        assert_eq!(
            tee.write_data(std::ptr::null_mut(), &[1_u8]).unwrap_err(),
            TeeError::InvalidPointer
        );
    }

    #[test]
    fn write_rejects_data_larger_than_allocation() {
        let mut tee = initialized_tee();
        let p = tee.allocate_memory(4).unwrap();

        assert_eq!(
            tee.write_data(p, &[0_u8; 8]).unwrap_err(),
            TeeError::InvalidAllocationSize
        );
        // A write that exactly fills the allocation is accepted.
        assert!(tee.write_data(p, &[0_u8; 4]).is_ok());
    }

    #[test]
    fn execute_rejects_unknown_pointer() {
        let tee = initialized_tee();
        assert_eq!(
            tee.execute_computation(&[std::ptr::null()], &federated_averaging_params())
                .unwrap_err(),
            TeeError::InvalidPointer
        );
    }

    #[test]
    fn execute_rejects_payload_not_multiple_of_four() {
        let mut tee = initialized_tee();
        let odd = tee.allocate_memory(3).unwrap();

        assert_eq!(
            tee.execute_computation(&[odd.cast_const()], &federated_averaging_params())
                .unwrap_err(),
            TeeError::InvalidInput("payload length must be a multiple of 4")
        );
    }
}