用python实现经验模态分解+小波软阈值去噪

PyEmd模块安装

试过很多博主说的pip insyall PyEmd都失败了,偶然间运气好发现正确的安装方式是pip install PyEmd-signal。如果找不到相关的库或者模块,直接去github上去搜索,上面有很详细的安装教程,不要被误导

pywt模块安装

pywt可以实现小波分解与重构,小波阈值降噪,小波包分解等功能,同样安装也是用相应的pip instal pywt来进行安装,如果找不到还是去github上寻找。

特别说明

关于EMD类方法和小波阈值降噪的相关理论知识可直接百度或者在知网找几篇硕博论文来看,里面有详细的推导过程。

不要期望能够推导出和理解相关的公式。这些公式非常复杂,人们无法理解。个人建议了解相关算法流程。相关的优化也是不同方法的排列组合。还是有必要多尝试的。

以上只是我个人的看法。如果你不喜欢它,你可以把它拿走。不要在这里显示你的优越感。我写这篇文章的目的只是为了记录! ! ! !

代码

由于写代码的时候,个人的理论了解程度仅仅停留在入门阶段,IMF分量(EMD类方法分解得到的分量)的选择是凭借个人感觉来选择的,正确的做法是计算多尺度排列熵(github上可以找到相关的模块,个人正在研究相关代码)、相关系数等来进行选择。

个人玩具代码,有很多不精确的地方。

EMD类方法实现

import numpy as np
from PyEMD import EMD,EEMD,CEEMDAN,Visualisation
from threshold import Threshold
from matplotlib import pyplot as plt
import pywt
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']

def read_txt_file(input_file_path):
    """该函数主要用来从txt文件中读取所需数据,并转换数据类型
    输入为待处理文件的路径
    输出为一个存放txt文件数据的列表"""
    file_list = []
    file = open(input_file_path)
    file_lines = file.readline()
    file_lines = list(file_lines)
    file_lines.pop(0)
    file_lines.pop(-1)
    file_lines = ''.join(file_lines)
    cur = file_lines.strip().split(",")
    for i in range(0,len(cur)):
        file_list.append(float(cur[i]))
    #print(file_list)
    return file_list

class EmdFunction:
    """
    emd方法的调用
    """
    def __init__(self,data,function_name,sym,level,imfs_start_step,thr_select,thr_way):

        X = np.array(data)
        self.signal = (X - np.mean(X)) / np.std(X)
        self.function_name = function_name
        self.soft_threshold = Threshold(data,thr_select,sym,level,thr_way)
        self.imfs_start_step = imfs_start_step


    def emd_completed(self):

        if self.function_name == 'EMD':
            emd = EMD()
            emd.emd(self.signal)
            ims,res = emd.get_imfs_and_residue()
            return ims,res
        elif self.function_name == 'EEMD':
            eemd = EEMD()
            eemd.eemd(self.signal)
            ims,res = eemd.get_imfs_and_residue()
            return ims,res
        elif self.function_name == 'CEEMDAN':
            ceemdan = CEEMDAN()
            ceemdan.ceemdan(self.signal)
            ims,res = ceemdan.get_imfs_and_residue()
            return ims,res

    def plot_imfs_and_res(self,imfs,res):
        t = np.arange(0,len(self.signal),1)
        vis = Visualisation()
        vis.plot_imfs(imfs=imfs,residue=res,t=t,include_residue=True)
        vis.show()

    def wavelet_and_emd(self):

        useful_imfs_add = np.zeros(len(self.signal)).tolist()
        imfs,res = self.emd_completed()
        for i in range(self.imfs_start_step,len(imfs)):
            useful_imfs_add += imfs[i]

        for j in range(0,self.imfs_start_step):
            data = imfs[j]
            #mid_param = self.soft_threshold.wavelet_dec_rec(data)
            mid_param = self.soft_threshold.wavelet_dec_rec(data)
            useful_imfs_add  += mid_param
        return useful_imfs_add

    def plot_org_sotfthreshold(self):

        end_signal = self.wavelet_and_emd()
        snr = self.soft_threshold.compute_snr(self.signal,end_signal)
        rmse = self.soft_threshold.compute_mse(self.signal,end_signal)
        print('snr:{} , rmse:{}'.format(snr, rmse))
        figure,(ax1,ax2) = plt.subplots(nrows=2,ncols=1)
        ax1.plot(self.signal, label='org signal')
        ax1.set_title('降噪前的信号')
        ax1.legend()

        ax2.plot(end_signal, 'g',label='after wavele')
        ax2.set_title('降噪后的信号')
        ax2.legend()
        plt.show()



if __name__ == "__main__":
    path = 'D:\桌面文件夹\数据文件/0001001228.txt'
    #path = 'D:\桌面文件夹\新建文件夹 (3)/1101000158.txt'

    data = read_txt_file(path)

    function_name = 'CEEMDAN'

    sym = 'sym8'
    level = 3
    imfs_start_step = 4
    thr_select = 'sqtwolog'
    thr_way = 'soft'

    emdfunction = EmdFunction(data,function_name,sym,level,imfs_start_step,thr_select,thr_way)
    # ims,res=emdfunction.emd_completed()
    # emdfunction.plot_imfs_and_res(ims,res)
    emdfunction.plot_org_sotfthreshold()


小波软阈值代码实现

import numpy as np
import os
from matplotlib import pyplot as plt
import pywt
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']


def read_txt_file(input_file_path):
    """该函数主要用来从txt文件中读取所需数据,并转换数据类型
    输入为待处理文件的路径
    输出为一个存放txt文件数据的列表"""
    file_list = []
    file = open(input_file_path)
    file_lines = file.readline()
    file_lines = list(file_lines)
    file_lines.pop(0)
    file_lines.pop(-1)
    file_lines = ''.join(file_lines)
    cur = file_lines.strip().split(",")
    for i in range(0,len(cur)):
        file_list.append(float(cur[i]))
    #print(file_list)
    return file_list

class Threshold:
    """小波阈值降噪"""
    def __init__(self,data,thr_select,wave_bais='sym8',level=3,thr_way='soft'):

        if type(data) == list:
            X = np.array(data)
            self.data = (X - np.mean(X)) / np.std(X)
        else:
            self.data = (data - np.mean(data)) / np.std(data)

        self.data = data

        self.wave_bais = wave_bais
        self.level = level
        self.thr_way = thr_way
        if thr_select in ['rigrsure','heursure','sqtwolog','minimaxi']:
            self.thr_select = thr_select
        else:
            raise print('取值计算函数名称错误,请重新输入')

    def thrselect(self,data):
        """
        阈值lambda的计算方式选择
        :return: 返回阈值
        """
        N = len(data)
        if self.thr_select == 'sqtwolog':
            #固定阈值
            thr = round(np.sqrt(2.0 * np.log(N)),4)
            return thr

        elif self.thr_select == 'minimaxi':
            #极大极小阈值
            if N<32:
                thr =0
            else:
                thr = 0.3936 + 0.1829*(np.log(N)/np.log(2))
            return thr

        elif self.thr_select == 'rigrsure':
            # #风险阈值
            # sx = np.sort(abs(self.data))
            # sx2 = np.square(sx)
            # N1 = np.repeat(N-2*[i for i in range(0,N)],1)
            pass
            return -1

        elif self.thr_select == 'heursure':
            pass
            return -1

    def wavelet_dec_rec(self,data):
        """小波分解"""
        coffe = pywt.wavedec(data,self.wave_bais,level=self.level)
        #低频分量分量
        ca = coffe[0]
        #高频分量
        cd_out_list = []
        cd_out_list.append(ca)
        #阈值
        thr = self.thrselect(data)
        for i in range(1,len(coffe)):
            cd = coffe[i]
            ysotf = pywt.threshold(cd,thr,self.thr_way)
            cd_out_list.append(ysotf)

        Y = pywt.waverec(cd_out_list,self.wave_bais)
        return Y

    def plot_signal(self,data):
        #获得降噪后的信号
        Y = self.wavelet_dec_rec(data)

        #绘制原始图像
        figure, axes = plt.subplots(2, 1)
        ax1 = axes[0]
        ax1.set_title('降噪前的信号')
        ax1.plot(self.data)
        #绘制降噪的图像
        ax2 = axes[1]
        ax2.set_title('降噪后的信号')
        ax2.plot(Y,color='g')
        plt.show()

    @staticmethod
    def compute_snr(org_signal, final_signal):
        """
        信噪比:信噪比越大越好
        均方根误差:均方根误差越小越好,越小去噪效果越好
        :param org_signal:原始信号
        :param final_signal:降噪后的信号
        :return: 信噪比,均方根误差
        """

        clean = np.array(final_signal)
        org_signal = np.array(org_signal)
        #est_noise = org_signal - clean
        # power_data = np.mean(np.square(data))
        # power_noise = np.mean(np.square(data - final_signal))
        #snr = 10 * np.log10((np.sum(clean ** 2)) / (np.sum(est_noise ** 2)))
        # snr = (math.log((power_data/power_noise),10) )* 10
        sigPower = sum(abs(clean) ** 2) / len(clean)  # 求出信号功率
        noisePower = sum(abs(org_signal - clean) ** 2) / len(org_signal - clean)  # 求出噪声功率
        SNR_10 = 10 * np.log10(sigPower / noisePower)
        #SNR_10 = (sigPower / noisePower)
        return SNR_10

    @staticmethod
    def compute_mse(org_signal, final_signal):
        """
        计算均方根误差:均方根误差越小越好,越小去噪效果越好
        :param org_signal:原始信号
        :param final_signal:降噪后的信号
        :return: 均方根误差
        """
        data = np.array(org_signal)
        final_signal = np.array(final_signal)
        rmse = np.sqrt(np.mean(np.square(data - final_signal)))
        return rmse




if __name__ == "__main__":
    #path = 'D:\桌面文件夹\新建文件夹 (3)/1101000158.txt'
    path = 'D:\桌面文件夹\数据文件/0001001228.txt'

    data = read_txt_file(path)
    data = (data - np.mean(data)) / np.std(data)

    thr_select = 'sqtwolog'
    wave_bais = 'sym8'
    level = 3
    thr_way = 'soft'

    wave = Threshold(data,thr_select,wave_bais,level,thr_way)
    Y = wave.wavelet_dec_rec(data)

    snr = wave.compute_snr(wave.data,Y)
    rmse = wave.compute_mse(wave.data,Y)
    print('snr:{},rmse:{}'.format(snr,rmse))

    wave.plot_signal(data)






文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年4月2日 下午3:51
下一篇 2022年4月2日 下午3:58

相关推荐