MNISTの使い方まとめ

概要

この記事ではMNISTデータセットの概要とそのダウンロード方法についてまとめる。

MNISTとは

「Modified National Institute of Standards and Technology database」の略称。
0~9の手書き数字文字とそのラベル(正解データ)データセットで、よく機械学習のサンプルコードの学習対象とされる。

  • 画像は学習データが60,000枚、テストデータが10,000枚で合計70,000枚
  • 各ピクセルが0~255の値を持つグレースケールの画像
  • 画像のサイズは縦横ともに28ピクセル

本家はこちら(http://yann.lecun.com/exdb/mnist/)
本家の情報では、各データセットのサイズは下記の通り。

  • train-images-idx3-ubyte.gz: training set images (9912422 bytes)
  • train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
  • t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
  • t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

    ダウンロード方法

    主要なライブラリではとても簡単に扱えるように整備されており、
    ソースコード内でダウンロードする手段が用意されています。

Keras

1
2
3
# https://keras.io/datasets/
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

PyTorch

1
2
3
# https://pytorch.org/docs/stable/torchvision/datasets.html#mnist
from torchvision.datasets import MNIST
mnist_data = MNIST('YOUR_DIR_PATH',download=True,)

TensorFlow

1
2
3
# https://www.tensorflow.org/datasets/overview
import tensorflow_datasets as tfds
dataset = tfds.load('mnist', as_supervised=True)

scikit-learn

1
2
3
# https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_openml.html
from sklearn.datasets import fetch_openml
X, y = fetch_openml('mnist_784', return_X_y=True)

データセットの中身の確認

上述のscikit-learnでのダウンロードを選択したと仮定して、データの中身を確認してみる。

Xとyの形状を確認

1
2
3
4
>>> X.shape
(70000, 784)
>>> y.shape
(70000,)

画像はきちんと70,000枚ダウンロードできている。Xの形状をみると、2軸目が784となっており28*28の画像がflatに格納されていることがわかる。

Xの確認

1
2
3
4
5
6
7
8
>>> X
array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])

なるほど。先頭と末尾は0ばかりのようだ。ここでは割愛するが、X[0]などを確認してみると良いだろう。

yの確認

1
2
>>> y
array(['5', '0', '4', ..., '4', '5', '6'], dtype=object)

objectとして各文字が保存されている。ラベルによって画像が
ソートされているわけではなさそうだ。
one-hot vectorに変換したい場合は以下のようにpreprocessing.LabelBinarizerを使うと良いだろう。

1
2
3
4
5
6
7
8
9
10
11
# https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelBinarizer.html
>>> from sklearn import preprocessing
>>> lb = preprocessing.LabelBinarizer()
>>> lb.fit_transform(y)
array([[0, 0, 0, ..., 0, 0, 0],
[1, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]])

データの分布

collectionsを使って、各ラベルの数え上げを行う。

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> import collections
>>> c = collections.Counter(y)
>>> c
Counter({'0': 6903,
'1': 7877,
'2': 6990,
'3': 7141,
'4': 6824,
'5': 6313,
'6': 6876,
'7': 7293,
'8': 6825,
'9': 6958})

偏りがあるようなので注意が必要だ。

まとめ

MNISTデータセットの概要とそのダウンロード方法についてまとめた。