This repository contains the code for the Word Sense Disambiguation (WSD) PyTorch models that have been trained and developed by the UCREL NLP Group at Lancaster University, UK.
Requires Python 3.10 or greater. It is best to install the version of PyTorch you would like to use (e.g. the CPU or GPU build) before installing this package; otherwise you will get the default version of PyTorch for your operating system/setup. We require torch>=2.2,<3.0.
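As a quick sanity check, the version constraint above can be tested against a PyTorch version string. The sketch below is only an illustration: the helper name `torch_version_supported` is hypothetical and not part of this package, and it only looks at the leading major.minor numbers, so pre-release versions are not handled the way pip would handle them.

```python
import re


def torch_version_supported(version: str) -> bool:
    """Return True if `version` satisfies torch>=2.2,<3.0.

    Hypothetical helper for illustration, not part of wsd-torch-models.
    """
    # Keep only the leading "major.minor" part, e.g. "2.5.1+cpu" -> (2, 5)
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (2, 2) <= (major, minor) < (3, 0)


if __name__ == "__main__":
    # In practice you would pass torch.__version__ here.
    for candidate in ["2.1.2", "2.2.0", "2.5.1+cpu", "3.0.0"]:
        print(candidate, torch_version_supported(candidate))
```

If PyTorch is already installed, passing `torch.__version__` to the helper tells you whether the existing install falls inside the supported range.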
    pip install wsd-torch-models

Here we list the various WSD models we have implemented and how to use them.
An inference-only implementation of the Bi-Encoder Model (BEM) for Word Sense Disambiguation from the paper Moving Down the Long Tail of Word Sense Disambiguation with Gloss Informed Bi-encoders. The bi-encoder encodes the word(s) to disambiguate using their text context (e.g. the whole sentence or document), encodes every sense definition it is given, and returns the most similar sense definition(s) for each word. Unlike the original BEM model, we use the same model to encode both the text to disambiguate and the label definitions.
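The ranking step of a bi-encoder can be sketched in a few lines. The example below assumes dot-product similarity, as used in the original BEM paper; the scoring used inside this library may differ. The function name `rank_senses` and the three-dimensional toy vectors are made up for illustration, since real embeddings come from the encoder.

```python
from typing import Dict, List


def dot(a: List[float], b: List[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def rank_senses(
    token_embedding: List[float],
    definition_embeddings: Dict[str, List[float]],
    top_n: int = 1,
) -> List[str]:
    """Return the `top_n` sense labels whose definition embedding is most similar."""
    scores = {
        label: dot(token_embedding, embedding)
        for label, embedding in definition_embeddings.items()
    }
    return sorted(scores, key=lambda label: scores[label], reverse=True)[:top_n]


if __name__ == "__main__":
    # Toy embedding for the token "bank" in a river context (made-up numbers).
    bank_in_context = [0.9, 0.1, 0.2]
    # Toy embeddings for three sense definitions (made-up numbers).
    definitions = {
        "W3": [1.0, 0.0, 0.1],
        "I1": [0.1, 1.0, 0.0],
        "M6": [0.6, 0.0, 0.5],
    }
    print(rank_senses(bank_in_context, definitions, top_n=2))
```

Because the definitions are embedded independently of the text, their embeddings can be computed once and reused for every token that is disambiguated.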
These models were trained using the code from the following GitHub repository https://github.com/UCREL/experimental-wsd and ported over to this library for inference only use with easy saving and loading from the HuggingFace hub.
We currently have 4 pre-trained BEM models that predict sense labels from the USAS sense inventory, which contains 232 sense categories. Compared to WordNet, which has approximately 117,000 senses, this inventory is very coarse. More details about these models and how they were trained can be found in our forthcoming paper. The models are:
- ucrelnlp/PyMUSAS-Neural-English-Small-BEM - 17 million parameter English only model.
- ucrelnlp/PyMUSAS-Neural-English-Base-BEM - 68 million parameter English only model.
- ucrelnlp/PyMUSAS-Neural-Multilingual-Small-BEM - 140 million parameter Multilingual model.
- ucrelnlp/PyMUSAS-Neural-Multilingual-Base-BEM - 307 million parameter Multilingual model.
An example of how to run them is shown below; this particular example uses the Base English model:
    from transformers import AutoTokenizer
    import torch

    from wsd_torch_models.bem import BEM

    if __name__ == "__main__":
        wsd_model_name = "ucrelnlp/PyMUSAS-Neural-English-Base-BEM"
        wsd_model = BEM.from_pretrained(wsd_model_name)
        tokenizer = AutoTokenizer.from_pretrained(wsd_model_name, add_prefix_space=True)

        wsd_model.eval()
        # Change this to the device you would like to use, e.g. cpu
        model_device = "cpu"
        wsd_model.to(device=model_device)

        sentence = "The river bank was full of fish"
        sentence_tokens = sentence.split()

        with torch.inference_mode(mode=True):
            # sub_word_tokenizer can be None; when None it will download the appropriate tokenizer,
            # but generally it is better to give it the tokenizer as it saves the operation
            # of checking if the tokenizer is already downloaded.
            predictions = wsd_model.predict(sentence_tokens, sub_word_tokenizer=tokenizer, top_n=5)

            for sentence_token, semantic_tags in zip(sentence_tokens, predictions):
                print(f"Token: {sentence_token}")
                print("Most likely tags: ")
                for tag in semantic_tags:
                    tag_definition = wsd_model.label_to_definition[tag]
                    print(f"\t{tag}: {tag_definition}")
                print()

Output from running the code above:
    Token: The
    Most likely tags:
        Z5: title: Grammatical bin description: Prepositions/adverbs/conjunctions, etc
        I1.1: title: Money: Affluence description: Terms relating to (level of) wealth/prosperity
        Z2: title: Geographical names description: Nouns that distinguish/identify a specific place (e.g. the name of a road, a city, a country, a continent, etc.)
        M6: title: Location and direction description: Terms depicting position of/point of reference for X
        I1: title: Money generally description: Terms relating to money generally

    Token: river
    Most likely tags:
        W3: title: Geographical terms description: Geographical terms
        M4: title: Means of transport (Water) description: Terms depicting means of transport/ways of transporting and/or travelling (by water)
        Z2: title: Geographical names description: Nouns that distinguish/identify a specific place (e.g. the name of a road, a city, a country, a continent, etc.)
        O2: title: Objects generally description: Terms relating to objects generally
        N5: title: Quantities description: Terms depicting quantities

    Token: bank
    Most likely tags:
        W3: title: Geographical terms description: Geographical terms
        M4: title: Means of transport (Water) description: Terms depicting means of transport/ways of transporting and/or travelling (by water)
        M6: title: Location and direction description: Terms depicting position of/point of reference for X
        O2: title: Objects generally description: Terms relating to objects generally
        I1: title: Money generally description: Terms relating to money generally

    Token: was
    Most likely tags:
        Z5: title: Grammatical bin description: Prepositions/adverbs/conjunctions, etc
        A3: title: Being description: General/abstract terms relating to being/existing
        A5.4: title: Evaluation: Authenticity description: Evaluative terms depicting authenticity
        O4.2: title: Judgement of appearance (pretty etc.) description: Descriptive terms relating to the appearance/look of X
        A5.1: title: Evaluation: Good/bad description: Evaluative terms depicting quality

    Token: full
    Most likely tags:
        N5.1: title: Entirety; maximum description: Terms depicting maximal/maximum quantities
        N3.2: title: Measurement: Size description: Terms of measurement relating to size
        I3.2: title: Work and employment: Professionalism description: Terms relating to (level of) professionalism
        A5.1: title: Evaluation: Good/bad description: Evaluative terms depicting quality
        M4: title: Means of transport (Water) description: Terms depicting means of transport/ways of transporting and/or travelling (by water)

    Token: of
    Most likely tags:
        Z5: title: Grammatical bin description: Prepositions/adverbs/conjunctions, etc
        N5.1: title: Entirety; maximum description: Terms depicting maximal/maximum quantities
        N3.2: title: Measurement: Size description: Terms of measurement relating to size
        N5: title: Quantities description: Terms depicting quantities
        X9.1: title: Ability: Ability, intelligence description: Terms depicting (level of) ability/intelligence

    Token: fish
    Most likely tags:
        L2: title: Living creatures generally description: Terms relating to living creatures (e.g. non-human)
        F1: title: Food description: Terms relating to food and food preparation
        S2.1: title: People: Female description: Terms relating to females
        M4: title: Means of transport (Water) description: Terms depicting means of transport/ways of transporting and/or travelling (by water)
        Z2: title: Geographical names description: Nouns that distinguish/identify a specific place (e.g. the name of a road, a city, a country, a continent, etc.)

NOTE: the pre-trained models we have released come with the sense definitions they have been trained to predict (the USAS sense definitions). If you would like to use a different set of sense definitions, please look at the wsd_torch_models.bem.BEM.embed_and_set_label_definitions method, which allows you to change the sense definitions the model will predict. We have not tested how well these models perform on zero-shot sense prediction, e.g. training on one sense inventory and predicting on data using a different sense inventory.
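The definition strings printed in the output above combine a title and a description in one line. If you want them as separate fields, a small parser is enough. The sketch below assumes every definition follows the `title: ... description: ...` layout seen in that output; the helper name `split_definition` is hypothetical and not part of this package.

```python
from typing import Tuple


def split_definition(definition: str) -> Tuple[str, str]:
    """Split a 'title: X description: Y' string into (title, description).

    Assumes the single-line layout shown in the example output.
    """
    # Split on the first " description: " marker only, so titles that
    # themselves contain a colon (e.g. "Money: Affluence") stay intact.
    title_part, marker, description = definition.partition(" description: ")
    if not marker or not title_part.startswith("title: "):
        raise ValueError(f"Unexpected definition format: {definition!r}")
    return title_part[len("title: "):], description


if __name__ == "__main__":
    # Two definitions copied from the example output above.
    examples = {
        "W3": "title: Geographical terms description: Geographical terms",
        "I1.1": "title: Money: Affluence description: Terms relating to (level of) wealth/prosperity",
    }
    for tag, definition in examples.items():
        title, description = split_definition(definition)
        print(f"{tag} | {title} | {description}")
```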
All of these models have been trained on a portion of the ucrelnlp/English-USAS-Mosaico dataset, specifically data/wikipedia_shard_0.jsonl.gz, which contains 1,083 English Wikipedia articles with 444,880 sentences and 6.6 million tokens, of which 5.3 million tokens have silver labels generated by an English rule based semantic tagger.
| Parameter | 17M English | 68M English | 140M Multilingual | 307M Multilingual |
|---|---|---|---|---|
| Layers | 7 | 19 | 22 | 22 |
| Hidden Size | 256 | 512 | 384 | 768 |
| Intermediate Size | 384 | 768 | 1152 | 1152 |
| Attention Heads | 4 | 8 | 6 | 12 |
| Total Parameters | 17M | 68M | 140M | 307M |
| Non-embedding Parameters | 3.9M | 42.4M | 42M | 110M |
| Max Sequence Length | 8,000 | 8,000 | 8,192 | 8,192 |
| Vocabulary Size | 50,368 | 50,368 | 256,000 | 256,000 |
| Tokenizer | ModernBERT | ModernBERT | Gemma 2 | Gemma 2 |
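The gap between total and non-embedding parameters in the table can be cross-checked with simple arithmetic. The sketch below assumes the embedding parameters are dominated by a single token embedding matrix of size vocabulary size × hidden size; that is an assumption about the architectures, not something stated in the table. Under that assumption the estimates land within roughly half a million parameters of the values implied by the table.

```python
# (vocabulary size, hidden size, total params in millions, non-embedding params in millions)
# All four values per model are copied from the table above.
MODELS = {
    "17M English": (50_368, 256, 17.0, 3.9),
    "68M English": (50_368, 512, 68.0, 42.4),
    "140M Multilingual": (256_000, 384, 140.0, 42.0),
    "307M Multilingual": (256_000, 768, 307.0, 110.0),
}


def embedding_params_millions(vocab_size: int, hidden_size: int) -> float:
    """Size of a vocab_size x hidden_size embedding matrix, in millions of parameters."""
    return vocab_size * hidden_size / 1_000_000


if __name__ == "__main__":
    for name, (vocab, hidden, total, non_embedding) in MODELS.items():
        estimated = embedding_params_millions(vocab, hidden)
        implied = total - non_embedding
        print(f"{name}: estimated {estimated:.1f}M, table implies {implied:.1f}M")
```

This also makes clear why the multilingual models are so much larger than the English ones at similar depth: most of the extra parameters sit in the 256,000-entry vocabulary.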
We have evaluated the models on 5 datasets covering 5 different languages. 4 of these datasets are publicly available, whereas one (the Irish data) requires permission from the data owner to access. The top 1 and top 5 accuracy results for these models are shown below; for a more comprehensive comparison please see the technical report.
| Dataset | 17M English | 68M English | 140M Multilingual | 307M Multilingual |
|---|---|---|---|---|
| Top 1 accuracy (%) | | | | |
| Chinese | - | - | 42.2 | 47.9 |
| English | 66.4 | 70.1 | 66.0 | 70.2 |
| Finnish | - | - | 15.8 | 25.9 |
| Irish | - | - | 28.5 | 35.6 |
| Welsh | - | - | 21.7 | 42.0 |
| Top 5 accuracy (%) | | | | |
| Chinese | - | - | 66.3 | 70.4 |
| English | 87.6 | 90.0 | 88.9 | 90.1 |
| Finnish | - | - | 32.8 | 42.4 |
| Irish | - | - | 47.6 | 51.6 |
| Welsh | - | - | 40.8 | 56.4 |
The publicly available datasets can be found on the HuggingFace Hub at ucrelnlp/USAS-WSD.
Note that the English models have not been evaluated on the non-English datasets, as they are unlikely to represent non-English text well or to perform well on non-English data.
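For readers unfamiliar with the metric, top 1 and top 5 accuracy can be computed as sketched below. This is a minimal illustration that assumes one gold label per token and counts a token as correct when that label appears among the first k predicted tags. The exact protocol used for the reported results (for example, how tokens with several gold tags are handled) is not described here, so treat that as an assumption. The gold labels in the example are made up.

```python
from typing import Sequence


def top_k_accuracy(
    predictions: Sequence[Sequence[str]],
    gold_labels: Sequence[str],
    k: int,
) -> float:
    """Percentage of tokens whose gold label is among the first k predicted tags."""
    if len(predictions) != len(gold_labels):
        raise ValueError("predictions and gold_labels must have the same length")
    if not gold_labels:
        raise ValueError("at least one token is required")
    correct = sum(
        gold in ranked_tags[:k]
        for ranked_tags, gold in zip(predictions, gold_labels)
    )
    return 100.0 * correct / len(gold_labels)


if __name__ == "__main__":
    # Ranked tags per token, most likely first (same shape as the output of predict).
    predictions = [["Z5", "I1.1"], ["W3", "M4"], ["W3", "M4"], ["L2", "F1"]]
    # Made-up gold labels, for illustration only.
    gold_labels = ["Z5", "W3", "M4", "F1"]
    print(top_k_accuracy(predictions, gold_labels, k=1))
    print(top_k_accuracy(predictions, gold_labels, k=2))
```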
You can either use the dev container with your favourite editor, e.g. VSCode, or create your setup locally. Below we demonstrate both.
Both setups share the same tools:
- uv for Python packaging and development
- make (OPTIONAL) for automation of tasks, not strictly required but makes life easier.
A dev container uses a docker container to create the required development environment; the Dockerfile we use for this dev container can be found at ./.devcontainer/Dockerfile. Running it locally requires docker to be installed. You can also run it in a cloud based code editor; for a list of supported editors/cloud editors see the following webpage.
To run for the first time on a local VSCode editor (a slightly more detailed guide can be found on the VSCode website):
- Ensure docker is running.
- Ensure the VSCode Dev Containers extension is installed in your VSCode editor.
- Open the command palette (CMD + SHIFT + P) and then select Dev Containers: Rebuild and Reopen in Container
You should now have everything you need to develop: uv, make and, for VSCode, various extensions like Pylance.
If you have any trouble see the VSCode website.
To run locally, first ensure you have the following tools installed:
- uv for Python packaging and development (version 0.9.9).
- make (OPTIONAL) for automation of tasks, not strictly required but makes life easier.
  - Ubuntu: apt-get install make
  - Mac: Xcode command line tools includes make, else you can use brew.
  - Windows: Various solutions proposed in this blog post on how to install on Windows, including Cygwin and Windows Subsystem for Linux.
When developing on the project you will want to install the Python package locally in editable format with all the extra requirements. This can be done like so:

    uv sync

This code base uses isort, flake8 and mypy to ensure that the format of the code is consistent and that it contains type hints. The isort and mypy settings can be found within ./pyproject.toml and the flake8 settings can be found in ./.flake8. To run these linters:
    make lint

To run the tests with code coverage (NOTE: these are the code coverage tests that the Continuous Integration (CI) reports at the top of this README):
    make tests

The default or recommended Python version is shown in [.python-version](./.python-version), currently 3.13. This can be changed using the uv command:
    uv python pin
    # uv python pin 3.13

Some of the WSD models were originally trained using PyTorch Lightning. This section details how we convert these models to PyTorch models with a HuggingFace PyTorch Model Hub Mixin, which allows the model to be easily loaded from and saved to the HuggingFace Hub, and how we then upload the converted models to the HuggingFace Hub.
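One step that conversions of this kind commonly need is renaming weights. A Lightning checkpoint stores its weights under a `state_dict` entry, with each key prefixed by the attribute name the LightningModule uses for the wrapped model. The sketch below shows that renaming on a plain dictionary. It is an assumption about what such a conversion can involve and is not taken from the conversion script in this repository; the `model.` prefix, the key names and the helper name `strip_key_prefix` are all hypothetical.

```python
from typing import Dict, TypeVar

T = TypeVar("T")


def strip_key_prefix(state_dict: Dict[str, T], prefix: str) -> Dict[str, T]:
    """Return a new dict with `prefix` removed from matching keys.

    Keys that do not start with `prefix` (e.g. loss or metric state)
    are dropped, as they do not belong to the wrapped model.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    # Stand-in for checkpoint["state_dict"]; real values would be tensors.
    lightning_state_dict = {
        "model.encoder.weight": 1,
        "model.encoder.bias": 2,
        "loss_fn.weight": 3,
    }
    print(strip_key_prefix(lightning_state_dict, "model."))
```

The renamed dictionary is what a plain PyTorch module would expect when loading weights, since its parameter names carry no wrapper prefix.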
The script has various arguments, which are detailed in its help section:
    usage: convert_and_upload_bem_model.py [-h] [-r] [-t] [-m] hf_repository_id hf_branch model_checkpoint readme_template_path

    Converts a PyTorch Lightning model to a PyTorch HuggingFace model and uploads it to the HuggingFace Hub. The script allows you to just update the model README, model tokenizer, the model itself, or any combination of these options.

    positional arguments:
      hf_repository_id      The repository ID to upload the model too on the HuggingFace Hub, e.g. ucrelnlp/PyMUSAS-Neural-English-Small-BEM
      hf_branch             The branch to upload the model too on the HuggingFace Hub, e.g. main, a branch named after the step the model was trained on.
      model_checkpoint      Path to the model checkpoint that you would like to upload
      readme_template_path  File path to the models README template

    options:
      -h, --help            show this help message and exit
      -r, --update-readme   update model README
      -t, --update-tokenizer
                            update model tokenizer
      -m, --update-model    update model

To upload the model, tokenizer and README for all 4 models to the main branch:
    uv run scripts/convert_and_upload_bem_model.py ucrelnlp/PyMUSAS-Neural-English-Small-BEM main checkpoints/bem_english_small/model-step=532637-validation_accuracy=0.99394.ckpt model_readmes/pymusas_bem.md -rmt
    uv run scripts/convert_and_upload_bem_model.py ucrelnlp/PyMUSAS-Neural-English-Base-BEM main checkpoints/bem_english_base/model-step=532637-validation_accuracy=0.99669.ckpt model_readmes/pymusas_bem.md -rmt
    uv run scripts/convert_and_upload_bem_model.py ucrelnlp/PyMUSAS-Neural-Multilingual-Small-BEM main checkpoints/bem_multilingual_small/model-step=392261-validation_accuracy=0.99615.ckpt model_readmes/pymusas_bem.md -rmt
    uv run scripts/convert_and_upload_bem_model.py ucrelnlp/PyMUSAS-Neural-Multilingual-Base-BEM main checkpoints/bem_multilingual_base/model-step=240947-validation_accuracy=0.99625.ckpt model_readmes/pymusas_bem.md -rmt

To upload only an updated/new README:
    uv run scripts/convert_and_upload_bem_model.py ucrelnlp/PyMUSAS-Neural-English-Small-BEM main checkpoints/bem_english_small/model-step=532637-validation_accuracy=0.99394.ckpt model_readmes/pymusas_bem.md -r

As of Python version 3.11:
- from typing_extensions import Self: the typing_extensions package can be removed and this import can be replaced with from typing import Self
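While Python 3.10 is still supported, one possible interim pattern is a version-conditional import, sketched below. This is an illustration rather than what the code base currently does, and the `StepCounter` class is a made-up example of a method annotated with `Self`.

```python
import sys

if sys.version_info >= (3, 11):
    # Self is part of the standard library from Python 3.11 onwards.
    from typing import Self
else:
    # Backport for Python 3.10; requires the typing_extensions package.
    from typing_extensions import Self


class StepCounter:
    """Toy class whose method returns the instance it was called on."""

    def __init__(self) -> None:
        self.count = 0

    def increment(self) -> Self:
        self.count += 1
        return self


if __name__ == "__main__":
    counter = StepCounter().increment().increment()
    print(counter.count)
```

Once the minimum supported version is 3.11, the `else` branch and the typing_extensions dependency can be dropped.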
Technical report is forthcoming.
- Paul Rayson (p.rayson@lancaster.ac.uk)
- Andrew Moore (a.p.moore@lancaster.ac.uk / andrew.p.moore94@gmail.com)
- UCREL Research Centre (ucrel@lancaster.ac.uk) at Lancaster University.