Skip to main content

simple_ldap/sort/
adapter.rs

1//! This module implements Server Side Sort (SSS) search extension
2//! as described in [RFC 2891](https://datatracker.ietf.org/doc/rfc2891/).
3
4use async_trait::async_trait;
5use itertools::Itertools;
6use ldap3::{
7    LdapError, LdapResult, ResultEntry, Scope, SearchStream,
8    adapters::{Adapter, SoloMarker},
9    controls::{Control, MakeCritical, RawControl},
10};
11use std::{fmt::Debug, mem};
12use thiserror::Error;
13use tracing::debug;
14
15use crate::sort::{
16    SERVER_SIDE_SORT_REQUEST_OID, SERVER_SIDE_SORT_RESPONSE_OID,
17    control::{self, ServerSideSortResponse, SortResult},
18};
19
20/// Search adapter for sorting the results serverside.
21#[derive(Debug, Clone)]
22pub(crate) struct ServerSideSort {
23    /// The server evaluates these sort criteria in order.
24    /// If there's a tie, the next criteria will be consulted.
25    ///
26    ///
27    /// Invariant: This vec doesn't contain elements with the same `attribute` field.
28    //
29    //  (It shouldn't be empty either but that's not enforced at this level.)
30    sorts: Vec<SortBy>,
31}
32
33#[derive(Debug, Error)]
34#[error("Attributes {} occur more than once in the sort list.",
35    attributes.iter().format(", ")
36)]
37pub struct DuplicateSortAttributes {
38    attributes: Vec<String>,
39}
40
41impl ServerSideSort {
42    /// Create new adapter instance.
43    ///
44    /// Duplicate attributes aren't allowed.
45    ///
46    /// Servers are allowed to limit the amount of attributes to sort by.
47    /// In this case the search should just err.
48    pub fn new(sorts: Vec<SortBy>) -> Result<Self, DuplicateSortAttributes> {
49        // First validate the inputs.
50        let duplicates = sorts
51            .iter()
52            .map(|SortBy { attribute, .. }| attribute)
53            .duplicates()
54            .collect_vec();
55
56        if !duplicates.is_empty() {
57            let attributes = duplicates.into_iter().map(ToOwned::to_owned).collect();
58            Err(DuplicateSortAttributes { attributes })
59        }
60        // Everything is good in this branch.
61        else {
62            Ok(ServerSideSort { sorts })
63        }
64    }
65}
66
/// A single sort key: which attribute to order by, and in which direction.
///
// The RFC's `orderingRule` isn't exposed, since how it's meant to be used is unclear.
#[derive(Debug, Clone)]
pub struct SortBy {
    /// The attribute whose values determine the order.
    pub attribute: String,
    /// When `true`, the order for this key is reversed.
    pub reverse: bool,
}
77
78/// Can be used by itself.
79impl SoloMarker for ServerSideSort {}
80
81#[async_trait]
82impl<'a, S, A> Adapter<'a, S, A> for ServerSideSort
83where
84    S: AsRef<str> + Clone + Debug + Send + Sync + 'a,
85    A: AsRef<[S]> + Clone + Debug + Send + Sync + 'a,
86{
87    async fn start(
88        &mut self,
89        stream: &mut SearchStream<'a, S, A>,
90        base: &str,
91        scope: Scope,
92        filter: &str,
93        attrs: A,
94    ) -> ldap3::result::Result<()> {
95        let stream_ldap = stream.ldap_handle();
96
97        // Check that SSS isn't defined already.
98        let sort_control_already_defined = stream_ldap.controls.as_ref().is_some_and(|vec| {
99            vec.iter()
100                .any(|control| control.ctype == SERVER_SIDE_SORT_REQUEST_OID)
101        });
102        if sort_control_already_defined {
103            return Err(LdapError::AdapterInit(String::from(
104                "found Server Side Sort control in op set already",
105            )));
106        }
107
108        // No need to keep these around in the adapter.
109        let sorts = mem::take(&mut self.sorts);
110        let new_control = control::ServerSideSortRequest {
111            // Convert the sort args to control parts.
112            sort_key_list: sorts.into_iter().map_into().collect(),
113        } // We want the search to fail if sorting isn't supported.
114        .critical();
115
116        // Adding the control to the search.
117        stream_ldap
118            .controls
119            .get_or_insert_default()
120            .push(new_control.into());
121
122        // Continue the chain.
123        stream.start(base, scope, filter, attrs).await
124    }
125
126    async fn next(
127        &mut self,
128        stream: &mut SearchStream<'a, S, A>,
129    ) -> ldap3::result::Result<Option<ResultEntry>> {
130        match stream.next().await? {
131            Some(result_entry) => {
132                // It's a little unclear to me whether I should be looking at this res in `stream`
133                // or the result_entry directly? Are the controls just the same?
134                let sss_control = stream.res.as_ref().and_then(
135                    |LdapResult {
136                         ctrls: controls, ..
137                     }| get_response_control(controls.as_slice()),
138                );
139
140                match sss_control {
141                    Some(ServerSideSortResponse {
142                        sort_result: SortResult::Success,
143                        ..
144                    }) => {
145                        // All good, passing on the result.
146                        Ok(Some(result_entry))
147                    }
148                    Some(ServerSideSortResponse { sort_result, .. }) => {
149                        panic!(
150                            "Server side sort result was {sort_result:?}. This should never be the case in this branch as the control was set to critical and so should have caused an error earlier."
151                        )
152                    }
153                    None => {
154                        debug!("No server side sort response control.");
155                        Ok(Some(result_entry))
156                    }
157                }
158            }
159            // I suppose we could check for the control here too, but my understanding is that it's only
160            // used when there are actually results.
161            None => Ok(None),
162        }
163    }
164
165    async fn finish(&mut self, stream: &mut SearchStream<'a, S, A>) -> LdapResult {
166        // Just logging here
167
168        let result = stream.finish().await;
169
170        let sss_control = get_response_control(result.ctrls.as_slice());
171
172        match sss_control {
173            None => debug!("No Server Side Sort control in the final result"),
174            Some(control) => debug!("The final Server Side Sort control: {control:?}"),
175        };
176
177        result
178    }
179}
180
181// Get and parse the SSS response control if there is one.
182//
183// My understanding from RFC 2981 section 2 is that whenever there is at least one search result,
184// there should also be the SSS response control.
185fn get_response_control(controls: &[Control]) -> Option<ServerSideSortResponse> {
186    controls
187        .iter()
188        // Control type isn't parsed since this control is implemented outside ldap3
189        // so we're just working with the raw values.
190        .map(|Control(_, raw)| raw)
191        .find(|raw| raw.ctype == SERVER_SIDE_SORT_RESPONSE_OID)
192        .map(RawControl::parse)
193}