如何显示比批量大小值更多的图像?

原文标题How to show more images than the batch size value?

我有以下代码:

train_ds = tf.keras.utils.image_dataset_from_directory(
  '/media/Tesi',
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(360, 360),
  batch_size=18)

class_names = train_ds.class_names

val_ds = tf.keras.utils.image_dataset_from_directory(
  '/media/Tesi',
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(360, 360),
  batch_size=18)

num_classes = len(class_names)

然后我创建一个模型并做出一些概率。当我在 val_ds 中显示图像时,我的代码是:

plt.figure(figsize=(20, 20))
for images, _ in val_ds.take(1):
    for i in range(18):
        ax = plt.subplot(6, 6, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[np.argmax(predictions[i])])
        plt.axis("off")

这样,我总是显示 val_ds 的前 18 个图像。例如,如何显示索引 18 到 36 的图像?谢谢

原文链接:https://stackoverflow.com//questions/71515145/how-to-show-more-images-than-the-batch-size-value

回复

我来回复
  • AloneTogether的头像
    AloneTogether 评论

    你可以使用tf.data.Dataset.skiptf.data.Dataset.take

    import tensorflow as tf
    import pathlib
    import matplotlib.pyplot as plt
    
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    batch_size = 18
    
    val_ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="validation",
      seed=123,
      image_size=(360, 360),
      batch_size=batch_size, shuffle=False)
    
    for images, _ in val_ds.skip(1).take(1):
      for i in range(18):
        ax = plt.subplot(6, 6, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.axis("off")
    

    在此示例中,前 18 张图像被跳过(1 批),然后您拍摄接下来的 18 张图像(也是 1 批)。您只需要确定shuffle=False,以确保您在调用时不会得到相同的图像take(1)

    2年前 0条评论