修改命名实体识别的损失函数和评估函数

This commit is contained in:
jiangzhonglian
2020-09-04 18:51:34 +08:00
parent fcaaf66052
commit a388f310bf

View File

@@ -3,9 +3,12 @@ import numpy as np
import pandas as pd
import platform
from collections import Counter
import keras
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras.layers import Embedding, Bidirectional, LSTM, Dropout
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
"""
# padding: pre(默认) 向前补充0 post 向后补充0
# truncating: 文本超过 pad_num, pre(默认) 删除前面 post 删除后面
@@ -31,7 +34,7 @@ def load_data():
# Counter({'的': 8, '中': 7, '致': 7, '党': 7})
word_counts = Counter(row[0].lower() for sample in train for row in sample)
vocab = [w for w, f in iter(word_counts.items()) if f >= 2]
chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', "B-ORG", "I-ORG"]
chunk_tags = Config.nlp_ner.chunk_tags
# 存储保留的有效个数的 vocab 和对应 chunk_tags
with open(Config.nlp_ner.path_config, 'wb') as outp:
@@ -57,7 +60,10 @@ def _parse_data(filename):
# 主要是分句: split_text 默认每个句子都是一行,所以原来换行就需要 两个split_text
texts = fn.read().decode('utf-8').strip().split(split_text + split_text)
# 对于每个字需要 split_text, 而字的内部需要用空格分隔
data = [[row.split() for row in text.split(split_text)] for text in texts]
# len(row) > 0 避免连续2个换行导致 row 数据为空
# row.split() 会删除空格或特殊符号,导致空格数据缺失!
data = [[[" ", "O"] if len(row.split()) != 2 else row.split() for row in text.split(split_text) if len(row) > 0] for text in texts]
# data = [[row.split() for row in text.split(split_text) if len(row.split()) == 2] for text in texts]
return data
@@ -96,10 +102,17 @@ def create_model(len_vocab, len_chunk_tags):
model = Sequential()
model.add(Embedding(len_vocab, Config.nlp_ner.EMBED_DIM, mask_zero=True)) # Random embedding
model.add(Bidirectional(LSTM(Config.nlp_ner.BiLSTM_UNITS // 2, return_sequences=True)))
model.add(Dropout(0.25))
crf = CRF(len_chunk_tags, sparse_target=True)
model.add(crf)
model.summary()
model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
model.compile('adam', loss=crf_loss, metrics=[crf_viterbi_accuracy])
# model.compile('rmsprop', loss=crf_loss, metrics=[crf_viterbi_accuracy])
# from keras.optimizers import Adam
# adam_lr = 0.0001
# adam_beta_1 = 0.5
# model.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1), loss=crf_loss, metrics=[crf_viterbi_accuracy])
return model
@@ -115,29 +128,38 @@ def test():
with open(Config.nlp_ner.path_config, 'rb') as inp:
(vocab, chunk_tags) = pickle.load(inp)
model = create_model(len(vocab), len(chunk_tags))
predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
text_EMBED, length = process_data(predict_text, vocab)
model.load_weights(Config.nlp_ner.path_model)
raw = model.predict(text_EMBED)[0][-length:]
result = [np.argmax(row) for row in raw]
result_tags = [chunk_tags[i] for i in result]
# predict_text = '造型独特,尺码偏大,估计是钉子头圆的半径的缘故'
with open(Config.nlp_ner.path_origin, "r", encoding="utf-8") as f:
lines = f.readlines()
for predict_text in lines:
content = predict_text.strip()
text_EMBED, length = process_data(content, vocab)
model.load_weights(Config.nlp_ner.path_model)
raw = model.predict(text_EMBED)[0][-length:]
pre_result = [np.argmax(row) for row in raw]
result_tags = [chunk_tags[i] for i in pre_result]
per, loc, org = '', '', ''
for s, t in zip(predict_text, result_tags):
if t in ('B-PER', 'I-PER'):
per += ' ' + s if (t == 'B-PER') else s
if t in ('B-ORG', 'I-ORG'):
org += ' ' + s if (t == 'B-ORG') else s
if t in ('B-LOC', 'I-LOC'):
loc += ' ' + s if (t == 'B-LOC') else s
print(['person:' + per, 'location:' + loc, 'organzation:' + org])
# 保存每句话的 实体和观点
result = {}
tag_list = [i for i in chunk_tags if i not in ["O"]]
for word, t in zip(content, result_tags):
# print(word, t)
if t not in tag_list:
continue
for i in range(0, len(tag_list), 2):
if t in tag_list[i:i+2]:
# print("\n>>> %s---%s==%s" % (word, t, tag_list[i:i+2]))
tag = tag_list[i].split("-")[-1]
if tag not in result:
result[tag] = ""
result[tag] += ' '+word if t==tag_list[i] else word
print(result)
# Script entry point: calls train() and then test(), in that order.
def main():
# print("--")
train()
test()
# if __name__ == "__main__":
# train()