Skip to content

Training a Mention Detection Model

Load a dataset

Copy and paste the following code to define the dataset-loading functions.

Dataset Loading Functions
import pandas as pd

import os, requests, zipfile
from tqdm.auto import tqdm
from collections import Counter

from pathlib import Path
import json
from collections import Counter

from propp_fr import load_spacy_model
from propp_fr import load_text_file, generate_tokens_df, save_tokens_df
from propp_fr import load_tokens_df, load_entities_df, add_features_to_entities, save_entities_df
from propp_fr import mentions_detection_LOOCV_full_model_training, generate_NER_model_card_from_LOOCV_directory

def realign_tokens_offsets(tokens_df, entities_df):
    """Snap every mention span in ``entities_df`` onto token boundaries.

    For each mention, the start token is the first token whose ``byte_offset``
    lies after the mention's ``byte_onset``. The end token is the last token
    whose ``byte_onset`` lies before the mention's ``byte_offset``. The
    mention's ``byte_onset``/``byte_offset`` are then replaced by the outer
    boundaries of those two tokens.

    Args:
        tokens_df: token table with ``byte_onset`` and ``byte_offset`` columns,
            indexed by token id.
        entities_df: mention table with ``byte_onset`` and ``byte_offset``
            columns. Modified in place.

    Returns:
        ``entities_df`` with ``start_token``/``end_token`` columns added and
        its offsets overwritten by the token-aligned values.
    """
    token_onsets = tokens_df["byte_onset"].to_numpy()
    token_offsets = tokens_df["byte_offset"].to_numpy()

    aligned = {"start_token": [], "end_token": [], "byte_onset": [], "byte_offset": []}
    for onset, offset in entities_df[["byte_onset", "byte_offset"]].values:
        # First token ending after the mention starts / last token starting before it ends.
        first = tokens_df.index[token_offsets > onset].min()
        last = tokens_df.index[token_onsets < offset].max()

        aligned["start_token"].append(first)
        aligned["end_token"].append(last)
        aligned["byte_onset"].append(tokens_df.loc[first, "byte_onset"])
        aligned["byte_offset"].append(tokens_df.loc[last, "byte_offset"])

    # Insertion order matches the original column assignment order.
    for column, values in aligned.items():
        entities_df[column] = values

    return entities_df

def extract_mention_text(text_content, entities_df):
    """Attach the surface text of every mention to ``entities_df``.

    Each mention's text is ``text_content[byte_onset:byte_offset]``. The texts
    are stored, cast to ``str``, in a ``text`` column of ``entities_df``,
    which is modified in place and returned.

    NOTE(review): despite the ``byte_`` column names, these values are used as
    Python ``str`` (character) indices here. Confirm the offsets are
    character-based for non-ASCII texts.
    """
    spans = entities_df[["byte_onset", "byte_offset"]].values
    entities_df["text"] = [text_content[onset:offset] for onset, offset in spans]
    entities_df["text"] = entities_df["text"].astype(str)
    return entities_df

def _download_archive(url, archive_path, timeout):
    """Stream ``url`` into ``archive_path``, moving it into place only once complete.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout`` seconds.
    """
    partial_path = archive_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    # An interrupted transfer leaves only the ".part" file behind, so the next
    # call re-downloads instead of treating a truncated archive as valid.
    os.replace(partial_path, archive_path)


def _read_dataset_language(files_directory, dataset_name, default="en"):
    """Return the language code declared in ``split_config.json``, or ``default``."""
    config_path = os.path.join(files_directory, "split_config.json")
    if not os.path.exists(config_path):
        print(f"Config not found for '{dataset_name}'")
        return default
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    if "dataset_language" not in config:
        print("No language configuration. Default to English.")
        return default
    language = config["dataset_language"]
    # The config can store the language as a one-element list (e.g. ['en']).
    # Such a list never compares equal to a plain code like "fr".
    if isinstance(language, (list, tuple)):
        language = language[0] if language else default
    print(language)
    return language


def _spacy_model_name_for(language):
    """Map a dataset language code to the spaCy pipeline used to tokenize it."""
    if language == "fr":
        return "fr_dep_news_trf"
    if language == "ru":
        return "ru_core_news_lg"
    return "en_core_web_trf"


def load_dataset(dataset_name, local_dataset_directory="loaded_datasets", preprocess=False,
                 force_download=False, timeout=60):
    """Download, extract and, if needed, preprocess a propp dataset.

    The archive ``<dataset_name>_propp_minimal_implementation.zip`` is fetched
    from the propp dataset site into ``local_dataset_directory``. It is then
    extracted to ``<local_dataset_directory>/<dataset_name>``.

    A document is processed when it has a ``.entities`` file but no
    ``.tokens`` file; when ``preprocess`` is true, every document is
    processed. Each processed document is:
      - tokenized with a spaCy model chosen from the language in
        ``split_config.json``;
      - given mentions realigned to token boundaries and enriched with
        features;
      - saved back to the dataset directory.

    Args:
        dataset_name: name of the published dataset (e.g. "litbank").
        local_dataset_directory: directory holding archives and extracted files.
        preprocess: re-run preprocessing even for documents that already have tokens.
        force_download: re-download and re-extract even if local copies exist.
        timeout: seconds to wait for the server while downloading.

    Returns:
        The path of the extracted dataset directory.

    Raises:
        requests.HTTPError: if the download fails with an HTTP error status.
    """
    # Where datasets will be stored
    os.makedirs(local_dataset_directory, exist_ok=True)

    dataset_URL_path = (
        "https://lattice-8094.github.io/propp/datasets/"
        f"{dataset_name}_propp_minimal_implementation.zip"
    )

    # Local paths
    archive_name = dataset_URL_path.split("/")[-1]
    archive_path = os.path.join(local_dataset_directory, archive_name)
    files_directory = os.path.join(local_dataset_directory, dataset_name.replace(".zip", ""))

    # Download if needed
    if not os.path.exists(archive_path) or force_download:
        print("Downloading dataset...")
        _download_archive(dataset_URL_path, archive_path, timeout)
        print(f"Downloaded dataset to {archive_path}")
    else:
        print("Archive already exists, skipping download.")

    # Extract if needed
    if not os.path.exists(files_directory) or force_download:
        print("Extracting dataset...")
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(files_directory)
        print(f"Dataset extracted to {files_directory}")
    else:
        print("Dataset already extracted.")

    dataset_language = _read_dataset_language(files_directory, dataset_name)
    spacy_model_name = _spacy_model_name_for(dataset_language)

    all_entities_files = sorted(
        {p.stem for p in Path(files_directory).iterdir() if p.suffix == ".entities"}
    )

    # Loaded lazily: the (large) spaCy model is only needed if at least one
    # document actually has to be (re)processed.
    spacy_model = None
    for file_name in tqdm(all_entities_files, leave=False):
        if os.path.exists(os.path.join(files_directory, file_name + ".tokens")) and not preprocess:
            continue
        if spacy_model is None:
            spacy_model = load_spacy_model(spacy_model_name)

        text_content = load_text_file(file_name, files_directory)
        tokens_df = generate_tokens_df(text_content, spacy_model, verbose=0)
        entities_df = load_entities_df(file_name, files_directory)

        entities_df = realign_tokens_offsets(tokens_df, entities_df)
        entities_df = extract_mention_text(text_content, entities_df)

        entities_df = add_features_to_entities(entities_df, tokens_df)

        if dataset_language != "fr":
            # Gender / number / grammatical person are only kept for French datasets.
            entities_df["gender"] = "Not_Assigned"
            entities_df["number"] = "Not_Assigned"
            entities_df["grammatical_person"] = "4"

        save_entities_df(entities_df, file_name, files_directory)
        save_tokens_df(tokens_df, file_name, files_directory)

    print(f"Dataset Directory: {files_directory}")
    return files_directory

Choose the dataset to load. The available datasets are listed in the comment at the top of the code below.

# Pick one of the published datasets:
#   "long-litbank-fr-PER-only", "litbank-fr", "litbank-ru", "litbank",
#   "conll2003-NER", "ontonotes5_english-NER"
dataset_name = "litbank"
local_dataset_directory = "loaded_datasets"

load_dataset(dataset_name, local_dataset_directory=local_dataset_directory)
files_directory = os.path.join(local_dataset_directory, dataset_name)

# Show which keys (language, cross-validation splits, ...) the split config defines.
split_config_path = os.path.join(files_directory, "split_config.json")
if os.path.exists(split_config_path):
    with open(split_config_path, "r") as f:
        config = json.load(f)
    print(f"\nConfiguration:\n{list(config.keys())}")
Archive already exists, skipping download.  
Dataset already extracted.  
['en']  

Dataset Directory: loaded_datasets/litbank  

The dataset contains a `split_config.json` file that declares the dataset language and, optionally, the train / validation / test splits.

Configuration:  
['test_0', 'test_1', 'test_2', 'test_3', 'test_4', 'test_5', 'test_6', 'test_7', 'test_8', 'test_9', 'dataset_language']
import os, requests, zipfile
from tqdm.auto import tqdm
from collections import Counter

# Where datasets will be stored
local_dataset_directory = "loaded_datasets"
os.makedirs(local_dataset_directory, exist_ok=True)

# Available datasets: long-litbank-fr-PER-only ; litbank-fr ; litbank-ru ; litbank ;
#                     conll2003-NER ; ontonotes5_english-NER
dataset_name = "conll2003-NER"

# Dataset URL
dataset_URL_path = (
    "https://lattice-8094.github.io/propp/datasets/"
    f"{dataset_name}_propp_minimal_implementation.zip"
)

# Local paths. The files go to <local_dataset_directory>/<dataset_name>, the
# same layout load_dataset() uses, so both loaders share one extracted copy.
archive_name = dataset_URL_path.split("/")[-1]
archive_path = os.path.join(local_dataset_directory, archive_name)
files_directory = os.path.join(local_dataset_directory, dataset_name)

# Download if needed
if not os.path.exists(archive_path):
    print("Downloading dataset...")
    # Download to a ".part" file and rename it only when complete. An
    # interrupted transfer then never looks like a valid archive; the timeout
    # keeps an unresponsive server from hanging the notebook.
    partial_path = archive_path + ".part"
    with requests.get(dataset_URL_path, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    os.replace(partial_path, archive_path)

    print(f"Downloaded dataset to {archive_path}")
else:
    print("Archive already exists, skipping download.")

# Extract if needed
if not os.path.exists(files_directory):
    print("Extracting dataset...")
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        zip_ref.extractall(files_directory)
    print(f"Dataset extracted to {files_directory}")
else:
    print("Dataset already extracted.")

Preprocess the dataset

from pathlib import Path

# One entry per annotated document. Keying on the ".entities" suffix (as
# load_dataset does) keeps non-document files such as split_config.json from
# showing up as a fake "split_config" document.
all_files = sorted({p.stem for p in Path(files_directory).iterdir() if p.suffix == ".entities"})
len(all_files)

Load the spaCy model

Since the dataset is in English and a GPU is available, we use the `en_core_web_trf` model.

from propp_fr import load_spacy_model

# English corpus with a GPU available: use spaCy's English "trf" pipeline.
spacy_pipeline_name = "en_core_web_trf"
spacy_model = load_spacy_model(spacy_pipeline_name)

Define Preprocessing Functions

def realign_tokens_offsets(tokens_df, entities_df):
    """Snap mention spans to the token grid.

    A mention's start token is the first token ending after the mention's
    ``byte_onset``. Its end token is the last token starting before the
    mention's ``byte_offset``. The mention's offsets are replaced by the outer
    boundaries of those two tokens.

    Args:
        tokens_df: tokens with ``byte_onset``/``byte_offset`` columns.
        entities_df: mentions with ``byte_onset``/``byte_offset`` columns;
            updated in place.

    Returns:
        ``entities_df`` with ``start_token``, ``end_token`` and realigned
        ``byte_onset``/``byte_offset`` columns.
    """
    token_starts = tokens_df["byte_onset"].to_numpy()
    token_ends = tokens_df["byte_offset"].to_numpy()

    def snap(mention_onset, mention_offset):
        first_token = tokens_df.index[token_ends > mention_onset].min()
        last_token = tokens_df.index[token_starts < mention_offset].max()
        return (
            first_token,
            last_token,
            tokens_df.loc[first_token, "byte_onset"],
            tokens_df.loc[last_token, "byte_offset"],
        )

    starts, ends, onsets, offsets = [], [], [], []
    for mention_onset, mention_offset in entities_df[["byte_onset", "byte_offset"]].values:
        first_token, last_token, onset, offset = snap(mention_onset, mention_offset)
        starts.append(first_token)
        ends.append(last_token)
        onsets.append(onset)
        offsets.append(offset)

    entities_df["start_token"] = starts
    entities_df["end_token"] = ends
    entities_df["byte_onset"] = onsets
    entities_df["byte_offset"] = offsets
    return entities_df

def extract_mention_text(text_content, entities_df):
    """Store each mention's surface text in a ``text`` column of ``entities_df``.

    The text is the slice of ``text_content`` between the mention's
    ``byte_onset`` and ``byte_offset``, cast to ``str``. ``entities_df`` is
    modified in place and returned.

    NOTE(review): the offsets are used as character indices into a Python
    ``str`` here; confirm this matches how they are produced for non-ASCII text.
    """
    def surface(span):
        start, end = span
        return text_content[start:end]

    entities_df["text"] = list(map(surface, entities_df[["byte_onset", "byte_offset"]].values))
    entities_df["text"] = entities_df["text"].astype(str)
    return entities_df

Preprocess the dataset to prepare it for mention-span detection training.

from propp_fr import load_text_file, generate_tokens_df, save_tokens_df
from propp_fr import load_tokens_df, load_entities_df, add_features_to_entities, save_entities_df

for file_name in tqdm(all_files):
    # Only annotated documents (those with a .entities file) can be processed.
    # This also skips stray stems such as "split_config" from split_config.json.
    if not os.path.exists(os.path.join(files_directory, file_name + ".entities")):
        continue
    # Already preprocessed on a previous run.
    if os.path.exists(os.path.join(files_directory, file_name + ".tokens")):
        continue
    text_content = load_text_file(file_name, files_directory)
    tokens_df = generate_tokens_df(text_content, spacy_model, verbose=0)
    entities_df = load_entities_df(file_name, files_directory)

    # Snap mentions to token boundaries, then recover their surface text.
    entities_df = realign_tokens_offsets(tokens_df, entities_df)
    entities_df = extract_mention_text(text_content, entities_df)

    entities_df = add_features_to_entities(entities_df, tokens_df)
    # Gender, number and grammatical person are not available for English,
    # so they get placeholder values.
    entities_df["gender"] = "Not_Assigned"
    entities_df["number"] = "Not_Assigned"
    entities_df["grammatical_person"] = "4"

    save_entities_df(entities_df, file_name, files_directory)
    save_tokens_df(tokens_df, file_name, files_directory)
from propp_fr import mentions_detection_LOOCV_full_model_training, generate_NER_model_card_from_LOOCV_directory

from collections import Counter

# Collect every entity category annotated across the corpus.
NER_categories = []
for file_name in all_files:
    entities_df = load_entities_df(file_name, files_directory)
    NER_categories.extend(entities_df["cat"].tolist())
print(Counter(NER_categories))
# Sort the categories so the label list has the same order on every run.
# Iteration order of a set of strings changes between interpreter sessions
# because of hash randomization. key=str tolerates non-string values.
NER_cat_list = sorted(set(NER_categories), key=str)
print(NER_cat_list)

Define Training Parameters

# Mention-detection model hyper-parameters
subword_pooling_strategy = "first_last"
nested_levels = [0]
tagging_scheme = "BIOES"

# Documents whose name starts with "test" are the files passed to
# cross-validation in the training call.
test_split = [name for name in all_files if name.startswith("test")]
print(f"Test split: {len(test_split):,}")

model_name = "FacebookAI/roberta-large"
embedding_model_name = model_name.rsplit("/", 1)[-1]  # e.g. "roberta-large"
trained_model_directory = os.path.join(
    files_directory, f"mentions_detection_model_{embedding_model_name}"
)

Train and evaluate the model, then generate the model card

# Gather every training option in one place, then launch training.
# NOTE(review): train_final_model=False presumably limits the run to the
# cross-validation models; confirm against the propp_fr documentation.
training_kwargs = dict(
    files_directory=files_directory,
    trained_model_directory=trained_model_directory,
    model_name=model_name,
    subword_pooling_strategy=subword_pooling_strategy,
    nested_levels=nested_levels,
    NER_cat_list=NER_cat_list,
    tagging_scheme=tagging_scheme,
    train_final_model=False,
    files_to_use_in_cross_validation=[test_split],
    verbose=0,
)
mentions_detection_LOOCV_full_model_training(**training_kwargs)

# Summarize the results written to trained_model_directory as a model card.
generate_NER_model_card_from_LOOCV_directory(trained_model_directory)
Model Card Example

---
language: fr
tags:
- NER
- literary-texts
- nested-entities
- propp-fr
license: apache-2.0
metrics:
- f1
- precision
- recall
base_model:
- FacebookAI/roberta-large
pipeline_tag: token-classification
---


INTRODUCTION:

This model, developed as part of the propp-fr project, is a NER model built on top of roberta-large embeddings, trained to predict nested entities in french, specifically for literary texts.

The predicted entities are:

  • PER: Person names (real or fictional)
  • ORG: Organizations (companies, institutions)
  • LOC: Geographical locations (non-political: mountains, rivers, cities)
  • MISC: Miscellaneous entities (events, nationalities, products, etc.)

MODEL PERFORMANCES (TEST SET):

| NER_tag | precision | recall | f1_score | support | support % |
|---|---|---|---|---|---|
| LOC | 93.74% | 93.41% | 93.57% | 1,668 | 29.53% |
| ORG | 92.98% | 93.26% | 93.12% | 1,661 | 29.41% |
| PER | 97.90% | 98.02% | 97.96% | 1,617 | 28.63% |
| MISC | 81.33% | 81.91% | 81.62% | 702 | 12.43% |
| micro_avg | 93.16% | 93.25% | 93.21% | 5,648 | 100.00% |
| macro_avg | 91.49% | 91.65% | 91.57% | 5,648 | 100.00% |

TRAINING PARAMETERS:

  • Entities types: ['MISC', 'ORG', 'PER', 'LOC']
  • Tagging scheme: BIOES
  • Nested entities levels: [0]
  • Split strategy: Leave-one-out cross-validation (1393 files)
  • Train/Validation split: 0.85 / 0.15
  • Batch size: 16
  • Initial learning rate: 0.00014

MODEL ARCHITECTURE:

Model Input: Maximum context roberta-large embeddings (1024 dimensions)

  • Locked Dropout: 0.5

  • Projection layer:

    • layer type: highway layer
    • input: 1024 dimensions
    • output: 2048 dimensions
  • BiLSTM layer:

    • input: 2048 dimensions
    • output: 256 dimensions (hidden state)
  • Linear layer:

    • input: 256 dimensions
    • output: 17 dimensions (predicted labels with BIOES tagging scheme)
  • CRF layer

Model Output: BIOES labels sequence

HOW TO USE:

Propp Documentation

PREDICTIONS CONFUSION MATRIX:

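| Gold Labels | LOC | ORG | PER | MISC | O | support |
|---|---|---|---|---|---|---|
| LOC | 1,558 | 38 | 1 | 22 | 49 | 1,668 |
| ORG | 30 | 1,549 | 2 | 12 | 68 | 1,661 |
| PER | 6 | 7 | 1,585 | 2 | 17 | 1,617 |
| MISC | 25 | 24 | 10 | 575 | 68 | 702 |
| O | 43 | 48 | 21 | 96 | 0 | 208 |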