EDIT:如果类的数量大于 5,那么可以使用新的 tf.contrib.data.sample_from_datasets() API(当前在 tf-nightly 中可用,并将在 TensorFlow 1.9 中正式提供)。
# Build a pipeline that yields batches containing CLASSES_PER_BATCH randomly
# chosen classes, with EXAMPLES_PER_CLASS_PER_BATCH examples from each class.
directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", ...]

CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH
NUM_CLASSES = len(directories)

# Build one dataset per class.
per_class_datasets = [
    tf.data.TFRecordDataset(tf.data.Dataset.list_files(d)) for d in directories]

# Next, build a dataset where each element is a vector of CLASSES_PER_BATCH
# classes to be chosen for a particular batch.
# BUG FIX: the original had an unbalanced extra `)` at the end of this call.
classes_per_batch_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH])

# Transform the dataset of per-batch class vectors into a dataset with one
# one-hot element per example (i.e. BATCH_SIZE examples per batch).
# BUG FIX: the original referenced an undefined lowercase `num_classes`;
# the constant defined above is `NUM_CLASSES`.
class_dataset = classes_per_batch_dataset.flat_map(
    lambda classes: tf.data.Dataset.from_tensor_slices(
        tf.one_hot(classes, NUM_CLASSES)).repeat(EXAMPLES_PER_CLASS_PER_BATCH))

# Use `tf.contrib.data.sample_from_datasets()` to select an example from the
# appropriate dataset in `per_class_datasets`.
example_dataset = tf.contrib.data.sample_from_datasets(per_class_datasets,
                                                       class_dataset)

# Finally, combine BATCH_SIZE consecutive examples into a batch.
result = example_dataset.batch(BATCH_SIZE)
如果您正好有 5 个类,则可以为每个目录定义一个嵌套数据集,并使用 Dataset.interleave() 将它们合并:
# NOTE: We're assuming that the 0th directory contains elements from class 0, etc.
directory_globs = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", "class_4/*"]

# Turn the glob strings into a dataset of (class_index, glob) pairs.
directories = (tf.data.Dataset.from_tensor_slices(directory_globs)
               .apply(tf.contrib.data.enumerate_dataset()))
# Define a function that maps each (class, directory) pair to the (shuffled)
# records in those files.
def per_directory_dataset(class_label, directory_glob):
  """Returns a dataset of (record, class_label) pairs for one class directory.

  Args:
    class_label: scalar tensor identifying the class (from enumerate_dataset).
    directory_glob: scalar string tensor, glob pattern for the class's files.
  """
  files = tf.data.Dataset.list_files(directory_glob, shuffle=True)
  # BUG FIX: the original passed the not-yet-defined `records` to
  # TFRecordDataset; the file-names dataset `files` is what it must consume.
  records = tf.data.TFRecordDataset(files)
  # Zip the records with their class.
  # NOTE: This part might not be necessary if the records contain information
  # about their class that can be parsed from them.
  return tf.data.Dataset.zip(
      (records, tf.data.Dataset.from_tensors(class_label).repeat(None)))
# NOTE: The `cycle_length` and `block_length` here aren't strictly necessary,
# because the batch size is exactly `number of classes * images per class`.
# However, these arguments may be useful if you want to decouple these numbers.
merged_records = directories.interleave(
    per_directory_dataset, cycle_length=5, block_length=5).batch(25)