"""Utilities for downloading from the web.
"""
import glob
import logging
import os
import shutil
import tarfile
import urllib
import urllib.request
import warnings
import zipfile

import chardet
from smart_open import open, parse_uri
from tqdm import tqdm

from mirdata.validate import md5
# Configure the root logger when this module is imported so that the INFO-level
# messages emitted by the download helpers in this module are displayed.
# NOTE(review): per the logging docs, basicConfig does nothing if the root logger
# already has handlers — confirm this import-time side effect is intended.
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
[docs]
def downloader(
    save_dir,
    remotes=None,
    index=None,
    partial_download=None,
    info_message=None,
    force_overwrite=False,
    cleanup=False,
    allow_invalid_checksum=False,
):
    """Download data to `save_dir` and optionally log a message.

    Args:
        save_dir (str):
            The directory to download the data to. It is created if it does
            not exist.
        remotes (dict or None):
            A dictionary of RemoteFileMetadata tuples of data in zip format.
            If None, there is no data to download. The dictionary passed in
            is never modified by this function.
        index (core.Index):
            A mirdata Index class, which may contain a remote index to be
            downloaded or a subset of remotes to download by default.
        partial_download (list or None):
            A list of keys to partially download the remote objects of the
            download dict. If empty/None, ``index.partial_download`` is used;
            if that is also empty, all remotes are downloaded.
        info_message (str or None):
            A string of info to log when this function finishes; it is
            formatted with ``save_dir``. If None, no string is logged.
        force_overwrite (bool):
            If True, existing files are overwritten by the downloaded files.
        cleanup (bool):
            Whether to delete the zip/tar file after extracting.
        allow_invalid_checksum (bool):
            Allow having an invalid checksum, and whenever this happens emit a
            warning instead of raising an error. Setting this also forces
            ``cleanup=True``.

    Raises:
        ValueError: if ``index`` is not given, or if ``partial_download`` is
            not a list whose entries are all keys of the remotes.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if not index:
        raise ValueError("Index must be specified.")

    # NOTE(review): archives are always removed after extraction when invalid
    # checksums are allowed; the motivation is not visible in this module.
    if allow_invalid_checksum:
        cleanup = True

    if cleanup:
        logging.warning(
            "Zip and tar files will be deleted after they are uncompressed. "
            + "If you download this dataset again, it will overwrite existing files, even if force_overwrite=False"
        )

    if index.remote:
        # Work on a copy so adding the "index" entry does not modify the
        # caller's dictionary as a side effect.
        remotes = dict(remotes) if remotes else {}
        remotes["index"] = index.remote

    # An explicit partial_download wins; otherwise use the index's default.
    partial_download = partial_download if partial_download else index.partial_download

    if remotes:
        objs_to_download = _select_remote_keys(remotes, partial_download)
        logging.info("Downloading {} to {}".format(objs_to_download, save_dir))
        for key in objs_to_download:
            remote = remotes[key]
            logging.info("[{}] downloading {}".format(key, remote.filename))
            _download_single_remote(
                remote, save_dir, force_overwrite, cleanup, allow_invalid_checksum
            )
            _unpack_remote_directories(remote, save_dir)

    if info_message is not None:
        logging.info(info_message.format(save_dir))


def _select_remote_keys(remotes, partial_download):
    """Return the keys of ``remotes`` to download, validating ``partial_download``."""
    if not partial_download:
        return list(remotes.keys())
    # check the keys in partial_download are in the download dict
    if not isinstance(partial_download, list) or any(
        key not in remotes for key in partial_download
    ):
        raise ValueError(
            "partial_download must be a list which is a subset of {}, but got {}".format(
                list(remotes.keys()), partial_download
            )
        )
    return partial_download


def _download_single_remote(
    remote, save_dir, force_overwrite, cleanup, allow_invalid_checksum
):
    """Download one remote, extracting it when it is a zip or tar-family archive."""
    extension = os.path.splitext(remote.filename)[-1]
    if ".zip" in extension:
        download_zip_file(
            remote, save_dir, force_overwrite, cleanup, allow_invalid_checksum
        )
    elif ".gz" in extension or ".tar" in extension or ".bz2" in extension:
        download_tar_file(
            remote, save_dir, force_overwrite, cleanup, allow_invalid_checksum
        )
    else:
        download_from_remote(remote, save_dir, force_overwrite, allow_invalid_checksum)


def _unpack_remote_directories(remote, save_dir):
    """Move the contents of each of ``remote.unpack_directories`` into the destination dir."""
    if not remote.unpack_directories:
        return
    # path to destination directory
    destination_dir = (
        os.path.join(save_dir, remote.destination_dir)
        if remote.destination_dir
        else save_dir
    )
    for src_dir in remote.unpack_directories:
        # path to directory to unpack
        source_dir = os.path.join(destination_dir, src_dir)
        if not os.path.exists(source_dir):
            # Skip only this directory; the remaining directories and remotes
            # are still processed (an early return would silently drop them).
            logging.info(
                "Data not downloaded, because it probably already exists on your computer. "
                + "Run .validate() to check, or rerun with force_overwrite=True to delete any "
                + "existing files and download from scratch"
            )
            continue
        move_directory_contents(source_dir, destination_dir)
[docs]
class DownloadProgressBar(tqdm):
    """A ``tqdm`` progress bar whose ``update_to`` fits urlretrieve's ``reporthook``.

    ``update_to`` is given running block counts and translates them into the
    incremental ``update`` calls that ``tqdm`` works with.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar so that it shows ``b * bsize`` units in total.

        Args:
            b (int): number of blocks transferred so far
            bsize (int): size of each block
            tsize (int or None): total size; replaces ``self.total`` when given
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        increment = transferred - self.n
        self.update(increment)
[docs]
def download_from_remote(remote, save_dir, force_overwrite, allow_invalid_checksum):
    """Download a remote dataset into path

    Fetch a dataset pointed by remote's url, save into path using remote's
    filename and ensure its integrity based on the MD5 Checksum of the
    downloaded file.

    Adapted from scikit-learn's sklearn.datasets.base._fetch_remote.

    Args:
        remote (RemoteFileMetadata): Named tuple containing remote dataset
            meta information: url, filename, checksum and destination_dir
        save_dir (str): Directory to save the file to. Usually `data_home`
        force_overwrite (bool):
            If True, overwrite existing file with the downloaded file.
            If False, does not overwrite, but checks that checksum is consistent.
        allow_invalid_checksum (bool):
            If True, a checksum mismatch only emits a UserWarning; otherwise an
            IOError is raised.

    Returns:
        str: Full path of the created file.

    Raises:
        NotImplementedError: if ``save_dir`` is not on the local filesystem.
        IOError: if the file's MD5 checksum differs from ``remote.checksum``
            and ``allow_invalid_checksum`` is False.
    """
    file_uri = parse_uri(save_dir)
    if file_uri.scheme != "file":
        raise NotImplementedError(
            "mirdata only supports downloading to a local filesystem. "
            "To use mirdata with a remote filesystem, download to a local filesytem, "
            "and transfer the data to your remote filesystem, setting data_home appropriately."
        )

    if remote.destination_dir is None:
        download_dir = save_dir
    else:
        download_dir = os.path.join(save_dir, remote.destination_dir)

    if not os.path.exists(download_dir):
        os.makedirs(download_dir)

    download_path = os.path.join(download_dir, remote.filename)

    if not os.path.exists(download_path) or force_overwrite:
        # if we got here, we want to overwrite any existing file
        if os.path.exists(download_path):
            os.remove(download_path)

        # If file doesn't exist or we want to overwrite, download it
        with DownloadProgressBar(
            unit="B", unit_scale=True, unit_divisor=1024, miniters=1
        ) as t:
            completed = False
            try:
                urllib.request.urlretrieve(
                    remote.url,
                    filename=download_path,
                    reporthook=t.update_to,
                    data=None,
                )
                completed = True
            except Exception:
                error_msg = """
mirdata failed to download the dataset from {}!
Please try again in a few minutes.
If this error persists, please raise an issue at
https://github.com/mir-dataset-loaders/mirdata,
and tag it with 'broken-link'.
""".format(
                    remote.url
                )
                logging.error(error_msg)
                raise
            finally:
                # Never leave a truncated file behind (failure or interrupt):
                # a later run without force_overwrite would otherwise treat it
                # as an already-downloaded file.
                if not completed and os.path.exists(download_path):
                    os.remove(download_path)
    else:
        logging.info(
            "{} already exists and will not be downloaded. ".format(download_path)
            + "Rerun with force_overwrite=True to delete this file and force the download."
        )

    checksum = md5(download_path)
    if remote.checksum != checksum:
        if allow_invalid_checksum:
            warnings.warn(
                "{} has an MD5 checksum ({}) "
                "differing from expected ({}), "
                "file may be corrupted.".format(
                    download_path, checksum, remote.checksum
                ),
                UserWarning,
            )
        else:
            raise IOError(
                "{} has an MD5 checksum ({}) "
                "differing from expected ({}), "
                "file may be corrupted.".format(
                    download_path, checksum, remote.checksum
                )
            )

    return download_path
[docs]
def download_zip_file(
    zip_remote, save_dir, force_overwrite, cleanup, allow_invalid_checksum
):
    """Fetch the zip archive described by ``zip_remote`` and extract it.

    The archive is extracted into the directory it was downloaded to.

    Args:
        zip_remote (RemoteFileMetadata):
            Object containing download information
        save_dir (str):
            Path to save downloaded file
        force_overwrite (bool):
            If True, overwrites existing files
        cleanup (bool):
            If True, remove zipfile after unzipping
        allow_invalid_checksum (bool):
            Forwarded to ``download_from_remote``: if True, a checksum mismatch
            produces a warning instead of an error.
    """
    unzip(
        download_from_remote(
            zip_remote, save_dir, force_overwrite, allow_invalid_checksum
        ),
        cleanup=cleanup,
    )
[docs]
def unzip(zip_path, cleanup):
    """Unzip a zip file inside its current directory.

    Args:
        zip_path (str): Path to zip file
        cleanup (bool): If True, remove zipfile after unzipping
    """
    # The context manager closes the archive even if extraction raises,
    # so the zip file is not left open.
    with zipfile.ZipFile(zip_path, "r") as zfile:
        extractall_unicode(zfile, os.path.dirname(zip_path))
    if cleanup:
        os.remove(zip_path)
[docs]
def download_tar_file(
    tar_remote, save_dir, force_overwrite, cleanup, allow_invalid_checksum
):
    """Fetch the tar archive described by ``tar_remote`` and extract it.

    The archive is extracted into the directory it was downloaded to.

    Args:
        tar_remote (RemoteFileMetadata): Object containing download information
        save_dir (str): Path to save downloaded file
        force_overwrite (bool): If True, overwrites existing files
        cleanup (bool): If True, remove tarfile after untarring
        allow_invalid_checksum (bool): Forwarded to ``download_from_remote``:
            if True, a checksum mismatch produces a warning instead of an error.
    """
    untar(
        download_from_remote(
            tar_remote, save_dir, force_overwrite, allow_invalid_checksum
        ),
        cleanup=cleanup,
    )
[docs]
def untar(tar_path, cleanup):
    """Untar a tar file inside its current directory.

    When the interpreter supports tarfile extraction filters (PEP 706), the
    ``"data"`` filter is applied. It rejects members that would be written
    outside the target directory (e.g. ``../x`` or absolute paths).

    Args:
        tar_path (str): Path to tar file
        cleanup (bool): If True, remove tarfile after untarring

    Raises:
        tarfile.FilterError: (filter-capable Pythons only) if a member would
            escape the extraction directory.
    """
    extract_kwargs = {}
    # Archives come from remote URLs (and may be unverified when invalid
    # checksums are allowed), so guard against path traversal when possible.
    if hasattr(tarfile, "data_filter"):
        extract_kwargs["filter"] = "data"
    with tarfile.open(tar_path, "r") as tfile:
        tfile.extractall(os.path.dirname(tar_path), **extract_kwargs)
    if cleanup:
        os.remove(tar_path)
[docs]
def move_directory_contents(source_dir, target_dir):
    """Move the contents of source_dir into target_dir, and delete source_dir

    Entries whose name already exists in ``target_dir`` are not moved; the
    existing target is kept and the source copy is discarded with
    ``source_dir``.

    Args:
        source_dir (str): path to source directory
        target_dir (str): path to target directory
    """
    # os.listdir (unlike glob "*") also returns dot-files and treats characters
    # such as "[" in source_dir literally. Entries glob missed were never
    # moved, and were then deleted by the rmtree below.
    for name in sorted(os.listdir(source_dir)):
        source_path = os.path.join(source_dir, name)
        target_path = os.path.join(target_dir, name)
        if os.path.exists(target_path):
            logging.info(
                "{} already exists. Run with force_overwrite=True to download from scratch".format(
                    target_path
                )
            )
            continue
        shutil.move(source_path, target_path)
    shutil.rmtree(source_dir)