Skip to main content

turn_server/
api.rs

1use std::time::{Duration, Instant};
2
3use anyhow::Result;
4use tokio::sync::{
5    Mutex,
6    mpsc::{Sender, channel},
7};
8
9use tonic::{
10    Request, Response, Status,
11    transport::{Channel, Server},
12};
13
14#[cfg(feature = "ssl")]
15use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig};
16
17use protos::{
18    GetTurnPasswordRequest, PasswordAlgorithm as ProtoPasswordAlgorithm, SessionQueryParams,
19    TurnAllocatedEvent, TurnChannelBindEvent, TurnCreatePermissionEvent, TurnDestroyEvent,
20    TurnRefreshEvent, TurnServerInfo, TurnSession, TurnSessionStatistics,
21    turn_hooks_service_client::TurnHooksServiceClient,
22    turn_service_server::{TurnService, TurnServiceServer},
23};
24
25use crate::{
26    Service,
27    codec::{crypto::Password, message::attributes::PasswordAlgorithm},
28    config::Config,
29    service::session::{Identifier, Session},
30    statistics::Statistics,
31};
32
33impl From<PasswordAlgorithm> for ProtoPasswordAlgorithm {
34    fn from(val: PasswordAlgorithm) -> Self {
35        match val {
36            PasswordAlgorithm::Md5 => Self::Md5,
37            PasswordAlgorithm::Sha256 => Self::Sha256,
38        }
39    }
40}
41
/// Two-way conversion between an identifier and the string form carried in
/// API requests (for example the `id` field of `SessionQueryParams`).
pub trait IdString {
    /// Error produced when a string cannot be turned back into an identifier.
    type Error;

    /// Renders the identifier as a string.
    fn to_string(&self) -> String;

    /// Parses an identifier from its string form.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when `s` is not a valid string form.
    fn from_string(s: String) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
50
51impl IdString for Identifier {
52    type Error = Status;
53
54    fn to_string(&self) -> String {
55        format!("{}/{}", self.source(), self.interface())
56    }
57
58    fn from_string(s: String) -> Result<Self, Self::Error> {
59        let (source, interface) = s
60            .split_once('/')
61            .ok_or(Status::invalid_argument("Invalid identifier"))?;
62
63        Ok(Self::new(
64            source
65                .parse()
66                .map_err(|_| Status::invalid_argument("Invalid source address"))?,
67            interface
68                .parse()
69                .map_err(|_| Status::invalid_argument("Invalid interface address"))?,
70        ))
71    }
72}
73
74struct RpcService {
75    config: Config,
76    service: Service,
77    statistics: Statistics,
78    uptime: Instant,
79}
80
81#[tonic::async_trait]
82impl TurnService for RpcService {
83    async fn get_info(&self, _: Request<()>) -> Result<Response<TurnServerInfo>, Status> {
84        Ok(Response::new(TurnServerInfo {
85            software: crate::SOFTWARE.to_string(),
86            uptime: self.uptime.elapsed().as_secs(),
87            interfaces: self
88                .config
89                .server
90                .get_external_addresses()
91                .iter()
92                .map(|addr| addr.to_string())
93                .collect(),
94            port_capacity: self.config.server.port_range.size() as u32,
95            port_allocated: self.service.get_session_manager().allocated() as u32,
96        }))
97    }
98
99    async fn get_session(
100        &self,
101        request: Request<SessionQueryParams>,
102    ) -> Result<Response<TurnSession>, Status> {
103        if let Some(Session::Authenticated {
104            username,
105            allocate_port,
106            allocate_channels,
107            permissions,
108            expires,
109            ..
110        }) = self
111            .service
112            .get_session_manager()
113            .get_session(&Identifier::from_string(request.into_inner().id)?)
114            .get_ref()
115        {
116            Ok(Response::new(TurnSession {
117                username: username.to_string(),
118                permissions: permissions.iter().map(|p| *p as i32).collect(),
119                channels: allocate_channels.iter().map(|p| *p as i32).collect(),
120                port: allocate_port.map(|p| p as i32),
121                expires: *expires as i64,
122            }))
123        } else {
124            Err(Status::not_found("Session not found"))
125        }
126    }
127
128    async fn get_session_statistics(
129        &self,
130        request: Request<SessionQueryParams>,
131    ) -> Result<Response<TurnSessionStatistics>, Status> {
132        if let Some(counts) = self
133            .statistics
134            .get(&Identifier::from_string(request.into_inner().id)?)
135        {
136            Ok(Response::new(TurnSessionStatistics {
137                received_bytes: counts.received_bytes as u64,
138                send_bytes: counts.send_bytes as u64,
139                received_pkts: counts.received_pkts as u64,
140                send_pkts: counts.send_pkts as u64,
141                error_pkts: counts.error_pkts as u64,
142            }))
143        } else {
144            Err(Status::not_found("Session not found"))
145        }
146    }
147
148    async fn destroy_session(
149        &self,
150        request: Request<SessionQueryParams>,
151    ) -> Result<Response<()>, Status> {
152        if self
153            .service
154            .get_session_manager()
155            .refresh(&Identifier::from_string(request.into_inner().id)?, 0)
156        {
157            Ok(Response::new(()))
158        } else {
159            Err(Status::failed_precondition("Session not found"))
160        }
161    }
162}
163
164pub enum HooksEvent {
165    Allocated(TurnAllocatedEvent),
166    ChannelBind(TurnChannelBindEvent),
167    CreatePermission(TurnCreatePermissionEvent),
168    Refresh(TurnRefreshEvent),
169    Destroy(TurnDestroyEvent),
170}
171
172struct RpcHooksServiceInner {
173    event_channel: Sender<HooksEvent>,
174    client: Mutex<TurnHooksServiceClient<Channel>>,
175}
176
177pub struct RpcHooksService(Option<RpcHooksServiceInner>);
178
179impl RpcHooksService {
180    pub async fn new(config: &Config) -> Result<Self> {
181        if let Some(hooks) = &config.hooks {
182            let (event_channel, mut rx) = channel(hooks.max_channel_size);
183            let client = {
184                let mut builder = Channel::builder(hooks.endpoint.as_str().try_into()?);
185
186                builder = builder.timeout(Duration::from_secs(hooks.timeout as u64));
187
188                #[cfg(feature = "ssl")]
189                if let Some(ssl) = &hooks.ssl {
190                    builder = builder.tls_config(
191                        ClientTlsConfig::new()
192                            .ca_certificate(Certificate::from_pem(ssl.certificate_chain.clone()))
193                            .domain_name(
194                                url::Url::parse(&hooks.endpoint)?.domain().ok_or_else(|| {
195                                    anyhow::anyhow!("Invalid hooks server domain")
196                                })?,
197                            ),
198                    )?;
199                }
200
201                TurnHooksServiceClient::new(
202                    builder
203                        .connect_timeout(Duration::from_secs(5))
204                        .timeout(Duration::from_secs(1))
205                        .connect_lazy(),
206                )
207            };
208
209            {
210                let mut client = client.clone();
211
212                tokio::spawn(async move {
213                    while let Some(event) = rx.recv().await {
214                        if match event {
215                            HooksEvent::Allocated(event) => {
216                                client.on_allocated_event(Request::new(event)).await
217                            }
218                            HooksEvent::ChannelBind(event) => {
219                                client.on_channel_bind_event(Request::new(event)).await
220                            }
221                            HooksEvent::CreatePermission(event) => {
222                                client.on_create_permission_event(Request::new(event)).await
223                            }
224                            HooksEvent::Refresh(event) => {
225                                client.on_refresh_event(Request::new(event)).await
226                            }
227                            HooksEvent::Destroy(event) => {
228                                client.on_destroy_event(Request::new(event)).await
229                            }
230                        }
231                        .is_err()
232                        {
233                            break;
234                        }
235                    }
236                });
237            }
238
239            log::info!("create hooks client, endpoint={}", hooks.endpoint);
240
241            Ok(Self(Some(RpcHooksServiceInner {
242                client: Mutex::new(client),
243                event_channel,
244            })))
245        } else {
246            Ok(Self(None))
247        }
248    }
249
250    pub fn send_event(&self, event: HooksEvent) {
251        if let Some(inner) = &self.0
252            && !inner.event_channel.is_closed()
253            && let Err(e) = inner.event_channel.try_send(event)
254        {
255            log::error!("Failed to send event to hooks server: {}", e);
256        }
257    }
258
259    pub async fn get_password(
260        &self,
261        realm: &str,
262        username: &str,
263        algorithm: PasswordAlgorithm,
264    ) -> Option<Password> {
265        if let Some(inner) = &self.0 {
266            let algorithm: ProtoPasswordAlgorithm = algorithm.into();
267            let password = inner
268                .client
269                .lock()
270                .await
271                .get_password(Request::new(GetTurnPasswordRequest {
272                    realm: realm.to_string(),
273                    username: username.to_string(),
274                    algorithm: algorithm as i32,
275                }))
276                .await
277                .ok()?
278                .into_inner()
279                .password;
280
281            return Some(match algorithm {
282                ProtoPasswordAlgorithm::Md5 => Password::Md5(password.try_into().ok()?),
283                ProtoPasswordAlgorithm::Sha256 => Password::Sha256(password.try_into().ok()?),
284                ProtoPasswordAlgorithm::Unspecified => unreachable!(),
285            });
286        }
287
288        None
289    }
290}
291
292pub async fn start_server(config: Config, service: Service, statistics: Statistics) -> Result<()> {
293    if let Some(api) = &config.api {
294        let mut builder = Server::builder();
295
296        builder = builder
297            .timeout(Duration::from_secs(api.timeout as u64))
298            .accept_http1(false);
299
300        #[cfg(feature = "ssl")]
301        if let Some(ssl) = &api.ssl {
302            builder = builder.tls_config(ServerTlsConfig::new().identity(Identity::from_pem(
303                ssl.certificate_chain.clone(),
304                ssl.private_key.clone(),
305            )))?;
306        }
307
308        log::info!("api server listening: listen={}", api.listen);
309
310        builder
311            .add_service(TurnServiceServer::new(RpcService {
312                config: config.clone(),
313                uptime: Instant::now(),
314                statistics,
315                service,
316            }))
317            .serve(api.listen)
318            .await?;
319    } else {
320        std::future::pending().await
321    }
322
323    Ok(())
324}