Question
Let's say I have defined a dataset in this way:
filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
How can I get the number of elements in the dataset (that is, the number of individual elements that make up one epoch)?
I know that tf.data.Dataset already knows the size of the dataset, because the repeat() method allows repeating the input pipeline for a specified number of epochs. So there must be a way to get this information.
Answer
tf.data.Dataset.list_files creates a tensor called MatchingFiles:0 (with the appropriate prefix if applicable).
You could evaluate
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
to get the number of files.
Of course, this would work in simple cases only, and in particular if you have only one sample (or a known number of samples) per image.
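For the simple one-sample-per-file case, the file count can also be checked outside the graph. The sketch below is not part of the tf.data API: count_matching_files is a hypothetical helper, and it assumes that the standard-library glob module matches the same files as list_files for a simple pattern such as "*.png". The file names in the self-check are made up.

```python
import glob
import os
import tempfile


def count_matching_files(directory, pattern="*.png"):
    """Count the files in `directory` whose names match `pattern`."""
    # glob.escape protects special characters in the directory name itself.
    return len(glob.glob(os.path.join(glob.escape(directory), pattern)))


# Self-check on a throwaway directory (hypothetical file names).
with tempfile.TemporaryDirectory() as tmp:
    for name in ("a.png", "b.png", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    n_files = count_matching_files(tmp)  # only the two .png files match
```

With one sample per image, n_files is also the number of elements in one epoch.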
In more complex situations, e.g. when you do not know the number of samples in each file, you can only observe the number of samples as an epoch ends.
To do this, you can watch the number of epochs that is counted by your Dataset. repeat() creates a member called _count, that counts the number of epochs. By observing it during your iterations, you can spot when it changes and compute your dataset size from there.
This counter may be buried in the hierarchy of Datasets that is created when calling member functions successively, so we have to dig it out like this.
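import warnings
import tensorflow as tf

d = my_dataset  # the dataset whose pipeline contains a repeat() call
# RepeatDataset seems not to be exposed -- this is a possible workaround.
# tf.data.Dataset itself cannot be instantiated directly, so build a
# trivial dataset first and take the type of its repeated version.
RepeatDataset = type(tf.data.Dataset.range(1).repeat())
try:
    # Walk up the chain of input datasets until the repeat stage is found.
    while not isinstance(d, RepeatDataset):
        d = d._input_dataset
except AttributeError:
    warnings.warn('no epoch counter found')
    epoch_counter = None
else:
    epoch_counter = d._count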
Note that with this technique, the computation of your dataset size is not exact, because the batch during which epoch_counter is incremented typically mixes samples from two successive epochs. So this computation is precise up to your batch length.
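The batch-length precision can be illustrated without TensorFlow. The sketch below is a plain-Python simulation, not the tf.data API: batches and size_bounds are hypothetical helpers that mimic dataset.repeat(n_epochs).batch(batch_size), and each sample is tagged with its epoch index as a stand-in for the epoch counter.

```python
from itertools import islice


def batches(n_samples, n_epochs, batch_size):
    """Yield lists of (epoch, index) pairs, batched across epoch boundaries."""
    stream = ((epoch, i) for epoch in range(n_epochs) for i in range(n_samples))
    while True:
        batch = list(islice(stream, batch_size))
        if not batch:
            return
        yield batch


def size_bounds(n_samples, n_epochs, batch_size):
    """Return (lower, upper) bounds on the dataset size seen by an observer
    who only notices, batch by batch, that the epoch has changed."""
    seen_before = 0
    for batch in batches(n_samples, n_epochs, batch_size):
        if batch[-1][0] > 0:
            # The epoch changed somewhere inside this batch, so the true size
            # lies between the samples seen so far and the end of this batch.
            return seen_before, seen_before + batch_size - 1
        seen_before += len(batch)
    # The stream ended before a second epoch started: the count is exact.
    return seen_before, seen_before


bounds_mixed = size_bounds(n_samples=10, n_epochs=2, batch_size=4)   # (8, 11)
bounds_single = size_bounds(n_samples=10, n_epochs=1, batch_size=4)  # (10, 10)
```

In the first call the true size 10 falls inside the reported range, whose width is one less than the batch length. In the second call the pipeline runs out after a single epoch, so the count is exact.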