从 Tensorflow 中的张量读取文件名
tensorflow 316
原文标题 :Reading in file names from a tensor in Tensorflow
背景:我正在尝试制作一个 GAN 来从大型数据集中生成图像,并且在加载训练数据时遇到了 OOM 问题。为了解决这个问题,我试图传入一个文件目录列表,并仅在需要时将它们作为图像读取。
问题:我不知道如何从张量本身解析出文件名。如果有人对如何将张量转换回列表或以某种方式遍历张量有任何见解。或者,如果这是解决此问题的不好方法,请告诉我
相关代码片段:
生成数据:注意:make_file_list()
返回我要读取的所有图像的文件名列表
data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)
训练功能:
def train(dataset, epochs):
plot_iteration = []
gen_loss_l = []
disc_loss_l = []
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
gen_loss, disc_loss = train_step(image_batch)
训练步骤:
@tf.function
def train_step(image_files):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
images = [load_img(filepath) for filepath in image_files]
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
错误:
line 250, in train_step *
images = [load_img(filepath) for filepath in image_files]
OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature
回复
我来回复-
bui 评论
该回答已被采纳!
在你的
train_step
上移除@tf.function
装饰器。如果你用@tf.function 装饰你的train_step
,Tensorflow 将尝试将里面的 Python 代码train_step
转换为执行图,而不是在 Eager 模式下运行。执行图提供加速,但也对可以执行哪些运算符(如错误所述)。要保持
@tf.function
ontrain_step
,您可以先在train
函数中执行迭代和加载步骤,然后将已加载的图像作为参数传递给train_step
,而不是尝试直接在train_step
中加载图像def train(dataset, epochs): plot_iteration = [] gen_loss_l = [] disc_loss_l = [] for epoch in range(epochs): start = time.time() for image_batch in dataset: images = [load_img(filepath) for filepath in image_batch ] gen_loss, disc_loss = train_step(images) @tf.function def train_step(images): noise = tf.random.normal([BATCH_SIZE, noise_dim]) with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise, training=True) ....
1年前