STonKGs

STonKGs model architecture components.

class BertForPreTrainingOutputWithPooling(loss=None, prediction_logits=None, seq_relationship_logits=None, hidden_states=None, attentions=None, pooler_output=None)[source]

Extends the BertForPreTrainingOutput class to additionally include the pooled output.
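A minimal sketch of the fields this output carries, using a hypothetical stand-in dataclass (the real class lives in stonkgs and builds on the Hugging Face output classes):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BertForPreTrainingOutputWithPoolingSketch:
    # Fields mirrored from BertForPreTrainingOutput
    loss: Optional[float] = None
    prediction_logits: Optional[List[List[float]]] = None
    seq_relationship_logits: Optional[List[List[float]]] = None
    hidden_states: Optional[list] = None
    attentions: Optional[list] = None
    # The extra field added by the override: the pooled [CLS] representation
    pooler_output: Optional[List[float]] = None


out = BertForPreTrainingOutputWithPoolingSketch(pooler_output=[0.1, 0.2])
```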

class STonKGsELMPredictionHead(config)[source]

Custom masked entity and language modeling (ELM) head used to predict both entities and text tokens.

Initialize the ELM head based on the (hyper)parameters in the provided BertConfig.

forward(hidden_states)[source]

Map hidden states to scores over the text vocabulary (first half of the output) and the KG vocabulary (second half).
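Because the head emits scores over both vocabularies in a single vector per position, downstream code splits that vector at the text-vocabulary boundary. A toy illustration with plain lists (the vocabulary sizes are made up):

```python
# Hypothetical sizes: 6 text tokens, 4 KG entities -> combined width 10
text_vocab_size = 6
kg_vocab_size = 4

# One combined score vector as the ELM head would emit for a single position
combined_logits = [0.9, 0.1, 0.0, 0.2, 0.3, 0.5, 1.2, 0.4, 0.7, 0.6]

text_logits = combined_logits[:text_vocab_size]  # first block: text vocab
kg_logits = combined_logits[text_vocab_size:]    # second block: KG vocab
```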

class STonKGsForPreTraining(config, nlp_model_type='dmis-lab/biobert-v1.1', kg_embedding_dict_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/embeddings_best_model.tsv')[source]

Create the pre-training part of the STonKGs model based on both text and entity embeddings.

Initialize the model architecture components of STonKGs.

Parameters:
  • config – Required for automated methods such as .from_pretrained in classes that inherit from this one

  • nlp_model_type (str) – Model type used to initialize the LM backbone for the text embeddings

  • kg_embedding_dict_path (str) – Path specification for the stored node2vec embeddings used for the KG backbone
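The node2vec embeddings are stored in a TSV file keyed by node identifier. A hedged sketch of parsing such a file into an embedding dict (the exact column layout and the example node names are assumptions; the real loader lives in stonkgs):

```python
import csv
import io

# Hypothetical TSV content: node identifier followed by its embedding values
tsv_text = (
    "p(HGNC:391)\t0.12\t-0.03\t0.88\n"
    "p(HGNC:620)\t0.05\t0.41\t-0.22\n"
)

kg_embedding_dict = {}
for row in csv.reader(io.StringIO(tsv_text), delimiter="\t"):
    node, *values = row
    kg_embedding_dict[node] = [float(v) for v in values]
```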

forward(input_ids=None, attention_mask=None, token_type_ids=None, masked_lm_labels=None, ent_masked_lm_labels=None, next_sentence_labels=None, return_dict=None, head_mask=None)[source]

Perform one forward pass over a combined sequence of text and entity input IDs.

Parameters:
  • input_ids – Concatenation of text + KG (random walk) embeddings

  • attention_mask – Attention mask of the combined input sequence

  • token_type_ids – Token type IDs of the combined input sequence

  • masked_lm_labels – Masked LM labels for only the text part

  • ent_masked_lm_labels – Masked ELM labels for only the KG part

  • next_sentence_labels – NSP labels (per sequence)

  • return_dict – Whether to return the output as a dict

  • head_mask – Mask used to nullify selected attention heads in the Transformer

Returns:

Loss and prediction logits, wrapped in a BertForPreTrainingOutputWithPooling object
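An illustrative sketch of how the combined inputs line up, with a short text part followed by a KG (random-walk) part; all concrete IDs are made up. Positions labeled -100 are ignored by the loss, following the standard Hugging Face convention:

```python
# Illustrative combined sequence: 4 text positions followed by 3 KG positions
text_len, kg_len = 4, 3

input_ids = [101, 7592, 103, 102] + [9001, 9002, 9003]  # token IDs are made up
token_type_ids = [0] * text_len + [1] * kg_len          # 0 = text part, 1 = KG part
attention_mask = [1] * (text_len + kg_len)              # no padding in this example

# masked_lm_labels covers only the text part; ent_masked_lm_labels only the KG part
masked_lm_labels = [-100, -100, 7592, -100] + [-100] * kg_len
ent_masked_lm_labels = [-100] * text_len + [-100, 9002, -100]
```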

classmethod from_default_pretrained(**kwargs)[source]

Get the default pre-trained STonKGs model.

Return type:

STonKGsForPreTraining
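The classmethod pattern behind this convenience constructor can be sketched with a toy class: a classmethod that fills in a default checkpoint name and delegates to a generic loader. Everything here, including the class and checkpoint name, is a hypothetical stand-in, not the real STonKGs implementation:

```python
class PreTrainingModelSketch:
    """Stand-in for STonKGsForPreTraining; not the real implementation."""

    def __init__(self, checkpoint: str):
        self.checkpoint = checkpoint

    @classmethod
    def from_pretrained(cls, checkpoint: str, **kwargs):
        # The real model would download weights for the named checkpoint here
        return cls(checkpoint, **kwargs)

    @classmethod
    def from_default_pretrained(cls, **kwargs):
        # Delegate to from_pretrained with a (hypothetical) default checkpoint
        return cls.from_pretrained("stonkgs/default-checkpoint", **kwargs)


model = PreTrainingModelSketch.from_default_pretrained()
```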