PyTorch 导入预训练词向量 2023

发布时间: 2023-03-24 14:52:33 | 作者: dgw博客

# 测试 Embedding
import torch
import gensim
import torch.nn as nn
# Load the pretrained vectors from a word2vec text-format file.
wvmodel = gensim.models.KeyedVectors.load_word2vec_format(
    "./data/vector.txt", binary=False, encoding="utf-8"
)

# Row 0 of the embedding table is reserved for the unknown token, so the
# table has one more row than the pretrained vocabulary.
vocab_size = len(wvmodel) + 1
vector_size = wvmodel.vector_size

# Start from random weights. Rows belonging to pretrained words are
# overwritten below; a row that is never overwritten (normally only the
# <unk> row) keeps its random initialisation.
weight = torch.randn(vocab_size, vector_size)

words = wvmodel.key_to_index

# Pretrained words are numbered from 1; index 0 is the unknown token.
word_to_idx = {word: i for i, word in enumerate(words, start=1)}
word_to_idx['<unk>'] = 0
idx_to_word = {i: word for i, word in enumerate(words, start=1)}
idx_to_word[0] = '<unk>'

# Copy every pretrained vector into its row of the weight matrix.
# Each word in `words` is guaranteed to be a key of `word_to_idx`, so no
# exception handling is needed around the lookup.
for word in words:
    weight[word_to_idx[word], :] = torch.from_numpy(wvmodel.get_vector(word))

# freeze=True keeps the pretrained weights fixed during training.
embedding = nn.Embedding.from_pretrained(weight, freeze=True)
# Bare expression: shows the module's repr in a notebook/REPL cell and has
# no effect when the file is run as a script.
embedding