KG-Baseline
KG baseline model for the fine-tuning classification task. It assumes that the KG node embeddings have already been pre-trained.
Run with: python -m src.stonkgs.models.kg_baseline_model
- class INDRAEntityDataset(embedding_dict, random_walk_dict, sources, targets, labels, max_len=254)
Custom dataset class for INDRA data.
Initialize the INDRA dataset from the random walk embeddings of the two nodes (source and target) in each triple.
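The following pure-Python sketch shows what such a dataset could return for one triple. It rests on several assumptions that the documentation above does not confirm: random_walk_dict maps a node ID to the list of node IDs visited by that node's random walk, embedding_dict maps a node ID to its pre-trained vector, and the source and target walks are each truncated to max_len // 2 before being concatenated. ToyINDRAEntityDataset is a hypothetical name and is not the stonkgs class.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


class ToyINDRAEntityDataset:
    """Hypothetical stand-in for INDRAEntityDataset (map-style, like a torch Dataset)."""

    def __init__(
        self,
        embedding_dict: Dict[str, Vector],
        random_walk_dict: Dict[str, List[str]],
        sources: Sequence[str],
        targets: Sequence[str],
        labels: Sequence[int],
        max_len: int = 254,
    ) -> None:
        self.embedding_dict = embedding_dict
        self.random_walk_dict = random_walk_dict
        self.sources = list(sources)
        self.targets = list(targets)
        self.labels = list(labels)
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[List[Vector], int]:
        half = self.max_len // 2
        # Assumption: each of the two nodes gets half of the max_len budget.
        walk = (
            self.random_walk_dict[self.sources[idx]][:half]
            + self.random_walk_dict[self.targets[idx]][:half]
        )
        # Map every visited node ID to its pre-trained embedding vector.
        sequence = [self.embedding_dict[node] for node in walk]
        return sequence, self.labels[idx]
```

Each item is therefore a sequence of at most max_len embedding vectors paired with the triple's class label. That sequence is the shape of input the classification model pools over.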
- class KGEClassificationModel(num_classes, class_weights, d_in=768, lr=0.001)
KGE baseline model.
Initialize the components of the KGE-based classification model.
The model consists of 1) “max-pooling” (embedding-dimension-wise max), 2) dropout, 3) a linear layer (d_in x num_classes), and 4) softmax. The softmax is not part of the model itself but of the class, which passes class_weights to the cross_entropy function.
- Parameters
num_classes – number of target classes
class_weights – class weights passed to the cross_entropy function
d_in – dimension of the input embeddings
lr – learning rate
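Softmax can stay outside the model if, as assumed here, cross_entropy refers to PyTorch's torch.nn.functional.cross_entropy (the Lightning-style test_epoch_end hook suggests PyTorch). That function applies log-softmax to raw logits itself. The plain-Python sketch below shows how a per-class weight vector such as class_weights enters the mean-reduced loss. Each sample's negative log-likelihood is scaled by the weight of its true class, and the sum is divided by the sum of those weights. The function names are illustrative.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def weighted_cross_entropy(
    batch_logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    class_weights: Sequence[float],
) -> float:
    """Mean-reduced weighted cross-entropy, the convention of F.cross_entropy(weight=...)."""
    weighted_nll = 0.0
    weight_sum = 0.0
    for logits, y in zip(batch_logits, labels):
        w = class_weights[y]
        weighted_nll += -w * log_softmax(logits)[y]  # weight by the true class
        weight_sum += w
    return weighted_nll / weight_sum
```

Up-weighting a rare class makes its misclassifications cost more, which is the usual reason to pass class weights on an imbalanced fine-tuning dataset.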
- forward(x)
Perform the forward pass: pooling (dimension-wise max), followed by a linear layer and a softmax.
- Parameters
x – embedding sequences (random walk embeddings) for the given triples
- Returns
class probabilities for the given triples
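The forward pass described above can be sketched in plain Python on a single sequence. This is an illustration, not the stonkgs code. It stores weight as num_classes rows of length d_in, which is the layout torch.nn.Linear uses internally. Dropout is left out because it acts as the identity at inference time.

```python
import math
from typing import List, Sequence


def max_pool(seq: Sequence[Sequence[float]]) -> List[float]:
    # Embedding-dimension-wise max: for each dimension, the max over all positions.
    return [max(column) for column in zip(*seq)]


def linear(x: Sequence[float], weight: Sequence[Sequence[float]], bias: Sequence[float]) -> List[float]:
    # One output logit per class: dot(row, x) + bias.
    return [sum(w_j * x_j for w_j, x_j in zip(row, x)) + b for row, b in zip(weight, bias)]


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def forward(seq: Sequence[Sequence[float]], weight, bias) -> List[float]:
    # Pool the random walk embeddings, project to class logits, normalize.
    return softmax(linear(max_pool(seq), weight, bias))
```

Max-pooling makes the output independent of sequence length, so triples whose walks were truncated differently still map to a single d_in-dimensional vector.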
- test_epoch_end(outputs)
Return the mean and standard deviation of the weighted-average F1 score over all batches.
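A weighted-average F1 score weights each class's F1 by its support (the number of true instances of that class), as scikit-learn's f1_score(..., average='weighted') does. The sketch below computes that score per batch and then aggregates the batch scores. The choice of population standard deviation (statistics.pstdev, matching numpy.std's default) is an assumption; the original may use a different estimator.

```python
import statistics
from collections import Counter
from typing import Sequence, Tuple


def weighted_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Support-weighted mean of per-class F1 scores."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for cls, n_cls in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = n_cls - tp
        # F1 = 2TP / (2TP + FP + FN); equals 2PR / (P + R) when both are defined.
        f1 = 2 * tp / (2 * tp + fp + fn)
        score += f1 * n_cls / total
    return score


def aggregate(batches: Sequence[Tuple[Sequence[int], Sequence[int]]]) -> Tuple[float, float]:
    # One weighted F1 per (y_true, y_pred) batch, then mean and population std.
    scores = [weighted_f1(t, p) for t, p in batches]
    return statistics.mean(scores), statistics.pstdev(scores)
```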
- get_train_test_splits(data, label_column_name='class', random_seed=42, n_splits=5, max_dataset_size=100000)
Return deterministic train/test indices for n_splits folds of the fine-tuning dataset that is passed.
- Return type
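The splits are deterministic because a dedicated random number generator is seeded, so the same data and random_seed always yield the same folds. The stdlib sketch below assumes two things that the documentation does not confirm. First, it assumes the folds are stratified on the label column, which label_column_name suggests. Second, it assumes datasets larger than max_dataset_size are randomly down-sampled. The real function may instead delegate to a library such as scikit-learn's StratifiedKFold. stratified_kfold_indices is a hypothetical name.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def stratified_kfold_indices(
    labels: Sequence[Hashable],
    n_splits: int = 5,
    random_seed: int = 42,
    max_dataset_size: int = 100_000,
) -> List[Tuple[List[int], List[int]]]:
    rng = random.Random(random_seed)  # local RNG: same seed, same splits
    indices = list(range(len(labels)))
    # Assumption: larger datasets are down-sampled to max_dataset_size rows.
    if len(indices) > max_dataset_size:
        indices = sorted(rng.sample(indices, max_dataset_size))

    # Group row indices by label so that every fold keeps the class ratios.
    by_class: Dict[Hashable, List[int]] = defaultdict(list)
    for i in indices:
        by_class[labels[i]].append(i)

    fold_of: Dict[int, int] = {}
    position = 0
    for cls in sorted(by_class, key=str):  # fixed class order for determinism
        members = by_class[cls]
        rng.shuffle(members)
        # Deal the shuffled rows of this class round-robin across the folds.
        for i in members:
            fold_of[i] = position % n_splits
            position += 1

    splits = []
    for fold in range(n_splits):
        test = [i for i in indices if fold_of[i] == fold]
        train = [i for i in indices if fold_of[i] != fold]
        splits.append((train, test))
    return splits
```

Every kept row lands in exactly one test fold, and each fold's train and test parts are disjoint.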
- run_kg_baseline_classification_cv(triples_path, embedding_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/stable/models/kg-hpo/embeddings_best_model.tsv', random_walks_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/stable/models/kg-hpo/random_walks_best_model.tsv', logging_uri_mlflow=None, n_splits=5, epochs=100, train_batch_size=8, test_batch_size=64, lr=0.001, label_column_name='class', log_steps=500, task_name='', max_dataset_size=100000)
Run the KG baseline classification with n_splits-fold cross-validation.
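A cross-validation run like this one typically trains one fresh model per fold and reports the spread of the held-out scores. The driver below is a hypothetical skeleton. train_fn and eval_fn stand in for the actual training step (epochs, train_batch_size) and evaluation step (test_batch_size), which the documentation does not spell out. splits are (train indices, test indices) pairs of the kind get_train_test_splits is documented to return.

```python
import statistics
from typing import Any, Callable, Dict, List, Sequence, Tuple

Split = Tuple[List[int], List[int]]


def run_cv(
    splits: Sequence[Split],
    train_fn: Callable[[List[int]], Any],
    eval_fn: Callable[[Any, List[int]], float],
) -> Dict[str, Any]:
    scores: List[float] = []
    for train_idx, test_idx in splits:
        model = train_fn(train_idx)  # fresh model per fold, trained on train rows only
        scores.append(eval_fn(model, test_idx))  # score on the held-out fold
    return {
        "mean": statistics.mean(scores),
        # Population std over folds (assumption; a sample std is equally plausible).
        "std": statistics.pstdev(scores),
        "per_fold": scores,
    }
```

Creating the model inside the loop matters: reusing one model across folds would leak information from earlier test folds into later training.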
Getting the node embeddings via node2vec
Node2vec model.
Run with: python -m src.stonkgs.models.node2vec
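node2vec generates node sequences with a second-order biased random walk that uses a return parameter p and an in-out parameter q. Stepping back to the previous node has weight 1/p, moving to a neighbor of the previous node has weight 1, and moving farther away has weight 1/q. The sketch below implements that walk on an unweighted, undirected adjacency dict. It omits the skip-gram training that turns walks into embeddings, and it does not reflect the library or hyperparameters that the stonkgs module actually uses. node2vec_walk is a hypothetical name.

```python
import random
from typing import Dict, List, Optional, Set


def node2vec_walk(
    graph: Dict[str, Set[str]],
    start: str,
    walk_length: int,
    p: float = 1.0,
    q: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[str]:
    if rng is None:
        rng = random.Random(0)
    walk = [start]
    while len(walk) < walk_length:
        cur = walk[-1]
        neighbors = sorted(graph[cur])  # sorted so a seeded walk is reproducible
        if not neighbors:
            break  # dead end: stop the walk early
        if len(walk) == 1:
            walk.append(rng.choice(neighbors))  # first step is uniform
            continue
        prev = walk[-2]
        weights = []
        for nxt in neighbors:
            if nxt == prev:
                weights.append(1.0 / p)  # return to the previous node
            elif nxt in graph[prev]:
                weights.append(1.0)  # stay at distance 1 from prev
            else:
                weights.append(1.0 / q)  # move outward (distance 2 from prev)
        walk.append(rng.choices(neighbors, weights=weights, k=1)[0])
    return walk
```

A low q pushes walks outward (DFS-like exploration), while a low p keeps them near the start node (BFS-like).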
- run_link_prediction(kg, model)
Run the link prediction task for a given KG and node2vec model.
- Return type
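The documentation does not say how link prediction is scored. One common and simple protocol is shown here purely as an assumption. It scores a candidate edge by the dot product of its endpoint embeddings and reports ROC AUC over held-out true edges versus sampled non-edges. The AUC is computed directly as the probability that a positive edge outscores a negative one, with ties counted as one half. All function names are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[str, str]


def edge_score(emb: Dict[str, List[float]], u: str, v: str) -> float:
    # Dot product of the two endpoint embeddings.
    return sum(a * b for a, b in zip(emb[u], emb[v]))


def roc_auc(pos_scores: Sequence[float], neg_scores: Sequence[float]) -> float:
    # Fraction of (positive, negative) pairs ranked correctly; ties count 1/2.
    wins = 0.0
    for s_pos in pos_scores:
        for s_neg in neg_scores:
            if s_pos > s_neg:
                wins += 1.0
            elif s_pos == s_neg:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


def evaluate_link_prediction(
    emb: Dict[str, List[float]], pos_edges: Sequence[Edge], neg_edges: Sequence[Edge]
) -> float:
    pos = [edge_score(emb, u, v) for u, v in pos_edges]
    neg = [edge_score(emb, u, v) for u, v in neg_edges]
    return roc_auc(pos, neg)
```

An AUC of 0.5 corresponds to random ranking and 1.0 to perfect separation of true edges from non-edges.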