表格数据:在不诉诸迭代的情况下实现自定义张量层
原文标题 :Tabular data: Implementing a custom tensor layer without resorting to iteration
我有一个张量操作的想法,它通过迭代并不难实现,批量大小为 1。但是我想尽可能地并行化它。
我有两个形状为 (n, 5) 的张量,称为 X 和 Y。X 实际上应该表示 5 个形状为 (n, 1): (x_1, …, x_n) 的一维张量。 Y 同上。
我想计算一个形状为 (n, 25) 的张量,其中每一列代表张量操作 f(x_i, y_j) 的输出,其中 f 对于所有 1 <= i, j <= 5 都是固定的。操作 f具有输出形状 (n, 1),就像 x_i 和 y_i。
我觉得有必要澄清 f 本质上是一个完全连接的层,从形状为 (1, 10) 的串联 […x_i, …y_i] 张量到形状为 (1,5) 的输出层.
同样,很容易看到如何通过迭代和切片手动执行此操作。但是,这可能非常缓慢。分批执行此操作,其中张量 X、Y 现在具有形状 (n, 5, batch_size) 也是可取的,特别是对于小批量梯度下降。
在这里很难说清楚我为什么想要创建这个网络。我觉得它适合我的“逐项表格数据”领域,并且与完全连接的网络相比,显着减少了每次操作的权重数量。
这可以使用张量流吗?当然不只使用 keras。下面是每个 AloneTogether 请求的 numpy 示例
import numpy as np
features = 16
batch_size = 256
X_batch = np.random.random((features, 5, batch_size))
Y_batch = np.random.random((features, 5, batch_size))
# one tensor operation to reduce weights in this custom 'layer'
f = np.random.random((features, 2 * features))
for b in range(batch_size):
X = X_batch[:, :, b]
Y = Y_batch[:, :, b]
for i in range(5):
x_i = X[:, i:i+1]
for j in range(5):
y_j = Y[:, j:j+1]
x_i_y_j = np.concatenate([x_i, y_j], axis=0)
# f(x_i, y_j)
# implemented by a fully-connected layer
f_i_j = np.matmul(f, x_i_y_j)
回复
我来回复-
gergelybat 评论
该回答已被采纳!
您需要的所有操作(连接和矩阵乘法)都可以批处理。这里的困难部分是,您想将 X 中所有项目的特征与 Y 中所有项目的特征(所有组合)连接起来。我推荐的解决方案是扩展维度X 到
[batch, features, 5, 1]
,将 Y 的维度扩展到[batch, features, 1, 5]
比tf.repeat()
两个张量,使它们的形状变为[batch, features, 5, 5]
。现在你可以连接 X 和 Y。你将有一个形状为[batch, 2*features, 5, 5]
的张量。观察到这样所有组合都建立了。下一步是矩阵乘法。tf.matmul()
也可以做批量矩阵乘法,但我在这里使用tf.einsum()
因为我想更多地控制哪些维度被视为批量。完整代码:import tensorflow as tf import numpy as np batch_size=3 features=6 items=5 x = np.random.uniform(size=[batch_size,features,items]) y = np.random.uniform(size=[batch_size,features,items]) f = np.random.uniform(size=[2*features,features]) x_reps= tf.repeat(x[:,:,:,tf.newaxis], items, axis=3) y_reps= tf.repeat(y[:,:,tf.newaxis,:], items, axis=2) xy_conc = tf.concat([x_reps,y_reps], axis=1) f_i_j = tf.einsum("bfij, fg->bgij", xy_conc,f) f_i_j = tf.reshape(f_i_j , [batch_size,features,items*items])
2年前