Permalink
Browse files
feature: DISTILBERT transform
- Loading branch information
Showing
with
8 additions
and
7 deletions.
-
+1
−0
.gitignore
-
+3
−1
models/wos_train.py
-
+3
−5
streams/stream_data.py
-
+1
−1
streams/transform_data.py
|
@@ -110,5 +110,6 @@ dmypy.json |
|
|
# Datasets |
|
|
assets/datasets/wos_v_1_transformed_BERT_hidden_0.pt |
|
|
assets/datasets/wos_v_1_transformed_SCIBERT_hidden_0.pt |
|
|
assets/datasets/wos_v_1_transformed_DISTILBERT_hidden_0.pt |
|
|
|
|
|
# End of https://www.gitignore.io/api/python |
|
@@ -45,7 +45,9 @@ def train_wos_batch( |
|
|
) |
|
|
optimizer = optim.Adam(model.parameters(), lr=lr) |
|
|
|
|
|
model_name = "lstm-wos-{}-ver-{}-batch".format(transformer_model.name, stream.version) |
|
|
model_name = "lstm-wos-{}-ver-{}-batch".format( |
|
|
transformer_model.name, stream.version |
|
|
) |
|
|
model_path = os.path.join(PATH, model_name) |
|
|
epoch = 0 |
|
|
if not os.path.exists(os.path.join(model_path, "checkpoint.pt")): |
|
|
|
@@ -14,7 +14,8 @@ |
|
|
PATH = os.path.join(Path(__file__).parents[1], "assets/datasets") |
|
|
TRANSFORMED_DATASETS = [ |
|
|
os.path.join(PATH, "wos_v_1_transformed_BERT_hidden_0.pt"), |
|
|
os.path.join(PATH, "wos_v_1_transformed_SCIBERT_hidden_0.pt") |
|
|
os.path.join(PATH, "wos_v_1_transformed_SCIBERT_hidden_0.pt"), |
|
|
os.path.join(PATH, "wos_v_1_transformed_DISTILBERT_hidden_0.pt"), |
|
|
] |
|
|
|
|
|
|
|
@@ -34,10 +35,7 @@ class WOSStream(Stream): |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
version=1, |
|
|
transformer_model=TransformerModel.BERT, |
|
|
transform=True, |
|
|
self, version=1, transformer_model=TransformerModel.BERT, transform=True, |
|
|
): |
|
|
super().__init__() |
|
|
self.version = version |
|
|
|
@@ -45,4 +45,4 @@ def transform_wos(version=1, transformer_model=TransformerModel.BERT, hidden_sta |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
transform_wos(transformer_model=TransformerModel.SCIBERT) |
|
|
transform_wos(transformer_model=TransformerModel.DISTILBERT) |
0 comments on commit
22cefbe