Skip to content
Please note that GitHub no longer supports your web browser.

We recommend upgrading to the latest Google Chrome or Firefox.

Learn more
Permalink
Browse files

feature: SCIBERT transform

  • Loading branch information
BogdanFloris committed Jan 22, 2020
1 parent 21e5c9f commit efbacef8c398b4307243489c19b6fb46c4b97134
@@ -109,5 +109,6 @@ dmypy.json

# Datasets
assets/datasets/wos_v_1_transformed_BERT_hidden_0.pt
assets/datasets/wos_v_1_transformed_SCIBERT_hidden_0.pt

# End of https://www.gitignore.io/api/python
File renamed without changes.
@@ -7,6 +7,7 @@
from torch import nn, optim
from streams.stream_data import WOSStream
from models.wos_classifier import LSTM
from constants.transformers import TransformerModel


PATH = os.path.join(Path(__file__).parents[1], "assets/models")
@@ -19,6 +20,7 @@ def train_wos_batch(
lr=0.001,
batch_size=utils.BATCH_SIZE,
transform=True,
transformer_model=TransformerModel.BERT,
print_every=10,
device="cpu",
):
@@ -29,11 +31,12 @@ def train_wos_batch(
lr (float): learning rate of the optimizer
batch_size (int): the batch size
transform (bool): transform the dataset or not
transformer_model (TransformerModel): the transformer model to use
print_every (int): print stats parameter
device (string): the device to run the training on (cpu or gpu)
"""
# Prepare stream
stream = WOSStream(transform=transform)
stream = WOSStream(transformer_model=transformer_model, transform=transform)
stream.prepare_for_use()

# Check for checkpoints and initialize
@@ -42,7 +45,7 @@ def train_wos_batch(
)
optimizer = optim.Adam(model.parameters(), lr=lr)

model_name = "lstm-wos-ver-{}-batch".format(stream.version)
model_name = "lstm-wos-{}-ver-{}-batch".format(transformer_model.name, stream.version)
model_path = os.path.join(PATH, model_name)
epoch = 0
if not os.path.exists(os.path.join(model_path, "checkpoint.pt")):
@@ -14,6 +14,7 @@
PATH = os.path.join(Path(__file__).parents[1], "assets/datasets")
TRANSFORMED_DATASETS = [
os.path.join(PATH, "wos_v_1_transformed_BERT_hidden_0.pt"),
os.path.join(PATH, "wos_v_1_transformed_SCIBERT_hidden_0.pt")
]


@@ -30,16 +31,13 @@ class WOSStream(Stream):
transformer_model (TransformerModel): the transformer model to use
transform (bool): whether to transform the dataset while streaming,
or use an already transformed dataset
dataset_idx (int): if transform is False, then this represents the index in the
TRANSFORMED_DATASETS list of the transformed dataset to use
"""

def __init__(
self,
version=1,
transformer_model=TransformerModel.BERT,
transform=True,
dataset_idx=0,
):
super().__init__()
self.version = version
@@ -49,7 +47,13 @@ def __init__(
self.no_classes = None
self.current_seq_lengths = None
self.transform = transform
self.dataset_idx = dataset_idx

if transformer_model == TransformerModel.BERT:
self.dataset_idx = 0
elif transformer_model == TransformerModel.SCIBERT:
self.dataset_idx = 1
else:
self.dataset_idx = 2

if transform:
self.transformer = Transformer(transformer_model)
@@ -45,4 +45,4 @@ def transform_wos(version=1, transformer_model=TransformerModel.BERT, hidden_sta


if __name__ == "__main__":
transform_wos()
transform_wos(transformer_model=TransformerModel.SCIBERT)

0 comments on commit efbacef

Please sign in to comment.
You can’t perform that action at this time.