1use crate::ops;
2use crate::AnyTensor;
3use crate::DataType;
4use crate::Operation;
5use crate::Output;
6use crate::Result;
7use crate::Scope;
8use crate::Shape;
9use crate::Tensor;
10use crate::TensorType;
11use std::borrow::Borrow;
12
/// A graph variable: a `VariableV2` operation together with the assign operation that
/// sets its initial value.
///
/// Construct one with [`Variable::builder`].
#[derive(Debug, Clone)]
pub struct Variable {
    /// Unique name of the `VariableV2` operation, as obtained from the scope.
    pub(crate) name: String,
    /// Assign operation that writes the initial value into the variable; run it as a
    /// target before reading the variable.
    pub(crate) initializer: Operation,
    /// Output of the `VariableV2` operation.
    pub(crate) output: Output,
    /// Element type of the variable (the op's `dtype` attribute).
    pub(crate) dtype: DataType,
    /// Shape of the variable (the op's `shape` attribute).
    pub(crate) shape: Shape,
}
22
23impl Variable {
24 pub fn builder<'a>() -> VariableBuilder<'a> {
26 VariableBuilder::default()
27 }
28
29 pub fn name(&self) -> &str {
31 &self.name
32 }
33
34 pub fn output(&self) -> &Output {
36 &self.output
37 }
38
39 pub fn initializer(&self) -> &Operation {
41 &self.initializer
42 }
43
44 pub fn data_type(&self) -> DataType {
46 self.dtype
47 }
48
49 pub fn shape(&self) -> &Shape {
51 &self.shape
52 }
53}
54
/// Source of a variable's initial value, as configured on a [`VariableBuilder`].
#[derive(Debug)]
enum VariableInitialValue<'a> {
    /// No initial value was provided; `VariableBuilder::build` rejects this.
    Unspecified,
    /// An owned constant tensor, added to the graph as a constant when building.
    TensorBox(Box<dyn AnyTensor>),
    /// A borrowed constant tensor; the borrow must outlive the builder (`'a`).
    TensorRef(&'a dyn AnyTensor),
    /// An arbitrary graph output used directly as the initial value.
    Output(Output),
}
62
/// Builder for [`Variable`]s, obtained from [`Variable::builder`].
///
/// `build` requires both a data type and an initial value. The `const_initial_*` methods
/// set the data type automatically; `initial_value` does not.
#[derive(Debug)]
pub struct VariableBuilder<'a> {
    /// Where the variable's initial value comes from.
    initial_value: VariableInitialValue<'a>,
    /// Value for the `shape` attribute of the `VariableV2` op.
    shape: Shape,
    /// Element type; `None` until set by a const initializer or `data_type`.
    dtype: Option<DataType>,
}
70
71impl<'a> Default for VariableBuilder<'a> {
72 fn default() -> Self {
73 Self {
74 initial_value: VariableInitialValue::Unspecified,
75 shape: Shape(None),
76 dtype: None,
77 }
78 }
79}
80
81impl<'a> VariableBuilder<'a> {
82 pub fn const_initial_value<T: TensorType, TT: Into<Tensor<T>>>(self, value: TT) -> Self {
85 let t: Tensor<T> = value.into();
86 let shape = t.shape();
87 Self {
88 initial_value: VariableInitialValue::TensorBox(Box::<Tensor<T>>::new(t)),
89 dtype: Some(T::data_type()),
90 shape,
91 }
92 }
93
94 pub fn const_initial_tensor<T: TensorType>(self, value: &'a Tensor<T>) -> Self {
97 let shape = value.shape();
98 Self {
99 initial_value: VariableInitialValue::TensorRef(value),
100 dtype: Some(T::data_type()),
101 shape,
102 }
103 }
104
105 pub fn initial_value<T: Into<Output>>(self, value: T) -> Self {
108 Self {
109 initial_value: VariableInitialValue::Output(value.into()),
110 ..self
111 }
112 }
113
114 pub fn shape<S: Into<Shape>>(self, shape: S) -> Self {
116 Self {
117 shape: shape.into(),
118 ..self
119 }
120 }
121
122 pub fn data_type(self, data_type: DataType) -> Self {
124 Self {
125 dtype: Some(data_type),
126 ..self
127 }
128 }
129
130 pub fn build(self, scope: &mut Scope) -> Result<Variable> {
132 let name = scope.get_unique_name_for_op("VariableV2");
133 let dtype = match self.dtype {
134 Some(d) => d,
135 None => return Err(invalid_arg!("data_type must be specified")),
136 };
137 let variable_op = {
138 let mut graph = scope.graph_mut();
139 let mut nd = graph.new_operation("VariableV2", &name)?;
140 nd.set_attr_type("dtype", dtype)?;
141 nd.set_attr_shape("shape", &self.shape)?;
142 nd.finish()?
143 };
144 let initial_value = match self.initial_value {
145 VariableInitialValue::Unspecified => {
146 return Err(invalid_arg!("an initial value is required"))
147 }
148 VariableInitialValue::TensorBox(t) => ops::any_constant(t.borrow(), scope)?.into(),
149 VariableInitialValue::TensorRef(t) => ops::any_constant(t, scope)?.into(),
150 VariableInitialValue::Output(o) => o,
151 };
152 let initializer = ops::assign(variable_op.clone(), initial_value, scope)?;
153 Ok(Variable {
154 name,
155 output: variable_op.into(),
156 initializer,
157 dtype,
158 shape: self.shape,
159 })
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Code;
    use crate::Session;
    use crate::SessionOptions;
    use crate::SessionRunArgs;

    /// Asserts that `variable` has the expected name, shape, and data type. Shape and data
    /// type are checked both in its fields and in its `VariableV2` op attributes.
    fn assert_variable_metadata(variable: &Variable, name: &str, shape: Shape, dtype: DataType) {
        assert_eq!(variable.name, name);
        assert_eq!(variable.shape, shape);
        assert_eq!(variable.dtype, dtype);
        assert_eq!(
            variable.output.operation.get_attr_shape("shape").unwrap(),
            shape
        );
        assert_eq!(
            variable.output.operation.get_attr_type("dtype").unwrap(),
            dtype
        );
    }

    /// Runs the variable's initializer in a fresh session over `scope`'s graph. It then
    /// fetches and returns the variable's value.
    fn initialize_and_read<T: TensorType>(scope: &Scope, variable: &Variable) -> Tensor<T> {
        let options = SessionOptions::new();
        let session = Session::new(&options, &scope.graph()).unwrap();

        let mut init_args = SessionRunArgs::new();
        init_args.add_target(&variable.initializer);
        session.run(&mut init_args).unwrap();

        let mut read_args = SessionRunArgs::new();
        let fetch = read_args.request_fetch(&variable.output.operation, 0);
        session.run(&mut read_args).unwrap();
        read_args.fetch::<T>(fetch).unwrap()
    }

    #[test]
    fn const_initialized_scalar() {
        let scope = Scope::new_root_scope();

        let variable = Variable::builder()
            .const_initial_value(3.0f32)
            .build(&mut scope.with_op_name("foo"))
            .unwrap();
        assert_variable_metadata(&variable, "foo", Shape(Some(vec![])), DataType::Float);

        let output = initialize_and_read::<f32>(&scope, &variable);
        assert_eq!(&output[..], &[3.0f32]);
    }

    #[test]
    fn const_initialized_matrix() {
        let scope = Scope::new_root_scope();

        let initial = Tensor::<i32>::new(&[2, 3])
            .with_values(&[1, 2, 3, 4, 5, 6])
            .unwrap();
        let variable = Variable::builder()
            .const_initial_tensor(&initial)
            .build(&mut scope.with_op_name("foo"))
            .unwrap();
        assert_variable_metadata(
            &variable,
            "foo",
            Shape(Some(vec![Some(2), Some(3)])),
            DataType::Int32,
        );

        let output = initialize_and_read::<i32>(&scope, &variable);
        assert_eq!(&output[..], &initial[..]);
    }

    #[test]
    fn custom_initializer_missing_dtype() {
        let mut scope = Scope::new_root_scope();
        let value = Tensor::new(&[]).with_values(&[3.0f32]).unwrap();
        let const_op = ops::constant(value, &mut scope).unwrap();

        assert_eq!(
            Variable::builder()
                .initial_value(const_op)
                .build(&mut scope.with_op_name("foo"))
                .unwrap_err()
                .code(),
            Code::InvalidArgument
        );
    }

    #[test]
    fn missing_initial_value() {
        let scope = Scope::new_root_scope();

        // The data type is set, so the error must come from the missing initial value.
        assert_eq!(
            Variable::builder()
                .data_type(DataType::Float)
                .build(&mut scope.with_op_name("foo"))
                .unwrap_err()
                .code(),
            Code::InvalidArgument
        );
    }

    #[test]
    fn custom_initializer() {
        let mut scope = Scope::new_root_scope();
        let value = Tensor::new(&[]).with_values(&[3.0f32]).unwrap();
        let const_op = ops::constant(value, &mut scope).unwrap();

        let variable = Variable::builder()
            .initial_value(const_op)
            .data_type(DataType::Float)
            .build(&mut scope.with_op_name("foo"))
            .unwrap();
        assert_variable_metadata(&variable, "foo", Shape(None), DataType::Float);

        let output = initialize_and_read::<f32>(&scope, &variable);
        assert_eq!(&output[..], &[3.0f32]);
    }
}