在张量流中求解一组线性系统

原文标题Solving a set of linear systems in tensorflow

我在理解 tensorflow 函数的工作机制时遇到问题:tf.linalg.solve。我想求解一组线性系统(AX = Y),其中线性系数(A)是共享的,但有多个批次Y,它们是不同的。使用 numpy,我可以简单地通过以下方式完成:

np.random.seed(0)
mtx = np.random.uniform(size= (1,4,4))
vec = np.random.uniform(size= (100,4,1))
solution = np.linalg.solve(mtx,vec)
print(abs(np.matmul(mtx,solution) - vec).max())
# 5.551115123125783e-16

这给了我一个非常一致的解决方案。但是当我切换到 tensorflow 时,它给了我结果:

mtx = tf.random.uniform(shape = (1,4,4))
vec = tf.random.uniform(shape = (100,4,1))
solution = tf.linalg.solve(mtx,vec)
print(tf.math.reduce_max(abs(tf.matmul(mtx,solution) - vec))) 
# tf.Tensor(1.3136615, shape=(), dtype=float32)

根据文档,我假设该解决方案应该根据相应的vec解决。但是它似乎没有给我在tensorflow中的预期结果。由于我是新用户,我可能会搞砸一些事情。它会是如果可以提供任何信息,我们将不胜感激。

原文链接:https://stackoverflow.com//questions/71588962/solving-a-set-of-linear-systems-in-tensorflow

回复

我来回复
  • Alexey Tochin的头像
    Alexey Tochin 评论

    第一个维度是 Tesorflow 中的批处理维度,请参阅文档。

    mtx = tf.random.uniform(shape = (5, 4, 4))
    vec = tf.random.uniform(shape = (5, 4, 1))
    solution = tf.linalg.solve(mtx, vec)
    print(solution.shape)
    # TensorShape([5, 4, 1])
    print(tf.math.reduce_max(abs(tf.matmul(mtx, solution) - vec))) 
    # tf.Tensor(5.1259995e-06, shape=(), dtype=float32)
    
    2年前 0条评论