Scikit-learn的train_test_split与索引

9 浏览
0 Comments

Scikit-learn的train_test_split与索引

在使用train_test_split()时,我如何获得数据的原始索引?

我目前的情况是这样的:

from sklearn.cross_validation import train_test_split
import numpy as np
data = np.reshape(np.randn(20),(10,2)) # 10个训练样例
labels = np.random.randint(2, size=10) # 10个标签
x1, x2, y1, y2 = train_test_split(data, labels, size=0.2)

但是这样并没有给出原始数据的索引。

一个解决方法是将索引添加到数据中(例如data = [(i, d) for i, d in enumerate(data)]),然后将它们传递给train_test_split,然后再次展开。

是否有更简洁的解决方法?

0