tent_thrift/server/
multiplexed.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use log::debug;
19
20use std::collections::HashMap;
21use std::convert::Into;
22use std::fmt;
23use std::fmt::{Debug, Formatter};
24use std::sync::{Arc, Mutex};
25
26use crate::protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol};
27
28use super::{handle_process_result, TProcessor};
29
30const MISSING_SEPARATOR_AND_NO_DEFAULT: &str =
31    "missing service separator and no default processor set";
32type ThreadSafeProcessor = Box<dyn TProcessor + Send + Sync>;
33
34/// A `TProcessor` that can demux service calls to multiple underlying
35/// Thrift services.
36///
37/// Users register service-specific `TProcessor` instances with a
38/// `TMultiplexedProcessor`, and then register that processor with a server
39/// implementation. Following that, all incoming service calls are automatically
40/// routed to the service-specific `TProcessor`.
41///
42/// A `TMultiplexedProcessor` can only handle messages sent by a
43/// `TMultiplexedOutputProtocol`.
44#[derive(Default)]
45pub struct TMultiplexedProcessor {
46    stored: Mutex<StoredProcessors>,
47}
48
49#[derive(Default)]
50struct StoredProcessors {
51    processors: HashMap<String, Arc<ThreadSafeProcessor>>,
52    default_processor: Option<Arc<ThreadSafeProcessor>>,
53}
54
55impl TMultiplexedProcessor {
56    /// Create a new `TMultiplexedProcessor` with no registered service-specific
57    /// processors.
58    pub fn new() -> TMultiplexedProcessor {
59        TMultiplexedProcessor {
60            stored: Mutex::new(StoredProcessors {
61                processors: HashMap::new(),
62                default_processor: None,
63            }),
64        }
65    }
66
67    /// Register a service-specific `processor` for the service named
68    /// `service_name`. This implementation is also backwards-compatible with
69    /// non-multiplexed clients. Set `as_default` to `true` to allow
70    /// non-namespaced requests to be dispatched to a default processor.
71    ///
72    /// Returns success if a new entry was inserted. Returns an error if:
73    /// * A processor exists for `service_name`
74    /// * You attempt to register a processor as default, and an existing default exists
75    #[allow(clippy::map_entry)]
76    pub fn register<S: Into<String>>(
77        &mut self,
78        service_name: S,
79        processor: Box<dyn TProcessor + Send + Sync>,
80        as_default: bool,
81    ) -> crate::Result<()> {
82        let mut stored = self.stored.lock().unwrap();
83
84        let name = service_name.into();
85        if !stored.processors.contains_key(&name) {
86            let processor = Arc::new(processor);
87
88            if as_default {
89                if stored.default_processor.is_none() {
90                    stored.processors.insert(name, processor.clone());
91                    stored.default_processor = Some(processor.clone());
92                    Ok(())
93                } else {
94                    Err("cannot reset default processor".into())
95                }
96            } else {
97                stored.processors.insert(name, processor);
98                Ok(())
99            }
100        } else {
101            Err(format!("cannot overwrite existing processor for service {}", name).into())
102        }
103    }
104
105    fn process_message(
106        &self,
107        msg_ident: &TMessageIdentifier,
108        i_prot: &mut dyn TInputProtocol,
109        o_prot: &mut dyn TOutputProtocol,
110    ) -> crate::Result<()> {
111        let (svc_name, svc_call) = split_ident_name(&msg_ident.name);
112        debug!("routing svc_name {:?} svc_call {}", &svc_name, &svc_call);
113
114        let processor: Option<Arc<ThreadSafeProcessor>> = {
115            let stored = self.stored.lock().unwrap();
116            if let Some(name) = svc_name {
117                stored.processors.get(name).cloned()
118            } else {
119                stored.default_processor.clone()
120            }
121        };
122
123        match processor {
124            Some(arc) => {
125                let new_msg_ident = TMessageIdentifier::new(
126                    svc_call,
127                    msg_ident.message_type,
128                    msg_ident.sequence_number,
129                );
130                let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident);
131                (*arc).process(&mut proxy_i_prot, o_prot)
132            }
133            None => Err(missing_processor_message(svc_name).into()),
134        }
135    }
136}
137
138impl TProcessor for TMultiplexedProcessor {
139    fn process(
140        &self,
141        i_prot: &mut dyn TInputProtocol,
142        o_prot: &mut dyn TOutputProtocol,
143    ) -> crate::Result<()> {
144        let msg_ident = i_prot.read_message_begin()?;
145
146        debug!("process incoming msg id:{:?}", &msg_ident);
147        let res = self.process_message(&msg_ident, i_prot, o_prot);
148
149        handle_process_result(&msg_ident, res, o_prot)
150    }
151}
152
153impl Debug for TMultiplexedProcessor {
154    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
155        let stored = self.stored.lock().unwrap();
156        write!(
157            f,
158            "TMultiplexedProcess {{ registered_count: {:?} default: {:?} }}",
159            stored.processors.keys().len(),
160            stored.default_processor.is_some()
161        )
162    }
163}
164
/// Split a message name of the form `service:call` at its first `:`.
///
/// Returns `(Some(service), call)` when a separator is present (any further
/// colons stay in `call`), and `(None, ident_name)` when there is none.
fn split_ident_name(ident_name: &str) -> (Option<&str>, &str) {
    match ident_name.find(':') {
        // `:` is a single byte, so `pos + 1` is always a valid char boundary.
        Some(pos) => (Some(&ident_name[..pos]), &ident_name[pos + 1..]),
        None => (None, ident_name),
    }
}
175
176fn missing_processor_message(svc_name: Option<&str>) -> String {
177    match svc_name {
178        Some(name) => format!("no processor found for service {}", name),
179        None => MISSING_SEPARATOR_AND_NO_DEFAULT.to_owned(),
180    }
181}
182
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use crate::protocol::{
        TBinaryInputProtocol, TBinaryOutputProtocol, TMessageIdentifier, TMessageType,
    };
    use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
    use crate::{ApplicationError, ApplicationErrorKind};

    use super::*;

    type TestInputProtocol = TBinaryInputProtocol<ReadHalf<TBufferChannel>>;
    type TestOutputProtocol = TBinaryOutputProtocol<WriteHalf<TBufferChannel>>;

    #[test]
    fn should_split_name_into_proper_separator_and_service_call() {
        let (serv, call) = split_ident_name("foo:bar_call");
        assert_eq!(serv, Some("foo"));
        assert_eq!(call, "bar_call");
    }

    #[test]
    fn should_return_full_ident_if_no_separator_exists() {
        let (serv, call) = split_ident_name("bar_call");
        assert_eq!(serv, None);
        assert_eq!(call, "bar_call");
    }

    #[test]
    fn should_only_split_on_first_separator() {
        let (serv, call) = split_ident_name("foo:bar:baz");
        assert_eq!(serv, Some("foo"));
        assert_eq!(call, "bar:baz");
    }

    #[test]
    fn should_write_error_if_no_separator_found_and_no_default_processor_exists() {
        let (mut i, mut o) = build_objects();
        stage_incoming_message(&mut o, &TMessageIdentifier::new("foo", TMessageType::Call, 10));

        let p = TMultiplexedProcessor::new();
        p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out

        assert_application_error_written(
            &mut i,
            &mut o,
            TMessageIdentifier::new("foo", TMessageType::Exception, 10),
            ApplicationError::new(ApplicationErrorKind::Unknown, MISSING_SEPARATOR_AND_NO_DEFAULT),
        );
    }

    #[test]
    fn should_write_error_if_separator_exists_and_no_processor_found() {
        let (mut i, mut o) = build_objects();
        stage_incoming_message(
            &mut o,
            &TMessageIdentifier::new("missing:call", TMessageType::Call, 10),
        );

        let p = TMultiplexedProcessor::new();
        p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out

        assert_application_error_written(
            &mut i,
            &mut o,
            TMessageIdentifier::new("missing:call", TMessageType::Exception, 10),
            ApplicationError::new(
                ApplicationErrorKind::Unknown,
                missing_processor_message(Some("missing")),
            ),
        );
    }

    #[test]
    fn should_reject_duplicate_service_name() {
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(Service::default()), false)
            .unwrap();
        assert!(p
            .register("service_1", Box::new(Service::default()), false)
            .is_err());
    }

    #[test]
    fn should_reject_second_default_processor_without_registering_it() {
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(Service::default()), true)
            .unwrap();
        assert!(p
            .register("service_2", Box::new(Service::default()), true)
            .is_err());
        // the rejected registration must not have claimed the name
        p.register("service_2", Box::new(Service::default()), false)
            .unwrap();
    }

    /// Test processor that records whether it was invoked; a second
    /// invocation fails so accidental double dispatch is caught too.
    #[derive(Default)]
    struct Service {
        pub invoked: Arc<AtomicBool>,
    }

    impl TProcessor for Service {
        fn process(
            &self,
            _: &mut dyn TInputProtocol,
            _: &mut dyn TOutputProtocol,
        ) -> crate::Result<()> {
            let swapped = self
                .invoked
                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok();
            if swapped {
                Ok(())
            } else {
                Err("failed swap".into())
            }
        }
    }

    #[test]
    fn should_route_call_to_correct_processor() {
        let (mut i, mut o) = build_objects();

        // build the services
        let svc_1 = Service::default();
        let atm_1 = Arc::clone(&svc_1.invoked);
        let svc_2 = Service::default();
        let atm_2 = Arc::clone(&svc_2.invoked);

        // register them
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(svc_1), false).unwrap();
        p.register("service_2", Box::new(svc_2), false).unwrap();

        // make the service call
        stage_incoming_message(
            &mut o,
            &TMessageIdentifier::new("service_1:call", TMessageType::Call, 10),
        );
        p.process(&mut i, &mut o).unwrap();

        // service 1 should have been invoked, not service 2
        assert!(atm_1.load(Ordering::Relaxed));
        assert!(!atm_2.load(Ordering::Relaxed));
    }

    #[test]
    fn should_route_call_to_correct_processor_if_no_separator_exists_and_default_processor_set() {
        let (mut i, mut o) = build_objects();

        // build the services
        let svc_1 = Service::default();
        let atm_1 = Arc::clone(&svc_1.invoked);
        let svc_2 = Service::default();
        let atm_2 = Arc::clone(&svc_2.invoked);

        // register them
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(svc_1), false).unwrap();
        p.register("service_2", Box::new(svc_2), true).unwrap(); // second processor is default

        // make the service call (it's an old client, so we have to be backwards compatible)
        stage_incoming_message(
            &mut o,
            &TMessageIdentifier::new("old_call", TMessageType::Call, 10),
        );
        p.process(&mut i, &mut o).unwrap();

        // service 2 should have been invoked, not service 1
        assert!(!atm_1.load(Ordering::Relaxed));
        assert!(atm_2.load(Ordering::Relaxed));
    }

    /// Write `ident` through `o`, then move the bytes into the shared channel's
    /// read buffer so the paired input protocol will read them next.
    fn stage_incoming_message(o: &mut TestOutputProtocol, ident: &TMessageIdentifier) {
        o.write_message_begin(ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();
    }

    /// Feed what the processor wrote to `o` back into `i` and check that it is
    /// `expected_ident` followed by `expected_err`.
    fn assert_application_error_written(
        i: &mut TestInputProtocol,
        o: &mut TestOutputProtocol,
        expected_ident: TMessageIdentifier,
        expected_err: ApplicationError,
    ) {
        i.transport.set_readable_bytes(&o.transport.write_bytes());
        let rcvd_ident = i.read_message_begin().unwrap();
        assert_eq!(rcvd_ident, expected_ident);
        let rcvd_err = crate::Error::read_application_error_from_in_protocol(i).unwrap();
        assert_eq!(rcvd_err, expected_err);
    }

    fn build_objects() -> (TestInputProtocol, TestOutputProtocol) {
        let c = TBufferChannel::with_capacity(128, 128);
        let (r_c, w_c) = c.split().unwrap();
        (
            TBinaryInputProtocol::new(r_c, true),
            TBinaryOutputProtocol::new(w_c, true),
        )
    }
}