跑通GaitSet(跑不通你来揍我)

一、从github下载GaitSet

这里打包好了到百度网盘,部分访问外网慢的小伙伴可以从这里下载喔!

论文地址GaitSet: Cross-view Gait Recognition through Utilizing Gait as a Deep Set | IEEE Journals & Magazine | IEEE Xplore
github地址GitHub – AbnerHqC/GaitSet: A flexible, effective and fast cross-view gait recognition network

百度网盘(内含数据集)

链接:https://pan.baidu.com/s/1k0l-BBdMvYJdl5lCeyCQMQ 
提取码:w1mb 

二、安装环境(pytorch)

按一下链接配置好环境,就一定ok亲测有效!觉得有用记得给个三联哦!

深度学习环境配置——windows下的torch-gpu环境配置_阿良是炼丹师的博客-CSDN博客

具体的包参考这儿:

# Name                    Version                   Build  Channel
absl-py                   1.0.0                    pypi_0    pypi
argon2-cffi               20.1.0           py36h2bbff1b_1
async_generator           1.10             py36h28b3542_0
attrs                     21.4.0             pyhd3eb1b0_0
backcall                  0.2.0              pyhd3eb1b0_0
bleach                    4.1.0              pyhd3eb1b0_0
cachetools                4.2.4                    pypi_0    pypi
certifi                   2021.5.30        py36haa95532_0
cffi                      1.14.6           py36h2bbff1b_0
charset-normalizer        2.0.12                   pypi_0    pypi
colorama                  0.4.4              pyhd3eb1b0_0
cycler                    0.11.0                   pypi_0    pypi
dataclasses               0.8                      pypi_0    pypi
decorator                 4.4.2                    pypi_0    pypi
defusedxml                0.7.1              pyhd3eb1b0_0
dominate                  2.6.0                    pypi_0    pypi
entrypoints               0.3                      py36_0
google-auth               2.6.2                    pypi_0    pypi
google-auth-oauthlib      0.4.6                    pypi_0    pypi
grpcio                    1.45.0                   pypi_0    pypi
h5py                      2.10.0                   pypi_0    pypi
idna                      3.3                      pypi_0    pypi
imageio                   2.15.0                   pypi_0    pypi
importlib-metadata        4.8.3                    pypi_0    pypi
ipykernel                 5.3.4            py36h5ca1d4c_0
ipython                   7.16.1           py36h5ca1d4c_0
ipython_genutils          0.2.0              pyhd3eb1b0_1
jedi                      0.17.0                   py36_0
jinja2                    3.0.3              pyhd3eb1b0_0
joblib                    1.1.0                    pypi_0    pypi
jsonschema                3.0.2                    py36_0
jupyter_client            7.1.2              pyhd3eb1b0_0
jupyter_core              4.8.1            py36haa95532_0
jupyterlab_pygments       0.1.2                      py_0
kiwisolver                1.3.1                    pypi_0    pypi
m2w64-gcc-libgfortran     5.3.0                         6
m2w64-gcc-libs            5.3.0                         7
m2w64-gcc-libs-core       5.3.0                         7
m2w64-gmp                 6.1.0                         2
m2w64-libwinpthread-git   5.0.0.4634.697f757               2
markdown                  3.3.6                    pypi_0    pypi
markupsafe                2.0.1            py36h2bbff1b_0
matplotlib                3.1.2                    pypi_0    pypi
mistune                   0.8.4            py36he774522_0
msys2-conda-epoch         20160418                      1
nbclient                  0.5.3              pyhd3eb1b0_0
nbconvert                 6.0.7                    py36_0
nbformat                  5.1.3              pyhd3eb1b0_0
nest-asyncio              1.5.1              pyhd3eb1b0_0
networkx                  2.5.1                    pypi_0    pypi
notebook                  6.4.3            py36haa95532_0
numpy                     1.16.2                   pypi_0    pypi
oauthlib                  3.2.0                    pypi_0    pypi
opencv-python             4.1.2.30                 pypi_0    pypi
packaging                 21.3               pyhd3eb1b0_0
pandas                    1.1.5                    pypi_0    pypi
pandoc                    2.12                 haa95532_0
pandocfilters             1.5.0              pyhd3eb1b0_0
parso                     0.8.3              pyhd3eb1b0_0
pickleshare               0.7.5           pyhd3eb1b0_1003
pillow                    8.4.0                    pypi_0    pypi
pip                       21.2.2           py36haa95532_0
prettytable               2.5.0                    pypi_0    pypi
prometheus_client         0.13.1             pyhd3eb1b0_0
prompt-toolkit            3.0.20             pyhd3eb1b0_0
protobuf                  3.19.4                   pypi_0    pypi
pyasn1                    0.4.8                    pypi_0    pypi
pyasn1-modules            0.2.8                    pypi_0    pypi
pycparser                 2.21               pyhd3eb1b0_0
pyecharts                 1.9.1                    pypi_0    pypi
pygments                  2.11.2             pyhd3eb1b0_0
pyparsing                 3.0.7                    pypi_0    pypi
pyrsistent                0.17.3           py36he774522_0
pysnooper                 1.1.1                    pypi_0    pypi
python                    3.6.13               h3758d61_0
python-dateutil           2.8.2              pyhd3eb1b0_0
pytz                      2021.3                   pypi_0    pypi
pywavelets                1.1.1                    pypi_0    pypi
pywin32                   228              py36hbaba5e8_1
pywinpty                  0.5.7                    py36_0
pyzmq                     22.2.1           py36hd77b12b_1
ranger                    0.10                     pypi_0    pypi
requests                  2.27.1                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
rsa                       4.8                      pypi_0    pypi
scikit-image              0.17.2                   pypi_0    pypi
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.2.1                    pypi_0    pypi
seaborn                   0.11.2                   pypi_0    pypi
send2trash                1.8.0              pyhd3eb1b0_1
setuptools                58.0.4           py36haa95532_0
simplejson                3.17.6                   pypi_0    pypi
six                       1.16.0             pyhd3eb1b0_1
sklearn                   0.0                      pypi_0    pypi
sqlite                    3.38.0               h2bbff1b_0
tensorboard               2.8.0                    pypi_0    pypi
tensorboard-data-server   0.6.1                    pypi_0    pypi
tensorboard-plugin-wit    1.8.1                    pypi_0    pypi
terminado                 0.9.4            py36haa95532_0
testpath                  0.5.0              pyhd3eb1b0_0
threadpoolctl             3.1.0                    pypi_0    pypi
tifffile                  2020.9.3                 pypi_0    pypi
torch                     1.2.0                    pypi_0    pypi
torchsnooper              0.8                      pypi_0    pypi
torchvision               0.4.0                    pypi_0    pypi
tornado                   6.1              py36h2bbff1b_0
tqdm                      4.60.0                   pypi_0    pypi
traitlets                 4.3.3            py36haa95532_0
typing-extensions         4.1.1                    pypi_0    pypi
urllib3                   1.26.9                   pypi_0    pypi
vc                        14.2                 h21ff451_1
vs2015_runtime            14.27.29016          h5e58377_2
wcwidth                   0.2.5              pyhd3eb1b0_0
webencodings              0.5.1                    py36_1
werkzeug                  2.0.3                    pypi_0    pypi
wheel                     0.37.1             pyhd3eb1b0_0
wincertstore              0.2              py36h7fe50ca_0
winpty                    0.4.3                         4
xarray                    0.16.2                   pypi_0    pypi
zipp                      3.6.0                    pypi_0    pypi

如果按照以上环境一一配置,绝对可以运行,亲测有效!!!

三、踩坑

3.1 config.py配置

给出详细的注释:(win10 num_workers必须设置为0)不然跑不了!

conf = {
    "WORK_PATH": "./work",
    "CUDA_VISIBLE_DEVICES": "0",  # 所用GPU编号
    "data": {
        'dataset_path': r"C:\Users\3i\Desktop\GaitSet-master\output",  # 数据加载路径(预处理时的输出 绝对 路径)
        'resolution': '64',  # 输出轮廓图的分辨率,不用更改
        'dataset': 'CASIA-B',  # 数据集名称
        # In CASIA-B, data of subject #5 is incomplete.
        # Thus, we ignore it in training.
        # For more detail, please refer to
        # function: utils.data_loader.load_data
        'pid_num': 73,  # 训练集人数,73用于训练,其余用于测试
        'pid_shuffle': True,  # 是否对数据集进行随机划分,如果为False,则直接选取1-pid_num
    },
    "model": {
        'hidden_dim': 256,  # 最后一层全连接层的隐藏层数
        'lr': 1e-4,  # 学习率
        'hard_or_full_trip': 'full',  # 损失函数
        'batch_size': (8, 16),  # 批次p*k = 8*16,
        'restore_iter': 0,  # 第几步开始训练
        'total_iter': 80000,  # 训练次数
        'margin': 0.2,  # 损失函数的margin参数
        'num_workers': 0,  # 线程数
        'frame_num': 30,  # 每个批次的帧数
        'model_name': 'GaitSet',
    },
}

3.2 pretreatment.py配置

可以用原来的(百般修正后)*:

# -*- coding: utf-8 -*-
# @Author  : Abner
# @Time    : 2018/12/19

import os
from scipy import misc as scisc
import cv2
import numpy as np
from warnings import warn
from time import sleep
import argparse

from multiprocessing import Pool
from multiprocessing import TimeoutError as MP_TimeoutError

START = "START"
FINISH = "FINISH"
WARNING = "WARNING"
FAIL = "FAIL"

def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'

wd = os.getcwd()
input_path = os.path.join(wd, 'GaitDatasetB-silh')
output_path = os.path.join(wd, 'output')

parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--input_path', default=input_path, type=str,
                    help='Root path of raw dataset.')
parser.add_argument('--output_path', default=output_path, type=str,
                    help='Root path for output.')
parser.add_argument('--log_file', default='./pretreatment.log', type=str,
                    help='Log file path. Default: ./pretreatment.log')
parser.add_argument('--log', default=False, type=boolean_string,
                    help='If set as True, all logs will be saved. '
                         'Otherwise, only warnings and errors will be saved.'
                         'Default: False')
parser.add_argument('--worker_num', default=1, type=int,
                    help='How many subprocesses to use for data pretreatment. '
                         'Default: 1')
opt = parser.parse_args()

INPUT_PATH = opt.input_path
OUTPUT_PATH = opt.output_path
IF_LOG = opt.log
LOG_PATH = opt.log_file
WORKERS = opt.worker_num

T_H = 64
T_W = 64


def log2str(pid, comment, logs):
    str_log = ''
    if type(logs) is str:
        logs = [logs]
    for log in logs:
        str_log += "# JOB %d : --%s-- %s\n" % (
            pid, comment, log)
    return str_log


def log_print(pid, comment, logs):
    str_log = log2str(pid, comment, logs)
    if comment in [WARNING, FAIL]:
        with open(LOG_PATH, 'a') as log_f:
            log_f.write(str_log)
    if comment in [START, FINISH]:
        if pid % 500 != 0:
            return
    print(str_log, end='')


def cut_img(img, seq_info, frame_name, pid):
    # A silhouette contains too little white pixels
    # might be not valid for identification.
    if img.sum() <= 10000:
        message = 'seq:%s, frame:%s, no data, %d.' % (
            '-'.join(seq_info), frame_name, img.sum())
        warn(message)
        log_print(pid, WARNING, message)
        return None
    # Get the top and bottom point
    y = img.sum(axis=1)
    y_top = (y != 0).argmax(axis=0)
    y_btm = (y != 0).cumsum(axis=0).argmax(axis=0)
    img = img[y_top:y_btm + 1, :]
    # As the height of a person is larger than the width,
    # use the height to calculate resize ratio.
    _r = img.shape[1] / img.shape[0]
    _t_w = int(T_H * _r)
    img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC)
    # Get the median of x axis and regard it as the x center of the person.
    sum_point = img.sum()
    sum_column = img.sum(axis=0).cumsum()
    x_center = -1
    for i in range(sum_column.size):
        if sum_column[i] > sum_point / 2:
            x_center = i
            break
    if x_center < 0:
        message = 'seq:%s, frame:%s, no center.' % (
            '-'.join(seq_info), frame_name)
        warn(message)
        log_print(pid, WARNING, message)
        return None
    h_T_W = int(T_W / 2)
    left = x_center - h_T_W
    right = x_center + h_T_W
    if left <= 0 or right >= img.shape[1]:
        left += h_T_W
        right += h_T_W
        _ = np.zeros((img.shape[0], h_T_W))
        img = np.concatenate([_, img, _], axis=1)
    img = img[:, left:right]
    return img.astype('uint8')


def cut_pickle(seq_info, pid):
    seq_name = '-'.join(seq_info)
    log_print(pid, START, seq_name)
    seq_path = os.path.join(INPUT_PATH, *seq_info)
    out_dir = os.path.join(OUTPUT_PATH, *seq_info)
    frame_list = os.listdir(seq_path)
    frame_list.sort()
    count_frame = 0
    for _frame_name in frame_list:
        frame_path = os.path.join(seq_path, _frame_name)
        img = cv2.imread(frame_path)[:, :, 0]
        img = cut_img(img, seq_info, _frame_name, pid)
        if img is not None:
            # Save the cut img
            save_path = os.path.join(out_dir, _frame_name)
            cv2.imwrite(save_path, img)
            count_frame += 1
    # Warn if the sequence contains less than 5 frames
    if count_frame < 5:
        message = 'seq:%s, less than 5 valid data.' % (
            '-'.join(seq_info))
        warn(message)
        log_print(pid, WARNING, message)

    log_print(pid, FINISH,
              'Contain %d valid frames. Saved to %s.'
              % (count_frame, out_dir))

if __name__ == '__main__':
    pool = Pool(WORKERS)
    results = list()
    pid = 0

    print('Pretreatment Start.\n'
          'Input path: %s\n'
          'Output path: %s\n'
          'Log file: %s\n'
          'Worker num: %d' % (
              INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS))

    id_list = os.listdir(INPUT_PATH)
    id_list.sort()
    # Walk the input path
    for _id in id_list:
        seq_type = os.listdir(os.path.join(INPUT_PATH, _id))
        seq_type.sort()
        for _seq_type in seq_type:
            view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type))
            view.sort()
            for _view in view:
                seq_info = [_id, _seq_type, _view]
                out_dir = os.path.join(OUTPUT_PATH, *seq_info)
                os.makedirs(out_dir)
                results.append(
                    pool.apply_async(
                        cut_pickle,
                        args=(seq_info, pid)))
                sleep(0.02)
                pid += 1

    pool.close()
    unfinish = 1
    while unfinish > 0:
        unfinish = 0
        for i, res in enumerate(results):
            try:
                res.get(timeout=0.1)
            except Exception as e:
                if type(e) == MP_TimeoutError:
                    unfinish += 1
                    continue
                else:
                    print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n',
                          i, type(e))
                    raise e
    pool.join()

也可以用网上可以替代的预处理代码,代码中有一些小错误,经修正后确认可用,可以替代原作者的pretreatment.py代码
修正后代码如下:

'''
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def cut_image(path,cut_path,size):
    '''
    剪切图片
    :param path: 输入图片路径
    :param cut_path: 剪切图片后的输出路径
    :param size: 要剪切的图片大小
    :return:
    '''
    for (root,dirs,files) in os.walk(path):
        temp = root.replace(path,cut_path)
        if not os.path.exists(temp):
            os.makedirs(temp)
        for file in files:
            image,flag = cut(Image.open(os.path.join(root,file)))
            if not flag: Image.fromarray(image).convert('L').resize((size,size)).save(os.path.join(temp,file))

    pass

def cut(image):
    '''
    通过找到人的最小最大高度与宽度把人的轮廓分割出来,、
    因为原始轮廓图为二值图,因此头顶为将二值图像列相加后,形成一列后第一个像素值不为0的索引。
    同理脚底为形成一列后最后一个像素值不为0的索引。
    人的宽度也同理。
    :param image: 需要裁剪的图片 N*M的矩阵
    :return: temp:裁剪后的图片 size*size的矩阵。flag:是否是符合要求的图片
    '''
    image = np.array(image)

    # 找到人的最小最大高度与宽度
    height_min = (image.sum(axis=1)!=0).argmax()
    height_max = ((image.sum(axis=1)!=0).cumsum()).argmax()
    width_min = (image.sum(axis=0)!=0).argmax()
    width_max = ((image.sum(axis=0)!=0).cumsum()).argmax()
    head_top = image[height_min,:].argmax()
    # 设置切割后图片的大小,为size*size,因为人的高一般都会大于宽
    size=height_max-height_min
    temp = np.zeros((size,size))

    # 将width_max-width_min(宽)乘height_max-height_min(高,szie)的人的轮廓图,放在size*size的图片中央
    # l = (width_max-width_min)//2
    # r = width_max-width_min-l
    # 以头为中心,将将width_max-width_min(宽)乘height_max-height_min(高,szie)的人的轮廓图,放在size*size的图片中央
    l1 = head_top-width_min
    r1 = width_max-head_top
    # 若宽大于高,或头的左侧或右侧身子比要生成图片的一般要大。则此图片为不符合要求的图片
    flag = False
    if size<=width_max-width_min or size//2<r1 or size//2<l1:
        flag = True
        return temp,flag
    # centroid = np.array([(width_max+width_min)/2,(height_max+height_min)/2],dtype='int')
    temp[:,(size//2-l1):(size//2+r1)] = image[height_min:height_max,width_min:width_max ]

    return temp,flag

def GEI(cut_path,data_path,size):
    '''
    生成步态能量图
    :param cut_path: 剪切后的图片路径
    :param data_path: 生成图片的路径
    :param size: 生成能量图大小
    :return:
    '''
    for (root,dirs,files) in os.walk(cut_path):
        temp = root.replace(cut_path,data_path)
        if not os.path.exists(temp):
            os.makedirs(temp)
        GEI = np.zeros([size,size])
        if len(files)!=0:
            for file in files:
                GEI += Image.open(os.path.join(root,file)).convert('L')
            GEI /= len(files)
            Image.fromarray(GEI).convert('L').resize((size,size)).save(os.path.join(temp,'1.png'))
    pass


if __name__=='__main__':
    cut_image("C:\\Users\\China\\Desktop\\GaitDatas","C:\\Users\\China\\Desktop\\CutImage",126)
    GEI("C:\\Users\\China\\Desktop\\CutImage","C:\\Users\\China\\Desktop\\GEIData",126) 

3.3 train.py的踩坑

错误:

pytorch container.py… IndexError: index 0 is out of range

解决: 

重载了模型,也就是训练过一次原来工作目录已经有了模型,删掉原来的模型(partition文件全部删掉)或者换一个工作路径(原来的工作路径: ./work 

or 

使用了多块GPU,不知道为啥超过一个就报错。

跑通GaitSet(跑不通你来揍我)

跑通GaitSet(跑不通你来揍我)

4.test.py

不用改直接运行:(可以不训练本来就自带哦!)

前提:

跑通GaitSet(跑不通你来揍我)

config.py中:

'pid_num': 73,  #  LT划分方式 pid_num+1用于训练,其余用于测试


'pid_shuffle': False,  # 是否对数据集进行随机划分,如果为False,则直接选取
'model_name': 'GaitSet',
'dataset': 'CASIA-B',

其他Bug

错误: 

 Warning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead

解决:

在triplet.py文件中的第19-24行 找到带有_mask的两个变量分别加上 

hp_mask = hp_mask.bool()

hn_mask = hn_mask.bool()

两行如下:

跑通GaitSet(跑不通你来揍我)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2022年5月30日 上午11:07
下一篇 2022年5月30日

相关推荐