pytorch - torchvision MNIST HTTPError: HTTP Error 403: Forbidden

标签 pytorch dataset mnist torchvision

我正在尝试复制此网页中的实验 https://adversarial-ml-tutorial.org/adversarial_examples/

我下载了该教程的 Jupyter notebook，保存到本机并用 Jupyter Notebook 打开。当我运行以下代码获取数据集时:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Request the MNIST train and test splits rooted at ../data with download=True;
# ToTensor is applied to every sample. The first of these calls is the line the
# traceback below points at (MNIST.__init__ -> self.download() -> HTTP 403).
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
# Wrap both splits in loaders with batches of 100; only the training loader shuffles.
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

我收到以下错误:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz
0/? [00:00<?, ?it/s]
---------------------------------------------------------------------------
HTTPError                                 Traceback (most recent call last)
<ipython-input-15-e6f62798f426> in <module>
      2 from torch.utils.data import DataLoader
      3 
----> 4 mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
      5 mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
      6 train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)

~\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in __init__(self, root, train, transform, target_transform, download)
     77 
     78         if download:
---> 79             self.download()
     80 
     81         if not self._check_exists():

~\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in download(self)
    144         for url, md5 in self.resources:
    145             filename = url.rpartition('/')[2]
--> 146             download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
    147 
    148         # process and save as torch files

~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_and_extract_archive(url, download_root, extract_root, filename, md5, remove_finished)
    254         filename = os.path.basename(url)
    255 
--> 256     download_url(url, download_root, filename, md5)
    257 
    258     archive = os.path.join(download_root, filename)

~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_url(url, root, filename, md5)
     82                 )
     83             else:
---> 84                 raise e
     85         # check integrity of downloaded file
     86         if not check_integrity(fpath, md5):

~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_url(url, root, filename, md5)
     70             urllib.request.urlretrieve(
     71                 url, fpath,
---> 72                 reporthook=gen_bar_updater()
     73             )
     74         except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]

~\Anaconda3\lib\urllib\request.py in urlretrieve(url, filename, reporthook, data)
    245     url_type, path = splittype(url)
    246 
--> 247     with contextlib.closing(urlopen(url, data)) as fp:
    248         headers = fp.info()
    249 

~\Anaconda3\lib\urllib\request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    220     else:
    221         opener = _opener
--> 222     return opener.open(url, data, timeout)
    223 
    224 def install_opener(opener):

~\Anaconda3\lib\urllib\request.py in open(self, fullurl, data, timeout)
    529         for processor in self.process_response.get(protocol, []):
    530             meth = getattr(processor, meth_name)
--> 531             response = meth(req, response)
    532 
    533         return response

~\Anaconda3\lib\urllib\request.py in http_response(self, request, response)
    639         if not (200 <= code < 300):
    640             response = self.parent.error(
--> 641                 'http', request, response, code, msg, hdrs)
    642 
    643         return response

~\Anaconda3\lib\urllib\request.py in error(self, proto, *args)
    567         if http_err:
    568             args = (dict, 'default', 'http_error_default') + orig_args
--> 569             return self._call_chain(*args)
    570 
    571 # XXX probably also want an abstract factory that knows when it makes

~\Anaconda3\lib\urllib\request.py in _call_chain(self, chain, kind, meth_name, *args)
    501         for handler in handlers:
    502             func = getattr(handler, meth_name)
--> 503             result = func(*args)
    504             if result is not None:
    505                 return result

~\Anaconda3\lib\urllib\request.py in http_error_default(self, req, fp, code, msg, hdrs)
    647 class HTTPDefaultErrorHandler(BaseHandler):
    648     def http_error_default(self, req, fp, code, msg, hdrs):
--> 649         raise HTTPError(req.full_url, code, msg, hdrs, fp)
    650 
    651 class HTTPRedirectHandler(BaseHandler):

HTTPError: HTTP Error 403: Forbidden

非常感谢任何解决此问题的帮助。 我也可以直接从链接下载数据集,但我不知道如何使用它!

最佳答案

是的,这是一个已知错误:https://github.com/pytorch/vision/issues/3500

可能的解决方案是修补 MNIST download 方法。

但它需要安装wget

对于 Linux:

sudo apt install wget

对于 Windows:

choco install wget

安装 wget 之后，使用下面的代码修补 download 方法:

import os
import subprocess as sp
from torchvision.datasets.mnist import MNIST, read_image_file, read_label_file
from torchvision.datasets.utils import extract_archive


def patched_download(self):
    """Download the MNIST archives with ``wget`` and build the processed files.

    Intended as a replacement for ``MNIST.download`` when the default
    urllib-based download is rejected with HTTP 403. Does nothing if
    ``self._check_exists()`` already reports the dataset as present.

    For every ``(url, md5)`` pair in ``self.resources`` the archive is fetched
    into the raw folder (unless a file with a matching MD5 is already there),
    verified, and extracted in place. The four extracted IDX files are then
    read and saved with ``torch.save`` under ``self.processed_folder`` as
    ``self.training_file`` and ``self.test_file``.

    Raises:
        FileNotFoundError: typically raised by ``subprocess.run`` when the
            ``wget`` executable is not installed / not on PATH.
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero
            status (e.g. the server refused the request).
        RuntimeError: if a downloaded archive does not match its expected MD5.
    """
    if self._check_exists():
        return

    # Imported locally so the patch works even when the surrounding script
    # never imported torch itself (the original snippet relied on a missing
    # module-level ``import torch``).
    import torch
    from torchvision.datasets.utils import check_integrity

    # Use one expanded path for creating, downloading into and reading from
    # the raw folder, so a root containing "~" cannot split across two places.
    raw_dir = os.path.expanduser(self.raw_folder)
    os.makedirs(raw_dir, exist_ok=True)
    os.makedirs(self.processed_folder, exist_ok=True)

    # download files
    for url, md5 in self.resources:
        filename = url.rpartition('/')[2]
        archive = os.path.join(raw_dir, filename)

        # Skip the network round-trip when a verified copy is already present.
        if not check_integrity(archive, md5):
            # -O writes to exactly this path, overwriting any stale or partial
            # file; with -P wget would save a second copy as "<name>.1" and
            # the stale file would be extracted instead.
            # check=True turns a failed download into an immediate error
            # rather than a confusing extraction failure later on.
            sp.run(["wget", url, "-O", archive], check=True)
            if not check_integrity(archive, md5):
                raise RuntimeError(
                    "Downloaded file {} is corrupted (MD5 mismatch)".format(archive)
                )

        print("Extracting {} to {}".format(archive, raw_dir))
        extract_archive(archive, raw_dir, False)

    # process and save as torch files
    print('Processing...')

    training_set = (
        read_image_file(os.path.join(raw_dir, 'train-images-idx3-ubyte')),
        read_label_file(os.path.join(raw_dir, 'train-labels-idx1-ubyte'))
    )
    test_set = (
        read_image_file(os.path.join(raw_dir, 't10k-images-idx3-ubyte')),
        read_label_file(os.path.join(raw_dir, 't10k-labels-idx1-ubyte'))
    )
    with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
        torch.save(training_set, f)
    with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
        torch.save(test_set, f)

    print('Done!')


# Swap in the wget-based download for every MNIST instance. This has to happen
# before the constructors below run, because MNIST.__init__ calls
# self.download() when download=True.
MNIST.download = patched_download

# `transforms` and `DataLoader` are the names imported in the question's snippet
# (from torchvision / torch.utils.data); they are not imported again here.
mnist_train = MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
# One sample per batch; only the training loader shuffles.
train_loader = DataLoader(mnist_train, batch_size=1, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=1, shuffle=False)

关于 pytorch - torchvision MNIST HTTPError: HTTP Error 403: Forbidden，我们在 Stack Overflow 上找到一个类似的问题: https://stackoverflow.com/questions/66467005/

相关文章:

python - 如何在pytorch模型中加载检查点文件?

python - 调用 super 的 forward() 方法

dictionary - 如何将数据集转换为存储库内的字典。我在 Foundry 中使用 pyspark

c# - 如何在ListView/GridView中显示DataSet

theano - Caffe 与 Theano MNIST 示例

python - 如何将 Hermite 多项式与随机梯度下降 (SGD) 结合使用?

python - 由于内存问题,如何保存仅与预训练bert模型的分类器层相关的参数?

java - 在 BIRT Designer 中禁用与数据源的连接

python - MNIST 图像纠偏

python-3.x - PyTorch Fashion-MNIST (ETL)