从 Tensorflow 中的张量读取文件名

原文标题Reading in file names from a tensor in Tensorflow

背景:我正在尝试制作一个 GAN 来从大型数据集中生成图像,并且在加载训练数据时遇到了 OOM 问题。为了解决这个问题,我试图传入一个文件目录列表,并仅在需要时将它们作为图像读取。

问题:我不知道如何从张量本身解析出文件名。如果有人对如何将张量转换回列表或以某种方式遍历张量有任何见解。或者,如果这是解决此问题的不好方法,请告诉我

相关代码片段:

生成数据:注意:make_file_list()返回我要读取的所有图像的文件名列表

data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)

训练功能:

def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)

训练步骤:

@tf.function
def train_step(image_files):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    images = [load_img(filepath) for filepath in image_files]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

错误:

line 250, in train_step  *
        images = [load_img(filepath) for filepath in image_files]

    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature

原文链接:https://stackoverflow.com//questions/71991105/reading-in-file-names-from-a-tensor-in-tensorflow

回复

我来回复
  • bui的头像
    bui 评论

    在你的train_step上移除@tf.function装饰器。如果你用@tf.function 装饰你的train_step,Tensorflow 将尝试将里面的 Python 代码train_step转换为执行图,而不是在 Eager 模式下运行。执行图提供加速,但也对可以执行哪些运算符(如错误所述)。

    要保持@tf.functionontrain_step,您可以先在train函数中执行迭代和加载步骤,然后将已加载的图像作为参数传递给train_step,而不是尝试直接在train_step中加载图像

    def train(dataset, epochs):
        plot_iteration = []
        gen_loss_l = []
        disc_loss_l = []
    
        for epoch in range(epochs):
            start = time.time()
    
        for image_batch in dataset:
            images = [load_img(filepath) for filepath in image_batch ]
            gen_loss, disc_loss = train_step(images)
    
    @tf.function
    def train_step(images):
        noise = tf.random.normal([BATCH_SIZE, noise_dim])
    
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator(noise, training=True)
            ....
    
    1年前 0条评论