前言
对于一个序列而言,求均值和方差根据定义式是不难的,其时空复杂度均为 。但有的时候,我们的样本是一个一个给的,此时新来了一个样本,我们总不可能把原来的样本都捞出来再算一次均值、方差吧,那样时空复杂度都是 了。因此,我们需要一个递推的方式,假设我们已知前 个样本的均值和方差 ,且知道了新的样本 ,以 复杂度给出 。
数学推导
我们知道样本均值、样本方差的无偏估计如下
不妨令
则均值递推式如下
方差递推式的推导要复杂一些,首先将式 稍微变形,引入式 。
然后引入式 。
然后再对样本方差定义式化简,得到式 。
式 的推导第三行到第四行的等号是因为完全平方展开的交叉项为 ,具体推导可见式 。
最终得到递推关系式(化简成这样为了保留 公共项,减少计算量)
C++代码实现
核心代码其实只有这么一点
void SeqMeanVar::AppendImpl(double new_value) {
double xn1_mun = new_value - m_mean; // x(n + 1) - mu(n)
double rev_beta_n = 1 - 1 / m_n; // 1 - beta(n)
++ m_n;
double beta_n1 = 1 / m_n; // beta(n + 1)
m_mean += beta_n1 * xn1_mun; // mu(n + 1) = mu(n) + beta(n + 1) * (x(n + 1) - mu(n))
m_var = rev_beta_n * m_var + beta_n1 * xn1_mun * xn1_mun; // var(n + 1) = (1 - beta(n)) * var(n) + beta(n + 1) * (x(n + 1) - mu(n))^2
}
完整代码(含测试样例)如下——
#include <iostream> // cout
// 有初始值的均值迭代器
class SeqMeanVar
{
public:
SeqMeanVar();
SeqMeanVar(double init_value);
const double GetN() const;
const double GetMean() const;
const double GetVar() const;
// type=0 为检查均值是否有意义, type=1 为检查方差是否有意义
bool IsValid(int type=1) const;
void Append(double new_value);
private:
void InitialConstruct(double init_value);
void AppendImpl(double new_value);
double m_n;
double m_mean;
double m_var;
friend std::ostream &operator<<(std::ostream &os, SeqMeanVar &seq);
};
SeqMeanVar::SeqMeanVar()
: m_n(0)
, m_mean(0)
, m_var(0)
{
}
SeqMeanVar::SeqMeanVar(double init_value)
{
InitialConstruct(init_value);
}
void SeqMeanVar::InitialConstruct(double init_value) {
m_n = 1;
m_mean = init_value;
m_var = 0;
}
const double SeqMeanVar::GetN() const {
return m_n;
}
const double SeqMeanVar::GetMean() const {
return m_mean;
}
const double SeqMeanVar::GetVar() const {
return m_var;
}
bool SeqMeanVar::IsValid(int type) const {
switch (type) {
case 0: return m_n >= 1; // 检查均值
case 1: return m_n >= 2; // 检查方差
default: return false;
}
}
void SeqMeanVar::Append(double new_value) {
if (m_n == 0)
InitialConstruct(new_value);
else
AppendImpl(new_value);
}
void SeqMeanVar::AppendImpl(double new_value) {
double xn1_mun = new_value - m_mean; // x(n + 1) - mu(n)
double rev_beta_n = 1 - 1 / m_n; // 1 - beta(n)
++ m_n;
double beta_n1 = 1 / m_n; // beta(n + 1)
m_mean += beta_n1 * xn1_mun; // mu(n + 1) = mu(n) + beta(n + 1) * (x(n + 1) - mu(n))
m_var = rev_beta_n * m_var + beta_n1 * xn1_mun * xn1_mun; // var(n + 1) = (1 - beta(n)) * var(n) + beta(n + 1) * (x(n + 1) - mu(n))^2
}
std::ostream &operator<<(std::ostream &os, SeqMeanVar &seq) {
os << "(n = " << seq.m_n << ", mean = " << seq.m_mean << ", var = " << seq.m_var << ")";
return os;
}
int main()
{
SeqMeanVar seqMV(2);
std::cout << seqMV << std::endl;
seqMV.Append (4);
std::cout << seqMV << std::endl;
seqMV.Append (6);
std::cout << seqMV << std::endl;
seqMV.Append (7);
std::cout << seqMV << std::endl;
seqMV.Append (8);
std::cout << seqMV << std::endl;
getchar();
return 0;
}
文章出处登录后可见!
已经登录?立即刷新