Custom Embedding Models¶
One of the key features of TaxoTagger is its support for custom embedding models: users can plug in their own models, or models from other sources, to compute the embeddings for a given FASTA file.
TaxoTagger comes with a few pre-trained models, but users can add their own by following the steps below.
Model format
Currently, TaxoTagger only supports PyTorch models. If your model was built in another framework, you will need to convert it to PyTorch before using it with TaxoTagger.
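For instance, once your weights are loaded into a PyTorch `nn.Module`, you could serialize the model to the expected `.pt` file with `torch.save`. This is only a minimal sketch; `ExampleTransformer` and its layers are placeholders, not part of TaxoTagger:

```python
import torch
import torch.nn as nn


class ExampleTransformer(nn.Module):
    """Placeholder module standing in for your converted model."""

    def __init__(self, dim: int = 64) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, dim)  # illustrative layer only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)


model = ExampleTransformer()
# Save the model so that the file name matches the model name registered below
torch.save(model, "Example-Transformer.pt")
```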
1. Register your custom embedding model in the `PRETRAINED_MODELS` dictionary¶
The names of the models are the keys of the dictionary, and the values are the download URLs. The model names should be unique and should not contain any spaces. The download URLs should be direct download links to the model files. Most importantly, the name of the downloaded file should be the same as the model name, with the extension `.pt`.
For example, if your custom model is named `Example-Transformer` and its model file is named `Example-Transformer.pt`, you can add the following entry to the `PRETRAINED_MODELS` dictionary:
```python
PRETRAINED_MODELS = {
    "MycoAI-CNN": "https://zenodo.org/records/10904344/files/MycoAI-CNN.pt",
    "MycoAI-BERT": "https://zenodo.org/records/10904344/files/MycoAI-BERT.pt",
    "Example-Transformer": "https://example.com/Example-Transformer.pt"
}
```
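Since TaxoTagger expects the downloaded file to be named `<model name>.pt`, it can help to sanity-check the URL before registering it. A small, hypothetical check:

```python
from pathlib import Path
from urllib.parse import urlparse

model_name = "Example-Transformer"
url = "https://example.com/Example-Transformer.pt"

# The last path segment of a direct download link should be "<model name>.pt"
filename = Path(urlparse(url).path).name
assert filename == f"{model_name}.pt", f"Unexpected file name: {filename}"
```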
2. Implement a new wrapper class for the embedding model¶
Add a new wrapper class for the embedding model to the `taxotagger.models` module (source code file). The wrapper class should inherit from the `EmbedModelBase` abstract class and implement the `embed` method to calculate the embeddings for the given FASTA file.
Here is an example of a custom embedding model wrapper class:
```python
# For clarity, we omitted the imports and other parts of the code, e.g. docstrings
class ExampleTransformerEmbedModel(EmbedModelBase):
    name = "Example-Transformer"  # (1)!

    def __init__(self, config: ProjectConfig) -> None:  # (2)!
        self._config = config
        self.model = load_model(self.name, config)

    def embed(self, fasta_file: str) -> dict[str, list[dict[str, Any]]]:  # (3)!
        # Parse the input FASTA file
        sequences = read_fasta(fasta_file)
        # Calculate embeddings
        embeddings = self.model(sequences)
        # Return the embeddings
        return embeddings
```
- It's important to set the `name` attribute to the name of the model.
- It's recommended to add a constructor to the class to load the model.
- The `embed` method should calculate the embeddings for the given FASTA file and return them as a dictionary. The logic for calculating the embeddings is specific to the model, and you should implement it accordingly.
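With the class in place, you could try it out directly to confirm it produces embeddings. This is a minimal sketch; the import paths and the FASTA file name are assumptions, so adapt them to the actual package layout:

```python
from taxotagger import ProjectConfig  # assumed import path
from taxotagger.models import ExampleTransformerEmbedModel

config = ProjectConfig()
model = ExampleTransformerEmbedModel(config)

# "query.fasta" is a placeholder for your own FASTA file
embeddings = model.embed("query.fasta")
print(type(embeddings))  # the embed method is documented to return a dict
```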
3. Add the new wrapper class to the `ModelFactory.get_model` method¶
After implementing the new wrapper class, you need to add it to the `ModelFactory.get_model` method (source code file). This method should return an instance of the wrapper class for the given model name.
Here is an example of adding the new wrapper class `ExampleTransformerEmbedModel`:
```python
class ModelFactory:
    """Factory class to get the embedding model for the given model identifier."""

    @staticmethod
    def get_model(model_id: str, config: ProjectConfig) -> EmbedModelBase:
        """Get the embedding model for the given model identifier.

        Args:
            model_id: The identifier of the model to load.
            config: The configurations for the project.

        Returns:
            The embedding model instance for the given model identifier.

        Examples:
            >>> config = ProjectConfig()
            >>> model = ModelFactory.get_model("MycoAI-CNN", config)
        """
        if model_id == "MycoAI-CNN":
            return MycoAICNNEmbedModel(config)
        elif model_id == "MycoAI-BERT":
            return MycoAIBERTEmbedModel(config)
        elif model_id == "Example-Transformer":  # (1)!
            return ExampleTransformerEmbedModel(config)  # (2)!
        # Add more embedding models here if needed
        else:
            raise ValueError(
                f"Invalid model id {model_id}. Valid models are {PRETRAINED_MODELS.keys()}"
            )
```
- Add the new model name to the `ModelFactory.get_model` method.
- Return an instance of the new wrapper class for the given model name.
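Once the factory knows about the new model, you can request it by its registered name, mirroring the `Examples` section of the docstring above. The FASTA path and import locations below are assumptions:

```python
from taxotagger import ProjectConfig  # assumed import path
from taxotagger.models import ModelFactory

config = ProjectConfig()
model = ModelFactory.get_model("Example-Transformer", config)
embeddings = model.embed("query.fasta")  # placeholder FASTA file
```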
4. Test the custom embedding model¶
Implement a test case for the custom embedding model to ensure that it works correctly. You can add the unit tests to the file `test_models.py`.
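A possible shape for such a test, assuming the project uses `pytest` (the fixture content, import paths, and assertion below are illustrative, not the project's actual test suite):

```python
from taxotagger import ProjectConfig  # assumed import path
from taxotagger.models import ModelFactory


def test_example_transformer_embed(tmp_path):
    # Write a tiny FASTA file to embed (sequence content is illustrative)
    fasta_file = tmp_path / "query.fasta"
    fasta_file.write_text(">seq1\nACGTACGTACGTACGT\n")

    config = ProjectConfig()
    model = ModelFactory.get_model("Example-Transformer", config)
    embeddings = model.embed(str(fasta_file))

    # The embed method is documented to return a dictionary
    assert isinstance(embeddings, dict)
```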
5. Submit a pull request or build your own version of TaxoTagger¶
You can submit a pull request to the TaxoTagger repository to add your custom embedding model to the tool. Alternatively, you can build your own version of TaxoTagger with the custom embedding model and use it for your projects.