Training
STonKGs pre-training
Script for running the pre-training procedure of STonKGs.
STonKGs fine-tuning
Runs the STonKGs model on the fine-tuning classification task, assuming the model embeddings are pre-trained.
Run with: python -m src.stonkgs.models.stonkgs_finetuning
- class INDRADataset(encodings, labels)
Custom Dataset class for INDRA data containing the combination of text and KG triple data.
Initialize INDRA Dataset based on the combined input sequence consisting of text and triple data.
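As a rough illustration of the (encodings, labels) constructor above, the following sketch mimics a map-style dataset in plain Python. The class name ToyINDRADataset, the feature names, and the token ids are hypothetical and chosen for illustration only; the real INDRADataset may store tensors and behave differently.

```python
from typing import Dict, List


class ToyINDRADataset:
    """Plain-Python stand-in that mimics the map-style dataset protocol."""

    def __init__(self, encodings: Dict[str, List[List[int]]], labels: List[int]):
        # encodings maps a feature name (e.g. "input_ids") to one row per example
        self.encodings = encodings
        self.labels = labels

    def __len__(self) -> int:
        # One label per example, so the label count is the dataset size
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        # Collect the idx-th row of every feature and attach the label
        item = {name: rows[idx] for name, rows in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item


# Made-up ids: each row stands for a combined text + KG input sequence
encodings = {
    "input_ids": [[101, 7, 8, 102, 3, 4], [101, 9, 102, 5, 6, 0]],
    "attention_mask": [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0]],
}
dataset = ToyINDRADataset(encodings, labels=[0, 1])
example = dataset[1]
```

Indexing returns one dictionary per example, which is the shape that batch collation utilities typically expect.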
- class STonKGsForSequenceClassification(config, **kwargs)
Create the fine-tuning part of the STonKGs model based on the pre-trained STonKGs model.
Note that this class inherits from STonKGsForPreTraining rather than PreTrainedModel, which deviates from the typical Hugging Face inheritance pattern for fine-tuning classes.
Initialize the STonKGs sequence classification model based on the pre-trained STonKGs model architecture.
- get_train_test_splits(train_data, type_column_name='labels', random_seed=42, n_splits=5, max_dataset_size=100000)
Return train/test indices for each of the n_splits splits of the fine-tuning dataset that is passed.
- Return type:
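To make the idea of per-split train/test indices concrete, here is a minimal pure-Python sketch of a stratified k-fold split. It is not the implementation of get_train_test_splits, which may rely on a library splitter and may also subsample to max_dataset_size; the helper name stratified_kfold_indices and the toy labels are assumptions.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_kfold_indices(
    labels: Sequence[str], n_splits: int = 5, random_seed: int = 42
) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, test_indices) pairs, one per fold."""
    rng = random.Random(random_seed)

    # Group example indices by label so every fold sees every class
    by_label: Dict[str, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    # Deal the shuffled indices of each class round-robin into the folds
    folds: List[List[int]] = [[] for _ in range(n_splits)]
    for indices in by_label.values():
        rng.shuffle(indices)
        for position, idx in enumerate(indices):
            folds[position % n_splits].append(idx)

    # Fold k is the test set; all remaining folds form the training set
    splits = []
    for k in range(n_splits):
        test = sorted(folds[k])
        train = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        splits.append((train, test))
    return splits


# Toy data: 20 examples with two balanced classes
labels = ["up", "down"] * 10
splits = stratified_kfold_indices(labels, n_splits=5, random_seed=42)
```

Fixing the seed keeps the folds reproducible across runs, which is presumably the purpose of the random_seed parameter in the documented signature.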
- preprocess_fine_tuning_data(train_data_path, class_column_name='class', embedding_name_to_vector_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/embeddings_best_model.tsv', embedding_name_to_random_walk_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/random_walks_best_model.tsv', nlp_model_type='dmis-lab/biobert-v1.1', sep_id=102, unk_id=100)
Generate input_ids, attention_mask, token_type_ids, etc. based on the source, target, and evidence columns.
- Return type:
DataFrame
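The sketch below shows one plausible way that input_ids, attention_mask, and token_type_ids could be assembled from a text part and a KG part. The layout (text segment followed by KG segment, each closed by a separator), the padding id of 0, the tiny half_length, and the helper name build_combined_inputs are all assumptions for illustration; only sep_id=102 is taken from the signature above, and the actual preprocessing may differ.

```python
from typing import Dict, List

SEP_ID = 102  # default sep_id from the documented signature
PAD_ID = 0    # assumption: padding id, not part of the documented signature


def build_combined_inputs(
    text_ids: List[int], kg_ids: List[int], half_length: int = 4
) -> Dict[str, List[int]]:
    """Join a text part and a KG part into fixed-length model inputs."""

    def fit(ids: List[int]) -> List[int]:
        # Truncate to leave room for a trailing separator, then pad
        kept = ids[: half_length - 1] + [SEP_ID]
        return kept + [PAD_ID] * (half_length - len(kept))

    text_part, kg_part = fit(text_ids), fit(kg_ids)
    input_ids = text_part + kg_part
    return {
        "input_ids": input_ids,
        # Attend to everything except padding
        "attention_mask": [int(tok != PAD_ID) for tok in input_ids],
        # 0 marks the text segment, 1 marks the KG segment
        "token_type_ids": [0] * half_length + [1] * half_length,
    }


# Made-up ids: a short text part and a KG part that needs truncation
features = build_combined_inputs(text_ids=[7, 8], kg_ids=[21, 22, 23, 24, 25])
```

Giving each segment a fixed length means the boundary between text and KG positions is the same in every example, which keeps token_type_ids identical across rows.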
- run_sequence_classification_cv(train_data_path, model_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/stonkgs-pretraining/pretrained-stonkgs', output_dir='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/stonkgs', logging_uri_mlflow=None, label_column_name='labels', class_column_name='class', epochs=10, log_steps=500, lr=5e-05, batch_size=8, gradient_accumulation=1, task_name='', deepspeed=True, max_dataset_size=100000, cv=5)
Run cross-validation for the sequence classification task(s) using STonKGs.
- Return type:
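The general shape of a cross-validation run can be sketched as a loop over (train, test) index pairs that scores each fold and aggregates the results. This is a generic illustration, not the behavior of run_sequence_classification_cv: the helper cross_validate, the toy_scorer callable, and the returned dictionary are hypothetical, and no model is trained here.

```python
from statistics import mean
from typing import Callable, Dict, List, Sequence, Tuple

Split = Tuple[List[int], List[int]]


def cross_validate(
    splits: Sequence[Split],
    train_and_score: Callable[[List[int], List[int]], float],
) -> Dict[str, object]:
    """Score every fold with the given callable and average the results."""
    # Each call stands for: fit on the train indices, evaluate on the test indices
    scores = [train_and_score(train, test) for train, test in splits]
    return {"scores": scores, "mean": mean(scores)}


def toy_scorer(train_idx: List[int], test_idx: List[int]) -> float:
    # Hypothetical placeholder: returns the training share instead of a real metric
    return len(train_idx) / (len(train_idx) + len(test_idx))


# Four leave-one-out style folds over four examples
splits = [([0, 1, 2], [3]), ([0, 1, 3], [2]), ([0, 2, 3], [1]), ([1, 2, 3], [0])]
result = cross_validate(splits, toy_scorer)
```

Passing the scoring step in as a callable keeps the loop independent of the model, so the same driver can be reused with any per-fold training routine.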