【PyTorch】TensorBoard基本使用

文章目录

  • 一、Tensorboard基本使用
    • 1、SummaryWriter类使用
    • 2、安装TensorBoard
    • 3、add_scalar()方法
    • 4、add_image()方法
      • 4.1 img_tensor的说明
      • 4.2 dataformats的说明
      • 4.3 滑动显示

一、Tensorboard基本使用

Tensorboard为是Google TensorFlow的可视化工具,可以用于记录训练数据、评估数据、网络结构、图像等,并且可以在web上展示,对于观察神经网络的过程非常有帮助。

PyTorch也推出了自己的可视化工具,叫做torch.utils.tensorboard

学习本节内容必须提前准备好PyTorch(推荐GPU版)环境,后续也会推出PyTorch安装(Conda环境)。

1、SummaryWriter类使用

from torch.utils.tensorboard import SummaryWriter # 导入

按下 Ctrl键,点击蓝色字体,可以查看该类所在函数描述。

还有具体方法、例子的描述,不做过多赘述!

2、安装TensorBoard

conda环境:

# 1.激活conda环境
conda activate torch # torch为自己的虚拟环境
# 2.下载并安装
conda install tensorboard

pip环境:

pip install tensorboard
# 嫌慢,可以加国内源
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple

3、add_scalar()方法

函数原型:

def add_scalar(self,
               tag: str, 
               scalar_value: Any,
               global_step: int = None,
               walltime: float = None,
               new_style: bool = False,
               double_precision: bool = False) -> None

参数说明:

  • tag:类似于图标的title
  • scalar_value:数值,即Y轴
  • global_step:多少步,即X轴

实例1:绘制 y = x

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

for i in range(100):
    writer.add_scalar("y = x", i, i)

writer.close()

打开事件文件:

成功运行后,即可打开http://localhost:6006/;当然也可以更换端口:添加--port=6007

实例2:绘制 y = 2x

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

for i in range(100):
    writer.add_scalar("y = 2x", 2*i, i)

writer.close()

实例2:绘制 y = 3x(当我们未修改title时)

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

for i in range(100):
    writer.add_scalar("y = 2x", 3*i, i) # tille未作修改

writer.close()

会出现拟合,我们可以通过删除事件文件之后,重新打开Tensorboard

4、add_image()方法

函数原型:

def add_image(self,
              tag: str,
              img_tensor: Any,
              global_step: int = None,
              walltime: float = None,
              dataformats: str = "CHW") -> None

参数说明:

  • tag:图像title
  • img_tensor:图像的数据类型(torch.Tensornumpy.array,or string/blobname
  • global_step:训练的步骤

4.1 img_tensor的说明

参数 img_tensor 为图像的数据类型,指定了三种数据类型,但在实际情况中,往往并不是理想的这三种,以下介绍如何转换:

数据集请评论或直接私信我,后续也会贴出链接!!!

利用numpy.array(),对PIL图像进行转换:

4.2 dataformats的说明

当我们准备好实例执行时,会报出如下错误:

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

writer = SummaryWriter("logs")
image_path = "../data/tensorboard_data/train/ants_image/0013035.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)

writer.add_image("test", img_array, 1)

writer.close()

说明问题出在如下代码中:

writer.add_image("test", img_array, 1)

查看函数介绍发现:默认为(通道,高度,宽度),如果为 (高度,宽度,通道),需要添加参数 dataformats='HWC'

查看实例中图像的shape:

print(img_array.shape) # (512, 768, 3)

则需要添加参数:

writer.add_image("test", img_array, 1, dataformats='HWC') # 即可成功运行

4.3 滑动显示

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

writer = SummaryWriter("logs")
# image_path = "../data/tensorboard_data/train/ants_image/0013035.jpg" # 1
image_path = "../data/tensorboard_data/train/bees_image/16838648_415acd9e3f.jpg"  # 2
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
print(img_array.shape)

# writer.add_image("test", img_array, 1, dataformats='HWC') #1
writer.add_image("test", img_array, 1, dataformats='HWC')  # 2

writer.close()

【PyTorch】TensorBoard基本使用

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2023年11月14日
下一篇 2023年11月14日

相关推荐