灵活的线性卡尔曼滤波器
项目描述
林卡尔曼
linkalman是一个 python 包,用于解决具有高斯噪声的线性结构时间序列模型。与其他一些流行的用 python 编写的卡尔曼滤波器包相比,linkalman 具有以下几个优点:
- 考虑部分和完全不完整的测量
- 灵活方便的模型结构
- 稳健高效的实施
- 未知先验的正确实施
- 内置数值和 EM 算法
- 具有全面用户手册的开源
- 具有直观模型规范的模块化设计
安装
linkalman需要以下软件包才能运行:
- 麻木的
- 熊猫
- 网络x
- scipy
要安装linkalman,只需使用标准pip命令:
$ pip install linkalman
例子
在这里,我将提供一个使用linkalman. 有关更多示例,请参见此处,有关技术详细信息,请参见用户手册。
import pandas as pd
import numpy as np
from scipy.optimize import minimize
from linkalman.models import BaseConstantModel as BCM
import matplotlib.pyplot as plt
# Get data
df = pd.read_csv('https://raw.githubusercontent.com/jbrownlee/Datasets/master/daily-total-female-births.csv')
df['x'] = 1
df.set_index('Date', inplace=True)
首先,我们定义贝叶斯结构时间序列 (BSTS) 模型的系统动力学。这里我定义了一个随机线性趋势模型来从时间序列中提取趋势信息(详见用户手册的示例部分)
def my_f(theta):
sig1 = np.exp(theta[0])
sig2 = np.exp(theta[1])
sig3 = np.exp(theta[2])
F = np.array([[1, 1], [0, 1]])
Q = np.array([[sig1, 0], [0, sig2]])
R = np.array([[sig3]])
H = np.array([[1, 0]])
# Collect system matrices
M = {'F': F, 'Q': Q, 'H': H, 'R': R}
return M
接下来我们定义一个求解器或优化器,您可以选择任何您喜欢的求解器。这里我只使用scipy.optimize.minimize.
def my_solver(param, obj_func, verbose=False, **kwargs):
obj_ = lambda x: -obj_func(x)
res = minimize(obj_, param, **kwargs)
theta_opt = np.array(res.x)
fval_opt = res.fun
return theta_opt, fval_opt
现在我们可以拟合数据了。首先,我们初始化模型并输入系统动力学(my_f)和求解器(my_solver)。您也可以将关键字参数传递给 formy_f和my_solver。
model = BCM()
model.set_f(my_f)
model.set_solver(my_solver, method='nelder-mead',
options={'xatol': 1e-8, 'disp': True, 'maxiter': 10000})
theta_init = np.random.rand(3)
model.fit(df, theta_init, y_col=['Births'], x_col=['x'],
method='LLY')
df_LLY = model.predict(df)
这就对了!如果您想做额外的工作,您可以执行以下操作以围绕您的预测绘制置信区间。
df_LLY['kf_ub'] = df_LLY.Births_filtered + 1.96 * np.sqrt(df_LLY.Births_fvar)
df_LLY['kf_lb'] = df_LLY.Births_filtered - 1.96 * np.sqrt(df_LLY.Births_fvar)
df_LLY = df_LLY[df_LLY.index > '1959-01-01']
df_LLY.index = pd.to_datetime(df_LLY.index)
# Define plot function
def simple_plot(df, col_est, col_actual, col_ub, col_lb, label_est,
label_actual, title, figsize=(12, 8)):
ax = plt.figure(figsize=figsize)
plt.plot(df.index, df[col_est], 'r', label=label_est)
plt.scatter(df_LLY.index, df[col_actual], s=20, c='b',
marker='o', label=label_actual)
plt.fill_between(df.index, df[col_ub], df[col_lb], color='g', alpha=0.2)
ax.legend(loc='right', fontsize=9)
plt.title(title, fontsize=22)
plt.show()
simple_plot(df_LLY, 'Births_filtered', 'Births', 'kf_ub', 'kf_lb',
'Prediction', 'Births', 'Filtered Births Data')
执照
3 条款 BSD