在张量流中求解一组线性系统
tensorflow 188
原文标题 :Solving a set of linear systems in tensorflow
我在理解 tensorflow 函数的工作机制时遇到问题:tf.linalg.solve。我想求解一组线性系统(AX = Y),其中线性系数(A)是共享的,但有多个批次Y,它们是不同的。使用 numpy,我可以简单地通过以下方式完成:
np.random.seed(0)
mtx = np.random.uniform(size= (1,4,4))
vec = np.random.uniform(size= (100,4,1))
solution = np.linalg.solve(mtx,vec)
print(abs(np.matmul(mtx,solution) - vec).max())
# 5.551115123125783e-16
这给了我一个非常一致的解决方案。但是当我切换到 tensorflow 时,它给了我结果:
mtx = tf.random.uniform(shape = (1,4,4))
vec = tf.random.uniform(shape = (100,4,1))
solution = tf.linalg.solve(mtx,vec)
print(tf.math.reduce_max(abs(tf.matmul(mtx,solution) - vec)))
# tf.Tensor(1.3136615, shape=(), dtype=float32)
根据文档,我假设该解决方案应该根据相应的vec解决。但是它似乎没有给我在tensorflow中的预期结果。由于我是新用户,我可能会搞砸一些事情。它会是如果可以提供任何信息,我们将不胜感激。
回复
我来回复-
Alexey Tochin 评论
第一个维度是 Tesorflow 中的批处理维度,请参阅文档。
mtx = tf.random.uniform(shape = (5, 4, 4)) vec = tf.random.uniform(shape = (5, 4, 1)) solution = tf.linalg.solve(mtx, vec) print(solution.shape) # TensorShape([5, 4, 1]) print(tf.math.reduce_max(abs(tf.matmul(mtx, solution) - vec))) # tf.Tensor(5.1259995e-06, shape=(), dtype=float32)
2年前