TransWikia.com

Проверка нейросети на своих данных tensorflow

Stack Overflow на русском Asked by Ylanaish on September 4, 2020

Есть обученная на данных mnist нейросеть. Когда загружаю свое изображение цифры, получаю следующую ошибку:

TypeError                                 Traceback (most recent call last)

<ipython-input-169-3810d7660ceb> in <module>()
      1 i = 0
      2 plt.figure()
----> 3 plot_images(i, predictions, test_labels, x)
      4 plt.show()

6 frames

/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in set_data(self, A)
    697                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    698             raise TypeError("Invalid shape {} for image data"
--> 699                             .format(self._A.shape))
    700 
    701         if self._A.ndim == 3:

TypeError: Invalid shape (28, 28, 1) for image data

На тестовых данных ошибки нет. Код пишу в google colab. Вот полный код:

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
from google.colab import files

(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

model = keras.Sequential([
                          keras.layers.Flatten(input_shape = (28, 28)),
                          keras.layers.Dropout(0.2),
                          keras.layers.Dense(128, activation = 'relu'),
                          keras.layers.Dropout(0.2),
                          keras.layers.Dense(10, activation = 'softmax')
])
model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

model.fit(train_images, train_labels, epochs = 10)

files.upload()
images = keras.preprocessing.image.load_img("three.png", target_size=(28, 28))    
x = keras.preprocessing.image.img_to_array(images)
x = tf.image.rgb_to_grayscale(x)
x = np.expand_dims(x, axis=0)
x = x/255.0
img = []
img.append(x)

predictions = model.predict(img)

def plot_images(i, predictions_array, true_label, img):
  predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])

  plt.imshow(img, cmap = plt.cm.binary)
  predictions_label = np.argmax(predictions_array)

  plt.xlabel('{}'.format(class_names[predictions_label]))

i = 0
plt.figure()
plot_images(i, predictions, test_labels, x)
plt.show()

Только начал изучать tensorflow, прошу помочь. Спасибо.

введите сюда описание изображения

Add your own answers!

Ask a Question

Get help from others!

© 2024 TransWikia.com. All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP