mirror of https://github.com/apachecn/ailearning.git

Commit: Modify the loss function and evaluation metric of the named entity recognition model
@@ -3,9 +3,12 @@ import numpy as np
 import pandas as pd
 import platform
 from collections import Counter
+import keras
 from keras.models import Sequential
-from keras.layers import Embedding, Bidirectional, LSTM
+from keras.layers import Embedding, Bidirectional, LSTM, Dropout
 from keras_contrib.layers import CRF
+from keras_contrib.losses import crf_loss
+from keras_contrib.metrics import crf_viterbi_accuracy
 """
 # padding: pre (default) pads zeros at the front; post pads them at the end
 # truncating: if the text exceeds pad_num, pre (default) drops the front, post drops the end
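The padding/truncating comment above describes keras.preprocessing.sequence.pad_sequences. A minimal sketch of both modes, with made-up sequences:

    from keras.preprocessing.sequence import pad_sequences

    seqs = [[1, 2, 3], [4, 5, 6, 7, 8]]

    # padding='pre' (default) prepends zeros; truncating='pre' (default) drops the front.
    print(pad_sequences(seqs, maxlen=4))
    # [[0 1 2 3]
    #  [5 6 7 8]]

    # padding='post' appends zeros; truncating='post' drops the end.
    print(pad_sequences(seqs, maxlen=4, padding='post', truncating='post'))
    # [[1 2 3 0]
    #  [4 5 6 7]]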
@@ -31,7 +34,7 @@ def load_data():
     # Counter({'的': 8, '中': 7, '致': 7, '党': 7})
     word_counts = Counter(row[0].lower() for sample in train for row in sample)
     vocab = [w for w, f in iter(word_counts.items()) if f >= 2]
-    chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', "B-ORG", "I-ORG"]
+    chunk_tags = Config.nlp_ner.chunk_tags

     # Persist the retained vocab and the corresponding chunk_tags
     with open(Config.nlp_ner.path_config, 'wb') as outp:
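The f >= 2 cutoff above keeps only characters seen at least twice in the training set. A self-contained sketch of that step (the two tiny samples are invented for illustration):

    from collections import Counter

    # Each sample is a list of [char, tag] rows, as produced by _parse_data.
    train = [[['中', 'B-LOC'], ['国', 'I-LOC'], ['的', 'O']],
             [['中', 'B-LOC'], ['央', 'I-LOC'], ['的', 'O']]]

    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    vocab = [w for w, f in word_counts.items() if f >= 2]
    print(vocab)  # ['中', '的'] -- '国' and '央' appear only once and are dropped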
@@ -57,7 +60,10 @@ def _parse_data(filename):
     # Sentence splitting: split_text assumes one sentence per line, so a blank line means two consecutive split_text
     texts = fn.read().decode('utf-8').strip().split(split_text + split_text)
     # Each character sits on its own split_text line; within a line, char and tag are space-separated
-    data = [[row.split() for row in text.split(split_text)] for text in texts]
+    # len(row) > 0 skips the empty rows produced by two consecutive newlines
+    # row.split() alone drops spaces and special symbols, which would lose the space-character data!
+    data = [[[" ", "O"] if len(row.split()) != 2 else row.split() for row in text.split(split_text) if len(row) > 0] for text in texts]
+    # data = [[row.split() for row in text.split(split_text) if len(row.split()) == 2] for text in texts]
     return data

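The [" ", "O"] fallback above exists because row.split() on a whitespace-only line returns an empty list, which would silently lose space characters. A sketch of the same parsing logic on an in-memory string (file handling omitted; the sample text is invented):

    def parse_bio(text, split_text='\n'):
        # One "char tag" pair per line; a blank line ends a sentence.
        texts = text.strip().split(split_text + split_text)
        return [[[" ", "O"] if len(row.split()) != 2 else row.split()
                 for row in sent.split(split_text) if len(row) > 0]
                for sent in texts]

    sample = "中 B-LOC\n国 I-LOC\n的 O\n\n好 O"
    print(parse_bio(sample))
    # [[['中', 'B-LOC'], ['国', 'I-LOC'], ['的', 'O']], [['好', 'O']]]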
@@ -96,10 +102,17 @@ def create_model(len_vocab, len_chunk_tags):
     model = Sequential()
     model.add(Embedding(len_vocab, Config.nlp_ner.EMBED_DIM, mask_zero=True))  # Random embedding
     model.add(Bidirectional(LSTM(Config.nlp_ner.BiLSTM_UNITS // 2, return_sequences=True)))
+    model.add(Dropout(0.25))
     crf = CRF(len_chunk_tags, sparse_target=True)
     model.add(crf)
     model.summary()
-    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
+    model.compile('adam', loss=crf_loss, metrics=[crf_viterbi_accuracy])
+    # model.compile('rmsprop', loss=crf_loss, metrics=[crf_viterbi_accuracy])
+
+    # from keras.optimizers import Adam
+    # adam_lr = 0.0001
+    # adam_beta_1 = 0.5
+    # model.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1), loss=crf_loss, metrics=[crf_viterbi_accuracy])
     return model

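The compile change is the point of this commit: the CRF layer's bound properties crf.loss_function and crf.accuracy are swapped for the module-level crf_loss and crf_viterbi_accuracy from keras-contrib (a swap commonly made to ease saving and reloading the model). A standalone sketch of the same architecture; the dimensions below are placeholders, not the project's Config values:

    from keras.models import Sequential
    from keras.layers import Embedding, Bidirectional, LSTM, Dropout
    from keras_contrib.layers import CRF
    from keras_contrib.losses import crf_loss
    from keras_contrib.metrics import crf_viterbi_accuracy

    VOCAB_SIZE, EMBED_DIM, BILSTM_UNITS, N_TAGS = 3000, 100, 200, 7  # illustrative only

    model = Sequential()
    model.add(Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=True))  # index 0 reserved for padding
    model.add(Bidirectional(LSTM(BILSTM_UNITS // 2, return_sequences=True)))
    model.add(Dropout(0.25))
    crf = CRF(N_TAGS, sparse_target=True)  # sparse_target: integer tag ids, not one-hot
    model.add(crf)
    model.compile('adam', loss=crf_loss, metrics=[crf_viterbi_accuracy])
    model.summary()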
@@ -115,29 +128,38 @@ def test():
     with open(Config.nlp_ner.path_config, 'rb') as inp:
         (vocab, chunk_tags) = pickle.load(inp)
     model = create_model(len(vocab), len(chunk_tags))
-    predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
-    text_EMBED, length = process_data(predict_text, vocab)
-    model.load_weights(Config.nlp_ner.path_model)
-    raw = model.predict(text_EMBED)[0][-length:]
-    result = [np.argmax(row) for row in raw]
-    result_tags = [chunk_tags[i] for i in result]
+    # predict_text = '造型独特,尺码偏大,估计是钉子头圆的半径的缘故'
+    with open(Config.nlp_ner.path_origin, "r", encoding="utf-8") as f:
+        lines = f.readlines()
+    for predict_text in lines:
+        content = predict_text.strip()
+        text_EMBED, length = process_data(content, vocab)
+        model.load_weights(Config.nlp_ner.path_model)
+        raw = model.predict(text_EMBED)[0][-length:]
+        pre_result = [np.argmax(row) for row in raw]
+        result_tags = [chunk_tags[i] for i in pre_result]

         per, loc, org = '', '', ''

         for s, t in zip(predict_text, result_tags):
             if t in ('B-PER', 'I-PER'):
                 per += ' ' + s if (t == 'B-PER') else s
             if t in ('B-ORG', 'I-ORG'):
                 org += ' ' + s if (t == 'B-ORG') else s
             if t in ('B-LOC', 'I-LOC'):
                 loc += ' ' + s if (t == 'B-LOC') else s

         print(['person:' + per, 'location:' + loc, 'organization:' + org])
         # Save each sentence's entities and opinions
         result = {}
         tag_list = [i for i in chunk_tags if i not in ["O"]]
         for word, t in zip(content, result_tags):
             # print(word, t)
             if t not in tag_list:
                 continue
             for i in range(0, len(tag_list), 2):
                 if t in tag_list[i:i+2]:
                     # print("\n>>> %s---%s==%s" % (word, t, tag_list[i:i+2]))
                     tag = tag_list[i].split("-")[-1]
                     if tag not in result:
                         result[tag] = ""
                     result[tag] += ' ' + word if t == tag_list[i] else word
         print(result)


 def main():
     # print("--")
     train()
     test()

 # if __name__ == "__main__":
 #     train()
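The entity-grouping loop in test() assumes chunk_tags lists each B-X/I-X pair adjacently, so stepping through tag_list two at a time recovers the entity type. A self-contained sketch of that decode step on fake argmax output (no trained model needed; the indices are invented):

    chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG']
    content = '周恩来访问埃塞俄比亚'
    pre_result = [1, 2, 2, 0, 0, 3, 4, 4, 4, 4]  # pretend argmax per character
    result_tags = [chunk_tags[i] for i in pre_result]

    result = {}
    tag_list = [t for t in chunk_tags if t != 'O']  # adjacent (B-X, I-X) pairs
    for word, t in zip(content, result_tags):
        if t not in tag_list:
            continue
        for i in range(0, len(tag_list), 2):
            if t in tag_list[i:i + 2]:
                tag = tag_list[i].split('-')[-1]  # 'B-PER' -> 'PER'
                result.setdefault(tag, '')
                # A leading space marks the start of a new entity of that type.
                result[tag] += ' ' + word if t == tag_list[i] else word
    print(result)  # {'PER': ' 周恩来', 'LOC': ' 埃塞俄比亚'}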