记录若干`tf.py_function`的使用的方式,便于查阅

本文所使用的TensorFlow版本为 2.9.0-rc0
众所周知,在TensorFlow2.x中 tf.py_function可以帮助我们让原本只能在Eager Mode下才能运行的函数体顺利运行在计算图中,获得我们预期的输出。基本使用规则见tf.py_function,这里不再赘述。[0]

由于底层的实现倾向于序列,tf.py_function对映射类型并不友好,即若输入类型为 dict[str,tf.Tensor], 是无法直接输出对应的字典形式的dict[str,tf.Tensor],这无疑增加了我们的coding难度。

1.处理输入输出structure一致的映射

简单起见,我们构建两个同功能的函数体original_func和modified_func,后者支持tf.py_function,将输入的张量翻倍,不改变输入输出的映射结构。为了体现tf.py_function的特点与使用难点,我们以tf.string形式作为输入,则函数体内必须先将tf.string转化为一般的tf.in32或tf.float32。具体的代码如下:

import tensorflow as tf 
import functools
def tf_py_function_wrapper(func=None):
    # since tf.py_function can not deal with dict directly, and its using form is not easy
    # here we make this wrapper, it can trans `func`'s output 
    # all tensors that a user want to use by calling their `numpy()` functions should become the inputs of the wrapped `func`, otherwise, `numpy()` will not work
    # since tf.nest.flatten  tf.nest.pack_sequence_as will sort dict structure's `keys` automaticlly, we do not use tf.nest here to avoid unexpected behavior
    if func is None:
        return functools.partial(tf_py_function_wrapper,)
    @functools.wraps(func)
    def wrappered(inputs:dict[str,tf.Tensor],output_structure:dict[str,tf.TensorSpec])->dict[str,tf.Tensor]:
        inp = tuple(inputs.values())
        Tout = tuple(output_structure.values())
        flattened_output = tf.py_function(func,inp=inp,Tout=Tout)
        return dict(zip(output_structure.keys(),flattened_output))
    return wrappered  

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
#@tf.function 
def original_func(inputs:dict[str,tf.Tensor]): # 各元素*2
    # @tf.function会导致函数运行在图模式, 此时tf.Tensor不存在numpy()方法, 因此当前函数尽可能运行在Eager模式
    for key in inputs:
        inputs[key] = tf.convert_to_tensor(int(str(inputs[key].numpy(),encoding='UTF-8')))*2
    return inputs
print(original_func(inputs))

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
@tf_py_function_wrapper
def modified_func(*inputs): # 各元素*2
    return [tf.convert_to_tensor(int(str(item.numpy(),encoding='UTF-8')))*2 for item in inputs]
output_structure = {key:tf.TensorSpec(shape=None,dtype=tf.int32) for key in inputs}
print(modified_func(inputs,output_structure))

其中 original_func()是我们原本希望实现的函数功能,modified_func()是对original_func()的修改,不再是以字典作为输入输出结构,而是以序列(字典的值)作为输入输出,可以被tf.py_function接收。为了获得字典形式的输出,我们构建自定义的装饰器tf_py_function_wrapper,将modified_func()的输出序列依据output_structure映射回字典。

2. 处理输入输出structure不一致的映射

1中的例子比较简单,输入输出的structure是一致的,但有时候我们需要处理输入输出不一致的情形,比如对输入做多组patch切片,必然无法保证输出的structure不变。以下代码展示了基于1中样例,但输入输出的结构不一致的情况下tf.py_function的使用。

import tensorflow as tf 
import functools
def tf_py_function_wrapper(func=None):
    # since tf.py_function can not deal with dict directly, and its using form is not easy
    # here we make this wrapper, it can trans `func`'s output 
    # all tensors that a user want to use by calling their `numpy()` functions should become the inputs of the wrapped `func`, otherwise, `numpy()` will not work
    # since tf.nest.flatten  tf.nest.pack_sequence_as will sort dict structure's `keys` automaticlly, we do not use tf.nest here to avoid unexpected behavior
    if func is None:
        return functools.partial(tf_py_function_wrapper,)
    @functools.wraps(func)
    def wrappered(inputs:dict[str,tf.Tensor],output_structure:dict[str,tf.TensorSpec])->dict[str,tf.Tensor]:
        inp = tuple(inputs.values())
        Tout = tuple(output_structure.values())
        flattened_output = tf.py_function(func,inp=inp,Tout=Tout)
        return dict(zip(output_structure.keys(),flattened_output))
    return wrappered  

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
#@tf.function 
def original_func(inputs:dict[str,tf.Tensor]): # 各元素*2
    # @tf.function会导致函数运行在图模式, 此时tf.Tensor不存在numpy()方法, 因此当前函数尽可能运行在Eager模式
    for key in inputs:
        inputs[key] = tf.convert_to_tensor(int(str(inputs[key].numpy(),encoding='UTF-8')))*2
    inputs["k5"] = tf.random.normal(shape=[],dtype=tf.float32)
    inputs["k6"] = tf.random.normal(shape=[],dtype=tf.float16)
    return inputs
print(original_func(inputs))

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
@tf_py_function_wrapper
def modified_func(*inputs): # 各元素*2
    return [tf.convert_to_tensor(int(str(item.numpy(),encoding='UTF-8')))*2 for item in inputs]+[tf.random.normal(shape=[],dtype=tf.float32),tf.random.normal(shape=[],dtype=tf.float16)]
output_structure = {key:tf.TensorSpec(shape=None,dtype=tf.int32) for key in inputs}|\
                   {"k5":tf.TensorSpec(shape=None,dtype=tf.float32),"k6":tf.TensorSpec(shape=None,dtype=tf.float16)}
print(modified_func(inputs,output_structure))

其中 original_func()是我们原本希望实现的函数功能,相比于1中,此处original_func()返回了额外的内容,导致输入输出结构不一致。
modified_func()为了应对这种不一致,首先需要在返回的序列中,按顺序添加额外的元素,其次,需要修改output_structure,使其继续和modified_func()的返回值序列保存一种一致性,从而使输出序列能被tf_py_function_wrapper正确还原为字典形式。

Summarize

对比1和2中样例才可以发现,tf_py_function_wrapper的设计为tf.py_function在处理映射类型的输入张量时增添了便捷性,不论输入输出structure一致或者不一致,都可以适用。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2022年5月7日
下一篇 2022年5月7日

相关推荐