Data Science Asked by Junhan Ouyang on October 12, 2020
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense, Flatten, Dropout, BatchNormalization
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from collections import Counter
Generator = ImageDataGenerator(rotation_range=40, shear_range=0.15, zoom_range=0.4)
models = []
eva_list = []
for i in range(5):
    X_train, X_test, Y_train, Y_test = train_test_split(train_image, label_image, test_size=0.3)
    Train_generator = Generator.flow(X_train, Y_train, batch_size=128)
    model = Sequential()
    # Conv2D layers here (not important)
    model.add(Dense(10, activation="softmax"))
    model.compile(optimizer='adam', loss='CategoricalCrossentropy', metrics=['accuracy'])
    model.summary()
    model.fit(Train_generator, batch_size=128, epochs=50, verbose=2)
    eva_list.append(model.evaluate(X_test, Y_test, verbose=1))
    models.append(model)
test = test / 255
test = test.to_numpy().reshape((-1, 28, 28, 1))
result_group = []
for i in range(len(models)):
    # loop through the models, predict with each one, and store the predicted labels in result_group
    temp_result = models[i].predict(test)
    result_group.append(np.argmax(temp_result, axis=1))
result = []
# loop through each test index, collect every model's vote, and use the most common label as the final result
for i in range(len(result_group[0])):
    compare = []
    for z in range(len(result_group)):
        compare.append(result_group[z][i])
    common = Counter(compare)
    result.append(common.most_common(1)[0][0])
print(result)
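The per-sample voting loop above can also be written with NumPy. This is a minimal sketch, assuming each entry of result_group is a 1-D integer array of predicted labels (as produced by np.argmax) and that there are 10 classes (the MNIST digits); the function name majority_vote is hypothetical. One difference: on a tie, argmax picks the smallest label, whereas Counter.most_common returns the tied label encountered first, i.e. the one from the earliest model.

```python
import numpy as np


def majority_vote(result_group, n_classes=10):
    """Hard-voting ensemble over per-model label predictions (hypothetical helper).

    result_group: list of 1-D integer arrays, one per model, each holding the
    predicted class of every test sample.
    """
    votes = np.stack(result_group)          # shape (n_models, n_samples)
    n_samples = votes.shape[1]
    counts = np.zeros((n_samples, n_classes), dtype=int)
    rows = np.arange(n_samples)
    for preds in votes:                     # each model adds one vote per sample
        counts[rows, preds] += 1
    return counts.argmax(axis=1)            # ties resolve to the smallest label


result_group = [np.array([1, 7, 3]), np.array([1, 2, 3]), np.array([4, 7, 3])]
print(majority_vote(result_group))  # [1 7 3]
```

Voting by itself cannot produce or remove the symptom described below: if every model already predicts 1, any vote over them returns 1, so the cause has to be in the individual models or their inputs.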
I have spent multiple days on the MNIST dataset. I trained a fairly deep CNN model and got a good 99.1 percent accuracy. However, after reading some discussion posts, I figured it might be a good idea to use bagging to increase the accuracy. My approach is basically to create a list that stores the newly trained CNN models and, at the end, use a for loop to make a prediction with each of the models and take a majority vote. However, I can't figure out where I made a mistake: now all of my models predict 1 for all of the test data. Can anyone tell me what's going on?
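For comparison, the loop in the question trains each model on a fresh random 70/30 split, while classical bagging trains each model on a bootstrap sample: n rows drawn with replacement from the n training rows. The sketch below shows only that sampling step, assuming train_image and label_image are NumPy arrays; the helpers bootstrap_indices and out_of_bag_mask are hypothetical, and n = 1000 is a stand-in size.

```python
import numpy as np


# Hypothetical helpers (not from the question): bootstrap sampling for bagging.
def bootstrap_indices(n_samples, n_models, seed=0):
    """Return one index array per model, each drawn WITH replacement from range(n_samples)."""
    rng = np.random.default_rng(seed)
    return [rng.integers(0, n_samples, size=n_samples) for _ in range(n_models)]


def out_of_bag_mask(indices, n_samples):
    """Boolean mask of rows never drawn in `indices` (that model's out-of-bag rows)."""
    mask = np.ones(n_samples, dtype=bool)
    mask[indices] = False
    return mask


n = 1000  # stand-in for len(train_image)
for idx in bootstrap_indices(n, n_models=5):
    oob = out_of_bag_mask(idx, n)
    # With the question's arrays (assumed to be NumPy arrays) this would become:
    #   X_train, Y_train = train_image[idx], label_image[idx]
    #   X_val,   Y_val   = train_image[oob], label_image[oob]
    # Roughly 63% of rows appear at least once; the remainder are out-of-bag.
    print(len(np.unique(idx)), int(oob.sum()))
```

Each model then sees a different resample of the full training set, and its out-of-bag rows (about 1 - 1/e, or 63.2%, of rows are drawn at least once on average) could serve as that model's held-out check in place of X_test.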