STonKGs
STonKGs model architecture components.
- class BertForPreTrainingOutputWithPooling(loss=None, prediction_logits=None, seq_relationship_logits=None, hidden_states=None, attentions=None, pooler_output=None)[source]
Override of the BertForPreTrainingOutput class that additionally includes the pooled output.
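A minimal sketch of how such an output class could look, assuming it is a standard transformers ModelOutput dataclass with the pooled representation appended as one extra optional field (the field type is an assumption, not taken from the source):

    from dataclasses import dataclass
    from typing import Optional

    import torch
    from transformers.models.bert.modeling_bert import BertForPreTrainingOutput


    @dataclass
    class BertForPreTrainingOutputWithPooling(BertForPreTrainingOutput):
        """BertForPreTrainingOutput extended with the pooled [CLS] representation."""

        # Pooled output of the [CLS] position, in addition to the parent fields
        # (loss, prediction_logits, seq_relationship_logits, hidden_states, attentions).
        pooler_output: Optional[torch.FloatTensor] = None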
- class STonKGsELMPredictionHead(config)[source]
Custom masked entity and language modeling (ELM) head used to predict both entities and text tokens.
Initialize the ELM head based on the (hyper)parameters in the provided BertConfig.
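A hedged sketch of what a joint masked entity and language modeling head might look like, assuming a shared BERT-style transform followed by two decoders, one over the text vocabulary and one over the KG entity vocabulary. The kg_vocab_size attribute used for the entity decoder is a hypothetical name, not the actual config field:

    import torch
    from torch import nn
    from transformers.models.bert.modeling_bert import BertPredictionHeadTransform


    class STonKGsELMPredictionHead(nn.Module):
        """Sketch of a joint masked entity/language modeling (ELM) head."""

        def __init__(self, config):
            super().__init__()
            # Shared BERT-style transform (dense + activation + LayerNorm)
            self.transform = BertPredictionHeadTransform(config)
            # One decoder for text tokens, one for KG entities
            self.text_decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            # kg_vocab_size is a hypothetical attribute standing in for the entity vocabulary size
            self.entity_decoder = nn.Linear(config.hidden_size, config.kg_vocab_size, bias=False)

        def forward(self, hidden_states: torch.Tensor):
            hidden_states = self.transform(hidden_states)
            # Return logits over both vocabularies; the caller picks the text half of the
            # sequence for the MLM loss and the KG half for the ELM loss.
            return self.text_decoder(hidden_states), self.entity_decoder(hidden_states)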
- class STonKGsForPreTraining(config, nlp_model_type='dmis-lab/biobert-v1.1', kg_embedding_dict_path='/home/docs/checkouts/readthedocs.org/user_builds/stonkgs/checkouts/latest/models/kg-hpo/embeddings_best_model.tsv')[source]
Create the pre-training part of the STonKGs model based on both text and entity embeddings.
Initialize the model architecture components of STonKGs.
- Parameters:
config – Required for automated methods such as .from_pretrained in classes that inherit from this one
nlp_model_type (str) – Model type used to initialize the LM backbone for the text embeddings
kg_embedding_dict_path (str) – Path specification for the stored node2vec embeddings used for the KG backbone
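A hypothetical instantiation sketch based on the documented signature. The import path, the choice of building the config from the BioBERT checkpoint, and the embedding TSV location are all placeholders/assumptions rather than documented defaults:

    from transformers import BertConfig

    from stonkgs import STonKGsForPreTraining  # import path assumed; adjust to the package layout

    # Assumption: the config matches the LM backbone used for the text embeddings
    config = BertConfig.from_pretrained("dmis-lab/biobert-v1.1")

    model = STonKGsForPreTraining(
        config,
        nlp_model_type="dmis-lab/biobert-v1.1",
        kg_embedding_dict_path="models/kg-hpo/embeddings_best_model.tsv",  # placeholder path
    )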
- forward(input_ids=None, attention_mask=None, token_type_ids=None, masked_lm_labels=None, ent_masked_lm_labels=None, next_sentence_labels=None, return_dict=None, head_mask=None)[source]
Perform one forward pass for a given sequence of text_input_ids + ent_input_ids.
- Parameters:
input_ids – Concatenation of text + KG (random walk) embeddings
attention_mask – Attention mask of the combined input sequence
token_type_ids – Token type IDs of the combined input sequence
masked_lm_labels – Masked LM labels for only the text part
ent_masked_lm_labels – Masked ELM labels for only the KG part
next_sentence_labels – NSP labels (per sequence)
return_dict – Whether the output should be returned as a dict or not
head_mask – Mask used to nullify selected attention heads in the Transformer
- Returns:
The loss and prediction logits, wrapped in a BertForPreTrainingOutputWithPooling
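A toy forward pass continuing from the instantiation sketch above. The sequence length, the even text/KG split encoded in token_type_ids, the label shapes, and the -100 ignore index are assumptions about the expected inputs, not documented requirements:

    import torch

    # `model` is a STonKGsForPreTraining instance (see the instantiation sketch above).
    batch_size, seq_len = 2, 512  # assumed: text tokens in the first half, KG walk nodes in the second

    # -100 is the usual transformers ignore index; one "masked" position per half is filled
    # in so that both the MLM and ELM losses have something to predict.
    masked_lm_labels = torch.full((batch_size, seq_len), -100)
    masked_lm_labels[:, 5] = 42                    # pretend text token id 42 was masked
    ent_masked_lm_labels = torch.full((batch_size, seq_len), -100)
    ent_masked_lm_labels[:, seq_len // 2 + 5] = 7  # pretend KG entity id 7 was masked

    outputs = model(
        input_ids=torch.randint(0, 100, (batch_size, seq_len)),
        attention_mask=torch.ones(batch_size, seq_len, dtype=torch.long),
        token_type_ids=torch.cat(
            [
                torch.zeros(batch_size, seq_len // 2, dtype=torch.long),  # text half
                torch.ones(batch_size, seq_len // 2, dtype=torch.long),   # KG half
            ],
            dim=1,
        ),
        masked_lm_labels=masked_lm_labels,
        ent_masked_lm_labels=ent_masked_lm_labels,
        next_sentence_labels=torch.zeros(batch_size, dtype=torch.long),
        return_dict=True,
    )
    print(outputs.loss, outputs.prediction_logits.shape)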