Skip to main content

灵活的线性卡尔曼滤波器

项目描述

林卡尔曼

linkalman是一个 python 包,用于解决具有高斯噪声的线性结构时间序列模型。与其他一些流行的用 python 编写的卡尔曼滤波器包相比,linkalman 具有以下几个优点:

  • 考虑部分和完全不完整的测量
  • 灵活方便的模型结构
  • 稳健高效的实施
  • 未知先验的正确实施
  • 内置数值和 EM 算法
  • 具有全面用户手册的开源
  • 具有直观模型规范的模块化设计

安装

linkalman需要以下软件包才能运行:

  • 麻木的
  • 熊猫
  • 网络x
  • scipy

要安装linkalman,只需使用标准pip命令:

$ pip install linkalman

例子

在这里,我将提供一个使用linkalman. 有关更多示例,请参见此处,有关技术详细信息,请参见用户手册

import pandas as pd
import numpy as np
from scipy.optimize import minimize
from linkalman.models import BaseConstantModel as BCM
import matplotlib.pyplot as plt


# Get data
df = pd.read_csv('https://raw.githubusercontent.com/jbrownlee/Datasets/master/daily-total-female-births.csv')
df['x'] = 1
df.set_index('Date', inplace=True)

首先,我们定义贝叶斯结构时间序列 (BSTS) 模型的系统动力学。这里我定义了一个随机线性趋势模型来从时间序列中提取趋势信息(详见用户手册的示例部分)

def my_f(theta):
    sig1 = np.exp(theta[0])
    sig2 = np.exp(theta[1])
    sig3 = np.exp(theta[2])

    F = np.array([[1, 1], [0, 1]])
    Q = np.array([[sig1, 0], [0, sig2]]) 
    R = np.array([[sig3]])
    H = np.array([[1, 0]])
    # Collect system matrices
    M = {'F': F, 'Q': Q, 'H': H, 'R': R}

    return M 

接下来我们定义一个求解器或优化器,您可以选择任何您喜欢的求解器。这里我只使用scipy.optimize.minimize.

def my_solver(param, obj_func, verbose=False, **kwargs):
    obj_ = lambda x: -obj_func(x)
    res = minimize(obj_, param, **kwargs)
    theta_opt = np.array(res.x)
    fval_opt = res.fun
    return theta_opt, fval_opt

现在我们可以拟合数据了。首先,我们初始化模型并输入系统动力学(my_f)和求解器(my_solver)。您也可以将关键字参数传递给 formy_fmy_solver

model = BCM()
model.set_f(my_f)
model.set_solver(my_solver, method='nelder-mead', 
        options={'xatol': 1e-8, 'disp': True, 'maxiter': 10000})
theta_init = np.random.rand(3)
model.fit(df, theta_init, y_col=['Births'], x_col=['x'], 
              method='LLY')
df_LLY = model.predict(df)

这就对了!如果您想做额外的工作,您可以执行以下操作以围绕您的预测绘制置信区间。

df_LLY['kf_ub'] = df_LLY.Births_filtered + 1.96 * np.sqrt(df_LLY.Births_fvar)
df_LLY['kf_lb'] = df_LLY.Births_filtered - 1.96 * np.sqrt(df_LLY.Births_fvar)
df_LLY = df_LLY[df_LLY.index > '1959-01-01']
df_LLY.index = pd.to_datetime(df_LLY.index)

# Define plot function
def simple_plot(df, col_est, col_actual, col_ub, col_lb, label_est,
                label_actual, title, figsize=(12, 8)):
    ax = plt.figure(figsize=figsize)
    plt.plot(df.index, df[col_est], 'r', label=label_est)
    plt.scatter(df_LLY.index, df[col_actual], s=20, c='b', 
                marker='o', label=label_actual)
    plt.fill_between(df.index, df[col_ub], df[col_lb], color='g', alpha=0.2)
    ax.legend(loc='right', fontsize=9)
    plt.title(title, fontsize=22)
    plt.show()
simple_plot(df_LLY, 'Births_filtered', 'Births', 'kf_ub', 'kf_lb',  
           'Prediction', 'Births', 'Filtered Births Data')

执照

3 条款 BSD

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

linkalman-0.11.5.tar.gz (25.7 kB 查看哈希

已上传 source