如何显示比批量大小值更多的图像?
tensorflow 187
原文标题 :How to show more images than the batch size value?
我有以下代码:
train_ds = tf.keras.utils.image_dataset_from_directory(
'/media/Tesi',
validation_split=0.2,
subset="training",
seed=123,
image_size=(360, 360),
batch_size=18)
class_names = train_ds.class_names
val_ds = tf.keras.utils.image_dataset_from_directory(
'/media/Tesi',
validation_split=0.2,
subset="validation",
seed=123,
image_size=(360, 360),
batch_size=18)
num_classes = len(class_names)
然后我创建一个模型并做出一些概率。当我在 val_ds 中显示图像时,我的代码是:
plt.figure(figsize=(20, 20))
for images, _ in val_ds.take(1):
for i in range(18):
ax = plt.subplot(6, 6, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[np.argmax(predictions[i])])
plt.axis("off")
这样,我总是显示 val_ds 的前 18 个图像。例如,如何显示索引 18 到 36 的图像?谢谢
回复
我来回复-
AloneTogether 评论
该回答已被采纳!
你可以使用
tf.data.Dataset.skip
和tf.data.Dataset.take
:import tensorflow as tf import pathlib import matplotlib.pyplot as plt dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True) data_dir = pathlib.Path(data_dir) batch_size = 18 val_ds = tf.keras.utils.image_dataset_from_directory( data_dir, validation_split=0.2, subset="validation", seed=123, image_size=(360, 360), batch_size=batch_size, shuffle=False) for images, _ in val_ds.skip(1).take(1): for i in range(18): ax = plt.subplot(6, 6, i + 1) plt.imshow(images[i].numpy().astype("uint8")) plt.axis("off")
在此示例中,前 18 张图像被跳过(1 批),然后您拍摄接下来的 18 张图像(也是 1 批)。您只需要确定
shuffle=False
,以确保您在调用时不会得到相同的图像take(1)
。2年前