feature: data transform + first LSTM trained + profilers

BogdanFloris committed Jan 22, 2020
1 parent e19e723 commit 21e5c9f0154be27c3f6a32ecda164e5e500ff20f
@@ -107,4 +107,7 @@ dmypy.json
# Pyre type checker
.pyre/

# Datasets
assets/datasets/wos_v_1_transformed_BERT_hidden_0.pt

# End of https://www.gitignore.io/api/python
@@ -28,11 +28,12 @@ def __init__(self, model: TransformerModel = TransformerModel.BERT):
self.tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
self.model = model_class.from_pretrained(pretrained_weights)

def transform(self, text):
def transform(self, text, hidden_state=0):
"""Transforms the given text using the initialized transformer
Args:
text (str): text to be transformed
hidden_state (int): which hidden state to return
Returns:
the transformed text as a tensor of hidden states
@@ -45,6 +46,6 @@ def transform(self, text):
]
)
with torch.no_grad():
last_hidden_states = self.model(input_ids)[0]
last_hidden_states = self.model(input_ids)[hidden_state]

return last_hidden_states
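For context, a minimal usage sketch of the updated `transform` (not part of the commit; the example text and the shape comment are illustrative, and `hidden_state=0` selects the first element of the model output, i.e. the last hidden states for BERT):

# Sketch: embed one sentence with the default BERT-based transformer (assumed usage)
from constants.transformers import Transformer, TransformerModel

transformer = Transformer(TransformerModel.BERT)
embeddings = transformer.transform("Deep learning for text streams", hidden_state=0)
# expected shape: (1, sequence_length, hidden_size) before any squeeze()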

@@ -65,9 +65,7 @@ def __init__(
)

# Initialize the linear layer
self.fc = nn.Linear(
in_features=self.hidden_size, out_features=self.no_classes,
)
self.fc = nn.Linear(in_features=self.hidden_size, out_features=self.no_classes)

def forward(self, x, x_seq_lengths, hidden=None, cell=None):
""" Forward pass
@@ -114,7 +112,11 @@ def abs_max_pooling(t, dim=1):
# Max over absolute value in the dimension
_, abs_max_i = torch.max(t.abs(), dim=dim)
# Convert indices into one hot vectors
one_hot = f.one_hot(abs_max_i, num_classes=t.size()[dim]).transpose(dim, -1).type(torch.float)
one_hot = (
f.one_hot(abs_max_i, num_classes=t.size()[dim])
.transpose(dim, -1)
.type(torch.float)
)
# Multiply original with one hot to apply mask and then sum over the dimension
return torch.mul(t, one_hot).sum(dim=dim)
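As a quick illustration (not part of the commit), abs_max_pooling keeps, for every feature, the signed value whose absolute value is largest along `dim`; a minimal check, assuming the function above is in scope:

import torch

t = torch.tensor([[[1.0, -4.0], [-3.0, 2.0]]])  # shape (batch=1, seq=2, features=2)
print(abs_max_pooling(t, dim=1))                # tensor([[-3., -4.]])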

@@ -1,30 +1,63 @@
""" Training methods for different models.
"""
import os
import torch
import utils
from pathlib import Path
from torch import nn, optim
from streams.stream_data import WOSStream
from models.wos_classifier import LSTM, LSTMWrapper
from models.wos_classifier import LSTM


def train_wos_batch(epochs=1, lr=0.001, batch_size=utils.BATCH_SIZE, device="cpu"):
PATH = os.path.join(Path(__file__).parents[1], "assets/models")
if not os.path.isdir(PATH):
os.makedirs(PATH)


def train_wos_batch(
epochs=1,
lr=0.001,
batch_size=utils.BATCH_SIZE,
transform=True,
print_every=10,
device="cpu",
):
""" Trains a model using batches of data.
Args:
epochs (int): number of epochs to go over the dataset
lr (float): learning rate of the optimizer
batch_size (int): the batch size
transform (bool): whether to transform the dataset on the fly or use an already transformed one
print_every (int): how often, in batches, to print the running loss
device (string): the device to run the training on (cpu or gpu)
"""
stream = WOSStream()
# Prepare stream
stream = WOSStream(transform=transform)
stream.prepare_for_use()

# Check for checkpoints and initialize
model = LSTM(embedding_dim=utils.EMBEDDING_DIM, no_classes=stream.no_classes).to(
device
)
optimizer = optim.Adam(model.parameters(), lr=lr)

model_name = "lstm-wos-ver-{}-batch".format(stream.version)
model_path = os.path.join(PATH, model_name)
epoch = 0
if not os.path.exists(os.path.join(model_path, "checkpoint.pt")):
print("Starting training...")
os.makedirs(model_path, exist_ok=True)
else:
print("Resuming training from checkpoint...")
checkpoint = torch.load(os.path.join(model_path, "checkpoint.pt"))
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]

criterion = nn.NLLLoss()

for epoch in range(epochs):
for epoch in range(epoch, epochs):
# Initialize the loss
running_loss = 0
# Start iterating over the dataset
@@ -54,16 +87,39 @@ def train_wos_batch(epochs=1, lr=0.001, batch_size=utils.BATCH_SIZE, device="cpu

# Print statistics
running_loss += loss.item()
print(i, loss.item())
if i % 10 == 9:
if i % print_every == print_every - 1:
# Print every `print_every` batches
print("[{}, {}] loss: {}".format(epoch + 1, i + 1, running_loss / 10))
print(
"[{}/{} epochs, {}/{} batches] loss: {}".format(
epoch + 1,
epochs,
i + 1,
stream.n_samples // batch_size + 1,
running_loss / print_every,
)
)
running_loss = 0

# Increment i
i += 1

# Save checkpoint
print("Saving checkpoint...")
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
},
os.path.join(model_path, "checkpoint.pt"),
)
# Restart the stream
stream.restart()

# Save model
print("Finished training. Saving model..")
torch.save(model, os.path.join(model_path, "model.pt"))


def train_wos_stream():
""" Trains a model using a data stream.
@@ -75,4 +131,4 @@ def train_wos_stream():


if __name__ == "__main__":
train_wos_batch()
train_wos_batch(epochs=0, transform=False)
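For reference, a hedged sketch of how the artifacts written above could be reloaded later (illustrative only; `model_path`, `model` and `optimizer` are assumed to be built exactly as inside `train_wos_batch`):

# Resume from the checkpoint written at the end of every epoch (sketch, not part of the commit)
checkpoint = torch.load(os.path.join(model_path, "checkpoint.pt"))
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]

# Or load the final model saved after training
model = torch.load(os.path.join(model_path, "model.pt"))
model.eval()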
@@ -3,9 +3,10 @@ torch==1.4.0
transformers==2.3.0
numpy==1.18.1
scikit-multiflow==0.4.1
tqdm==4.41.1

# testing requirements
pytest==5.3.2
pytest==5.3.4
flake8==3.7.9
codecov==2.0.15
pytest-cov==2.8.1
@@ -2,37 +2,70 @@
This file contains streams classes that generate different drift_datasets
over time. The streams are based on the sk-multiflow framework.
"""
import os
import torch
import torch.nn as nn
from pathlib import Path
from skmultiflow.data.base_stream import Stream
from constants.transformers import Transformer, TransformerModel
from streams.loaders import load_wos


PATH = os.path.join(Path(__file__).parents[1], "assets/datasets")
TRANSFORMED_DATASETS = [
os.path.join(PATH, "wos_v_1_transformed_BERT_hidden_0.pt"),
]


class WOSStream(Stream):
""" Class that abstracts the Web of Science dataset into a streaming dataset.
There are 3 versions of the dataset (1, 2, 3), each having
increasingly more samples and targets.
When calling `next_sample` this class also transforms the text
to contextualized embeddings based on the given transformer.
Args:
version (int): the version of the dataset (1, 2, 3)
transformer_model (TransformerModel): the transformer model to use
transform (bool): whether to transform the dataset while streaming,
or use an already transformed dataset
dataset_idx (int): if transform is False, then this represents the index in the
TRANSFORMED_DATASETS list of the transformed dataset to use
"""

def __init__(self, version=1, transformer_model=TransformerModel.BERT):
def __init__(
self,
version=1,
transformer_model=TransformerModel.BERT,
transform=True,
dataset_idx=0,
):
super().__init__()
self.version = version
self.X = None
self.y = None
self.n_samples = None
self.no_classes = None
self.current_seq_lengths = None
self.transform = transform
self.dataset_idx = dataset_idx

self.transformer = Transformer(transformer_model)
if transform:
self.transformer = Transformer(transformer_model)

def prepare_for_use(self):
"""Prepares the stream for use by initializing
the X and y variables from the files.
"""
self.X, self.y, self.no_classes = load_wos(version=self.version)
if self.transform:
print("Preparing non-transformed dataset...")
self.X, self.y, self.no_classes = load_wos(version=self.version)
else:
print("Preparing transformed dataset...")
self.X, self.y, self.no_classes = torch.load(
TRANSFORMED_DATASETS[self.dataset_idx]
)
self.n_samples = len(self.y)
self.sample_idx = 0
self.current_sample_x = None
@@ -66,15 +99,20 @@ def next_sample(self, batch_size=1):
self.current_seq_lengths = []
for i in range(len(self.current_sample_x)):
# Transform to embeddings
self.current_sample_x[i] = self.transformer.transform(
self.current_sample_x[i]
).squeeze()
if self.transform:
self.current_sample_x[i] = self.transformer.transform(
self.current_sample_x[i]
)
# Squeeze tensor
self.current_sample_x[i] = self.current_sample_x[i].squeeze()

# Save the sequence length
self.current_seq_lengths.append(self.current_sample_x[i].shape[0])

# Pad and stack sequence
self.current_sample_x = nn.utils.rnn.pad_sequence(self.current_sample_x, batch_first=True)
self.current_sample_x = nn.utils.rnn.pad_sequence(
self.current_sample_x, batch_first=True
)

# Get the y target
self.current_sample_y = self.y[
@@ -106,7 +144,7 @@ def get_no_classes(self):


if __name__ == "__main__":
wos = WOSStream(transformer_model=TransformerModel.SCIBERT)
wos = WOSStream(transformer_model=TransformerModel.SCIBERT, transform=False)
wos.prepare_for_use()
x, y, _ = wos.next_sample(8)
print(x.shape)
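A minimal streaming sketch against the pre-transformed dataset (assumptions: the .pt file from this commit is on disk, and the third value returned by `next_sample` is the list of per-sample sequence lengths):

wos = WOSStream(version=1, transform=False, dataset_idx=0)
wos.prepare_for_use()
for _ in range(wos.n_samples // 8):
    x_batch, y_batch, seq_lengths = wos.next_sample(8)  # padded batch of embeddings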
@@ -0,0 +1,48 @@
"""
Transforms a dataset using a transformer and saves it
"""
import os
import torch
from pathlib import Path
from tqdm import tqdm
from streams.loaders import load_wos
from constants.transformers import TransformerModel, Transformer


SAVE_DIR = os.path.join(Path(__file__).parents[1], "assets/datasets")


def transform_wos(version=1, transformer_model=TransformerModel.BERT, hidden_state=0):
""" Transformers the given version of the Web of Science
dataset using the given transformer model and saves it
as a tuple (x, y, no_classes).
Args:
version (int): the WOS version
transformer_model (TransformerModel): the transformer model to use
hidden_state (int): which hidden state to get from the transformer
"""
print("Transforming dataset...")
transformer = Transformer(transformer_model)
x, y, no_classes = load_wos(version)
transformed_x = []

for i in tqdm(range(len(x))):
transformed_x.append(transformer.transform(x[i], hidden_state))

print("Saving dataset...")
f = open(
os.path.join(
SAVE_DIR,
"wos_v_{}_transformed_{}_hidden_{}.pt".format(
version, transformer_model.name, hidden_state
),
),
"wb",
)
torch.save((transformed_x, y, no_classes), f)
f.close()


if __name__ == "__main__":
transform_wos()
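Once saved, the file can be consumed either via `WOSStream(transform=False)` above or loaded directly; a minimal sketch, assuming the default arguments used in `__main__`:

# Load the tuple written by transform_wos() (sketch, not part of the commit)
x, y, no_classes = torch.load(
    os.path.join(SAVE_DIR, "wos_v_1_transformed_BERT_hidden_0.pt")
)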
