rust_ethernet_ip/
python.rs

1#![allow(non_local_definitions)]
2
3use crate::{EipClient, PlcValue, SubscriptionOptions};
4use pyo3::prelude::*;
5use pyo3::types::{PyDict, PyTuple};
6use pyo3::IntoPyObjectExt;
7use std::collections::HashMap;
8use tokio::runtime::Runtime;
9
10/// Python module for rust_ethernet_ip
11#[pymodule]
12fn rust_ethernet_ip(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
13    m.add_class::<PyEipClient>()?;
14    m.add_class::<PyPlcValue>()?;
15    m.add_class::<PySubscriptionOptions>()?;
16    Ok(())
17}
18
/// Python wrapper for EipClient
///
/// Pairs a connected `EipClient` with the Tokio `Runtime` used to run its
/// async methods from synchronous Python calls.
#[pyclass]
struct PyEipClient {
    /// Client connection, established by `EipClient::connect` in `new`.
    client: EipClient,
    // Rust drops fields in declaration order, so `client` is dropped before
    // `runtime`. Presumably intended so the client's I/O resources are released
    // while the runtime they were created on still exists — confirm before
    // reordering these fields.
    /// Runtime on which every client future is driven via `block_on`.
    runtime: Runtime,
}
25
/// Argument adapter for a Python `(name, value)` tuple, as passed in the list
/// given to `PyEipClient.write_tags_batch`.
struct TagValueArg {
    /// Tag name (first tuple element).
    name: String,
    /// Value to write (second tuple element).
    value: PyPlcValue,
}
31
32impl<'a> FromPyObject<'a> for TagValueArg {
33    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
34        let tuple = ob.downcast::<PyTuple>()?;
35        if tuple.len() != 2 {
36            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
37                "Expected tuple of length 2",
38            ));
39        }
40        let name = tuple.get_item(0)?.extract::<String>()?;
41        let value = tuple.get_item(1)?.extract::<PyPlcValue>()?;
42        Ok(TagValueArg { name, value })
43    }
44}
45
/// Argument adapter for a Python `(name, options)` tuple, as passed in the
/// list given to `PyEipClient.subscribe_to_tags`.
struct TagSubOptArg {
    /// Tag name (first tuple element).
    name: String,
    /// Subscription options (second tuple element).
    options: PySubscriptionOptions,
}
51
52impl<'a> FromPyObject<'a> for TagSubOptArg {
53    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
54        let tuple = ob.downcast::<PyTuple>()?;
55        if tuple.len() != 2 {
56            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
57                "Expected tuple of length 2",
58            ));
59        }
60        let name = tuple.get_item(0)?.extract::<String>()?;
61        let options = tuple.get_item(1)?.extract::<PySubscriptionOptions>()?;
62        Ok(TagSubOptArg { name, options })
63    }
64}
65
66#[pymethods]
67impl PyEipClient {
68    /// Create a new EipClient instance
69    #[new]
70    fn new(addr: &str) -> PyResult<Self> {
71        let runtime = Runtime::new().unwrap();
72        let client = runtime
73            .block_on(async { EipClient::connect(addr).await })
74            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
75
76        Ok(PyEipClient { client, runtime })
77    }
78
79    /// Read a tag value
80    fn read_tag(&mut self, tag_name: &str) -> PyResult<PyPlcValue> {
81        let value = self
82            .runtime
83            .block_on(async { self.client.read_tag(tag_name).await })
84            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
85
86        Ok(PyPlcValue { value })
87    }
88
89    /// Write a value to a tag
90    fn write_tag(&mut self, tag_name: &str, value: &PyPlcValue) -> PyResult<bool> {
91        let result = self
92            .runtime
93            .block_on(async { self.client.write_tag(tag_name, value.value.clone()).await });
94        match result {
95            Ok(_) => Ok(true),
96            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
97                e.to_string(),
98            )),
99        }
100    }
101
102    /// Read multiple tags in batch
103    fn read_tags_batch(&mut self, tag_names: Vec<String>) -> PyResult<Vec<(String, Py<PyAny>)>> {
104        Python::attach(|py| {
105            let runtime = tokio::runtime::Runtime::new().unwrap();
106            let results = runtime
107                .block_on(async {
108                    self.client
109                        .read_tags_batch(&tag_names.iter().map(|s| s.as_str()).collect::<Vec<_>>())
110                        .await
111                })
112                .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
113            let mut results_vec = Vec::new();
114            for (name, result) in results {
115                let obj = match result {
116                    Ok(v) => PyPlcValue { value: v }.into_bound_py_any(py)?,
117                    Err(e) => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string())
118                        .into_bound_py_any(py)?,
119                };
120                results_vec.push((name, obj.unbind()));
121            }
122            Ok(results_vec)
123        })
124    }
125
126    /// Write multiple tags in batch
127    fn write_tags_batch(
128        &mut self,
129        tag_values: Vec<TagValueArg>,
130    ) -> PyResult<Vec<(String, Py<PyAny>)>> {
131        Python::attach(|py| {
132            let runtime = tokio::runtime::Runtime::new().unwrap();
133            let results = runtime
134                .block_on(async {
135                    self.client
136                        .write_tags_batch(
137                            &tag_values
138                                .iter()
139                                .map(|arg| (arg.name.as_str(), arg.value.value.clone()))
140                                .collect::<Vec<_>>(),
141                        )
142                        .await
143                })
144                .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
145            let mut results_vec = Vec::new();
146            for (name, result) in results {
147                let obj = match result {
148                    Ok(()) => py.None().into_bound_py_any(py)?,
149                    Err(e) => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string())
150                        .into_bound_py_any(py)?,
151                };
152                results_vec.push((name, obj.unbind()));
153            }
154            Ok(results_vec)
155        })
156    }
157
158    /// Subscribe to a tag
159    fn subscribe_to_tag(&self, tag_path: &str, options: &PySubscriptionOptions) -> PyResult<()> {
160        self.runtime
161            .block_on(async {
162                self.client
163                    .subscribe_to_tag(tag_path, options.options.clone())
164                    .await
165            })
166            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
167
168        Ok(())
169    }
170
171    /// Subscribe to multiple tags
172    fn subscribe_to_tags(&self, tags: Vec<TagSubOptArg>) -> PyResult<()> {
173        self.runtime
174            .block_on(async {
175                self.client
176                    .subscribe_to_tags(
177                        &tags
178                            .iter()
179                            .map(|arg| (arg.name.as_str(), arg.options.options.clone()))
180                            .collect::<Vec<_>>(),
181                    )
182                    .await
183            })
184            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
185        Ok(())
186    }
187
188    /// Unregister the session
189    fn unregister_session(&mut self) -> PyResult<()> {
190        self.runtime
191            .block_on(async { self.client.unregister_session().await })
192            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
193
194        Ok(())
195    }
196}
197
/// Python wrapper for PlcValue
///
/// The wrapped value is exposed to Python through the `value` getter, which
/// converts it to the matching native Python object.
#[pyclass]
struct PyPlcValue {
    /// The wrapped PLC value.
    value: PlcValue,
}
203
204impl FromPyObject<'_> for PyPlcValue {
205    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
206        if let Ok(bool_val) = ob.extract::<bool>() {
207            Ok(PyPlcValue {
208                value: PlcValue::Bool(bool_val),
209            })
210        } else if let Ok(int_val) = ob.extract::<i32>() {
211            Ok(PyPlcValue {
212                value: PlcValue::Dint(int_val),
213            })
214        } else if let Ok(float_val) = ob.extract::<f64>() {
215            Ok(PyPlcValue {
216                value: PlcValue::Lreal(float_val),
217            })
218        } else if let Ok(string_val) = ob.extract::<String>() {
219            Ok(PyPlcValue {
220                value: PlcValue::String(string_val),
221            })
222        } else if let Ok(dict) = ob.downcast::<PyDict>() {
223            let mut map = HashMap::new();
224            for (key, value) in dict.iter() {
225                let key = key.extract::<String>()?;
226                let value = value.extract::<PyPlcValue>()?.value;
227                map.insert(key, value);
228            }
229            Ok(PyPlcValue {
230                value: PlcValue::Udt(map),
231            })
232        } else {
233            Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
234                "Unsupported value type",
235            ))
236        }
237    }
238}
239
240#[pymethods]
241impl PyPlcValue {
242    #[new]
243    fn new(value: Py<PyAny>) -> PyResult<Self> {
244        Python::attach(|py| {
245            let bound_value = value.bind(py);
246            if let Ok(val) = bound_value.extract::<bool>() {
247                Ok(PyPlcValue {
248                    value: PlcValue::Bool(val),
249                })
250            } else if let Ok(val) = bound_value.extract::<i32>() {
251                Ok(PyPlcValue {
252                    value: PlcValue::Dint(val),
253                })
254            } else if let Ok(val) = bound_value.extract::<f32>() {
255                Ok(PyPlcValue {
256                    value: PlcValue::Real(val),
257                })
258            } else if let Ok(val) = bound_value.extract::<f64>() {
259                Ok(PyPlcValue {
260                    value: PlcValue::Real(val as f32),
261                })
262            } else if let Ok(val) = bound_value.extract::<String>() {
263                Ok(PyPlcValue {
264                    value: PlcValue::String(val),
265                })
266            } else {
267                Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
268                    "Unsupported value type",
269                ))
270            }
271        })
272    }
273
274    #[staticmethod]
275    fn real(val: f32) -> Self {
276        PyPlcValue {
277            value: PlcValue::Real(val),
278        }
279    }
280    #[staticmethod]
281    fn lreal(val: f64) -> Self {
282        PyPlcValue {
283            value: PlcValue::Lreal(val),
284        }
285    }
286    #[staticmethod]
287    fn dint(val: i32) -> Self {
288        PyPlcValue {
289            value: PlcValue::Dint(val),
290        }
291    }
292    #[staticmethod]
293    fn lint(val: i64) -> Self {
294        PyPlcValue {
295            value: PlcValue::Lint(val),
296        }
297    }
298    #[staticmethod]
299    fn string(val: String) -> Self {
300        PyPlcValue {
301            value: PlcValue::String(val),
302        }
303    }
304
305    #[getter]
306    fn value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
307        match &self.value {
308            PlcValue::Bool(b) => Ok(b.into_bound_py_any(py)?.unbind()),
309            PlcValue::Sint(i) => Ok(i.into_bound_py_any(py)?.unbind()),
310            PlcValue::Int(i) => Ok(i.into_bound_py_any(py)?.unbind()),
311            PlcValue::Dint(i) => Ok(i.into_bound_py_any(py)?.unbind()),
312            PlcValue::Lint(i) => Ok(i.into_bound_py_any(py)?.unbind()),
313            PlcValue::Usint(u) => Ok(u.into_bound_py_any(py)?.unbind()),
314            PlcValue::Uint(u) => Ok(u.into_bound_py_any(py)?.unbind()),
315            PlcValue::Udint(u) => Ok(u.into_bound_py_any(py)?.unbind()),
316            PlcValue::Ulint(u) => Ok(u.into_bound_py_any(py)?.unbind()),
317            PlcValue::Real(f) => Ok(f.into_bound_py_any(py)?.unbind()),
318            PlcValue::Lreal(f) => Ok(f.into_bound_py_any(py)?.unbind()),
319            PlcValue::String(s) => Ok(s.into_bound_py_any(py)?.unbind()),
320            PlcValue::Udt(map) => {
321                let dict = PyDict::new(py);
322                for (k, v) in map.iter() {
323                    let v_py = PyPlcValue { value: v.clone() }.value(py)?;
324                    dict.set_item(k, v_py)?;
325                }
326                Ok(dict.unbind().into())
327            }
328        }
329    }
330
331    fn __str__(&self) -> String {
332        format!("{:?}", self.value)
333    }
334
335    fn __repr__(&self) -> String {
336        format!("PyPlcValue({:?})", self.value)
337    }
338}
339
/// Python wrapper for SubscriptionOptions
#[pyclass]
struct PySubscriptionOptions {
    /// Wrapped options (update rate, change threshold, timeout), exposed to
    /// Python through read-only getters.
    options: SubscriptionOptions,
}
345
346impl FromPyObject<'_> for PySubscriptionOptions {
347    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
348        let update_rate = ob.getattr("update_rate")?.extract::<u32>()?;
349        let change_threshold = ob.getattr("change_threshold")?.extract::<f32>()?;
350        let timeout = ob.getattr("timeout")?.extract::<u32>()?;
351
352        Ok(PySubscriptionOptions {
353            options: SubscriptionOptions {
354                update_rate,
355                change_threshold,
356                timeout,
357            },
358        })
359    }
360}
361
362#[pymethods]
363impl PySubscriptionOptions {
364    #[new]
365    fn new(update_rate: u32, change_threshold: f32, timeout: u32) -> PyResult<Self> {
366        let options = SubscriptionOptions {
367            update_rate,
368            change_threshold,
369            timeout,
370        };
371
372        Ok(PySubscriptionOptions { options })
373    }
374
375    /// Get the update rate in milliseconds
376    #[getter]
377    fn update_rate(&self) -> u32 {
378        self.options.update_rate
379    }
380
381    /// Get the change threshold for numeric values
382    #[getter]
383    fn change_threshold(&self) -> f32 {
384        self.options.change_threshold
385    }
386
387    /// Get the timeout in milliseconds
388    #[getter]
389    fn timeout(&self) -> u32 {
390        self.options.timeout
391    }
392}