Training

STonKGs pre-training

Script for running the pre-training procedure of STonKGs.

STonKGs fine-tuning

Runs the STonKGs model on the fine-tuning classification task, assuming the model embeddings are pre-trained.

Run with: python -m src.stonkgs.models.stonkgs_finetuning

class INDRADataset(encodings, labels)

Custom Dataset class for INDRA data containing the combination of text and KG triple data.

Initialize INDRA Dataset based on the combined input sequence consisting of text and triple data.
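The following is a minimal, dependency-free sketch of how a map-style dataset over `encodings` and `labels` can be organized. It is not the STonKGs implementation: the class name `EncodedPairDataset`, the field names, and the toy values are assumptions for illustration. It only shows the `__len__`/`__getitem__` protocol that map-style PyTorch datasets follow.

```python
from typing import Dict, List


class EncodedPairDataset:
    """Hypothetical stand-in for a map-style dataset over encodings and labels."""

    def __init__(self, encodings: Dict[str, List[List[int]]], labels: List[int]):
        # Every encoding field must provide exactly one row per label.
        for name, rows in encodings.items():
            if len(rows) != len(labels):
                raise ValueError(f"{name} has {len(rows)} rows, expected {len(labels)}")
        self.encodings = encodings
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        # Collect the idx-th row of every field and attach the label.
        item = {name: rows[idx] for name, rows in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item


# Toy data (assumed values, not real INDRA statements).
encodings = {
    "input_ids": [[101, 7, 102, 3], [101, 9, 102, 4]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1]],
}
dataset = EncodedPairDataset(encodings, labels=[0, 1])
first_example = dataset[0]
```

In a PyTorch setting, such a class would typically subclass `torch.utils.data.Dataset` and return tensors instead of plain lists.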

class STonKGsForSequenceClassification(config, **kwargs)

Create the fine-tuning part of the STonKGs model based on the pre-trained STonKGs model.

Note that this class inherits from STonKGsForPreTraining rather than PreTrainedModel, thereby deviating from the typical Hugging Face inheritance logic of the fine-tuning classes.

Initialize the STonKGs sequence classification model based on the pre-trained STonKGs model architecture.

forward(input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None)

Perform one forward pass for a given sequence of text_input_ids + ent_input_ids.

get_train_test_splits(train_data, type_column_name='labels', random_seed=42, n_splits=5, max_dataset_size=100000)

Return the train/test indices for each of the n_splits splits of the fine-tuning dataset that is passed.

Return type:

List
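As a rough illustration of what "train/test indices for each split" can look like, here is a pure-Python sketch of a seeded, label-stratified k-fold split. Whether the library stratifies by `type_column_name`, and the exact shape of the returned list, are assumptions; the helper `stratified_kfold_indices` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict
from typing import List, Sequence, Tuple


def stratified_kfold_indices(
    labels: Sequence[str], n_splits: int = 5, random_seed: int = 42
) -> List[Tuple[List[int], List[int]]]:
    """Return n_splits (train_indices, test_indices) pairs."""
    rng = random.Random(random_seed)

    # Group the row indices by their label.
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    # Shuffle each group and deal its indices round-robin into the folds,
    # so every fold receives a similar share of each label.
    folds: List[List[int]] = [[] for _ in range(n_splits)]
    for label in sorted(by_label):
        indices = by_label[label]
        rng.shuffle(indices)
        for position, idx in enumerate(indices):
            folds[position % n_splits].append(idx)

    # Each fold serves as the test set once; the rest form the training set.
    splits = []
    for k in range(n_splits):
        test_idx = sorted(folds[k])
        train_idx = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        splits.append((train_idx, test_idx))
    return splits


# Toy labels (assumed values): 10 rows of one class, 5 of another.
labels = ["increase"] * 10 + ["decrease"] * 5
splits = stratified_kfold_indices(labels, n_splits=5)
```

Fixing the seed makes the splits reproducible across runs, which is presumably the purpose of `random_seed=42` in the signature above.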

preprocess_fine_tuning_data(train_data_path, class_column_name='class', embedding_name_to_vector_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/embeddings_best_model.tsv', embedding_name_to_random_walk_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/random_walks_best_model.tsv', nlp_model_type='dmis-lab/biobert-v1.1', sep_id=102, unk_id=100)

Generate input_ids, attention_mask, token_type_ids, etc. based on the source, target, and evidence columns.

Return type:

DataFrame
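To make the three generated fields more concrete, the sketch below builds `input_ids`, `attention_mask`, and `token_type_ids` for one example from a text part and an entity part. It is an illustration under stated assumptions, not the library's preprocessing: the helper `encode_example`, the toy vocabularies, `PAD_ID`, `half_length`, and the layout (text half followed by entity half) are all hypothetical. Only the defaults `sep_id=102` and `unk_id=100` are taken from the signature above.

```python
from typing import Dict, List, Tuple

PAD_ID = 0  # assumption: id used for padding positions


def encode_example(
    evidence_tokens: List[str],
    entity_names: List[str],
    token_vocab: Dict[str, int],
    entity_vocab: Dict[str, int],
    half_length: int = 6,
    sep_id: int = 102,
    unk_id: int = 100,
) -> Dict[str, List[int]]:
    def encode_half(items: List[str], vocab: Dict[str, int]) -> Tuple[List[int], List[int]]:
        # Look up ids (unknown items fall back to unk_id) and keep room
        # for the separator that closes this half.
        ids = [vocab.get(item, unk_id) for item in items][: half_length - 1]
        ids.append(sep_id)
        n_real = len(ids)
        # Pad to the fixed length; the mask marks real positions with 1.
        ids.extend([PAD_ID] * (half_length - n_real))
        mask = [1] * n_real + [0] * (half_length - n_real)
        return ids, mask

    text_ids, text_mask = encode_half(evidence_tokens, token_vocab)
    ent_ids, ent_mask = encode_half(entity_names, entity_vocab)
    return {
        "input_ids": text_ids + ent_ids,
        "attention_mask": text_mask + ent_mask,
        # Segment 0 marks the text half, segment 1 the entity half.
        "token_type_ids": [0] * half_length + [1] * half_length,
    }


# Toy input (assumed values); "binds" is missing from the toy vocabulary.
example = encode_example(
    evidence_tokens=["protein", "binds", "receptor"],
    entity_names=["source_node", "target_node"],
    token_vocab={"protein": 11, "receptor": 12},
    entity_vocab={"source_node": 201, "target_node": 202},
)
```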

run_sequence_classification_cv(train_data_path, model_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/stonkgs-pretraining/pretrained-stonkgs', output_dir='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/stonkgs', logging_uri_mlflow=None, label_column_name='labels', class_column_name='class', epochs=10, log_steps=500, lr=5e-05, batch_size=8, gradient_accumulation=1, task_name='', deepspeed=True, max_dataset_size=100000, cv=5)

Run cross-validation for the sequence classification task(s) using STonKGs.

Return type:

Dict
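The general shape of a cross-validation run, one training and evaluation round per split followed by a summary, can be sketched as below. This is a generic illustration, not the STonKGs training loop: the helper `cross_validate`, the callback `train_and_score`, the toy scorer, and the keys of the returned dictionary are assumptions, and the contents of the `Dict` returned by the real function are not specified here.

```python
from statistics import mean, stdev
from typing import Callable, Dict, List, Sequence, Tuple

Split = Tuple[List[int], List[int]]


def cross_validate(
    splits: Sequence[Split],
    train_and_score: Callable[[List[int], List[int]], float],
) -> Dict[str, object]:
    """Run one training/evaluation round per split and summarize the scores."""
    scores = [train_and_score(train_idx, test_idx) for train_idx, test_idx in splits]
    return {
        "scores": scores,
        "mean": mean(scores),
        # The sample standard deviation needs at least two scores.
        "std": stdev(scores) if len(scores) > 1 else 0.0,
    }


def toy_scorer(train_idx: List[int], test_idx: List[int]) -> float:
    # Placeholder for "fine-tune on train_idx, evaluate on test_idx":
    # returns the fraction of rows that were available for training.
    return len(train_idx) / (len(train_idx) + len(test_idx))


# Four hand-written leave-one-out style splits over four rows.
splits = [([0, 1, 2], [3]), ([0, 1, 3], [2]), ([0, 2, 3], [1]), ([1, 2, 3], [0])]
result = cross_validate(splits, toy_scorer)
```

In a real run, the callback would fine-tune a fresh copy of the pre-trained model on each training split, so that no fold benefits from weights updated on another fold's test data.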