]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xnn/utils/data_utils.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / utils / data_utils.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 #   list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 #   this list of conditions and the following disclaimer in the documentation
13 #   and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 #   contributors may be used to endorse or promote products derived from
17 #   this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
31 # Some parts of the code are borrowed from: https://github.com/pytorch/vision
32 # with the following license:
33 #
34 # BSD 3-Clause License
35 #
36 # Copyright (c) Soumith Chintala 2016,
37 # All rights reserved.
38 #
39 # Redistribution and use in source and binary forms, with or without
40 # modification, are permitted provided that the following conditions are met:
41 #
42 # * Redistributions of source code must retain the above copyright notice, this
43 #   list of conditions and the following disclaimer.
44 #
45 # * Redistributions in binary form must reproduce the above copyright notice,
46 #   this list of conditions and the following disclaimer in the documentation
47 #   and/or other materials provided with the distribution.
48 #
49 # * Neither the name of the copyright holder nor the names of its
50 #   contributors may be used to endorse or promote products derived from
51 #   this software without specific prior written permission.
52 #
53 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
54 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
55 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
56 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
57 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
58 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
59 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
60 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
61 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
62 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
63 #
64 #################################################################################
67 # from torchvision.datasets
68 import os
69 import os.path
70 import hashlib
71 import gzip
72 import errno
73 import tarfile
74 import zipfile
76 import torch
77 from torch.utils.model_zoo import tqdm
80 def gen_bar_updater():
81     pbar = tqdm(total=None)
83     def bar_update(count, block_size, total_size):
84         if pbar.total is None and total_size:
85             pbar.total = total_size
86         progress_bytes = count * block_size
87         pbar.update(progress_bytes - pbar.n)
89     return bar_update
92 def calculate_md5(fpath, chunk_size=1024 * 1024):
93     md5 = hashlib.md5()
94     with open(fpath, 'rb') as f:
95         for chunk in iter(lambda: f.read(chunk_size), b''):
96             md5.update(chunk)
97     return md5.hexdigest()
100 def check_md5(fpath, md5, **kwargs):
101     return md5 == calculate_md5(fpath, **kwargs)
104 def check_integrity(fpath, md5=None):
105     if not os.path.isfile(fpath):
106         return False
107     if md5 is None:
108         return True
109     return check_md5(fpath, md5)
112 def makedir_exist_ok(dirpath):
113     """
114     Python2 support for os.makedirs(.., exist_ok=True)
115     """
116     try:
117         os.makedirs(dirpath)
118     except OSError as e:
119         if e.errno == errno.EEXIST:
120             pass
121         else:
122             raise
125 def download_url(url, root, filename=None, md5=None):
126     """Download a file from a url and place it in root.
128     Args:
129         url (str): URL to download file from
130         root (str): Directory to place downloaded file in
131         filename (str, optional): Name to save the file under. If None, use the basename of the URL
132         md5 (str, optional): MD5 checksum of the download. If None, do not check
133     """
134     from six.moves import urllib
136     root = os.path.expanduser(root)
137     if not filename:
138         filename = os.path.basename(url)
139     fpath = os.path.join(root, filename)
141     makedir_exist_ok(root)
143     # downloads file
144     if check_integrity(fpath, md5):
145         print('Using downloaded and verified file: ' + fpath)
146     else:
147         try:
148             print('Downloading ' + url + ' to ' + fpath)
149             urllib.request.urlretrieve(
150                 url, fpath,
151                 reporthook=gen_bar_updater()
152             )
153         except (urllib.error.URLError, IOError) as e:
154             if url[:5] == 'https':
155                 url = url.replace('https:', 'http:')
156                 print('Failed download. Trying https -> http instead.'
157                       ' Downloading ' + url + ' to ' + fpath)
158                 urllib.request.urlretrieve(
159                     url, fpath,
160                     reporthook=gen_bar_updater()
161                 )
162             else:
163                 raise e
165     return fpath
168 def list_dir(root, prefix=False):
169     """List all directories at a given root
171     Args:
172         root (str): Path to directory whose folders need to be listed
173         prefix (bool, optional): If true, prepends the path to each result, otherwise
174             only returns the name of the directories found
175     """
176     root = os.path.expanduser(root)
177     directories = list(
178         filter(
179             lambda p: os.path.isdir(os.path.join(root, p)),
180             os.listdir(root)
181         )
182     )
184     if prefix is True:
185         directories = [os.path.join(root, d) for d in directories]
187     return directories
190 def list_files(root, suffix, prefix=False):
191     """List all files ending with a suffix at a given root
193     Args:
194         root (str): Path to directory whose folders need to be listed
195         suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
196             It uses the Python "str.endswith" method and is passed directly
197         prefix (bool, optional): If true, prepends the path to each result, otherwise
198             only returns the name of the files found
199     """
200     root = os.path.expanduser(root)
201     files = list(
202         filter(
203             lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
204             os.listdir(root)
205         )
206     )
208     if prefix is True:
209         files = [os.path.join(root, d) for d in files]
211     return files
214 def download_file_from_google_drive(file_id, root, filename=None, md5=None):
215     """Download a Google Drive file from  and place it in root.
217     Args:
218         file_id (str): id of file to be downloaded
219         root (str): Directory to place downloaded file in
220         filename (str, optional): Name to save the file under. If None, use the id of the file.
221         md5 (str, optional): MD5 checksum of the download. If None, do not check
222     """
223     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
224     import requests
225     url = "https://docs.google.com/uc?export=download"
227     root = os.path.expanduser(root)
228     if not filename:
229         filename = file_id
230     fpath = os.path.join(root, filename)
232     makedir_exist_ok(root)
234     if os.path.isfile(fpath) and check_integrity(fpath, md5):
235         print('Using downloaded and verified file: ' + fpath)
236     else:
237         session = requests.Session()
239         response = session.get(url, params={'id': file_id}, stream=True)
240         token = _get_confirm_token(response)
242         if token:
243             params = {'id': file_id, 'confirm': token}
244             response = session.get(url, params=params, stream=True)
246         _save_response_content(response, fpath)
249 def _get_confirm_token(response):
250     for key, value in response.cookies.items():
251         if key.startswith('download_warning'):
252             return value
254     return None
257 def _save_response_content(response, destination, chunk_size=32768):
258     with open(destination, "wb") as f:
259         pbar = tqdm(total=None)
260         progress = 0
261         for chunk in response.iter_content(chunk_size):
262             if chunk:  # filter out keep-alive new chunks
263                 f.write(chunk)
264                 progress += len(chunk)
265                 pbar.update(progress - pbar.n)
266         pbar.close()
269 def _is_tar(filename):
270     return filename.endswith(".tar")
273 def _is_targz(filename):
274     return filename.endswith(".tar.gz")
277 def _is_gzip(filename):
278     return filename.endswith(".gz") and not filename.endswith(".tar.gz")
281 def _is_zip(filename):
282     return filename.endswith(".zip")
285 def extract_archive(from_path, to_path=None, remove_finished=False):
286     if to_path is None:
287         to_path = os.path.dirname(from_path)
289     if _is_tar(from_path):
290         with tarfile.open(from_path, 'r') as tar:
291             tar.extractall(path=to_path)
292     elif _is_targz(from_path):
293         with tarfile.open(from_path, 'r:gz') as tar:
294             tar.extractall(path=to_path)
295     elif _is_gzip(from_path):
296         to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
297         with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
298             out_f.write(zip_f.read())
299     elif _is_zip(from_path):
300         with zipfile.ZipFile(from_path, 'r') as z:
301             z.extractall(to_path)
302     else:
303         raise ValueError("Extraction of {} not supported".format(from_path))
305     if remove_finished:
306         os.remove(from_path)
309 def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
310                                  md5=None, remove_finished=False):
311     download_root = os.path.expanduser(download_root)
312     if extract_root is None:
313         extract_root = download_root
314     if not filename:
315         filename = os.path.basename(url)
317     download_url(url, download_root, filename, md5)
319     archive = os.path.join(download_root, filename)
320     print("Extracting {} to {}".format(archive, extract_root))
321     extract_archive(archive, extract_root, remove_finished)
324 def iterable_to_str(iterable):
325     return "'" + "', '".join([str(item) for item in iterable]) + "'"
328 def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
329     if not isinstance(value, torch._six.string_classes):
330         if arg is None:
331             msg = "Expected type str, but got type {type}."
332         else:
333             msg = "Expected type str for argument {arg}, but got type {type}."
334         msg = msg.format(type=type(value), arg=arg)
335         raise ValueError(msg)
337     if valid_values is None:
338         return value
340     if value not in valid_values:
341         if custom_msg is not None:
342             msg = custom_msg
343         else:
344             msg = ("Unknown value '{value}' for argument {arg}. "
345                    "Valid values are {{{valid_values}}}.")
346             msg = msg.format(value=value, arg=arg,
347                              valid_values=iterable_to_str(valid_values))
348         raise ValueError(msg)
350     return value