rust_ads_client/
client.rs

1use crate::reader::run_reader_thread;
2use crate::request_factory::{self, *};
3use ads_proto::error::AdsError;
4use ads_proto::proto::ads_state::AdsState;
5use ads_proto::proto::ads_transition_mode::AdsTransMode;
6use ads_proto::proto::ams_address::{AmsAddress, AmsNetId};
7use ads_proto::proto::ams_header::{AmsHeader, AmsTcpHeader};
8use ads_proto::proto::proto_traits::*;
9use ads_proto::proto::request::{
10    ReadDeviceInfoRequest, ReadRequest, ReadStateRequest, Request, WriteRequest,
11};
12use ads_proto::proto::response::Response;
13use ads_proto::proto::response::*;
14use ads_proto::proto::state_flags::StateFlags;
15use ads_proto::proto::sumup::sumup_request::{SumupReadRequest, SumupWriteRequest};
16use ads_proto::proto::sumup::sumup_response::{SumupReadResponse, SumupWriteResponse};
17use anyhow::Error;
18use anyhow::{anyhow, Result};
19use byteorder::{LittleEndian, ReadBytesExt};
20use std::collections::HashMap;
21use std::io::Write;
22use std::net::{Ipv4Addr, Shutdown, SocketAddr, TcpStream};
23use std::str::FromStr;
24use std::sync::mpsc::{channel, Receiver, Sender};
25use std::time::Duration;
26
/// UDP port used by the ADS protocol for discovery.
pub const ADS_UDP_SERVER_PORT: u16 = 48_899;
/// Unsecured TCP port of the ADS protocol (the port `Client` connects to).
pub const ADS_TCP_SERVER_PORT: u16 = 48_898;
/// TCP port of the secured ADS protocol.
pub const ADS_SECURE_TCP_SERVER_PORT: u16 = 8_016;
33
34pub type ClientResult<T> = Result<T, anyhow::Error>;
35type TxGeneral = Sender<(u32, Sender<ClientResult<Response>>)>;
36type TxNotification = Sender<(u32, Sender<ClientResult<AdsNotificationStream>>)>;
37type TxStreamUpdate = Sender<TcpStream>;
38
39#[derive(Debug)]
40pub struct Client {
41    route: Option<Ipv4Addr>,
42    ams_targed_address: AmsAddress,
43    ams_source_address: AmsAddress,
44    stream: Option<TcpStream>,
45    invoke_id: u32,
46    tx_general: Option<TxGeneral>,
47    tx_notification: Option<TxNotification>,
48    tx_stream_update: Option<TxStreamUpdate>,
49    thread_started: bool,
50    handle_list: HashMap<String, u32>,
51    notification_handle_list: HashMap<String, u32>,
52}
53
54impl Drop for Client {
55    fn drop(&mut self) {
56        if let Some(s) = &self.stream {
57            let _ = s.shutdown(Shutdown::Both);
58        }
59    }
60}
61
62impl Client {
63    /// Setup a new client. This will will not yet connect to the targed.
64    /// Call connect() after creation.
65    pub fn new(ams_targed_address: AmsAddress, route: Option<Ipv4Addr>) -> Self {
66        Client {
67            route,
68            ams_targed_address,
69            ams_source_address: AmsAddress::new(AmsNetId::from([0, 0, 0, 0, 0, 0]), 0),
70            stream: None,
71            invoke_id: 0,
72            tx_general: None,
73            tx_notification: None,
74            tx_stream_update: None,
75            thread_started: false,
76            handle_list: HashMap::new(),
77            notification_handle_list: HashMap::new(),
78        }
79    }
80
81    /// Connect to host and start reader thread.
82    /// Fails if host is not reachable or if the reader thread can't be started.
83    pub fn connect(&mut self) -> ClientResult<ReadStateResponse> {
84        if self.stream.is_none() {
85            self.stream = Some(self.create_stream()?);
86            if self.route.is_none() {
87                self.open_local_port()?;
88            }
89        }
90
91        if let Some(stream) = &self.stream {
92            if self.route.is_some() {
93                self.ams_source_address
94                    .update_from_socket_addr(stream.local_addr()?)?;
95            }
96
97            if !self.thread_started {
98                let (tx, rx) = channel::<(u32, Sender<ClientResult<Response>>)>();
99                let (tx_not, rx_not) =
100                    channel::<(u32, Sender<ClientResult<AdsNotificationStream>>)>();
101                let (tx_tcp, rx_tcp) = channel::<TcpStream>();
102                self.tx_general = Some(tx);
103                self.tx_notification = Some(tx_not);
104                self.tx_stream_update = Some(tx_tcp);
105                self.thread_started = run_reader_thread(stream.try_clone()?, rx, rx_not, rx_tcp)?;
106            } else if let Some(tx) = &self.tx_stream_update {
107                tx.send(stream.try_clone()?)?;
108            }
109            //Check if host is responding
110            self.read_state()
111        } else {
112            Err(anyhow!(AdsError::ErrPortNotConnected))
113        }
114    }
115
116    /// Create the TCP stream
117    fn create_stream(&mut self) -> ClientResult<TcpStream> {
118        let mut route = Ipv4Addr::from_str("127.0.0.1")?;
119        if let Some(r) = self.route {
120            route = r;
121        }
122
123        let stream = TcpStream::connect(SocketAddr::from((route, ADS_TCP_SERVER_PORT)))?;
124        stream.set_nodelay(true)?;
125        stream.set_write_timeout(Some(Duration::from_millis(1000)))?;
126        stream.set_read_timeout(Some(Duration::from_millis(1000)))?;
127        Ok(stream)
128    }
129
130    /// open local port in case of local machine
131    fn open_local_port(&mut self) -> ClientResult<()> {
132        let request_port_msg = [0, 16, 2, 0, 0, 0, 0, 0];
133        let mut buf = [0; 14];
134
135        if let Some(s) = &mut self.stream {
136            s.write_all(&request_port_msg).unwrap();
137            use std::io::Read;
138            s.read_exact(&mut buf)?;
139            let (_, mut buf_split) = buf.split_at(6);
140            let ams_address = AmsAddress::read_from(&mut buf_split);
141            self.ams_source_address = ams_address.unwrap();
142        }
143        Ok(())
144    }
145
146    /// Sends the supplied request
147    /// Blocks until the response has been received or on error occures
148    /// Fails if no tcp stream is available.
149    pub fn request(&mut self, request: Request) -> ClientResult<Response> {
150        let rx = self.request_rx(request)?;
151        let response = rx.recv()?;
152        self.check_tcp_stream(&response);
153        response
154    }
155
156    /// Sends a request and returns imediatly a receiver object to read from (mpsc::Receiver).
157    /// Fails if no tcp stream is available.
158    pub fn request_rx(&mut self, request: Request) -> ClientResult<Receiver<Result<Response>>> {
159        let ams_header = self.new_tcp_ams_request_header(request);
160        let (tx, rx) = channel::<ClientResult<Response>>();
161        self.get_general_tx()?
162            .send((self.invoke_id, tx))
163            .expect("Failed to send request to thread by mpsc channel");
164        let mut buffer = Vec::new();
165
166        ams_header.write_to(&mut buffer)?;
167
168        if let Some(s) = &mut self.stream {
169            s.write_all(&buffer)?;
170            return Ok(rx);
171        }
172        Err(anyhow!(AdsError::AdsErrClientPortNotOpen))
173    }
174
175    /// Read a var value by it's name.
176    /// Returns ReadResponse
177    pub fn read_by_name(&mut self, var_name: &str, len: u32) -> ClientResult<ReadResponse> {
178        let handle = self.get_var_handle(var_name)?;
179        let request = Request::Read(request_factory::get_read_request(handle, len));
180        let response = self.request(request)?;
181        let read_response: ReadResponse = response.try_into()?;
182        Ok(read_response)
183    }
184
185    /// Read a list of var values by name. This will bundle all requested variables into a single request.
186    /// Returns a HashMap<String, ReadResponse>
187    pub fn sumup_read_by_name(
188        &mut self,
189        var_list: &HashMap<String, u32>,
190    ) -> ClientResult<HashMap<String, ReadResponse>> {
191        let mut requests: Vec<ReadRequest> = Vec::new();
192        let mut var_names: Vec<String> = Vec::new();
193        for name in var_list.keys() {
194            var_names.push(name.clone());
195        }
196
197        let handles = self.sumup_get_var_handle(&var_names)?;
198        for (var, length) in var_list {
199            if let Some(h) = handles.get(var) {
200                requests.push(get_read_request(*h, *length));
201            }
202        }
203
204        let mut buf = Vec::new();
205        let sumup_request = SumupReadRequest::new(requests);
206        sumup_request.write_to(&mut buf)?;
207        let request = Request::ReadWrite(get_sumup_read_request(
208            sumup_request.request_count(),
209            sumup_request.expected_response_len(),
210            buf,
211        ));
212        let response = self.request(request)?;
213        let read_write_response: ReadWriteResponse = response.try_into()?;
214        let sumup_read_response =
215            SumupReadResponse::read_from(&mut read_write_response.data.as_slice())?;
216        let mut result: HashMap<String, ReadResponse> = HashMap::new();
217
218        if read_write_response.result == AdsError::ErrNoError {
219            for (n, (name, _)) in var_list.iter().enumerate() {
220                result.insert(name.clone(), sumup_read_response.read_responses[n].clone());
221            }
222        } else {
223            return Err(anyhow![read_write_response.result]);
224        }
225        Ok(result)
226    }
227
228    /// Write by name
229    /// Returns WriteResponse
230    pub fn write_by_name(&mut self, var_name: &str, data: Vec<u8>) -> ClientResult<WriteResponse> {
231        let handle = self.get_var_handle(var_name)?;
232        let request = Request::Write(request_factory::get_write_request(handle, data));
233        let response = self.request(request)?;
234        let write_response: WriteResponse = response.try_into()?;
235        Ok(write_response)
236    }
237
238    /// Write a list of var values by name. This will bundle all the write data into a single write request.
239    /// Returns a HashMap<String, WriteResponse>
240    pub fn sumup_write_by_name(
241        &mut self,
242        var_list: HashMap<String, Vec<u8>>,
243    ) -> ClientResult<HashMap<String, WriteResponse>> {
244        let mut requests: Vec<WriteRequest> = Vec::new();
245        for (varname, data) in &var_list {
246            let handle = self.get_var_handle(varname.as_str())?;
247            requests.push(get_write_request(handle, data.clone()));
248        }
249
250        let mut buf = Vec::new();
251        let sumup_request = SumupWriteRequest::new(requests);
252        sumup_request.write_to(&mut buf)?;
253        let request = Request::ReadWrite(get_sumup_write_request(
254            sumup_request.request_count(),
255            sumup_request.expected_response_len(),
256            buf,
257        ));
258        let response = self.request(request)?;
259        let read_write_response: ReadWriteResponse = response.try_into()?;
260        let sumup_write_response =
261            SumupWriteResponse::read_from(&mut read_write_response.data.as_slice())?;
262        let mut result: HashMap<String, WriteResponse> = HashMap::new();
263
264        if read_write_response.result == AdsError::ErrNoError {
265            for (n, (name, _)) in var_list.iter().enumerate() {
266                result.insert(
267                    name.clone(),
268                    sumup_write_response.write_responses[n].clone(),
269                );
270            }
271        } else {
272            return Err(anyhow![read_write_response.result]);
273        }
274        Ok(result)
275    }
276
277    /// Read device info
278    /// Returns ReadDeviceInfoResponse
279    pub fn read_device_info(&mut self) -> ClientResult<ReadDeviceInfoResponse> {
280        let request = Request::ReadDeviceInfo(ReadDeviceInfoRequest::new());
281        let response = self.request(request)?;
282        let device_info_response: ReadDeviceInfoResponse = response.try_into()?;
283        Ok(device_info_response)
284    }
285
286    /// Read PLC state
287    /// Returns ReadStateResponse
288    pub fn read_state(&mut self) -> ClientResult<ReadStateResponse> {
289        let request = Request::ReadState(ReadStateRequest::new());
290        let response = self.request(request)?;
291        let device_state: ReadStateResponse = response.try_into()?;
292        Ok(device_state)
293    }
294
295    /// Write control
296    /// Returns WriteControlResponse
297    pub fn write_control(
298        &mut self,
299        ads_state: AdsState,
300        device_state: u16,
301    ) -> ClientResult<WriteControlResponse> {
302        let request = Request::WriteControl(request_factory::get_write_control_request(
303            ads_state,
304            device_state,
305        ));
306        let response = self.request(request)?;
307        let write_control_response: WriteControlResponse = response.try_into()?;
308        Ok(write_control_response)
309    }
310
311    /// Read and write data
312    /// Returns ReadWriteResponse
313    pub fn read_write(
314        &mut self,
315        index_offset: u32,
316        read_len: u32,
317        write_data: Vec<u8>,
318    ) -> ClientResult<ReadWriteResponse> {
319        let request = Request::ReadWrite(request_factory::get_read_write_request(
320            index_offset,
321            read_len,
322            write_data,
323        ));
324        let response = self.request(request)?;
325        let read_write_response: ReadWriteResponse = response.try_into()?;
326        Ok(read_write_response)
327    }
328
329    /// Add device notification to receive updated values at value change or at a certain time interfall
330    /// Returns mpsc::receiver which can be polled
331    pub fn add_device_notification(
332        &mut self,
333        var_name: &str,
334        length: u32,
335        transmission_mode: AdsTransMode,
336        max_delay: u32,
337        cycle_time: u32,
338    ) -> ClientResult<Receiver<Result<AdsNotificationStream, Error>>> {
339        let handle = self.get_var_handle(var_name)?;
340        let request = Request::AddDeviceNotification(request_factory::get_add_device_notification(
341            handle,
342            length,
343            transmission_mode,
344            max_delay,
345            cycle_time,
346        ));
347
348        //Get notification handle
349        let response: AddDeviceNotificationResponse = self.request(request)?.try_into()?;
350        let handle = response.notification_handle;
351        //Create mpsc channel for notifications
352        let (tx, rx) = channel::<ClientResult<AdsNotificationStream>>();
353        //Send tx to reader thread
354        self.get_notification_tx()?
355            .send((handle, tx))
356            .expect("Failed to send request to thread by mpsc channel");
357
358        self.notification_handle_list
359            .insert(var_name.to_string(), handle);
360        Ok(rx)
361    }
362
363    /// Release a device notification on the host
364    /// Returns DeleteDeviceNotificationResponse
365    pub fn delete_device_notification(
366        &mut self,
367        var_name: &str,
368    ) -> ClientResult<DeleteDeviceNotificationResponse> {
369        let handle;
370        if let Some(h) = self.notification_handle_list.get(var_name) {
371            handle = *h;
372            let request = Request::DeleteDeviceNotification(
373                request_factory::get_delete_device_notification(handle),
374            );
375            let response = self.request(request)?;
376            let response: DeleteDeviceNotificationResponse = response.try_into()?;
377            self.notification_handle_list.remove(var_name);
378            return Ok(response);
379        }
380        Err(anyhow!(AdsError::AdsErrDeviceSymbolNotFound)) //??
381    }
382
383    fn get_var_handle(&mut self, var_name: &str) -> ClientResult<u32> {
384        if let Some(handle) = self.handle_list.get(var_name) {
385            Ok(*handle)
386        } else {
387            let handle = self.request_var_handle(var_name)?;
388            self.handle_list.insert(var_name.to_string(), handle);
389            Ok(handle)
390        }
391    }
392
393    fn sumup_get_var_handle(
394        &mut self,
395        var_names: &Vec<String>,
396    ) -> ClientResult<HashMap<String, u32>> {
397        let mut do_request: Vec<String> = Vec::new();
398        let mut handles: HashMap<String, u32> = HashMap::new();
399        for var in var_names {
400            if let Some(handle) = self.handle_list.get(var) {
401                handles.insert(var.clone(), *handle);
402            } else {
403                do_request.push(var.clone());
404            }
405        }
406
407        if !do_request.is_empty() {
408            let requested_handles = self.sumup_request_var_handle(&do_request)?;
409            for (name, handle) in requested_handles {
410                self.handle_list.insert(name.clone(), handle);
411                handles.insert(name, handle);
412            }
413        }
414        Ok(handles)
415    }
416
417    /// Request new var handle
418    fn request_var_handle(&mut self, var_name: &str) -> ClientResult<u32> {
419        let request = Request::ReadWrite(get_var_handle_request(var_name));
420        let response: ReadWriteResponse = self.request(request)?.try_into()?;
421
422        if response.length == 4 {
423            return Ok(response.data.as_slice().read_u32::<LittleEndian>()?);
424        }
425        Err(anyhow!(
426            "Failed to get var handle! Variable {} not found!",
427            var_name
428        ))
429    }
430
431    /// Sumup a var handle request
432    /// Not really a sumup request. This methode send for each handle request a tx.
433    // To Do. Is there a way to perform a sumup for handle requests?
434    fn sumup_request_var_handle(
435        &mut self,
436        var_list: &Vec<String>,
437    ) -> ClientResult<HashMap<String, u32>> {
438        let mut result: HashMap<String, u32> = HashMap::new();
439        for var in var_list {
440            let response = self.request(Request::ReadWrite(get_var_handle_request(var)))?;
441            let handle: ReadWriteResponse = response.try_into()?;
442            let handle = handle.data.as_slice().read_u32::<LittleEndian>()?;
443            result.insert(var.clone(), handle);
444        }
445        Ok(result)
446    }
447
448    /// Release var handle
449    pub fn release_handle(&mut self, var_name: &str) -> ClientResult<WriteResponse> {
450        if let Some(handle) = self.handle_list.get(var_name) {
451            let request = Request::Write(request_factory::get_release_handle_request(*handle));
452            let response = self.request(request)?;
453            let response: WriteResponse = response.try_into()?;
454            self.handle_list.remove(var_name);
455            return Ok(response);
456        }
457        Err(anyhow!("Handle not available"))
458    }
459
460    ///Create new tcp_ams_header with supplied request data.
461    fn new_tcp_ams_request_header(&mut self, request: Request) -> AmsTcpHeader {
462        self.invoke_id += 1;
463        AmsTcpHeader::from(AmsHeader::new(
464            self.ams_targed_address.clone(),
465            self.ams_source_address.clone(),
466            StateFlags::req_default(),
467            self.invoke_id,
468            request,
469        ))
470    }
471
472    ///Check if stream disconnected
473    fn check_tcp_stream(&mut self, response: &ClientResult<Response>) {
474        if let Err(e) = response {
475            if e.is::<AdsError>() {
476                let e = e.downcast_ref::<AdsError>();
477                if let Some(e) = e {
478                    if e == &AdsError::ErrPortNotConnected {
479                        if let Some(stream) = &self.stream {
480                            let _ = stream.shutdown(Shutdown::Both);
481                        }
482                        self.handle_list.clear();
483                        self.notification_handle_list.clear();
484                        self.stream = None;
485                    }
486                }
487            }
488        }
489    }
490
491    /// Gets the tx (mpsc::sender) to notify the reader thread about a new handle
492    fn get_general_tx(&self) -> ClientResult<&TxGeneral> {
493        if let Some(tx) = &self.tx_general {
494            return Ok(tx);
495        }
496        Err(anyhow!(AdsError::AdsErrClientError)) //ToDo create better error
497    }
498
499    /// Gets the tx (mpsc::sender) to notify the reader thread about a new notification handle
500    fn get_notification_tx(&self) -> ClientResult<&TxNotification> {
501        if let Some(tx) = &self.tx_notification {
502            return Ok(tx);
503        }
504        Err(anyhow!(AdsError::AdsErrClientError)) //ToDo create better error
505    }
506}