Skip to main content

trussed_staging/
virt.rs

1// Copyright (C) Nitrokey GmbH
2// SPDX-License-Identifier: Apache-2.0 or MIT
3
4//! Wrapper around [`trussed::virt`][] that provides clients with both the core backend and the [`StagingBackend`] backend.
5
6use trussed_core::api::{reply, request, Reply, Request};
7
8#[cfg(feature = "manage")]
9use littlefs2_core::Path;
10#[cfg(feature = "manage")]
11use trussed_core::types::Location;
12
13#[cfg(feature = "chunked")]
14use trussed_chunked::ChunkedExtension;
15#[cfg(feature = "fs-info")]
16use trussed_fs_info::FsInfoExtension;
17#[cfg(feature = "hkdf")]
18use trussed_hkdf::HkdfExtension;
19#[cfg(feature = "hpke")]
20use trussed_hpke::HpkeExtension;
21#[cfg(feature = "manage")]
22use trussed_manage::ManageExtension;
23#[cfg(feature = "wrap-key-to-file")]
24use trussed_wrap_key_to_file::WrapKeyToFileExtension;
25
26use crate::{StagingBackend, StagingContext};
27
/// Dispatcher that hands every core and extension request to a single [`StagingBackend`].
///
/// `Default` builds it around `StagingBackend::default()`.
#[derive(Default, Debug)]
pub struct Dispatcher {
    /// Backend instance that serves all requests routed through this dispatcher.
    backend: StagingBackend,
}
32
/// Identifiers of the custom backends known to [`Dispatcher`].
///
/// The enum is a plain fieldless ID, so it implements the common value traits
/// (`Clone`, `Copy`, `PartialEq`, `Eq`) in addition to `Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendIds {
    /// The staging backend wrapped by [`Dispatcher`].
    StagingBackend,
}
37
/// Identifiers of the extensions that [`Dispatcher`] can route.
///
/// Every variant is compiled in only when the Cargo feature of the same
/// extension is enabled, so the enum may have no variants at all. The numeric
/// encoding of each variant is defined by the `From<ExtensionIds> for u8` and
/// `TryFrom<u8> for ExtensionIds` implementations in this file.
///
/// As a fieldless ID type it implements the common value traits (`Clone`,
/// `Copy`, `PartialEq`, `Eq`) in addition to `Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionIds {
    /// Selects the [`ChunkedExtension`].
    #[cfg(feature = "chunked")]
    Chunked,
    /// Selects the [`HkdfExtension`].
    #[cfg(feature = "hkdf")]
    Hkdf,
    /// Selects the [`ManageExtension`].
    #[cfg(feature = "manage")]
    Manage,
    /// Selects the [`WrapKeyToFileExtension`].
    #[cfg(feature = "wrap-key-to-file")]
    WrapKeyToFile,
    /// Selects the [`FsInfoExtension`].
    #[cfg(feature = "fs-info")]
    FsInfo,
    /// Selects the [`HpkeExtension`].
    #[cfg(feature = "hpke")]
    Hpke,
}
53
/// Registers [`ChunkedExtension`] with [`Dispatcher`] under [`ExtensionIds::Chunked`].
#[cfg(feature = "chunked")]
impl ExtensionId<ChunkedExtension> for Dispatcher {
    type Id = ExtensionIds;

    const ID: Self::Id = ExtensionIds::Chunked;
}
59
/// Registers [`HkdfExtension`] with [`Dispatcher`] under [`ExtensionIds::Hkdf`].
#[cfg(feature = "hkdf")]
impl ExtensionId<HkdfExtension> for Dispatcher {
    type Id = ExtensionIds;

    const ID: Self::Id = ExtensionIds::Hkdf;
}
65
/// Registers [`ManageExtension`] with [`Dispatcher`] under [`ExtensionIds::Manage`].
#[cfg(feature = "manage")]
impl ExtensionId<ManageExtension> for Dispatcher {
    type Id = ExtensionIds;

    const ID: Self::Id = ExtensionIds::Manage;
}
71
/// Registers [`WrapKeyToFileExtension`] with [`Dispatcher`] under [`ExtensionIds::WrapKeyToFile`].
#[cfg(feature = "wrap-key-to-file")]
impl ExtensionId<WrapKeyToFileExtension> for Dispatcher {
    type Id = ExtensionIds;

    const ID: Self::Id = ExtensionIds::WrapKeyToFile;
}
77
/// Registers [`FsInfoExtension`] with [`Dispatcher`] under [`ExtensionIds::FsInfo`].
#[cfg(feature = "fs-info")]
impl ExtensionId<FsInfoExtension> for Dispatcher {
    type Id = ExtensionIds;

    const ID: Self::Id = ExtensionIds::FsInfo;
}
83
/// Registers [`HpkeExtension`] with [`Dispatcher`] under [`ExtensionIds::Hpke`].
#[cfg(feature = "hpke")]
impl ExtensionId<HpkeExtension> for Dispatcher {
    type Id = ExtensionIds;

    const ID: Self::Id = ExtensionIds::Hpke;
}
89
90impl From<ExtensionIds> for u8 {
91    fn from(value: ExtensionIds) -> Self {
92        match value {
93            #[cfg(feature = "chunked")]
94            ExtensionIds::Chunked => 0,
95            #[cfg(feature = "hkdf")]
96            ExtensionIds::Hkdf => 1,
97            #[cfg(feature = "manage")]
98            ExtensionIds::Manage => 2,
99            #[cfg(feature = "wrap-key-to-file")]
100            ExtensionIds::WrapKeyToFile => 3,
101            #[cfg(feature = "fs-info")]
102            ExtensionIds::FsInfo => 4,
103            #[cfg(feature = "hpke")]
104            ExtensionIds::Hpke => 5,
105        }
106    }
107}
108
109impl TryFrom<u8> for ExtensionIds {
110    type Error = Error;
111    fn try_from(value: u8) -> Result<Self, Error> {
112        match value {
113            #[cfg(feature = "chunked")]
114            0 => Ok(Self::Chunked),
115            #[cfg(feature = "hkdf")]
116            1 => Ok(Self::Hkdf),
117            #[cfg(feature = "manage")]
118            2 => Ok(Self::Manage),
119            #[cfg(feature = "wrap-key-to-file")]
120            3 => Ok(Self::WrapKeyToFile),
121            #[cfg(feature = "fs-info")]
122            4 => Ok(Self::FsInfo),
123            #[cfg(feature = "hpke")]
124            5 => Ok(Self::Hpke),
125            _ => Err(Error::FunctionNotSupported),
126        }
127    }
128}
129
130impl ExtensionDispatch for Dispatcher {
131    type BackendId = BackendIds;
132    type Context = StagingContext;
133    type ExtensionId = ExtensionIds;
134    fn core_request<P: Platform>(
135        &mut self,
136        _backend: &Self::BackendId,
137        ctx: &mut trussed::types::Context<Self::Context>,
138        request: &Request,
139        resources: &mut trussed::service::ServiceResources<P>,
140    ) -> Result<Reply, Error> {
141        self.backend
142            .request(&mut ctx.core, &mut ctx.backends, request, resources)
143    }
144
145    fn extension_request<P: Platform>(
146        &mut self,
147        _backend: &Self::BackendId,
148        extension: &Self::ExtensionId,
149        ctx: &mut trussed::types::Context<Self::Context>,
150        request: &request::SerdeExtension,
151        resources: &mut trussed::service::ServiceResources<P>,
152    ) -> Result<reply::SerdeExtension, Error> {
153        let _ = &extension;
154        let _ = &ctx;
155        let _ = &request;
156        let _ = &resources;
157        // Dereference to avoid compile issue when all features are disabled requiring a default branch
158        // See https://github.com/rust-lang/rust/issues/78123#
159        match *extension {
160            #[cfg(feature = "wrap-key-to-file")]
161            ExtensionIds::WrapKeyToFile => {
162                ExtensionImpl::<WrapKeyToFileExtension>::extension_request_serialized(
163                    &mut self.backend,
164                    &mut ctx.core,
165                    &mut ctx.backends,
166                    request,
167                    resources,
168                )
169            }
170
171            #[cfg(feature = "chunked")]
172            ExtensionIds::Chunked => {
173                ExtensionImpl::<ChunkedExtension>::extension_request_serialized(
174                    &mut self.backend,
175                    &mut ctx.core,
176                    &mut ctx.backends,
177                    request,
178                    resources,
179                )
180            }
181
182            #[cfg(feature = "hkdf")]
183            ExtensionIds::Hkdf => ExtensionImpl::<HkdfExtension>::extension_request_serialized(
184                &mut self.backend,
185                &mut ctx.core,
186                &mut ctx.backends,
187                request,
188                resources,
189            ),
190
191            #[cfg(feature = "manage")]
192            ExtensionIds::Manage => ExtensionImpl::<ManageExtension>::extension_request_serialized(
193                &mut self.backend,
194                &mut ctx.core,
195                &mut ctx.backends,
196                request,
197                resources,
198            ),
199            #[cfg(feature = "fs-info")]
200            ExtensionIds::FsInfo => ExtensionImpl::<FsInfoExtension>::extension_request_serialized(
201                &mut self.backend,
202                &mut ctx.core,
203                &mut ctx.backends,
204                request,
205                resources,
206            ),
207            #[cfg(feature = "hpke")]
208            ExtensionIds::Hpke => ExtensionImpl::<HpkeExtension>::extension_request_serialized(
209                &mut self.backend,
210                &mut ctx.core,
211                &mut ctx.backends,
212                request,
213                resources,
214            ),
215        }
216    }
217}
218
219use trussed::{
220    backend::{Backend, BackendId},
221    serde_extensions::*,
222    virt::{self, StoreConfig},
223    Platform,
224};
225use trussed_core::Error;
226
/// Alias for [`virt::Client`] whose dispatcher type parameter defaults to [`Dispatcher`].
pub type Client<'a, D = Dispatcher> = virt::Client<'a, D>;
228
229pub fn with_client<R, F>(store: StoreConfig, client_id: &str, f: F) -> R
230where
231    F: FnOnce(Client) -> R,
232{
233    virt::with_platform(store, |platform| {
234        platform.run_client_with_backends(
235            client_id,
236            Dispatcher::default(),
237            &[
238                BackendId::Custom(BackendIds::StagingBackend),
239                BackendId::Core,
240            ],
241            f,
242        )
243    })
244}
245
/// Like [`with_client`], but installs `should_preserve_file` on the dispatcher first.
///
/// The callback is stored in the `manage` state of the dispatcher's staging
/// backend before the client is started. Presumably the manage extension
/// consults it to decide which files to keep; confirm against the manage
/// backend implementation. Returns whatever `f` returns.
#[cfg(feature = "manage")]
pub fn with_client_and_preserve<R, F>(
    store: StoreConfig,
    client_id: &str,
    f: F,
    should_preserve_file: fn(&Path, location: Location) -> bool,
) -> R
where
    F: FnOnce(Client) -> R,
{
    // Backend list handed to the platform: staging backend, then core.
    const BACKENDS: &[BackendId<BackendIds>] = &[
        BackendId::Custom(BackendIds::StagingBackend),
        BackendId::Core,
    ];

    let mut dispatcher = Dispatcher::default();
    dispatcher.backend.manage.should_preserve_file = should_preserve_file;

    virt::with_platform(store, |platform| {
        platform.run_client_with_backends(client_id, dispatcher, BACKENDS, f)
    })
}
271
/// Runs `f` with `N` clients that share one dispatcher and one virtual platform.
///
/// `should_preserve_file` is stored in the `manage` state of the dispatcher's
/// staging backend before the clients are started. Every client ID is paired
/// with the same backend list: the staging backend followed by the core
/// backend. Returns whatever `f` returns.
///
/// Note that `should_preserve_file` comes before `f` here, unlike in
/// [`with_client_and_preserve`].
#[cfg(feature = "manage")]
pub fn with_clients_and_preserve<R, F, const N: usize>(
    store: StoreConfig,
    client_ids: [&str; N],
    should_preserve_file: fn(&Path, location: Location) -> bool,
    f: F,
) -> R
where
    F: FnOnce([Client; N]) -> R,
{
    // Backend list shared by all clients: staging backend, then core.
    const BACKENDS: &[BackendId<BackendIds>] = &[
        BackendId::Custom(BackendIds::StagingBackend),
        BackendId::Core,
    ];

    let mut dispatcher = Dispatcher::default();
    dispatcher.backend.manage.should_preserve_file = should_preserve_file;

    let clients = client_ids.map(|id| (id, BACKENDS));

    virt::with_platform(store, |platform| {
        platform.run_clients_with_backends(clients, dispatcher, f)
    })
}
294
295    virt::with_platform(store, |platform| {
296        platform.run_clients_with_backends(clients_backend, dispatcher, f)
297    })
298}