Permalink
Please
sign in to comment.
Showing
with
1,420 additions
and 23 deletions.
- +3 −0 .gitignore
- +3 −0 assets/models/lstm-wos-ver-1-batch/checkpoint.pt
- +3 −0 assets/models/lstm-wos-ver-1-batch/model.pt
- +3 −2 constants/transformers.py
- +1,005 −0 docs/profiler_with_transform.txt
- +237 −0 docs/profiler_without_transform.txt
- +6 −4 models/wos_classifier.py
- +64 −8 models/wos_train.py
- +2 −1 requirements.txt
- +46 −8 streams/stream_data.py
- +48 −0 streams/transform_data.py
Git LFS file not shown
Git LFS file not shown
@@ -0,0 +1,48 @@ | ||
""" | ||
Transforms a dataset using a transformer and saves it | ||
""" | ||
import os | ||
import torch | ||
from pathlib import Path | ||
from tqdm import tqdm | ||
from streams.loaders import load_wos | ||
from constants.transformers import TransformerModel, Transformer | ||
|
||
|
||
# Destination directory for transformed datasets: <repo_root>/assets/datasets.
SAVE_DIR = os.path.join(Path(__file__).parents[1], "assets/datasets")
|
||
|
||
def transform_wos(version=1, transformer_model=TransformerModel.BERT, hidden_state=0):
    """Transform the given version of the Web of Science dataset
    using the given transformer model and save it under SAVE_DIR
    as a tuple (x, y, no_classes).

    Args:
        version (int): the WOS version
        transformer_model (TransformerModel): the transformer model to use
        hidden_state (int): which hidden state to get from the transformer
    """
    print("Transforming dataset...")
    transformer = Transformer(transformer_model)
    x, y, no_classes = load_wos(version)

    # Iterate the documents directly; tqdm reports progress over the dataset.
    transformed_x = [
        transformer.transform(document, hidden_state) for document in tqdm(x)
    ]

    print("Saving dataset...")
    # Create the target directory if missing so the open() below cannot
    # fail with FileNotFoundError on a fresh checkout.
    os.makedirs(SAVE_DIR, exist_ok=True)
    save_path = os.path.join(
        SAVE_DIR,
        "wos_v_{}_transformed_{}_hidden_{}.pt".format(
            version, transformer_model.name, hidden_state
        ),
    )
    # Context manager guarantees the file handle is closed even if
    # torch.save raises partway through serialization.
    with open(save_path, "wb") as f:
        torch.save((transformed_x, y, no_classes), f)
|
||
|
||
if __name__ == "__main__":
    # Script entry point: transform WOS v1 with the default BERT model,
    # hidden state 0, and save the result under SAVE_DIR.
    transform_wos()
0 comments on commit
21e5c9f