你当前正在访问 Microsoft Azure Global Edition 技术文档网站。 如果需要访问由世纪互联运营的 Microsoft Azure 中国技术文档网站,请访问 https://docs.azure.cn。
包含手写数字的 MNIST 数据库
包含手写数字的 MNIST 数据库有一个 60,000 示例的训练集,还有一个 10,000 示例的测试集。 这些数字已在大小方面规范化,在固定大小的图像中居中。
注意
Microsoft 按“原样”提供 Azure 开放数据集。 Microsoft 对数据集的使用不提供任何担保(明示或暗示)、保证或条件。 在当地法律允许的范围内,Microsoft 对使用数据集而导致的任何损害或损失不承担任何责任,包括直接、必然、特殊、间接、偶发或惩罚性损害或损失。
此数据集是根据 Microsoft 接收源数据的原始条款提供的。 数据集可能包含来自 Microsoft 的数据。
此数据集来源于 MNIST 手写数字数据库。 它是美国国家标准与技术研究所发布的大型 NIST 手写体及字符数据库的子集。
存储位置
- Blob 帐户:azureopendatastorage
- 容器名:mnist
在容器中可直接使用以下四个文件:
- train-images-idx3-ubyte.gz:训练集图像(9912422 字节)
- train-labels-idx1-ubyte.gz:训练集标签(28881 字节)
- t10k-images-idx3-ubyte.gz:测试集图像(1648877 字节)
- t10k-labels-idx1-ubyte.gz:测试集标签(4542 字节)
数据访问
Azure Notebooks
使用 Azure 机器学习表格数据集将 MNIST 加载到数据帧中。
有关 Azure 机器学习数据集的详细信息,请参阅创建 Azure 机器学习数据集。
获取数据帧的完整数据集
from azureml.opendatasets import MNIST
mnist = MNIST.get_tabular_dataset()
mnist_df = mnist.to_pandas_dataframe()
mnist_df.info()
获取训练和测试数据帧
mnist_train = MNIST.get_tabular_dataset(dataset_filter='train')
mnist_train_df = mnist_train.to_pandas_dataframe()
X_train = mnist_train_df.drop("label", axis=1).astype(int).values/255.0
y_train = mnist_train_df.filter(items=["label"]).astype(int).values
mnist_test = MNIST.get_tabular_dataset(dataset_filter='test')
mnist_test_df = mnist_test.to_pandas_dataframe()
X_test = mnist_test_df.drop("label", axis=1).astype(int).values/255.0
y_test = mnist_test_df.filter(items=["label"]).astype(int).values
绘制一些数字图像
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
# now let's show some randomly chosen images from the traininng set.
count = 0
sample_size = 30
plt.figure(figsize=(16, 6))
for i in np.random.permutation(X_train.shape[0])[:sample_size]:
count = count + 1
plt.subplot(1, sample_size, count)
plt.axhline('')
plt.axvline('')
plt.text(x=10, y=-10, s=y_train[i], fontsize=18)
plt.imshow(X_train[i].reshape(28, 28), cmap=plt.cm.Greys)
plt.show()
下载或装载 MNIST 原始文件 Azure 机器学习文件数据集。
这仅适用于基于 Linux 的计算。 有关 Azure 机器学习数据集的详细信息,请参阅创建 Azure 机器学习数据集。
mnist_file = MNIST.get_file_dataset()
mnist_file
mnist_file.to_path()
将文件下载到本地存储
import os
import tempfile
data_folder = tempfile.mkdtemp()
data_paths = mnist_file.download(data_folder, overwrite=True)
data_paths
装载文件。 训练作业将在远程计算上运行时非常有用。
import gzip
import struct
import pandas as pd
import numpy as np
# load compressed MNIST gz files and return pandas dataframe of numpy arrays
def load_data(filename, label=False):
with gzip.open(filename) as gz:
gz.read(4)
n_items = struct.unpack('>I', gz.read(4))
if not label:
n_rows = struct.unpack('>I', gz.read(4))[0]
n_cols = struct.unpack('>I', gz.read(4))[0]
res = np.frombuffer(gz.read(n_items[0] * n_rows * n_cols), dtype=np.uint8)
res = res.reshape(n_items[0], n_rows * n_cols)
else:
res = np.frombuffer(gz.read(n_items[0]), dtype=np.uint8)
res = res.reshape(n_items[0], 1)
return pd.DataFrame(res)
import sys
mount_point = tempfile.mkdtemp()
print(mount_point)
print(os.path.exists(mount_point))
if sys.platform == 'linux':
print("start mounting....")
with mnist_file.mount(mount_point):
print("list dir...")
print(os.listdir(mount_point))
print("get the dataframe info of mounted data...")
train_images_df = load_data(next(path for path in data_paths if path.endswith("train-images-idx3-ubyte.gz")))
print(train_images_df.info())
Azure Databricks
使用 Azure 机器学习表格数据集将 MNIST 加载到数据帧中。
有关 Azure 机器学习数据集的详细信息,请参阅创建 Azure 机器学习数据集。
获取数据帧的完整数据集
# This is a package in preview.
from azureml.opendatasets import MNIST
mnist = MNIST.get_tabular_dataset()
mnist_df = mnist.to_spark_dataframe()
display(mnist_df.limit(5))
下载或装载 MNIST 原始文件 Azure 机器学习文件数据集。
这仅适用于基于 Linux 的计算。 有关 Azure 机器学习数据集的详细信息,请参阅创建 Azure 机器学习数据集。
mnist_file = MNIST.get_file_dataset()
mnist_file
mnist_file.to_path()
将文件下载到本地存储
import os
import tempfile
mount_point = tempfile.mkdtemp()
mnist_file.download(mount_point, overwrite=True)
装载文件。 训练作业将在远程计算上运行时非常有用。
import gzip
import struct
import pandas as pd
import numpy as np
# load compressed MNIST gz files and return numpy arrays
def load_data(filename, label=False):
with gzip.open(filename) as gz:
gz.read(4)
n_items = struct.unpack('>I', gz.read(4))
if not label:
n_rows = struct.unpack('>I', gz.read(4))[0]
n_cols = struct.unpack('>I', gz.read(4))[0]
res = np.frombuffer(gz.read(n_items[0] * n_rows * n_cols), dtype=np.uint8)
res = res.reshape(n_items[0], n_rows * n_cols)
else:
res = np.frombuffer(gz.read(n_items[0]), dtype=np.uint8)
res = res.reshape(n_items[0], 1)
return pd.DataFrame(res)
import sys
mount_point = tempfile.mkdtemp()
print(mount_point)
print(os.path.exists(mount_point))
print(os.listdir(mount_point))
if sys.platform == 'linux':
print("start mounting....")
with mnist_file.mount(mount_point):
print(context.mount_point )
print(os.listdir(mount_point))
train_images_df = load_data(os.path.join(mount_point, 'train-images-idx3-ubyte.gz'))
print(train_images_df.info())
后续步骤
查看开放数据集目录中的其余数据集。