Using the Dice metric in PyTorch "torchmetrics": dice_score() missing 2 required positional arguments: 'preds' and 'target'
I am trying to use the Dice metric from PyTorch "torchmetrics". I found an example that uses the Accuracy metric, shown below:
```python
from torchmetrics.classification import Accuracy

# model, epochs, train_data and valid_data are defined elsewhere.
# Note: torchmetrics >= 0.11 requires a task argument,
# e.g. Accuracy(task="multiclass", num_classes=...).
train_accuracy = Accuracy()
valid_accuracy = Accuracy()

for epoch in range(epochs):
    for i, (x, y) in enumerate(train_data):
        y_hat = model(x)

        # training step accuracy
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch{i} is {batch_acc}")

    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy.update(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset()
```
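The example above relies on torchmetrics' stateful "module" metrics: the object is created once without any data, each call (or `update()`) adds a batch to its internal state, `compute()` returns the aggregate over all batches seen, and `reset()` clears it. The plain-Python sketch below mimics that pattern for accuracy on class labels. `RunningAccuracy` is a hypothetical name; this is not the torchmetrics implementation, which works on tensors and also accepts probabilities or logits.

```python
# Hypothetical plain-Python sketch (not torchmetrics source code) of the
# stateful metric pattern used by classes such as Accuracy.


class RunningAccuracy:
    """Accumulates correct/total counts across batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulated state (e.g. once per epoch).
        self.correct = 0
        self.total = 0

    @staticmethod
    def _count(preds, target):
        # Number of matching labels and number of samples in one batch.
        return sum(int(p == t) for p, t in zip(preds, target)), len(target)

    def update(self, preds, target):
        # Add this batch's counts to the running state.
        correct, total = self._count(preds, target)
        self.correct += correct
        self.total += total

    def compute(self):
        # Accuracy over everything seen since the last reset().
        return self.correct / self.total if self.total else 0.0

    def __call__(self, preds, target):
        # Like calling a torchmetrics module: update the running state
        # and return the accuracy of this batch only.
        correct, total = self._count(preds, target)
        self.correct += correct
        self.total += total
        return correct / total if total else 0.0


acc = RunningAccuracy()
print(acc([1, 0, 1], [1, 1, 1]))  # batch accuracy: 2 of 3 correct
acc.update([0, 0], [0, 0])
print(acc.compute())  # running accuracy: 4 of 5 correct
acc.reset()
```

Because the state lives in the object, constructing it needs no `preds` or `target`; that is why `Accuracy()` can be created with no data.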
However, when I replace "Accuracy()" with "dice_score()", as shown below:
```python
from torchmetrics.functional import dice_score

train_accuracy = dice_score()
valid_accuracy = dice_score()
```
I get the following error:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
3 from torchmetrics.functional import dice_score
4
----> 5 train_accuracy_2 =dice_score()# Accuracy()
6 valid_accuracy_2 =dice_score()# Accuracy()
7
TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target'
```
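The traceback indicates that `dice_score`, imported from `torchmetrics.functional`, is a plain function: it computes a score for the `preds` and `target` passed to it, so `dice_score()` with no arguments fails. It would be called once per batch with that batch's predictions and targets (check the docs of your installed version for the expected shapes) rather than instantiated like `Accuracy()`. For epoch-level accumulation, some torchmetrics versions provide a module class (for example `torchmetrics.Dice`, or `torchmetrics.segmentation.DiceScore` in newer releases; availability depends on the version) that follows the same update/compute/reset pattern. The plain-Python sketch below contrasts the two forms on binary masks using Dice = 2TP / (2TP + FP + FN); `dice`, `RunningDice` and the empty-mask convention are hypothetical, not the torchmetrics API.

```python
# Hypothetical plain-Python sketch (not torchmetrics code) contrasting a
# functional metric with a stateful one, for binary masks as 0/1 lists.


def _counts(preds, target):
    # True positives, false positives and false negatives for one batch.
    tp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, target) if p == 0 and t == 1)
    return tp, fp, fn


def _dice_from_counts(tp, fp, fn):
    denom = 2 * tp + fp + fn
    # Assumption: two empty masks count as a perfect match.
    return 2 * tp / denom if denom else 1.0


def dice(preds, target):
    # Functional form: needs preds and target on every call.
    return _dice_from_counts(*_counts(preds, target))


class RunningDice:
    # Stateful form: constructed with no data, fed batches via update(),
    # and read once per epoch with compute().
    def __init__(self):
        self.reset()

    def reset(self):
        self.tp = self.fp = self.fn = 0

    def update(self, preds, target):
        tp, fp, fn = _counts(preds, target)
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def compute(self):
        return _dice_from_counts(self.tp, self.fp, self.fn)


print(dice([1, 1, 0, 0], [1, 0, 0, 0]))  # tp=1, fp=1, fn=0 -> 2/3
metric = RunningDice()
metric.update([1, 1, 0, 0], [1, 0, 0, 0])
metric.update([1, 1], [1, 1])
print(metric.compute())  # totals tp=3, fp=1, fn=0 -> 6/7
```

Note that accumulating counts (6/7 here) is not the same as averaging per-batch scores ((2/3 + 1) / 2); which one a given torchmetrics class reports depends on its reduction settings.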
Is there an example of using the "Dice" metric from "torchmetrics"?