[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / utils.py
1 import os
2 import os.path
3 import hashlib
4 import gzip
5 import errno
6 import tarfile
7 import zipfile
9 import torch
10 from torch.utils.model_zoo import tqdm
def gen_bar_updater():
    """Build a ``reporthook`` callback for ``urlretrieve`` that drives a tqdm bar.

    The returned callable accepts ``(count, block_size, total_size)`` and
    advances a shared progress bar to ``count * block_size`` bytes.
    """
    progress_bar = tqdm(total=None)

    def update(count, block_size, total_size):
        # Lazily learn the total once the server reports a non-zero size.
        if total_size and progress_bar.total is None:
            progress_bar.total = total_size
        downloaded = count * block_size
        progress_bar.update(downloaded - progress_bar.n)

    return update
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Return the MD5 hex digest of the file at *fpath*, read in chunks."""
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def check_md5(fpath, md5, **kwargs):
    """Return True when the file's MD5 digest equals *md5* (extra kwargs go to calculate_md5)."""
    return calculate_md5(fpath, **kwargs) == md5
def check_integrity(fpath, md5=None):
    """Return True when *fpath* exists and, if *md5* is given, matches that checksum.

    A missing file is always False; an existing file with ``md5=None`` is
    always True (presence alone counts as integrity in that case).
    """
    if not os.path.isfile(fpath):
        return False
    return True if md5 is None else check_md5(fpath, md5)
def makedir_exist_ok(dirpath):
    """Create *dirpath* (with parents), ignoring the error if it already exists.

    Python2-compatible stand-in for ``os.makedirs(..., exist_ok=True)``.
    NOTE: like the EEXIST check implies, this also swallows the case where
    the path exists as a non-directory — callers get no error for that.
    """
    try:
        os.makedirs(dirpath)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Returns:
        str: path of the downloaded (or pre-existing verified) file.

    Raises:
        RuntimeError: if the file on disk fails the integrity check after
            the download completed (e.g. truncated or corrupted transfer).
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # Skip the download entirely if a verified copy is already on disk.
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return fpath

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, IOError):
        if url.startswith('https'):
            # Some mirrors mishandle TLS; retry once over plain http.
            url = url.replace('https:', 'http:')
            print('Failed download. Trying https -> http instead.'
                  ' Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        else:
            # Bare raise preserves the original traceback (raise e did not).
            raise

    # Verify what actually landed on disk; without this, a truncated or
    # corrupted download would be silently accepted.
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted: " + fpath)

    return fpath
def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    ]

    if prefix is True:
        directories = [os.path.join(root, entry) for entry in directories]

    return directories
def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [
        entry for entry in os.listdir(root)
        if os.path.isfile(os.path.join(root, entry)) and entry.endswith(suffix)
    ]

    if prefix is True:
        files = [os.path.join(root, entry) for entry in files]

    return files
def download_file_from_google_drive(file_id, root, filename=None, md5=None):
    """Download a Google Drive file from and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename if filename else file_id)

    makedir_exist_ok(root)

    # A previously-downloaded, verified copy short-circuits the download.
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    session = requests.Session()
    response = session.get(url, params={'id': file_id}, stream=True)

    # Large files trigger a "can't virus-scan" interstitial; confirm via token.
    token = _get_confirm_token(response)
    if token:
        response = session.get(url, params={'id': file_id, 'confirm': token}, stream=True)

    _save_response_content(response, fpath)
182 def _get_confirm_token(response):
183 for key, value in response.cookies.items():
184 if key.startswith('download_warning'):
185 return value
187 return None
190 def _save_response_content(response, destination, chunk_size=32768):
191 with open(destination, "wb") as f:
192 pbar = tqdm(total=None)
193 progress = 0
194 for chunk in response.iter_content(chunk_size):
195 if chunk: # filter out keep-alive new chunks
196 f.write(chunk)
197 progress += len(chunk)
198 pbar.update(progress - pbar.n)
199 pbar.close()
202 def _is_tar(filename):
203 return filename.endswith(".tar")
206 def _is_targz(filename):
207 return filename.endswith(".tar.gz")
210 def _is_gzip(filename):
211 return filename.endswith(".gz") and not filename.endswith(".tar.gz")
214 def _is_zip(filename):
215 return filename.endswith(".zip")
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract a ``.tar``, ``.tar.gz``/``.tgz``, ``.gz`` or ``.zip`` archive.

    Args:
        from_path (str): path of the archive to extract
        to_path (str, optional): directory to extract into. Defaults to the
            directory containing the archive.
        remove_finished (bool, optional): if True, delete the archive after
            a successful extraction.

    Raises:
        ValueError: if the archive's extension is not supported.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # NOTE(review): extractall() trusts member paths inside the archive;
    # only use on archives from trusted sources (no path-traversal guard).
    if _is_targz(from_path) or from_path.endswith(".tgz"):
        # Handle .tar.gz/.tgz before the plain .gz case so gzipped tarballs
        # are extracted as archives rather than merely decompressed.
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        # A single gzipped file: decompress into to_path, dropping the .gz suffix.
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
                                 md5=None, remove_finished=False):
    """Download the archive at *url* into *download_root*, then extract it.

    Args:
        url (str): URL to download the archive from
        download_root (str): directory to place the downloaded archive in
        extract_root (str, optional): directory to extract into; defaults to download_root
        filename (str, optional): name to save the archive under; defaults to the URL basename
        md5 (str, optional): expected MD5 of the download; None disables the check
        remove_finished (bool, optional): delete the archive after extraction
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable):
    """Render *iterable* as comma-separated single-quoted items, e.g. ``'a', 'b'``.

    Note: an empty iterable yields ``''`` (a pair of quotes), matching the
    surrounding-quote construction.
    """
    items = [str(item) for item in iterable]
    return "'" + "', '".join(items) + "'"
def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
    """Validate that *value* is a string and optionally one of *valid_values*.

    Args:
        value: the value to check
        arg (str, optional): argument name used in error messages
        valid_values (iterable, optional): allowed values; None skips the check
        custom_msg (str, optional): overrides the default "unknown value" message

    Returns:
        The validated value, unchanged.

    Raises:
        ValueError: if *value* is not a str, or not among *valid_values*.
    """
    # torch._six.string_classes was a private API removed in modern PyTorch;
    # on Python 3 it was simply (str,), so isinstance(value, str) is
    # behavior-identical and does not break on newer torch versions.
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value