Issues in plotting Images using Keras

Data Science Asked by Anubhav Sachdev on July 13, 2021

I am trying to visualize skin cancer images using Keras. I have imported the images into my notebook and created batched datasets using tf.keras.preprocessing.image_dataset_from_directory. The code is as follows:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

Now I am trying to visualize the images, but I want one image from each class (there are 9 classes in the dataset). I have used the code below:

plt.figure(figsize = (10,10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3,3,i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

This code gives me a lot of duplicate classes. How do I get one image per class (I have 9 classes and want one plot for each of them)? I am not sure how to get unique classes and their images from a BatchDataset.

One Answer

Given that you already have a tf.data.Dataset, one way to do it is to iterate over the dataset and, each time you come across a new label, save its image (e.g. in a dictionary). Labels you have already seen are simply skipped.

Here is a short example just using the MNIST dataset that comes with tensorflow:

import matplotlib.pyplot as plt
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

Now that we have the dataset, we can shuffle it and set the batch size to one, so each iteration returns one image and one label:

# Shuffle and make batch_size 1 to iterate over single example
dataset = dataset.shuffle(buffer_size=2048).batch(1)

For this MNIST dataset, the labels are known to be the integers 0 to 9, one for each digit:

known_labels = list(range(10)) # in the case of MNIST used here

Now we can create an empty dictionary into which we will insert each class with its image as we come across it:

label_to_images = {}

Then we iterate over our dataset. If a label is not already in the dictionary, we add it together with its image. After each step we also check whether we have collected all labels and, if so, break out of the loop:

for image, label in dataset:
    label_as_int = int(label.numpy()[0])   # batch of one -> take its single label
    if label_as_int not in label_to_images:
        print(f"Found label: {label_as_int}")
        label_to_images[label_as_int] = image

    # Sort before comparing: dict keys are kept in insertion order
    if sorted(label_to_images) == known_labels:
        print(f"Got all labels! -> {list(label_to_images.keys())}")
        break

# OUTPUT
Found label: 1
Found label: 4
Found label: 7
Found label: 9
Found label: 2
Found label: 0
Found label: 6
Found label: 3
Found label: 5
Found label: 8
Got all labels! -> [1, 4, 7, 9, 2, 0, 6, 3, 5, 8]

Now we can plot them all as follows:

fig, axs = plt.subplots(2, 5)   # 2 rows, 5 columns

plt.gray()   # Need this because the MNIST images are all grayscale

for ax, (label, image_tensor) in zip(axs.flatten(), label_to_images.items()):
    image = image_tensor.numpy()[0]   # drop the batch dimension (batch size is 1); MNIST has no channel axis
    ax.set_title(f"Class: {label}")
    ax.imshow(image)
    ax.axis("off")    # we don't need to see the pixel numbers as axes

fig.tight_layout()

Which gives the following output:

[Image: MNIST digits]
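The 2 x 5 grid is hard-coded for the ten MNIST classes. For a different number of classes, the row count can be derived from the class count. The helper below is a minimal sketch; grid_shape is a hypothetical name, not part of matplotlib:

```python
import math
from typing import Tuple


def grid_shape(num_plots: int, num_cols: int) -> Tuple[int, int]:
    """Return (rows, cols) for a grid with room for `num_plots` subplots."""
    # Round up so a partially filled last row still gets its own row.
    num_rows = math.ceil(num_plots / num_cols)
    return num_rows, num_cols


print(grid_shape(10, 5))  # (2, 5): the MNIST layout used above
print(grid_shape(9, 3))   # (3, 3): nine classes, as in the question
```

The result can be unpacked straight into the plotting call, e.g. plt.subplots(*grid_shape(len(label_to_images), 5)). If the grid has more cells than classes, the leftover axes stay empty and can be hidden with ax.axis("off").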

Because the dataset is shuffled with shuffle(), running the code again will generally pick a different example image for each class.
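The same idea can be wrapped in a small helper that works on any iterable of single (image, label) pairs. This is only a sketch under stated assumptions: first_example_per_class is a hypothetical name, labels are assumed to be integers, and the stand-in data uses strings in place of images so that the snippet runs without TensorFlow:

```python
from typing import Any, Dict, Iterable, Tuple


def first_example_per_class(
    pairs: Iterable[Tuple[Any, Any]], num_classes: int
) -> Dict[int, Any]:
    """Return {label: first image seen with that label}.

    `pairs` must yield ONE (image, label) example at a time. Iteration
    stops as soon as `num_classes` distinct labels have been collected.
    """
    found: Dict[int, Any] = {}
    for image, label in pairs:
        label = int(label)  # accepts Python ints and NumPy integer scalars
        if label not in found:
            found[label] = image
            if len(found) == num_classes:
                break  # every class collected, stop early
    return found


# Stand-in data (assumption): strings play the role of images.
fake_pairs = [("img_a", 2), ("img_b", 0), ("img_c", 2), ("img_d", 1), ("img_e", 0)]
examples = first_example_per_class(fake_pairs, num_classes=3)
print(examples)  # {2: 'img_a', 0: 'img_b', 1: 'img_d'}
```

For the batched train_ds from the question, a call along the lines of first_example_per_class(train_ds.unbatch().as_numpy_iterator(), num_classes=len(class_names)) should yield one NumPy image per class, assuming the default integer labels. Here unbatch() splits each batch back into single examples, so the original batch size does not need to change. Counting the collected labels also avoids sorting the keys on every iteration.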

Answered by n1k31t4 on July 13, 2021
