28x28 mnist image unrolling to concatenate using numpy.vstack. Why is numpy.vstack so slow?

Data Science Asked by trendulous on March 23, 2021

I'm trying to do a simple reshape of a (60000, 28, 28) array of MNIST digits into a (60000, 784) numpy array, where each digit image has been unrolled into a flat vector.

The code I'm using to do this is:

import numpy as np
import tensorflow as tf

# (xdata, xlabel) is the 60000-image training set, (ydata, ylabel) the test set
(xdata, xlabel), (ydata, ylabel) = tf.keras.datasets.mnist.load_data()

newxdata = np.array([])
cnt = 0
for i in xdata:
    tmpx = i.ravel()                                 # flatten one 28x28 image to length 784
    if cnt == 0:
        newxdata = np.concatenate((newxdata, tmpx))  # first row
    else:
        newxdata = np.vstack((newxdata, tmpx))       # stack every subsequent row
    cnt = cnt + 1

Why does this take so long to run? Is there a way to speed it up? Ultimately the data will be fed into a Keras model in smaller batches. Would writing a generator that unrolls the images only when a batch of a given batch size is requested be more performant, or would it make no difference?

2 Answers

You can reshape a numpy array simply by:

newxdata = xdata.reshape((60000, 28*28))

for example. Or simply:

newxdata = xdata.reshape((len(xdata), -1))

Note that reshape is also available as a numpy function, which can be used as:

import numpy as np
newxdata = np.reshape(xdata, (60000, -1))

If you really did need an explicit loop, you could alternatively speed it up with multiprocessing.Pool or with libraries like CuPy or Numba, but for a plain reshape that isn't necessary.
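The loop in the question is slow because every np.vstack call allocates a brand-new array and copies all of the rows accumulated so far, so the total copying grows quadratically with the number of images, while a single reshape just reinterprets the existing buffer. A minimal timing sketch illustrating the difference (not from the original post; the random stand-in data, the 5000-image size, and the timer are assumptions made to keep the run short):

import time
import numpy as np

# Stand-in for MNIST training images, smaller so the slow loop finishes quickly
xdata = np.random.randint(0, 256, size=(5000, 28, 28), dtype=np.uint8)

# Loop with vstack: each call copies every row built so far
start = time.perf_counter()
out = np.empty((0, 28 * 28), dtype=xdata.dtype)
for img in xdata:
    out = np.vstack((out, img.ravel()))
loop_time = time.perf_counter() - start

# Single reshape: no per-row copying at all
start = time.perf_counter()
out2 = xdata.reshape(len(xdata), -1)
reshape_time = time.perf_counter() - start

print(f"vstack loop: {loop_time:.3f}s   reshape: {reshape_time:.6f}s")
print(np.array_equal(out, out2))   # True: both give the same (5000, 784) values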

Answered by Jojo on March 23, 2021

I went with that suggestion and made a few additions to the implementation.

Ultimately the transformed data will be ingested by a Keras function that expects input of shape (batch_size, unrolled_image), so I created a generator method .returnbatch(batch_size) that returns batches of shape (batch_size, unrolled_image).
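The post doesn't include the generator itself, but a minimal sketch of what such a batch generator could look like is below (only the returnbatch name comes from the post; the class name and everything else is an assumption):

import numpy as np

class MnistBatches:
    # Hypothetical wrapper around the training arrays; only returnbatch is named in the post
    def __init__(self, xdata, xlabel):
        self.xdata = xdata      # shape (60000, 28, 28)
        self.xlabel = xlabel    # shape (60000,)

    def returnbatch(self, batch_size):
        # Yield (batch_size, 784) image batches plus their labels, unrolling on demand
        for start in range(0, len(self.xdata), batch_size):
            imgs = self.xdata[start:start + batch_size]
            yield imgs.reshape(len(imgs), -1), self.xlabel[start:start + batch_size]

Usage would then be something like: for batch_x, batch_y in MnistBatches(xdata, xlabel).returnbatch(128): ...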

Answered by trendulous on March 23, 2021
