@@ -7,6 +7,7 @@
 from torch import nn, optim
 from streams.stream_data import WOSStream
 from models.wos_classifier import LSTM
+from constants.transformers import TransformerModel


 PATH = os.path.join(Path(__file__).parents[1], "assets/models")
@@ -19,6 +20,7 @@ def train_wos_batch(
     lr=0.001,
     batch_size=utils.BATCH_SIZE,
     transform=True,
+    transformer_model=TransformerModel.BERT,
     print_every=10,
     device="cpu",
 ):
@@ -29,11 +31,12 @@ def train_wos_batch(
         lr (float): learning rate of the optimizer
         batch_size (int): the batch size
         transform (bool): transform the dataset or not
+        transformer_model (TransformerModel): the transformer model to use
         print_every (int): print stats parameter
         device (string): the device to run the training on (cpu or gpu)
     """
     # Prepare stream
-    stream = WOSStream(transform=transform)
+    stream = WOSStream(transformer_model=transformer_model, transform=transform)
     stream.prepare_for_use()

     # Check for checkpoints and initialize
@@ -42,7 +45,7 @@ def train_wos_batch(
     )
     optimizer = optim.Adam(model.parameters(), lr=lr)

-    model_name = "lstm-wos-ver-{}-batch".format(stream.version)
+    model_name = "lstm-wos-{}-ver-{}-batch".format(transformer_model.name, stream.version)
     model_path = os.path.join(PATH, model_name)
     epoch = 0
     if not os.path.exists(os.path.join(model_path, "checkpoint.pt")):
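
For reference, a minimal sketch of calling the updated function with the new keyword. The import path for train_wos_batch is an assumption (only the diff above is known); TransformerModel.BERT and the other keyword defaults are taken from the diff.

# Usage sketch; the module hosting train_wos_batch is hypothetical.
from constants.transformers import TransformerModel
from train_wos import train_wos_batch  # assumed import path

# With the new keyword, checkpoints are written under a transformer-specific
# name, e.g. assets/models/lstm-wos-BERT-ver-<version>-batch/checkpoint.pt.
train_wos_batch(
    lr=0.001,
    transform=True,
    transformer_model=TransformerModel.BERT,  # default shown in the diff
    print_every=10,
    device="cpu",
)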