// secure_exec_kernel/pipe_manager.rs

1use crate::fd_table::{
2    FdResult, FileDescription, ProcessFdTable, SharedFileDescription, FILETYPE_PIPE, O_RDONLY,
3    O_WRONLY,
4};
5use crate::poll::{PollEvents, PollNotifier, POLLERR, POLLHUP, POLLIN, POLLOUT};
6use std::collections::{BTreeMap, VecDeque};
7use std::error::Error;
8use std::fmt;
9use std::sync::{Arc, Condvar, Mutex, MutexGuard};
10use std::time::{Duration, Instant};
11
/// Buffer capacity of each pipe, in bytes (64 KiB).
pub const MAX_PIPE_BUFFER_BYTES: usize = 64 * 1024;
/// Writes of at most this many bytes are queued whole or not at all (4 KiB).
pub const PIPE_BUF_BYTES: usize = 4 * 1024;
14
/// Result type returned by [`PipeManager`] operations.
pub type PipeResult<T> = Result<T, PipeError>;

/// Error returned by pipe operations: an errno-style code plus a human-readable detail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipeError {
    code: &'static str,
    message: String,
}

impl PipeError {
    /// The errno-style code of this error: one of `"EBADF"`, `"EPIPE"` or `"EAGAIN"`.
    pub fn code(&self) -> &'static str {
        self.code
    }

    /// Common constructor behind the code-specific helpers below.
    fn with_code(code: &'static str, message: impl Into<String>) -> Self {
        Self {
            code,
            message: message.into(),
        }
    }

    /// `EBADF`: the id is not a suitable pipe end, or its pipe no longer exists.
    fn bad_file_descriptor(message: impl Into<String>) -> Self {
        Self::with_code("EBADF", message)
    }

    /// `EPIPE`: an end of the pipe has been closed.
    fn broken_pipe(message: impl Into<String>) -> Self {
        Self::with_code("EPIPE", message)
    }

    /// `EAGAIN`: the operation could not proceed without blocking.
    fn would_block(message: impl Into<String>) -> Self {
        Self::with_code("EAGAIN", message)
    }
}

impl fmt::Display for PipeError {
    /// Renders as `"<code>: <message>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { code, message } = self;
        write!(f, "{code}: {message}")
    }
}

impl Error for PipeError {}
57
58#[derive(Debug, Clone)]
59pub struct PipeEnd {
60    pub description: SharedFileDescription,
61    pub filetype: u8,
62}
63
64#[derive(Debug, Clone)]
65pub struct PipePair {
66    pub read: PipeEnd,
67    pub write: PipeEnd,
68}
69
/// Maps a registered description id to the pipe and end it denotes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct PipeRef {
    pipe_id: u64,
    end: PipeSide,
}

/// Which end of a pipe a description refers to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PipeSide {
    Read,
    Write,
}
81
/// A reader parked on an empty pipe, waiting for data or end of file.
#[derive(Debug, Default)]
struct PendingRead {
    /// Maximum number of bytes the reader asked for.
    length: usize,
    /// `None` while waiting; `Some(Some(bytes))` when a writer handed data over directly;
    /// `Some(None)` when the reader was resolved with end of file.
    result: Option<Option<Vec<u8>>>,
}

/// Per-pipe state, only touched while the manager's mutex is held.
#[derive(Debug, Default)]
struct PipeState {
    /// Buffered chunks, oldest first.
    buffer: VecDeque<Vec<u8>>,
    closed_read: bool,
    closed_write: bool,
    /// Waiter ids of readers parked on this pipe, oldest first.
    waiting_reads: VecDeque<u64>,
}
95
96#[derive(Debug)]
97struct PipeManagerState {
98    pipes: BTreeMap<u64, PipeState>,
99    desc_to_pipe: BTreeMap<u64, PipeRef>,
100    waiters: BTreeMap<u64, PendingRead>,
101    next_pipe_id: u64,
102    next_desc_id: u64,
103    next_waiter_id: u64,
104}
105
106impl Default for PipeManagerState {
107    fn default() -> Self {
108        Self {
109            pipes: BTreeMap::new(),
110            desc_to_pipe: BTreeMap::new(),
111            waiters: BTreeMap::new(),
112            next_pipe_id: 1,
113            next_desc_id: 100_000,
114            next_waiter_id: 1,
115        }
116    }
117}
118
119#[derive(Debug)]
120struct PipeManagerInner {
121    state: Mutex<PipeManagerState>,
122    waiters: Condvar,
123}
124
125#[derive(Debug, Clone)]
126pub struct PipeManager {
127    inner: Arc<PipeManagerInner>,
128    notifier: Option<PollNotifier>,
129}
130
131impl Default for PipeManager {
132    fn default() -> Self {
133        Self {
134            inner: Arc::new(PipeManagerInner {
135                state: Mutex::new(PipeManagerState::default()),
136                waiters: Condvar::new(),
137            }),
138            notifier: None,
139        }
140    }
141}
142
143impl PipeManager {
144    pub fn new() -> Self {
145        Self::default()
146    }
147
148    pub(crate) fn with_notifier(notifier: PollNotifier) -> Self {
149        Self {
150            notifier: Some(notifier),
151            ..Self::default()
152        }
153    }
154
155    pub fn create_pipe(&self) -> PipePair {
156        let mut state = lock_or_recover(&self.inner.state);
157        let pipe_id = state.next_pipe_id;
158        state.next_pipe_id += 1;
159
160        let read_id = state.next_desc_id;
161        state.next_desc_id += 1;
162        let write_id = state.next_desc_id;
163        state.next_desc_id += 1;
164
165        state.pipes.insert(pipe_id, PipeState::default());
166        state.desc_to_pipe.insert(
167            read_id,
168            PipeRef {
169                pipe_id,
170                end: PipeSide::Read,
171            },
172        );
173        state.desc_to_pipe.insert(
174            write_id,
175            PipeRef {
176                pipe_id,
177                end: PipeSide::Write,
178            },
179        );
180        drop(state);
181
182        PipePair {
183            read: PipeEnd {
184                description: Arc::new(FileDescription::with_ref_count(
185                    read_id,
186                    format!("pipe:{pipe_id}:read"),
187                    O_RDONLY,
188                    0,
189                )),
190                filetype: FILETYPE_PIPE,
191            },
192            write: PipeEnd {
193                description: Arc::new(FileDescription::with_ref_count(
194                    write_id,
195                    format!("pipe:{pipe_id}:write"),
196                    O_WRONLY,
197                    0,
198                )),
199                filetype: FILETYPE_PIPE,
200            },
201        }
202    }
203
204    pub fn poll(&self, description_id: u64, requested: PollEvents) -> PipeResult<PollEvents> {
205        let state = lock_or_recover(&self.inner.state);
206        let pipe_ref = state
207            .desc_to_pipe
208            .get(&description_id)
209            .copied()
210            .ok_or_else(|| PipeError::bad_file_descriptor("not a pipe end"))?;
211        let pipe = state
212            .pipes
213            .get(&pipe_ref.pipe_id)
214            .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
215
216        let mut events = PollEvents::empty();
217        match pipe_ref.end {
218            PipeSide::Read => {
219                if requested.intersects(POLLIN) && !pipe.buffer.is_empty() {
220                    events |= POLLIN;
221                }
222                if pipe.closed_write {
223                    events |= POLLHUP;
224                }
225            }
226            PipeSide::Write => {
227                if pipe.closed_read {
228                    events |= POLLERR;
229                } else if requested.intersects(POLLOUT)
230                    && (available_capacity(pipe) > 0 || !pipe.waiting_reads.is_empty())
231                {
232                    events |= POLLOUT;
233                }
234            }
235        }
236
237        Ok(events)
238    }
239
240    pub fn write(&self, description_id: u64, data: impl AsRef<[u8]>) -> PipeResult<usize> {
241        self.write_with_mode(description_id, data, true)
242    }
243
244    pub fn write_blocking(&self, description_id: u64, data: impl AsRef<[u8]>) -> PipeResult<usize> {
245        self.write_with_mode(description_id, data, false)
246    }
247
248    pub fn write_with_mode(
249        &self,
250        description_id: u64,
251        data: impl AsRef<[u8]>,
252        nonblocking: bool,
253    ) -> PipeResult<usize> {
254        let payload = data.as_ref();
255        let mut state = lock_or_recover(&self.inner.state);
256        let pipe_ref = state
257            .desc_to_pipe
258            .get(&description_id)
259            .copied()
260            .ok_or_else(|| PipeError::bad_file_descriptor("not a pipe write end"))?;
261        if pipe_ref.end != PipeSide::Write {
262            return Err(PipeError::bad_file_descriptor("not a pipe write end"));
263        }
264
265        loop {
266            let waiter_id = {
267                let pipe = state
268                    .pipes
269                    .get_mut(&pipe_ref.pipe_id)
270                    .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
271                if pipe.closed_write {
272                    return Err(PipeError::broken_pipe("write end closed"));
273                }
274                if pipe.closed_read {
275                    return Err(PipeError::broken_pipe("read end closed"));
276                }
277                pipe.waiting_reads.pop_front()
278            };
279
280            if let Some(waiter_id) = waiter_id {
281                let waiter_length = match state.waiters.get(&waiter_id) {
282                    Some(waiter) => waiter.length,
283                    None => continue,
284                };
285                let delivered_len = waiter_length.min(payload.len());
286                let delivered = payload[..delivered_len].to_vec();
287                let remainder = &payload[delivered_len..];
288
289                if !remainder.is_empty() {
290                    let pipe = state
291                        .pipes
292                        .get_mut(&pipe_ref.pipe_id)
293                        .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
294                    pipe.buffer.push_back(remainder.to_vec());
295                }
296
297                if let Some(waiter) = state.waiters.get_mut(&waiter_id) {
298                    waiter.result = Some(Some(delivered));
299                    self.notify_waiters_and_pollers();
300                    return Ok(payload.len());
301                }
302                continue;
303            }
304
305            let current_buffer_size = {
306                let pipe = state
307                    .pipes
308                    .get(&pipe_ref.pipe_id)
309                    .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
310                buffer_size(&pipe.buffer)
311            };
312            let available = MAX_PIPE_BUFFER_BYTES.saturating_sub(current_buffer_size);
313
314            if payload.len() <= PIPE_BUF_BYTES {
315                if available >= payload.len() {
316                    let pipe = state
317                        .pipes
318                        .get_mut(&pipe_ref.pipe_id)
319                        .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
320                    pipe.buffer.push_back(payload.to_vec());
321                    self.notify_waiters_and_pollers();
322                    return Ok(payload.len());
323                }
324            } else if available > 0 {
325                let chunk_len = available.min(payload.len());
326                let pipe = state
327                    .pipes
328                    .get_mut(&pipe_ref.pipe_id)
329                    .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
330                pipe.buffer.push_back(payload[..chunk_len].to_vec());
331                self.notify_waiters_and_pollers();
332                return Ok(chunk_len);
333            }
334
335            if nonblocking {
336                return Err(PipeError::would_block("pipe buffer full"));
337            }
338
339            state = wait_or_recover(&self.inner.waiters, state);
340        }
341    }
342
343    pub fn read(&self, description_id: u64, length: usize) -> PipeResult<Option<Vec<u8>>> {
344        self.read_with_timeout(description_id, length, None)
345    }
346
347    pub fn read_with_timeout(
348        &self,
349        description_id: u64,
350        length: usize,
351        timeout: Option<Duration>,
352    ) -> PipeResult<Option<Vec<u8>>> {
353        let mut state = lock_or_recover(&self.inner.state);
354        let pipe_ref = state
355            .desc_to_pipe
356            .get(&description_id)
357            .copied()
358            .ok_or_else(|| PipeError::bad_file_descriptor("not a pipe read end"))?;
359        if pipe_ref.end != PipeSide::Read {
360            return Err(PipeError::bad_file_descriptor("not a pipe read end"));
361        }
362
363        let mut waiter_id = None;
364        let deadline = timeout.map(|duration| Instant::now() + duration);
365
366        loop {
367            if let Some(id) = waiter_id {
368                if let Some(waiter) = state.waiters.get_mut(&id) {
369                    if let Some(result) = waiter.result.take() {
370                        state.waiters.remove(&id);
371                        return Ok(result);
372                    }
373                }
374            }
375
376            {
377                let pipe = state
378                    .pipes
379                    .get_mut(&pipe_ref.pipe_id)
380                    .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
381
382                if !pipe.buffer.is_empty() {
383                    let result = drain_buffer(&mut pipe.buffer, length);
384                    self.notify_waiters_and_pollers();
385                    return Ok(Some(result));
386                }
387
388                if pipe.closed_write {
389                    if let Some(id) = waiter_id {
390                        state.waiters.remove(&id);
391                    }
392                    return Ok(None);
393                }
394            }
395
396            let id = if let Some(id) = waiter_id {
397                id
398            } else {
399                let next = state.next_waiter_id;
400                state.next_waiter_id += 1;
401                state.waiters.insert(
402                    next,
403                    PendingRead {
404                        length,
405                        result: None,
406                    },
407                );
408                let Some(pipe) = state.pipes.get_mut(&pipe_ref.pipe_id) else {
409                    state.waiters.remove(&next);
410                    return Err(PipeError::bad_file_descriptor("pipe not found"));
411                };
412                pipe.waiting_reads.push_back(next);
413                self.notify_waiters_and_pollers();
414                waiter_id = Some(next);
415                next
416            };
417
418            let Some(deadline) = deadline else {
419                state = wait_or_recover(&self.inner.waiters, state);
420                if !state.waiters.contains_key(&id) {
421                    waiter_id = None;
422                }
423                continue;
424            };
425
426            let now = Instant::now();
427            if now >= deadline {
428                if let Some(id) = waiter_id.take() {
429                    state.waiters.remove(&id);
430                    if let Some(pipe) = state.pipes.get_mut(&pipe_ref.pipe_id) {
431                        pipe.waiting_reads.retain(|queued| *queued != id);
432                    }
433                    self.notify_waiters_and_pollers();
434                }
435                return Err(PipeError::would_block("pipe read timed out"));
436            }
437
438            let remaining = deadline.saturating_duration_since(now);
439            let (next_state, wait_result) =
440                wait_timeout_or_recover(&self.inner.waiters, state, remaining);
441            state = next_state;
442            if !state.waiters.contains_key(&id) {
443                waiter_id = None;
444            }
445            if wait_result.timed_out() {
446                if let Some(id) = waiter_id.take() {
447                    state.waiters.remove(&id);
448                    if let Some(pipe) = state.pipes.get_mut(&pipe_ref.pipe_id) {
449                        pipe.waiting_reads.retain(|queued| *queued != id);
450                    }
451                    self.notify_waiters_and_pollers();
452                }
453                return Err(PipeError::would_block("pipe read timed out"));
454            }
455        }
456    }
457
458    pub fn close(&self, description_id: u64) {
459        let mut state = lock_or_recover(&self.inner.state);
460        let Some(pipe_ref) = state.desc_to_pipe.remove(&description_id) else {
461            return;
462        };
463
464        let (waiter_ids, remove_pipe, should_notify) =
465            if let Some(pipe) = state.pipes.get_mut(&pipe_ref.pipe_id) {
466                match pipe_ref.end {
467                    PipeSide::Read => {
468                        pipe.closed_read = true;
469                        (Vec::new(), pipe.closed_read && pipe.closed_write, true)
470                    }
471                    PipeSide::Write => {
472                        pipe.closed_write = true;
473                        let waiter_ids = pipe.waiting_reads.drain(..).collect::<Vec<_>>();
474                        (waiter_ids, pipe.closed_read && pipe.closed_write, true)
475                    }
476                }
477            } else {
478                (Vec::new(), false, false)
479            };
480
481        for waiter_id in waiter_ids {
482            if let Some(waiter) = state.waiters.get_mut(&waiter_id) {
483                waiter.result = Some(None);
484            }
485        }
486
487        if remove_pipe {
488            state.pipes.remove(&pipe_ref.pipe_id);
489        }
490        if should_notify {
491            self.notify_waiters_and_pollers();
492        }
493    }
494
495    pub fn is_pipe(&self, description_id: u64) -> bool {
496        lock_or_recover(&self.inner.state)
497            .desc_to_pipe
498            .contains_key(&description_id)
499    }
500
501    pub fn pipe_id_for(&self, description_id: u64) -> Option<u64> {
502        lock_or_recover(&self.inner.state)
503            .desc_to_pipe
504            .get(&description_id)
505            .map(|pipe_ref| pipe_ref.pipe_id)
506    }
507
508    pub fn pipe_count(&self) -> usize {
509        lock_or_recover(&self.inner.state).pipes.len()
510    }
511
512    pub fn buffered_bytes(&self) -> usize {
513        lock_or_recover(&self.inner.state)
514            .pipes
515            .values()
516            .map(|pipe| buffer_size(&pipe.buffer))
517            .sum()
518    }
519
520    pub fn waiting_reader_count(&self, description_id: u64) -> PipeResult<usize> {
521        let state = lock_or_recover(&self.inner.state);
522        let pipe_ref = state
523            .desc_to_pipe
524            .get(&description_id)
525            .copied()
526            .ok_or_else(|| PipeError::bad_file_descriptor("not a pipe end"))?;
527        let pipe = state
528            .pipes
529            .get(&pipe_ref.pipe_id)
530            .ok_or_else(|| PipeError::bad_file_descriptor("pipe not found"))?;
531        Ok(pipe.waiting_reads.len())
532    }
533
534    pub fn pending_read_waiter_count(&self) -> usize {
535        lock_or_recover(&self.inner.state).waiters.len()
536    }
537
538    pub fn create_pipe_fds(&self, fd_table: &mut ProcessFdTable) -> FdResult<(u32, u32)> {
539        let pipe = self.create_pipe();
540        let read_fd =
541            fd_table.open_with(Arc::clone(&pipe.read.description), FILETYPE_PIPE, None)?;
542        match fd_table.open_with(Arc::clone(&pipe.write.description), FILETYPE_PIPE, None) {
543            Ok(write_fd) => Ok((read_fd, write_fd)),
544            Err(error) => {
545                fd_table.close(read_fd);
546                self.close(pipe.read.description.id());
547                self.close(pipe.write.description.id());
548                Err(error)
549            }
550        }
551    }
552
553    fn notify_waiters_and_pollers(&self) {
554        self.inner.waiters.notify_all();
555        if let Some(notifier) = &self.notifier {
556            notifier.notify();
557        }
558    }
559}
560
/// Total number of bytes held across all buffered chunks.
fn buffer_size(buffer: &VecDeque<Vec<u8>>) -> usize {
    buffer.iter().fold(0, |total, chunk| total + chunk.len())
}
564
565fn available_capacity(pipe: &PipeState) -> usize {
566    MAX_PIPE_BUFFER_BYTES.saturating_sub(buffer_size(&pipe.buffer))
567}
568
/// Removes and returns up to `length` bytes from the front of `buffer`, preserving order.
///
/// A chunk that is only partly consumed has its unread tail put back at the front. The
/// result is empty when `length` is 0 or the buffer holds no bytes.
fn drain_buffer(buffer: &mut VecDeque<Vec<u8>>, length: usize) -> Vec<u8> {
    let mut drained: Vec<u8> = Vec::new();
    while drained.len() < length {
        let Some(mut chunk) = buffer.pop_front() else {
            break;
        };
        let still_needed = length - drained.len();
        if chunk.len() > still_needed {
            // Keep the unread tail queued so the next read resumes where this one stopped.
            let tail = chunk.split_off(still_needed);
            buffer.push_front(tail);
        }
        if drained.is_empty() {
            // Reuse the chunk's allocation rather than copying it.
            drained = chunk;
        } else {
            drained.extend_from_slice(&chunk);
        }
    }
    drained
}
599
/// Locks `mutex`. If a previous holder panicked, the poison flag is ignored and the guard is
/// returned anyway.
fn lock_or_recover<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}
606
/// Waits on `condvar`, releasing `guard` meanwhile, and returns the re-acquired guard.
/// Poisoning is ignored. Wake-ups may be spurious, so callers re-check their condition.
fn wait_or_recover<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    condvar
        .wait(guard)
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
613
/// Waits on `condvar` for at most `timeout`, releasing `guard` meanwhile. Returns the
/// re-acquired guard together with the timeout indicator. Poisoning is ignored.
fn wait_timeout_or_recover<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    condvar
        .wait_timeout(guard, timeout)
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}