KG-Baseline
KG baseline model on the fine-tuning classification task, assuming the model embeddings are pre-trained.
Run with: python -m src.stonkgs.models.kg_baseline_model
- class KGEClassificationModel(num_classes, class_weights, d_in=768, lr=0.001)
KGE baseline model.
Initialize the components of the KGE based classification model.
- Parameters:
The model consists of:
1) “Max-Pooling” (embedding-dimension-wise max)
2) Dropout
3) Linear layer (d_in x num_classes)
4) Softmax
Not part of the model itself, but of the class: class_weights for the cross_entropy function.
- forward(x)
Perform a forward pass consisting of dimension-wise max pooling and a linear layer followed by softmax.
- Parameters:
x – embedding sequences (random walk embeddings) for the given triples
- Returns:
class probabilities for the given triples
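
The forward pass described above can be sketched without any deep learning framework. This is a minimal illustration, not the library's implementation: the helper names (`max_pool`, `linear`, `softmax`), the toy weights, and the input values are all assumptions made for this example, and dropout is left out because it acts as the identity at inference time.

```python
import math
from typing import List


def max_pool(sequence: List[List[float]]) -> List[float]:
    """Dimension-wise max over a sequence of embedding vectors."""
    return [max(column) for column in zip(*sequence)]


def linear(vector: List[float], weights: List[List[float]], bias: List[float]) -> List[float]:
    """Affine map from d_in features to num_classes logits."""
    return [sum(w * v for w, v in zip(row, vector)) + b for row, b in zip(weights, bias)]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(sequence, weights, bias):
    # Dropout is omitted here: it is the identity at inference time.
    pooled = max_pool(sequence)
    return softmax(linear(pooled, weights, bias))


# Toy input (hypothetical values): 3 embeddings with d_in = 2, and 2 classes.
sequence = [[0.1, 0.9], [0.5, 0.2], [0.3, 0.4]]
weights = [[1.0, 0.0], [0.0, 1.0]]  # shape: num_classes x d_in
bias = [0.0, 0.0]
probabilities = forward(sequence, weights, bias)
```

Max pooling collapses a variable-length sequence into a single d_in-sized vector, which is why the linear layer can have a fixed d_in x num_classes shape regardless of the walk length.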
- test_epoch_end(outputs)
Return the mean and standard deviation of the weighted-average F1 score over all batches.
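
As a rough illustration of this aggregation, the sketch below computes a support-weighted F1 score per batch and then its mean and standard deviation. The `weighted_f1` helper and the batch contents are hypothetical, and the use of the sample standard deviation is an assumption; the source does not say which estimator is used.

```python
from collections import Counter
from statistics import mean, stdev


def weighted_f1(y_true, y_pred):
    """Support-weighted average of per-class F1 scores."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for label, count in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
        fn = count - tp
        denominator = 2 * tp + fp + fn
        f1 = 2 * tp / denominator if denominator else 0.0
        # Weight each class by its share of the true labels.
        score += (count / total) * f1
    return score


# Hypothetical per-batch outputs: (true labels, predicted labels).
batches = [
    ([0, 0, 1, 1], [0, 1, 1, 1]),
    ([0, 1, 1, 0], [0, 1, 1, 0]),
]
scores = [weighted_f1(y_true, y_pred) for y_true, y_pred in batches]
average_f1 = mean(scores)
std_f1 = stdev(scores)  # sample standard deviation (an assumption)
```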
- class Node2VecINDRAEntityDataset(embedding_dict, random_walk_dict, sources, targets, labels, max_len=254)
Custom dataset class for Node2vec-based INDRA data.
Initialize INDRA Dataset based on random walk embeddings for 2 nodes in each triple.
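
One way to picture a single item of such a dataset is sketched below. The function name `build_walk_sequence`, the toy dictionaries, and in particular the choice to give each of the two nodes half of `max_len` are assumptions for illustration; the source only states that the sequence is built from random walk embeddings of the two nodes.

```python
def build_walk_sequence(source, target, embedding_dict, random_walk_dict, max_len=254):
    """Concatenate the embedded random walks of both nodes of a triple."""
    half = max_len // 2  # assumed split: half of the sequence per node
    walk = random_walk_dict[source][:half] + random_walk_dict[target][:half]
    # Replace every node on the combined walk with its pre-trained embedding.
    return [embedding_dict[node] for node in walk]


# Toy data with 2-dimensional embeddings (hypothetical values).
embedding_dict = {"a": [0.1, 0.2], "b": [0.3, 0.4], "c": [0.5, 0.6]}
random_walk_dict = {"a": ["a", "b", "c"], "b": ["b", "c", "a"]}
sequence = build_walk_sequence("a", "b", embedding_dict, random_walk_dict, max_len=4)
```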
- class TransEINDRAEntityDataset(embedding_dict, sources, relations, targets, labels)
Custom dataset class for TransE-based INDRA data.
Initialize INDRA Dataset based on the head, relation, and tail (h, r, t) embeddings in each triple.
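
For the TransE variant, each triple can be thought of as a length-3 sequence of embeddings. The sketch below assumes that a single `embedding_dict` holds both entity and relation embeddings, which matches the constructor signature but is not confirmed by the source; `build_triple_sequence` and the toy values are hypothetical.

```python
def build_triple_sequence(source, relation, target, embedding_dict):
    """Return the [h, r, t] embeddings of a triple as a length-3 sequence."""
    return [embedding_dict[source], embedding_dict[relation], embedding_dict[target]]


# Toy lookup covering entities and one relation (hypothetical values).
embedding_dict = {
    "geneA": [0.1, 0.2],
    "increases": [0.0, 1.0],
    "geneB": [0.3, 0.4],
}
triple_sequence = build_triple_sequence("geneA", "increases", "geneB", embedding_dict)
```

A sequence built this way has the same shape as a (very short) random walk sequence, so it can pass through the same dimension-wise max pooling step.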
- get_train_test_splits(data, label_column_name='class', random_seed=42, n_splits=5, max_dataset_size=100000)
Return deterministic train/test indices for n_splits based on the fine-tuning dataset that is passed.
- Return type:
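
The idea of deterministic splits for `n_splits` folds can be sketched with a seeded, stratified round-robin assignment. This is an assumption-laden illustration rather than the function's actual logic: the name `stratified_splits` is hypothetical, stratification by label is assumed, and the `max_dataset_size` cap is not modeled.

```python
import random
from collections import defaultdict


def stratified_splits(labels, n_splits=5, random_seed=42):
    """Return a list of (train_indices, test_indices) pairs, one per fold."""
    rng = random.Random(random_seed)  # fixed seed makes the splits reproducible
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    folds = [[] for _ in range(n_splits)]
    for label in sorted(by_label):
        indices = by_label[label]
        rng.shuffle(indices)
        # Deal the shuffled indices of each class round-robin across the folds.
        for position, index in enumerate(indices):
            folds[position % n_splits].append(index)

    splits = []
    for test_fold in folds:
        test = sorted(test_fold)
        held_out = set(test)
        train = [i for i in range(len(labels)) if i not in held_out]
        splits.append((train, test))
    return splits


# Toy labels (hypothetical): 10 examples per class.
labels = ["up", "down"] * 10
splits = stratified_splits(labels, n_splits=5, random_seed=42)
```

Calling the function twice with the same seed yields identical folds, which is what makes results comparable across runs and model variants.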
- run_kg_baseline_classification_cv(triples_path, embedding_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/embeddings_best_model.tsv', random_walks_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/random_walks_best_model.tsv', logging_uri_mlflow=None, n_splits=5, epochs=100, train_batch_size=8, test_batch_size=64, lr=0.001, label_column_name='class', log_steps=500, task_name='', max_dataset_size=100000, model_variant='node2vec')
Run KG baseline classification.
Getting the node embeddings via node2vec