Skip to main content

turn_server/
api.rs

1use std::time::{Duration, Instant};
2
3use crate::{
4    Service,
5    codec::{crypto::Password, message::attributes::PasswordAlgorithm},
6    config::Config,
7    service::session::{Identifier, Session},
8    statistics::Statistics,
9};
10
11use anyhow::{Result, anyhow};
12use tokio::sync::{
13    Mutex,
14    mpsc::{Sender, channel},
15};
16
17use tonic::{
18    Request, Response, Status,
19    transport::{Channel, Server},
20};
21
22#[cfg(feature = "ssl")]
23use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig};
24
25use sdk::protos::{
26    BindAddress, GetTurnPasswordRequest, TurnAllocatedEvent, TurnChannelBindEvent,
27    TurnCreatePermissionEvent, TurnDestroyEvent, TurnRefreshEvent, TurnServerInfo, TurnSession,
28    TurnSessionStatistics,
29    turn_hooks_service_client::TurnHooksServiceClient,
30    turn_service_server::{TurnService, TurnServiceServer},
31};
32
33impl Into<sdk::protos::Transport> for crate::service::Transport {
34    fn into(self) -> sdk::protos::Transport {
35        use sdk::protos::Transport;
36
37        match self {
38            Self::Udp => Transport::Udp,
39            Self::Tcp => Transport::Tcp,
40        }
41    }
42}
43
44impl TryFrom<sdk::protos::Transport> for crate::service::Transport {
45    type Error = anyhow::Error;
46
47    fn try_from(value: sdk::protos::Transport) -> Result<Self, Self::Error> {
48        use sdk::protos::Transport;
49
50        match value {
51            Transport::Udp => Ok(Self::Udp),
52            Transport::Tcp => Ok(Self::Tcp),
53            Transport::Unspecified => Err(anyhow!("transport is unspecified")),
54        }
55    }
56}
57
58impl Into<sdk::protos::PasswordAlgorithm> for crate::codec::message::attributes::PasswordAlgorithm {
59    fn into(self) -> sdk::protos::PasswordAlgorithm {
60        use sdk::protos::PasswordAlgorithm;
61
62        match self {
63            Self::Md5 => PasswordAlgorithm::Md5,
64            Self::Sha256 => PasswordAlgorithm::Sha256,
65        }
66    }
67}
68
69impl Into<sdk::protos::Identifier> for Identifier {
70    fn into(self) -> sdk::protos::Identifier {
71        sdk::protos::Identifier {
72            source: self.source.to_string(),
73            external: self.external.to_string(),
74            interface: self.interface.to_string(),
75            transport: Into::<sdk::protos::Transport>::into(self.transport) as i32,
76        }
77    }
78}
79
80impl TryFrom<sdk::protos::Identifier> for crate::service::session::Identifier {
81    type Error = anyhow::Error;
82
83    fn try_from(value: sdk::protos::Identifier) -> Result<Self, Self::Error> {
84        use crate::service::{Transport, session::Identifier};
85
86        Ok(Identifier {
87            source: value.source.parse()?,
88            external: value.external.parse()?,
89            interface: value.interface.parse()?,
90            transport: Transport::try_from(sdk::protos::Transport::try_from(value.transport)?)?,
91        })
92    }
93}
94
95impl Into<sdk::protos::Interface> for &crate::service::InterfaceAddr {
96    fn into(self) -> sdk::protos::Interface {
97        sdk::protos::Interface {
98            address: self.addr.to_string(),
99            external: self.external.to_string(),
100            transport: Into::<sdk::protos::Transport>::into(self.transport) as i32,
101        }
102    }
103}
104
/// State shared by the `TurnService` gRPC handlers.
struct RpcService {
    /// Server configuration; read for the interface list and the port range.
    config: Config,
    /// Handle used to reach the session manager.
    service: Service,
    /// Per-session traffic counters, looked up by session identifier.
    statistics: Statistics,
    /// Reference instant whose elapsed time is reported as the server uptime.
    uptime: Instant,
}
111
112#[tonic::async_trait]
113impl TurnService for RpcService {
114    async fn get_info(&self, _: Request<()>) -> Result<Response<TurnServerInfo>, Status> {
115        Ok(Response::new(TurnServerInfo {
116            software: crate::SOFTWARE.to_string(),
117            uptime: self.uptime.elapsed().as_secs(),
118            interfaces: self
119                .config
120                .server
121                .get_interface_addrs()
122                .iter()
123                .map(|it| it.into())
124                .collect(),
125            port_capacity: self.config.server.port_range.size() as u32,
126            port_allocated: self.service.get_session_manager().allocated() as u32,
127        }))
128    }
129
130    async fn get_session(
131        &self,
132        request: Request<sdk::protos::Identifier>,
133    ) -> Result<Response<TurnSession>, Status> {
134        if let Some(Session::Authenticated {
135            username,
136            allocated_port,
137            channel_relay_table,
138            port_relay_table,
139            expires,
140            ..
141        }) = self
142            .service
143            .get_session_manager()
144            .get_session(
145                &Identifier::try_from(request.into_inner())
146                    .map_err(|e| Status::internal(e.to_string()))?,
147            )
148            .get_ref()
149        {
150            Ok(Response::new(TurnSession {
151                username: username.to_string(),
152                allocated_port: allocated_port.map(|p| p as i32),
153                expires: *expires as i64,
154                permissions: port_relay_table
155                    .iter()
156                    .map(|(k, v)| BindAddress {
157                        key: *k as i32,
158                        value: Some(v.clone().into()),
159                    })
160                    .collect(),
161                channels: channel_relay_table
162                    .iter()
163                    .map(|(k, v)| BindAddress {
164                        key: *k as i32,
165                        value: Some(v.clone().into()),
166                    })
167                    .collect(),
168            }))
169        } else {
170            Err(Status::not_found("Session not found"))
171        }
172    }
173
174    async fn get_session_statistics(
175        &self,
176        request: Request<sdk::protos::Identifier>,
177    ) -> Result<Response<TurnSessionStatistics>, Status> {
178        if let Some(counts) = self.statistics.get(
179            &Identifier::try_from(request.into_inner())
180                .map_err(|e| Status::internal(e.to_string()))?,
181        ) {
182            Ok(Response::new(TurnSessionStatistics {
183                received_bytes: counts.received_bytes as u64,
184                send_bytes: counts.send_bytes as u64,
185                received_pkts: counts.received_pkts as u64,
186                send_pkts: counts.send_pkts as u64,
187                error_pkts: counts.error_pkts as u64,
188            }))
189        } else {
190            Err(Status::not_found("Session not found"))
191        }
192    }
193
194    async fn destroy_session(
195        &self,
196        request: Request<sdk::protos::Identifier>,
197    ) -> Result<Response<()>, Status> {
198        if self.service.get_session_manager().refresh(
199            &Identifier::try_from(request.into_inner())
200                .map_err(|e| Status::internal(e.to_string()))?,
201            0,
202        ) {
203            Ok(Response::new(()))
204        } else {
205            Err(Status::failed_precondition("Session not found"))
206        }
207    }
208}
209
/// Notification queued for delivery to the external hooks service.
///
/// Each variant wraps the protobuf message that is sent through the matching
/// `on_*_event` RPC of the hooks client.
pub enum HooksEvent {
    /// Delivered through `on_allocated_event`.
    Allocated(TurnAllocatedEvent),
    /// Delivered through `on_channel_bind_event`.
    ChannelBind(TurnChannelBindEvent),
    /// Delivered through `on_create_permission_event`.
    CreatePermission(TurnCreatePermissionEvent),
    /// Delivered through `on_refresh_event`.
    Refresh(TurnRefreshEvent),
    /// Delivered through `on_destroy_event`.
    Destroy(TurnDestroyEvent),
}
217
/// Live state of a configured hooks integration.
struct RpcHooksServiceInner {
    /// Sending half of the bounded queue drained by the background forwarding task.
    event_channel: Sender<HooksEvent>,
    /// gRPC client used for password lookups.
    client: Mutex<TurnHooksServiceClient<Channel>>,
}
222
/// Client for the optional external hooks service.
///
/// Holds `None` when the configuration has no `hooks` section; in that case
/// events are discarded and password lookups return `None`.
pub struct RpcHooksService(Option<RpcHooksServiceInner>);
224
225impl RpcHooksService {
226    pub async fn new(config: &Config) -> Result<Self> {
227        if let Some(hooks) = &config.hooks {
228            let (event_channel, mut rx) = channel(hooks.max_channel_size);
229            let client = {
230                let mut builder = Channel::builder(hooks.endpoint.as_str().try_into()?);
231
232                builder = builder.timeout(Duration::from_secs(hooks.timeout as u64));
233
234                #[cfg(feature = "ssl")]
235                if let Some(ssl) = &hooks.ssl {
236                    builder = builder.tls_config(
237                        ClientTlsConfig::new()
238                            .ca_certificate(Certificate::from_pem(ssl.certificate_chain.clone()))
239                            .domain_name(
240                                url::Url::parse(&hooks.endpoint)?.domain().ok_or_else(|| {
241                                    anyhow::anyhow!("Invalid hooks server domain")
242                                })?,
243                            ),
244                    )?;
245                }
246
247                TurnHooksServiceClient::new(
248                    builder
249                        .connect_timeout(Duration::from_secs(5))
250                        .timeout(Duration::from_secs(1))
251                        .connect_lazy(),
252                )
253            };
254
255            {
256                let mut client = client.clone();
257
258                tokio::spawn(async move {
259                    while let Some(event) = rx.recv().await {
260                        if match event {
261                            HooksEvent::Allocated(event) => {
262                                client.on_allocated_event(Request::new(event)).await
263                            }
264                            HooksEvent::ChannelBind(event) => {
265                                client.on_channel_bind_event(Request::new(event)).await
266                            }
267                            HooksEvent::CreatePermission(event) => {
268                                client.on_create_permission_event(Request::new(event)).await
269                            }
270                            HooksEvent::Refresh(event) => {
271                                client.on_refresh_event(Request::new(event)).await
272                            }
273                            HooksEvent::Destroy(event) => {
274                                client.on_destroy_event(Request::new(event)).await
275                            }
276                        }
277                        .is_err()
278                        {
279                            break;
280                        }
281                    }
282                });
283            }
284
285            log::info!("create hooks client, endpoint={}", hooks.endpoint);
286
287            Ok(Self(Some(RpcHooksServiceInner {
288                client: Mutex::new(client),
289                event_channel,
290            })))
291        } else {
292            Ok(Self(None))
293        }
294    }
295
296    pub fn send_event(&self, event: HooksEvent) {
297        if let Some(inner) = &self.0
298            && !inner.event_channel.is_closed()
299            && let Err(e) = inner.event_channel.try_send(event)
300        {
301            log::error!("Failed to send event to hooks server: {}", e);
302        }
303    }
304
305    pub async fn get_password(
306        &self,
307        id: &Identifier,
308        realm: &str,
309        username: &str,
310        algorithm: PasswordAlgorithm,
311    ) -> Option<Password> {
312        if let Some(inner) = &self.0 {
313            use sdk::protos::PasswordAlgorithm;
314
315            let algorithm: PasswordAlgorithm = algorithm.into();
316
317            let password = inner
318                .client
319                .lock()
320                .await
321                .get_password(Request::new(GetTurnPasswordRequest {
322                    id: Some(id.into()),
323                    realm: realm.to_string(),
324                    username: username.to_string(),
325                    algorithm: algorithm as i32,
326                }))
327                .await
328                .ok()?
329                .into_inner()
330                .password;
331
332            return Some(match algorithm {
333                PasswordAlgorithm::Md5 => Password::Md5(password.try_into().ok()?),
334                PasswordAlgorithm::Sha256 => Password::Sha256(password.try_into().ok()?),
335                PasswordAlgorithm::Unspecified => unreachable!(),
336            });
337        }
338
339        None
340    }
341}
342
343pub async fn start_server(config: Config, service: Service, statistics: Statistics) -> Result<()> {
344    if let Some(api) = &config.api {
345        let mut builder = Server::builder();
346
347        builder = builder
348            .timeout(Duration::from_secs(api.timeout as u64))
349            .accept_http1(false);
350
351        #[cfg(feature = "ssl")]
352        if let Some(ssl) = &api.ssl {
353            builder = builder.tls_config(ServerTlsConfig::new().identity(Identity::from_pem(
354                ssl.certificate_chain.clone(),
355                ssl.private_key.clone(),
356            )))?;
357        }
358
359        log::info!("api server listening: listen={}", api.listen);
360
361        builder
362            .add_service(TurnServiceServer::new(RpcService {
363                config: config.clone(),
364                uptime: Instant::now(),
365                statistics,
366                service,
367            }))
368            .serve(api.listen)
369            .await?;
370    } else {
371        std::future::pending().await
372    }
373
374    Ok(())
375}