socio4health.utils.harmonizer_utils.classify_rows
- socio4health.utils.harmonizer_utils.classify_rows(data: DataFrame, col1: str, col2: str, col3: str, new_column_name: str = 'category', MODEL_PATH: str = './bert_finetuned_classifier') → DataFrame
Classify each row using a fine-tuned multiclass classification BERT model.
- Parameters:
data (pd.DataFrame) – The DataFrame with text columns.
col1 (str) – Name of the first column containing survey-related text.
col2 (str) – Name of the second column containing survey-related text.
col3 (str) – Name of the third column containing survey-related text.
new_column_name (str, optional) – Name of the new column to store the predicted categories (default is 'category').
MODEL_PATH (str) – Path to the model weights (default is './bert_finetuned_classifier').
- Returns:
pd.DataFrame with a new prediction column.
- Return type:
pd.DataFrame
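A minimal usage sketch, based only on the signature documented above. The column names (question, description, possible_answers) and the sample rows are hypothetical, chosen for illustration; the helper names build_example_frame and classify_example are not part of socio4health. The call assumes the package is installed and that fine-tuned weights exist at the given path.

```python
import pandas as pd


def build_example_frame() -> pd.DataFrame:
    """Build a tiny DataFrame with three survey-related text columns.

    Column names and rows are hypothetical, for illustration only.
    """
    return pd.DataFrame(
        {
            "question": ["What is your age?", "Do you own your home?"],
            "description": ["Age in completed years", "Housing tenure"],
            "possible_answers": ["0-120", "Yes; No"],
        }
    )


def classify_example(
    df: pd.DataFrame,
    model_path: str = "./bert_finetuned_classifier",
) -> pd.DataFrame:
    """Call classify_rows with the arguments documented above.

    The import is deferred so this module still loads on machines where
    socio4health or the model weights are not available.
    """
    from socio4health.utils.harmonizer_utils import classify_rows

    return classify_rows(
        df,
        "question",          # col1 (hypothetical column name)
        "description",       # col2 (hypothetical column name)
        "possible_answers",  # col3 (hypothetical column name)
        new_column_name="category",
        MODEL_PATH=model_path,
    )


# Typical use, assuming the package and model weights are present:
# frame = build_example_frame()
# result = classify_example(frame)
```

According to the Returns entry above, the result should carry one additional column, named by new_column_name, holding the predicted category for each row.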