分组时间序列(面板)数据的交叉验证

2024-04-24

我使用面板数据:随着时间的推移,我观察许多单位(例如人);对于每个单元,我都有相同固定时间间隔的记录。

当将数据分为训练集和测试集时,我们需要确保这两个集是不相交的并且顺序的,即训练集中的最新记录应该在测试集中最早的记录之前(参见例如此博客文章 https://robjhyndman.com/hyndsight/tscv/).

是否有面板数据交叉验证的标准 Python 实现?

我尝试过 Scikit-Learn时间序列分割 http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html,它不能解释群体,并且组随机分割 http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GroupShuffleSplit.html它无法解释数据的顺序性质,请参阅下面的代码。

import pandas as pd
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, TimeSeriesSplit

# generate panel data
user = np.repeat(np.arange(10), 12)
time = np.tile(pd.date_range(start='2018-01-01', periods=12, freq='M'), 10)
data = (pd.DataFrame({'user': user, 'time': time})
        .sort_values(['time', 'user'])
        .reset_index(drop=True))

tscv = TimeSeriesSplit(n_splits=4)
for train_idx, test_idx in tscv.split(data):
    train = data.iloc[train_idx]
    test = data.iloc[test_idx]
    train_end = train.time.max().date()
    test_start = test.time.min().date()
    print('TRAIN:', train_end, '\tTEST:', test_start, '\tSequential:', train_end < test_start, sep=' ')

Output:

TRAIN: 2018-03-31   TEST: 2018-03-31    Sequential: False
TRAIN: 2018-05-31   TEST: 2018-05-31    Sequential: False
TRAIN: 2018-08-31   TEST: 2018-08-31    Sequential: False
TRAIN: 2018-10-31   TEST: 2018-10-31    Sequential: False

因此,在这个例子中,我希望训练集和测试集仍然是连续的。

有许多相关的旧帖子,但没有(令人信服的)答案,请参见例如

  • https://stackoverflow.com/questions/51861417/time-series-prediction-for-grouped-data https://stackoverflow.com/questions/51861417/time-series-prediction-for-grouped-data[现已删除]

  • 时间序列数据的分层交叉验证 https://stackoverflow.com/questions/46698792/stratified-cross-validation-of-timeseries-data


scikit-learn 上请求了此功能,我添加了一个PR https://github.com/getgaurav2/scikit-learn/blob/d4a3af5cc9da3a76f0266932644b884c99724c57/sklearn/model_selection/_split.py#L2243为了它 。 这项技术在最近的一些项目中得到了令人惊叹的结果Kaggle 笔记本 https://www.kaggle.com/search?q=getgaurav2 .

  • scikit-learn 功能请求 :https://github.com/scikit-learn/scikit-learn/issues/14257 https://github.com/scikit-learn/scikit-learn/issues/14257
  • scikit-learn 公关:https://github.com/scikit-learn/scikit-learn/pull/16236 https://github.com/scikit-learn/scikit-learn/pull/16236
  • 卡格尔笔记本 1 https://www.kaggle.com/jorijnsmit/found-the-holy-grail-grouptimeseriessplit下面的代码块
  • Kaggle 笔记本 2 https://www.kaggle.com/marketneutral/purged-time-series-cv-xgboost-optuna/(清除时间序列 CV):这是一个很好的修改gap不同组之间的参数。功能要求 https://github.com/scikit-learn/scikit-learn/issues/19072Scikit-learn 上也提出了同样的问题。
  • Kaggle 笔记本 3 https://www.kaggle.com/code/konradb/ts-10-validation-methods-for-time-series: 非常清楚地总结了所有方法。
from sklearn.model_selection._split import _BaseKFold, indexable, _num_samples
from sklearn.utils.validation import _deprecate_positional_args

# https://github.com/getgaurav2/scikit-learn/blob/d4a3af5cc9da3a76f0266932644b884c99724c57/sklearn/model_selection/_split.py#L2243
class GroupTimeSeriesSplit(_BaseKFold):
    """Time Series cross-validator variant with non-overlapping groups.
    Provides train/test indices to split time series data samples
    that are observed at fixed time intervals according to a
    third-party provided group.
    In each split, test indices must be higher than before, and thus shuffling
    in cross validator is inappropriate.
    This cross-validation object is a variation of :class:`KFold`.
    In the kth split, it returns first k folds as train set and the
    (k+1)th fold as test set.
    The same group will not appear in two different folds (the number of
    distinct groups has to be at least equal to the number of folds).
    Note that unlike standard cross-validation methods, successive
    training sets are supersets of those that come before them.
    Read more in the :ref:`User Guide <cross_validation>`.
    Parameters
    ----------
    n_splits : int, default=5
        Number of splits. Must be at least 2.
    max_train_size : int, default=None
        Maximum size for a single training set.
    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import GroupTimeSeriesSplit
    >>> groups = np.array(['a', 'a', 'a', 'a', 'a', 'a',\
                           'b', 'b', 'b', 'b', 'b',\
                           'c', 'c', 'c', 'c',\
                           'd', 'd', 'd'])
    >>> gtss = GroupTimeSeriesSplit(n_splits=3)
    >>> for train_idx, test_idx in gtss.split(groups, groups=groups):
    ...     print("TRAIN:", train_idx, "TEST:", test_idx)
    ...     print("TRAIN GROUP:", groups[train_idx],\
                  "TEST GROUP:", groups[test_idx])
    TRAIN: [0, 1, 2, 3, 4, 5] TEST: [6, 7, 8, 9, 10]
    TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a']\
    TEST GROUP: ['b' 'b' 'b' 'b' 'b']
    TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] TEST: [11, 12, 13, 14]
    TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b']\
    TEST GROUP: ['c' 'c' 'c' 'c']
    TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]\
    TEST: [15, 16, 17]
    TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b' 'c' 'c' 'c' 'c']\
    TEST GROUP: ['d' 'd' 'd']
    """
    @_deprecate_positional_args
    def __init__(self,
                 n_splits=5,
                 *,
                 max_train_size=None
                 ):
        super().__init__(n_splits, shuffle=False, random_state=None)
        self.max_train_size = max_train_size

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.
        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.
        Yields
        ------
        train : ndarray
            The training set indices for that split.
        test : ndarray
            The testing set indices for that split.
        """
        if groups is None:
            raise ValueError(
                "The 'groups' parameter should not be None")
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        n_splits = self.n_splits
        n_folds = n_splits + 1
        group_dict = {}
        u, ind = np.unique(groups, return_index=True)
        unique_groups = u[np.argsort(ind)]
        n_samples = _num_samples(X)
        n_groups = _num_samples(unique_groups)
        for idx in np.arange(n_samples):
            if (groups[idx] in group_dict):
                group_dict[groups[idx]].append(idx)
            else:
                group_dict[groups[idx]] = [idx]
        if n_folds > n_groups:
            raise ValueError(
                ("Cannot have number of folds={0} greater than"
                 " the number of groups={1}").format(n_folds,
                                                     n_groups))
        group_test_size = n_groups // n_folds
        group_test_starts = range(n_groups - n_splits * group_test_size,
                                  n_groups, group_test_size)
        for group_test_start in group_test_starts:
            train_array = []
            test_array = []
            for train_group_idx in unique_groups[:group_test_start]:
                train_array_tmp = group_dict[train_group_idx]
                train_array = np.sort(np.unique(
                                      np.concatenate((train_array,
                                                      train_array_tmp)),
                                      axis=None), axis=None)
            train_end = train_array.size
            if self.max_train_size and self.max_train_size < train_end:
                train_array = train_array[train_end -
                                          self.max_train_size:train_end]
            for test_group_idx in unique_groups[group_test_start:
                                                group_test_start +
                                                group_test_size]:
                test_array_tmp = group_dict[test_group_idx]
                test_array = np.sort(np.unique(
                                              np.concatenate((test_array,
                                                              test_array_tmp)),
                                     axis=None), axis=None)
            yield [int(i) for i in train_array], [int(i) for i in test_array]

GridSearchCV 示例。从 SO 帖子修改的代码here https://stackoverflow.com/questions/46732748/how-do-i-use-a-timeseriessplit-with-a-gridsearchcv-object-to-tune-a-model-in-sci.


import xgboost as xgb
from sklearn.model_selection import  GridSearchCV
import numpy as np
groups = np.array(['a', 'a', 'a', 'b', 'b', 'c'])

X = np.array([[4, 5, 6, 1, 0, 2], [3.1, 3.5, 1.0, 2.1, 8.3, 1.1]]).T
y = np.array([1, 6, 7, 1, 2, 3])

model = xgb.XGBRegressor()
param_search = {'max_depth' : [3, 5]}

tscv = GroupTimeSeriesSplit(n_splits=2)
gsearch = GridSearchCV(estimator=model, cv=tscv,
                        param_grid=param_search)
gsearch.fit(X, y , groups=groups)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

分组时间序列(面板)数据的交叉验证 的相关文章

随机推荐

  • 在 Apps 脚本中对同一工作表使用 Google Sheets API

    通过 SpreadsheetApp 全局 使用绑定到电子表格的 Apps 脚本来影响电子表格非常简单 但是 有一些功能 例如在工作表上获取 设置过滤器 只能通过 Google Sheets REST API 访问 我见过一个在 Apps 脚
  • 如何将 Red5 与 Asp.net 结合使用

    我想在线录制语音 我想我需要使用FMS或Red5 但我不知道如何将Red5与Asp net一起使用 实际上这是我第一次尝试处理这样的事情 目前我是一名 net开发人员 所以请有人告诉我一种处理它的方法 并告诉我如何将 Red5 与 Asp
  • 正则表达式:如何匹配包含重复模式的字符串?

    是否有一个正则表达式模式可以匹配包含重复模式的字符串 例如 a b c d y z 你有什么主意吗 也许您正在寻找这样的东西 这将匹配以逗号分隔的表单序列列表 where and 可以是任何字符
  • 使自定义 monad 转换器成为 MonadError 的实例

    我想让我的 monad 转换器成为一个实例MonadError如果转换后的单子是一个实例 基本上我希望我的变压器的行为与内置变压器一样 例如有一个MonadError实例为StateT MonadError e m gt MonadErro
  • 如何从另一个 sbt 项目引用外部 sbt 项目?

    我对 Scala 应用程序和通用核心库进行了以下设置 根 gt ApplicationA gt project gt build sbt gt CoreLibrary gt project gt build sbt 我想将 Applicat
  • 将 Yup 验证错误转换为可用对象

    Problem 我有一个 formik 表单 需要有 2 个不同的验证模式 具体取决于用户使用哪个按钮提交 我看到有些人说使用状态来决定哪个 但我想避免使用状态 因为在这种情况下感觉不对 我看过是的文档 https www npmjs co
  • 格式化整数时 printf 中的精度字段

    当我执行这两行时 printf 5d n 3 use of precision filed printf 05d n 3 use of 0 flag to prepend with 0 我得到以下输出 00003 00003 结果相同 所以
  • Google 字体无法在移动设备中加载

    我读过类似的帖子 但这个问题有点不同 我有 rest of the code 在 css 样式文件中我有 body font family Source Sans Pro sans serif rest of the code 它在浏览器中
  • 将新对象附加到 JSON 文件中的数组

    如何将附加对象添加到现有 JSON 文件 即对象数组 中 这是我的 JS 代码 const fs require fs let Human Name John age 20 Human JSON stringify Human null 2
  • 自动完成搜索字符串的多个部分,然后返回最可能的部分

    有点像这个问题 https stackoverflow com questions 824144 how do i use jquery autocomplete for multiple words 我有很多文本片段 每天都会使用很多很多
  • 使用 nokogiri 干式搜索网站的每个页面

    我想搜索网站的每个页面 我的想法是找到页面上保留在域内的所有链接 访问它们 然后重复 我也必须采取措施 避免重复努力 所以开始很容易 page http example com nf Nokogiri HTML open page link
  • Azure Functions 中 PowerShell 脚本的选项在哪里

    我想使用 PowerShell 创建 Azure Function 当我谈到 Azure 希望我选择要创建的函数类型时 唯一可用的语言是 C F 和 JavaScript 我错过了什么吗 如何使用 PowerShell 创建 Azure 函
  • 尝试使用 Comparator 按名称排序、忽略大小写以及先处理空值

    我在使用 Java 8 Comparator 类对项目列表进行排序时遇到问题 我当前的工作比较器如下 comparator Comparator comparing Person getName Comparator nullsFirst
  • Android 中从时间戳获取日期名称

    我有一个类 当它初始化时 它会使用公共 getter 在私有字段中记录初始化时间 public class TestClass private long mTimestamp public TestClass mTimestamp Syst
  • 每个 ajax 请求都会调用 preRenderView

    我正在使用 jquery waypoints 和 jsf 实现无限滚动link http kahimyang info kauswagan code blogs 1405 building a page with infinite scro
  • CSS自定义组合框问题

    我需要一个自定义组合框 所以 我实施了ul 问题是我无法通过单击在顶部打开组合框列表button 展示的同时ul 它移动button到网页底部 Code ul width 100px background color rgb 224 224
  • 在 Emacs 中定义新的工具提示

    我想向 emacs 添加自定义工具提示 更具体地说 每当我将鼠标悬停在符号 函数 变量 名称上时 用我的鼠标我想看到带有符号定义的工具提示 我知道我可以使用 cscope 这样的工具找到此类信息 但我不知道如何找到 将 cscope 的输出
  • 运行烘焙命令时出现 SQLSTATE HY000 2002

    我在运行烘焙命令时遇到问题 我认为它与 mysql 有关 但我在 Stackoverflow 上没有找到此错误的任何解决方案 这是我的app php Datasources gt default gt className gt Cake D
  • Kafka的消息键有什么特别的地方吗?

    我没有看到任何提及消息键 org apache kafka clients producer ProducerRecord key 除了它们可以用于主题分区 我可以自由地将我喜欢的任何数据放入密钥中 还是有一些我应该遵守的特殊语义 该密钥似
  • 分组时间序列(面板)数据的交叉验证

    我使用面板数据 随着时间的推移 我观察许多单位 例如人 对于每个单元 我都有相同固定时间间隔的记录 当将数据分为训练集和测试集时 我们需要确保这两个集是不相交的并且顺序的 即训练集中的最新记录应该在测试集中最早的记录之前 参见例如此博客文章