tfrecord_codegen/
lib.rs

1use anyhow::{bail, format_err, Context, Error, Result};
2use flate2::read::GzDecoder;
3use itertools::{chain, Itertools};
4use once_cell::sync::Lazy;
5use std::{
6    env::{self, VarError},
7    fs::{self, File},
8    io::{self, BufReader, BufWriter},
9    path::{Path, PathBuf},
10};
11use tar::Archive;
12
13const BUILD_METHOD_ENV: &str = "TFRECORD_BUILD_METHOD";
14
15static OUT_DIR: Lazy<PathBuf> = Lazy::new(|| PathBuf::from(env::var("OUT_DIR").unwrap()));
16static GENERATED_PROTOBUF_FILE: Lazy<PathBuf> = Lazy::new(|| (*OUT_DIR).join("tensorflow.rs"));
17const TENSORFLOW_VERSION: &str =
18    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/tensorflow_version"));
19const DEFAULT_TENSORFLOW_URL: &str = concat!(
20    "https://github.com/tensorflow/tensorflow/archive/v",
21    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/tensorflow_version")),
22    ".tar.gz"
23);
24
/// Strategy for obtaining the TensorFlow `.proto` sources used to generate bindings.
///
/// Values are constructed by [`guess_build_method`] from the `TFRECORD_BUILD_METHOD`
/// environment variable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BuildMethod {
    /// Download a source archive from the given URL.
    Url(String),
    /// Use an already unpacked TensorFlow source tree.
    SrcDir(PathBuf),
    /// Use a local TensorFlow source archive (gzip-compressed tar).
    SrcFile(PathBuf),
    /// Use an installed TensorFlow located by its install prefix.
    InstallPrefix(PathBuf),
}
32
33pub fn guess_build_method() -> Result<Option<BuildMethod>> {
34    let build_method = match env::var(BUILD_METHOD_ENV) {
35        Ok(text) => {
36            const URL_PREFIX: &str = "url://";
37            const SRC_DIR_PREFIX: &str = "src_dir://";
38            const SRC_FILE_PREFIX: &str = "src_file://";
39            const INSTALL_PREFIX_PREFIX: &str = "install_prefix://";
40
41            let method = if let Some(url) = text.strip_prefix(URL_PREFIX) {
42                match url {
43                    "" => BuildMethod::Url(DEFAULT_TENSORFLOW_URL.to_string()),
44                    _ => BuildMethod::Url(url.to_string()),
45                }
46            } else if let Some(dir) = text.strip_prefix(SRC_DIR_PREFIX) {
47                BuildMethod::SrcDir(dir.into())
48            } else if let Some(path) = text.strip_prefix(SRC_FILE_PREFIX) {
49                BuildMethod::SrcFile(path.into())
50            } else if let Some(prefix) = text.strip_prefix(INSTALL_PREFIX_PREFIX) {
51                BuildMethod::InstallPrefix(prefix.into())
52            } else {
53                return Err(build_method_error());
54            };
55
56            method
57        }
58        Err(VarError::NotPresent) => return Err(build_method_error()),
59        Err(VarError::NotUnicode(_)) => {
60            bail!(
61                r#"the value of environment variable "{}" is not Unicode"#,
62                BUILD_METHOD_ENV
63            );
64        }
65    };
66    Ok(Some(build_method))
67}
68
69pub fn build_method_error() -> Error {
70    format_err!(
71        r#"By enabling the "generate_protobuf_src" feature,
72the environment variable "{BUILD_METHOD_ENV}" must be set with the following format.
73
74- "url://"
75  Download the source from default URL "{DEFAULT_TENSORFLOW_URL}".
76
77- "url://https://github.com/tensorflow/tensorflow/archive/vX.Y.Z.tar.gz"
78  Download the source from specified URL.
79
80- "src_dir:///path/to/tensorflow/dir"
81  Specify unpacked TensorFlow source directory.
82
83- "src_file:///path/to/tensorflow/file.tar.gz"
84  Specify TensorFlow source package file.
85
86- "install_prefix:///path/to/tensorflow/prefix"
87  Specify the installed TensorFlow by install prefix.
88"#,
89    )
90}
91
92pub fn build_by_url<P>(url: &str, out_dir: P) -> Result<()>
93where
94    P: AsRef<Path>,
95{
96    eprintln!("download file {}", url);
97    let src_file = download_tensorflow(url).with_context(|| format!("unable to download {url}"))?;
98    build_by_src_file(&src_file, out_dir)
99        .with_context(|| format!("remove {} and try again", src_file.display()))?;
100    Ok(())
101}
102
103pub fn build_by_src_dir<P, P2>(src_dir: P, out_dir: P2) -> Result<()>
104where
105    P: AsRef<Path>,
106    P2: AsRef<Path>,
107{
108    let src_dir = src_dir.as_ref();
109
110    // re-run if the dir changes
111    println!("cargo:rerun-if-changed={}", src_dir.display());
112
113    compile_protobuf(src_dir, out_dir)?;
114    Ok(())
115}
116
117pub fn build_by_src_file<P, P2>(src_file: P, out_dir: P2) -> Result<()>
118where
119    P: AsRef<Path>,
120    P2: AsRef<Path>,
121{
122    let src_file = src_file.as_ref();
123
124    // re-run if the dir changes
125    println!("cargo:rerun-if-changed={}", src_file.display());
126
127    let src_dir = extract_src_file(src_file)?;
128    compile_protobuf(src_dir, out_dir)?;
129    Ok(())
130}
131
132pub fn build_by_install_prefix<P, P2>(prefix: P, out_dir: P2) -> Result<()>
133where
134    P: AsRef<Path>,
135    P2: AsRef<Path>,
136{
137    let dir = prefix.as_ref().join("include").join("tensorflow");
138    compile_protobuf(dir, out_dir)?;
139    Ok(())
140}
141
142pub fn extract_src_file<P>(src_file: P) -> Result<PathBuf>
143where
144    P: AsRef<Path>,
145{
146    let working_dir = OUT_DIR.join("tensorflow");
147    let src_file = src_file.as_ref();
148    let src_dirname = format!("tensorflow-{TENSORFLOW_VERSION}");
149    let src_dir = working_dir.join(&src_dirname);
150
151    // remove previously extracted dir
152    if src_dir.is_dir() {
153        fs::remove_dir_all(&src_dir)?;
154    }
155
156    // extract package
157    {
158        let file = BufReader::new(
159            File::open(src_file)
160                .with_context(|| format!("unable to open {}", src_file.display()))?,
161        );
162        let tar = GzDecoder::new(file);
163        let mut archive = Archive::new(tar);
164        archive
165            .unpack(&working_dir)
166            .with_context(|| format!("unable to unpack {}", working_dir.display()))?;
167
168        if !src_dir.is_dir() {
169            bail!(
170                r#"expect "{}" directory in source package. Did you download the correct version?"#,
171                src_dirname
172            );
173        }
174    }
175
176    Ok(src_dir)
177}
178
179pub fn compile_protobuf<P1, P2>(src_dir: P1, out_dir: P2) -> Result<()>
180where
181    P1: AsRef<Path>,
182    P2: AsRef<Path>,
183{
184    let dir = src_dir.as_ref();
185    let include_dir = dir;
186    let proto_paths = {
187        let example_pattern = dir
188            .join("tensorflow")
189            .join("core")
190            .join("example")
191            .join("*.proto");
192        let framework_pattern = dir
193            .join("tensorflow")
194            .join("core")
195            .join("framework")
196            .join("*.proto");
197        let event_proto = dir
198            .join("tensorflow")
199            .join("core")
200            .join("util")
201            .join("event.proto");
202
203        let example_iter = glob::glob(example_pattern.to_str().unwrap())
204            .with_context(|| format!("unable to find {}", example_pattern.display()))?;
205        let framework_iter = glob::glob(framework_pattern.to_str().unwrap())
206            .with_context(|| format!("unable to find {}", framework_pattern.display()))?;
207        let paths: Vec<_> =
208            chain!(example_iter, framework_iter, [Ok(event_proto)]).try_collect()?;
209        paths
210    };
211
212    let out_dir = out_dir.as_ref();
213    let prebuild_src_dir = out_dir.join("prebuild_src");
214    let w_serde_path = prebuild_src_dir.join("tensorflow_with_serde.rs");
215    let wo_serde_path = prebuild_src_dir.join("tensorflow_without_serde.rs");
216
217    fs::create_dir_all(prebuild_src_dir)?;
218
219    // without serde
220    {
221        prost_build::compile_protos(&proto_paths, &[PathBuf::from(include_dir)])?;
222        fs::copy(&*GENERATED_PROTOBUF_FILE, wo_serde_path)?;
223    }
224
225    // with serde
226    {
227        prost_build::Config::new()
228            .type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]")
229            .compile_protos(&proto_paths, &[PathBuf::from(include_dir)])?;
230        fs::copy(&*GENERATED_PROTOBUF_FILE, w_serde_path)?;
231    }
232
233    Ok(())
234}
235
236pub fn download_tensorflow(url: &str) -> Result<PathBuf> {
237    let working_dir = OUT_DIR.join("tensorflow");
238    let tar_path = working_dir.join(format!("v{}.tar.gz", TENSORFLOW_VERSION));
239
240    // createw working dir
241    fs::create_dir_all(&working_dir)?;
242
243    // return if downloaded package exists
244    if tar_path.is_file() {
245        return Ok(tar_path);
246    }
247
248    // download file
249    io::copy(
250        &mut ureq::get(url).call()?.into_reader(),
251        &mut BufWriter::new(File::create(&tar_path)?),
252    )?;
253
254    Ok(tar_path)
255}