socio4health.utils.harmonizer_utils.classify_rows

socio4health.utils.harmonizer_utils.classify_rows(data: DataFrame, col1: str, col2: str, col3: str, new_column_name: str = 'category', MODEL_PATH: str = './bert_finetuned_classifier') → DataFrame

Classify each row of a DataFrame using a fine-tuned BERT multiclass classification model.

Parameters:
  • data (pd.DataFrame) – The DataFrame containing the text columns to classify.

  • col1 (str) – Name of the first column containing survey-related text.

  • col2 (str) – Name of the second column containing survey-related text.

  • col3 (str) – Name of the third column containing survey-related text.

  • new_column_name (str, optional) – Name of the new column that stores the predicted categories (default is 'category').

  • MODEL_PATH (str, optional) – Path to the fine-tuned model weights (default is './bert_finetuned_classifier').

Returns:

A DataFrame with a new column, named by new_column_name, containing the predicted category for each row.

Return type:

pd.DataFrame
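
Example:

The sketch below illustrates the input and output shape described above; it is not the socio4health implementation. It assumes the three text columns are joined into one string per row before classification, which this page does not confirm. The names classify_rows_sketch and keyword_predict, the column names, and the sample rows are all hypothetical, and a keyword rule stands in for the fine-tuned BERT model so the example runs without model weights.

```python
import pandas as pd


def classify_rows_sketch(data, col1, col2, col3, predict, new_column_name="category"):
    """Illustrative stand-in only; NOT the socio4health implementation."""
    # Assumption: the three text columns are joined into one string per row.
    parts = [data[c].fillna("").astype(str) for c in (col1, col2, col3)]
    texts = parts[0] + " " + parts[1] + " " + parts[2]

    # Work on a copy so the caller's DataFrame is left unchanged.
    result = data.copy()
    result[new_column_name] = [predict(text) for text in texts]
    return result


def keyword_predict(text):
    # Hypothetical placeholder for the fine-tuned BERT classifier.
    return "health" if "hospital" in text.lower() else "other"


# Hypothetical survey dictionary rows.
df = pd.DataFrame({
    "variable_name": ["P6090", "P6040"],
    "question": ["Visited a hospital last month?", "Age in years"],
    "description": ["Health service use", None],
})

out = classify_rows_sketch(df, "variable_name", "question", "description", keyword_predict)
print(out[["variable_name", "category"]])
```

With the real function, the call would take the same DataFrame and column names, with MODEL_PATH pointing at the directory that holds the fine-tuned weights instead of a predict callable.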