TransWikia.com

Is it possible to have stratified train-test split of a set based on two columns?

Data Science Asked on January 9, 2021

Consider a dataframe that contains two columns, text and label. I can easily create a stratified train-test split using sklearn.model_selection.train_test_split. All I have to do is pass the column to stratify on (in this case label) to the stratify parameter.
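For reference, here is a minimal sketch of the single-column case described above. The toy data (80 comments, a 75/25 label ratio) and the split size are illustrative assumptions:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 60 comments with label 0, 20 with label 1
df = pd.DataFrame({
    "text": [f"comment {i}" for i in range(80)],
    "label": [0] * 60 + [1] * 20,
})

# Stratify on the single label column so both splits keep the 75/25 ratio
train_df, test_df = train_test_split(
    df, test_size=0.25, stratify=df["label"], random_state=0
)

print(train_df["label"].value_counts(normalize=True).sort_index())
print(test_df["label"].value_counts(normalize=True).sort_index())
```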

Now, consider a dataframe that contains three columns, text, subreddit, and label. I would like to make a stratified train-test split using the label column, but I also want to make sure that there is no bias in terms of the subreddit column. For example, the test set might end up with many more comments from subreddit X than the train set.

How can I do this in Python?

One Answer

One option is to pass both columns to the stratify parameter, which also accepts multidimensional (2-D) arrays. Here is the description from the scikit-learn documentation:

stratify : array-like, default=None

If not None, data is split in a stratified fashion, using this as the class labels.
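As far as I can tell from the scikit-learn source, a 2-D stratify array is handled by treating each distinct row (here, each (subreddit, label) pair) as its own class. The sketch below makes that idea explicit, assuming a DataFrame `df` of hypothetical toy data: it builds one combined key per row and stratifies on that single column.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
# Hypothetical toy data: subreddit in {0, 1, 2}, binary label
df = pd.DataFrame({
    "text": [f"comment {i}" for i in range(600)],
    "subreddit": rng.integers(0, 3, 600),
    "label": rng.integers(0, 2, 600),
})

# One class per (subreddit, label) pair, e.g. "2_1"
strat_key = df["subreddit"].astype(str) + "_" + df["label"].astype(str)

train_df, test_df = train_test_split(
    df, test_size=0.2, stratify=strat_key, random_state=0
)

# Each combination should appear in the test set in roughly its overall share
print(strat_key.loc[test_df.index].value_counts(normalize=True).sort_index())
print(strat_key.value_counts(normalize=True).sort_index())
```

Passing `df[["subreddit", "label"]]` directly as `stratify` should give an equivalent joint stratification. The explicit key is mainly handy when you want to inspect the combinations before splitting.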


Here is an example:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# create dummy data with an unbalanced subreddit distribution: the 1000 values are
# reshaped row-wise, so the first 250 rows hold values 0-2 and the last 250 hold 0-9
X = pd.DataFrame(
    np.concatenate((np.random.randint(0, 3, 500), np.random.randint(0, 10, 500)), axis=0).reshape((500, 2)),
    columns=["text", "subreddit"])
y = pd.DataFrame(np.random.randint(0, 2, 500).reshape((500, 1)), columns=["label"])

# stratify on the target variable and the subreddit column at the same time;
# y itself stays label-only, so no columns need to be dropped afterwards
strat = pd.concat([X["subreddit"], y], axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=strat)

As you can see, the split is also stratified by subreddit:

Subreddit shares in the training set:

X_train.groupby("subreddit").count()/len(X_train)

gives

               text
subreddit
0          0.232000
1          0.232000
2          0.213333
3          0.034667
4          0.037333
5          0.045333
6          0.056000
7          0.056000
8          0.048000
9          0.045333

Subreddit shares in the test set:

X_test.groupby("subreddit").count()/len(X_test)

gives

            text
subreddit
0          0.232
1          0.240
2          0.208
3          0.032
4          0.032
5          0.048
6          0.056
7          0.056
8          0.048
9          0.048
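Instead of comparing the two tables by eye, the largest difference in shares can be computed directly. A small sketch; `share_gap` is a hypothetical helper, not part of pandas or scikit-learn:

```python
import pandas as pd


def share_gap(train: pd.Series, test: pd.Series) -> float:
    """Largest absolute difference in category shares between two splits (hypothetical helper)."""
    train_share = train.value_counts(normalize=True)
    test_share = test.value_counts(normalize=True)
    # Align on the union of categories; a category missing from one split counts as share 0
    train_share, test_share = train_share.align(test_share, fill_value=0)
    return float((train_share - test_share).abs().max())


# Toy example: category 3 appears only in the test split
train = pd.Series([0, 0, 1, 1, 2])
test = pd.Series([0, 1, 1, 3])
gap = share_gap(train, test)
print(gap)
```

With the answer's variables this would be `share_gap(X_train["subreddit"], X_test["subreddit"])`; for the tables above, the largest gap is about 0.008 (subreddit 1).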

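Naturally, this only works if there is enough data to stratify on subreddit and the target variable at the same time: every (subreddit, label) combination must occur at least twice, and both the train and the test set must contain at least as many samples as there are combinations. Otherwise scikit-learn raises a ValueError.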
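Before splitting, it can help to check these conditions yourself. The sketch below is a hypothetical helper; its rules approximately mirror the checks scikit-learn performs for a stratified shuffle split (the train-size calculation in particular is simplified):

```python
import math

import pandas as pd


def check_joint_stratify(strat: pd.DataFrame, test_size: float) -> list:
    """Return reasons a joint stratified split would likely fail (hypothetical helper)."""
    counts = strat.value_counts()  # one entry per distinct row combination
    problems = []
    rare = counts[counts < 2]
    if len(rare):
        problems.append(f"{len(rare)} combination(s) occur only once")
    n_test = math.ceil(test_size * len(strat))
    n_train = len(strat) - n_test  # simplified; scikit-learn floors (1 - test_size) * n
    n_classes = len(counts)
    if n_test < n_classes:
        problems.append(f"test set ({n_test}) is smaller than the number of combinations ({n_classes})")
    if n_train < n_classes:
        problems.append(f"train set ({n_train}) is smaller than the number of combinations ({n_classes})")
    return problems


# Toy example: (subreddit 2, label 1) appears only once, and 5 rows are too few
strat = pd.DataFrame({"subreddit": [0, 0, 1, 1, 2], "label": [0, 0, 1, 1, 1]})
problems = check_joint_stratify(strat, test_size=0.25)
print(problems)
```

With the answer's variables, `check_joint_stratify(strat, test_size=0.25)` would be run on the same `strat` frame that is passed to `train_test_split`.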

Correct answer by Sammy on January 9, 2021
