【阅读论文】TimesNet短期预测的基本流程梳理

目录


前言

果然是初入机器学习的新手,对pycharm、pytorch的套路了解得太少,在学习之路上犯了不少错误,走了不少弯路,虽然现在依旧是个新人,但也还是来做个阶段性的总结,也算是成长的证明。还是以TimesNet为例,下面用基于m4数据集(quarterly类别)的short term forecasting程序来做说明。

一、run.py

主程序,主要是两个部分:args和train, test。

1.args

作用:导入基本参数。

主要代码:

import argparse
parser = argparse.ArgumentParser(description='TimesNet')
parser.add_argument('--属性名', type=类型, default=默认值, help='说明')
args = parser.parse_args()

# parser.add_argument中还有一些用的少的参数required, action, nargs

2.train, test

作用:开始训练、测试模型。

主要代码:

setting = '{}_{}_..._{}_{}'.format(args.属性, ..., args.属性)
exp = Exp(args)       # 把args传递给exp_short_term_forecasting
exp.train(setting)    # 用于给checkpoints命名
exp.test(setting)     # 用于给test_results的子文件夹命名
torch.cuda.empty_cache()    # 清空显存缓冲区

二、TimesNet_M4.sh

作用:便于预设参数,批量执行程序。

主要代码:

export CUDA_VISIBLE_DEVICES=0
# 使用的显卡序号,个人电脑的主卡多为“0”,服务器可以按需选择

model_name=TimesNet   # 模型的名字

python -u run.py \
  --参数名 参数值 \    # 提前设置各种所需参数
  --model $model_name \  # model的名字已经在上面写了
  ...            \
  --参数名 参数值

三、exp_short_term_forecasting.py

作用:短期预测的主要函数。

主要代码:

class Exp_Short_Term_Forecast(Exp_Basic):
# 基于Exp_Basic而新建的类

    def __init__(self, args):
    # 初始化

    def _build_model(self):
    # 选择TimesNet模型,基于pytorch的nn.Module写的

    def _get_data(self, flag):
    # 读取m4数据,基于pytorch的DataLoader写的

    def _select_optimizer(self):
    # 选择优化器,直接用pytorch的

    def _select_criterion(self, loss_name='MSE'):
    # 选择评价标准/结束标准,MSE是直接用pytorch的

    def train(self, setting):
    # 训练模型

    def vali(self, train_loader, vali_loader, criterion):
    # 验证模型

    def test(self, setting, test=0):
    # 测试模型

由于vali()仅用了一次,test()和train()相似度比较高,故下文只解释train()。

四、train()

主要代码:

train_data, train_loader = self._get_data(flag='train')
# 读取数据
# 这里不是很理解,在Dataset_M4类中可以发现,他是先把10w条数据全部读进去,再根据seasonal_patterns进行筛选的,反正程序都是按照季节依次执行的,为啥不直接读取季节的csv文件呢?

early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
# 早停,为了避免过拟合

for epoch in range(self.args.train_epochs):
    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
    # batch_x, batch_y是从打乱顺序的m4数据中截取的16个长度为24的序列,其中batch_x是前16个,batch_y是后16个,因此batch_x和batch_y有8个是相同的,但截取序列时,只保证batch_x中的数据有意义(非空),batch_y随意,如Q19657中截取的24个数据
    # [[ 1380.0000],[ 1350.0000],[ 1330.0000],[ 1320.0000],[ 1300.0000],[ 1300.0000],
    #  [ 1280.0000],[ 1280.0000],[ 1280.0000],[ 1280.0000],[ 1260.0000],[ 1260.0000],
    #  [ 1260.0000],[ 1250.0000],[ 1240.0000],[ 1230.0000],[ 1230.0000],[    0.0000],
    #  [    0.0000],[    0.0000],[    0.0000],[    0.0000],[    0.0000],[    0.0000]]

        model_optim.zero_grad()  # 将网络中的梯度置零

        # decoder input
        dec_inp = torch.xxxxxx
        # dec_inp为后8个变为0的batch_y

        outputs = self.model(batch_x, None, dec_inp, None)  # 运行TimesNet模型
        # size为16x16,实际上只有batch_x有用

        outputs = outputs[:, -self.args.pred_len:, f_dim:]  # outputs的最后8个是预测值
        loss_value, loss_sharpness, loss
        train_loss.append(loss.item())
        # 计算loss

        loss.backward()
        model_optim.step()
        # pytorch的一套组合拳
        # optimizer.zero_grad() 清空过往梯度;
        # loss.backward() 反向传播,计算当前梯度;
        # optimizer.step() 根据梯度更新网络参数

best_model_path = path + '/' + 'checkpoint.pth'
self.model.load_state_dict(torch.load(best_model_path))
# 读取最佳节点来继续运行,免得发生意外,程序重跑

return self.model

五、TimesNet.py

主要是三个部分:Model, TimesBlock和FFT_for_Period。

1.Model

主要代码:

class Model(nn.Module):
    def __init__(self, configs):

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        # Quarterly中,一个train的batch_x为[16,16,1],一个test的batch_x为[1,16,1]

        means, stdev
        x_enc /= stdev
        # 对x_enc做Z-Score标准化

        # embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        # 升维,增加参数,输入[16,16,1],输出[16,16,64]

        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)
        # 每个序列都预测了8个值,输出[16,24,64]

        # TimesNet
        for i in range(self.layer):
        # layer=e_layers,Quarterly中e_layers=2,2层TimesBlock处理

            enc_out = self.layer_norm(self.model[i](enc_out))
            # 跳转至TimesBlock
            # 返回结果[16,24,64]

        # porject back
        dec_out = self.projection(enc_out)
        # 将每行64维的数据投影至1维,输出[16,24,1]

        # De-Normalization from Non-stationary Transformer
        dec_out, stdev, means
        # 将结果进行还原,输出[16,24,1]

        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    # 由于模型是基于nn.Module写的,运行模型时默认先执行forward
    # 所以用forward选择对应的任务来执行。

2.TimesBlock

主要代码:

class TimesBlock(nn.Module):
    def __init__(self, configs):

    def forward(self, x):
        B, T, N = x.size()  # enc_out为[16,24,64]
        period_list, period_weight = FFT_for_Period(x, self.k)
        # 跳转至FFT_for_Period
        # 返回结果5个周期和对应频率的振幅

        for i in range(self.k):
            # padding
            # 根据seq_len+pred_len能否整除period,决定是否补零
            # 此处能整除,输出[16,24,64]

            # reshape
            out = out.reshape
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            # 根据周期,将1维序列变为2维变量

            # reshape back
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            # 将2维变量变回1维序列

            res.append(out[:, :(self.seq_len + self.pred_len), :])
        res = torch.stack(res, dim=-1)
        # 记录残差

        # adaptive aggregation
        period_weight, res
        # 结合振幅计算新的残差

        # residual connection
        res = res + x
        return res

3.FFT_for_Period

主要代码:

def FFT_for_Period(x, k=2):
    xf = torch.fft.rfft(x, dim=1)   # 傅里叶变换

    # find period by amplitudes
    frequency_list = abs(xf).mean(0).mean(-1)   # 频域求平均振幅
    frequency_list[0] = 0   # 第一个频率很高,但没用

    _, top_list = torch.topk(frequency_list, k)  # 获取最大的k个频率
    top_list = top_list.detach().cpu().numpy()
    # detach()阻断反传,但数据仍在现存里,cpu无法获取
    # cpu()将数据移至cpu,返回值是cpu上的Tensor
    # numpy()将cpu上的tensor转为numpy数据,为ndarray类型,返回值为numpy.array()

    period = x.shape[1] // top_list     # 周期=T/f
    return period, abs(xf).mean(-1)[:, top_list]    # 返回周期、k个频率的振幅

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2023年11月14日
下一篇 2023年11月14日

相关推荐

此站出售,如需请站内私信或者邮箱!