在 pytorch “torchmetrics” 中使用 Dice 度量: dice_score() 缺少 2 个必需的位置参数:’preds’ 和 ‘target’

青葱年少 pytorch 289

原文标题Using Dice metric in pytorch “torchmetrics” : dice_score() missing 2 required positional arguments: ‘preds’ and ‘target’

我正在尝试使用 pytorch“torchmetrics”中的 Dice 指标。我找到了一个使用准确度指标的例子。如下所示:

from torchmetrics.classification import Accuracy

train_accuracy = Accuracy()
valid_accuracy = Accuracy()

for epoch in range(epochs):
    for x, y in train_data:
        y_hat = model(x)

        # training step accuracy
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch{i} is {batch_acc}")

    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy.update(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset() 

但是,当我用“Dice_score()”替换“Accuracy()”时。如下所示:

from torchmetrics.functional import dice_score

train_accuracy =dice_score()
valid_accuracy =dice_score()

我收到以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
      3 from torchmetrics.functional import dice_score
      4 
----> 5 train_accuracy_2 =dice_score()# Accuracy()
      6 valid_accuracy_2 =dice_score()# Accuracy()
      7 

TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target' 

是否有使用“torchmetrics”中的“Dice”度量的示例

原文链接:https://stackoverflow.com//questions/71931060/using-dice-metric-in-pytorch-torchmetrics-dice-score-missing-2-required-po

回复

我来回复
  • jakub的头像
    jakub 评论

    torchmetrics.classification.dice_score是骰子分数的功能接口。这意味着它是一个无状态函数,期望得到基本事实和预测。骰子得分似乎没有模块接口,就像准确度一样。

    torchmetrics.classification.Accuracy是一个维护状态的类。在引擎盖下,它使用功能接口,即torchmetrics.functional.accuracy

    这不是以任何方式强制执行的,但通常类以 CamelCase 命名,函数以 snake_case 命名。

    2年前 0条评论