Stack Overflow на русском Asked by Ylanaish on September 4, 2020
Есть обученная на данных mnist нейросеть. Когда загружаю свое изображение цифры, получаю следующую ошибку:
TypeError Traceback (most recent call last)
<ipython-input-169-3810d7660ceb> in <module>()
1 i = 0
2 plt.figure()
----> 3 plot_images(i, predictions, test_labels, x)
4 plt.show()
6 frames
/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in set_data(self, A)
697 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
698 raise TypeError("Invalid shape {} for image data"
--> 699 .format(self._A.shape))
700
701 if self._A.ndim == 3:
TypeError: Invalid shape (28, 28, 1) for image data
На тестовых данных ошибки нет. Код пишу в google colab. Вот полный код:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
from google.colab import files
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0
model = keras.Sequential([
keras.layers.Flatten(input_shape = (28, 28)),
keras.layers.Dropout(0.2),
keras.layers.Dense(128, activation = 'relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation = 'softmax')
])
model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
model.fit(train_images, train_labels, epochs = 10)
files.upload()
images = keras.preprocessing.image.load_img("three.png", target_size=(28, 28))
x = keras.preprocessing.image.img_to_array(images)
x = tf.image.rgb_to_grayscale(x)
x = np.expand_dims(x, axis=0)
x = x/255.0
img = []
img.append(x)
predictions = model.predict(img)
def plot_images(i, predictions_array, true_label, img):
predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.imshow(img, cmap = plt.cm.binary)
predictions_label = np.argmax(predictions_array)
plt.xlabel('{}'.format(class_names[predictions_label]))
i = 0
plt.figure()
plot_images(i, predictions, test_labels, x)
plt.show()
Только начал изучать tensorflow, прошу помочь. Спасибо.
Get help from others!
Recent Questions
Recent Answers
© 2024 TransWikia.com. All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP