diff --git a/tutorials/keras/text_NER.py b/tutorials/keras/text_NER.py index 6ffa7c95..7f4f1cbf 100644 --- a/tutorials/keras/text_NER.py +++ b/tutorials/keras/text_NER.py @@ -3,9 +3,12 @@ import numpy as np import pandas as pd import platform from collections import Counter +import keras from keras.models import Sequential -from keras.layers import Embedding, Bidirectional, LSTM +from keras.layers import Embedding, Bidirectional, LSTM, Dropout from keras_contrib.layers import CRF +from keras_contrib.losses import crf_loss +from keras_contrib.metrics import crf_viterbi_accuracy """ # padding: pre(默认) 向前补充0 post 向后补充0 # truncating: 文本超过 pad_num, pre(默认) 删除前面 post 删除后面 @@ -31,7 +34,7 @@ def load_data(): # Counter({'的': 8, '中': 7, '致': 7, '党': 7}) word_counts = Counter(row[0].lower() for sample in train for row in sample) vocab = [w for w, f in iter(word_counts.items()) if f >= 2] - chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', "B-ORG", "I-ORG"] + chunk_tags = Config.nlp_ner.chunk_tags # 存储保留的有效个数的 vovab 和 对应 chunk_tags with open(Config.nlp_ner.path_config, 'wb') as outp: @@ -57,7 +60,10 @@ def _parse_data(filename): # 主要是分句: split_text 默认每个句子都是一行,所以原来换行就需要 两个split_text texts = fn.read().decode('utf-8').strip().split(split_text + split_text) # 对于每个字需要 split_text, 而字的内部需要用空格分隔 - data = [[row.split() for row in text.split(split_text)] for text in texts] + # len(row) > 0 避免连续2个换行,导致 row 数据为空 + # row.split() 会删除空格或特殊符号,导致空格数据缺失! + data = [[[" ", "O"] if len(row.split()) != 2 else row.split() for row in text.split(split_text) if len(row) > 0] for text in texts] + # data = [[row.split() for row in text.split(split_text) if len(row.split()) == 2] for text in texts] return data @@ -96,10 +102,17 @@ def create_model(len_vocab, len_chunk_tags): model = Sequential() model.add(Embedding(len_vocab, Config.nlp_ner.EMBED_DIM, mask_zero=True)) # Random embedding model.add(Bidirectional(LSTM(Config.nlp_ner.BiLSTM_UNITS // 2, return_sequences=True))) + model.add(Dropout(0.25)) crf = CRF(len_chunk_tags, sparse_target=True) model.add(crf) model.summary() - model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy]) + model.compile('adam', loss=crf_loss, metrics=[crf_viterbi_accuracy]) + # model.compile('rmsprop', loss=crf_loss, metrics=[crf_viterbi_accuracy]) + + # from keras.optimizers import Adam + # adam_lr = 0.0001 + # adam_beta_1 = 0.5 + # model.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1), loss=crf_loss, metrics=[crf_viterbi_accuracy]) return model @@ -115,29 +128,38 @@ def test(): with open(Config.nlp_ner.path_config, 'rb') as inp: (vocab, chunk_tags) = pickle.load(inp) model = create_model(len(vocab), len(chunk_tags)) - predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚' - text_EMBED, length = process_data(predict_text, vocab) - model.load_weights(Config.nlp_ner.path_model) - raw = model.predict(text_EMBED)[0][-length:] - result = [np.argmax(row) for row in raw] - result_tags = [chunk_tags[i] for i in result] + # predict_text = '造型独特,尺码偏大,估计是钉子头圆的半径的缘故' + with open(Config.nlp_ner.path_origin, "r", encoding="utf-8") as f: + lines = f.readlines() + for predict_text in lines: + content = predict_text.strip() + text_EMBED, length = process_data(content, vocab) + model.load_weights(Config.nlp_ner.path_model) + raw = model.predict(text_EMBED)[0][-length:] + pre_result = [np.argmax(row) for row in raw] + result_tags = [chunk_tags[i] for i in pre_result] - per, loc, org = '', '', '' - - for s, t in zip(predict_text, result_tags): - if t in ('B-PER', 'I-PER'): - per += ' ' + s if (t == 'B-PER') else s - if t in ('B-ORG', 'I-ORG'): - org += ' ' + s if (t == 'B-ORG') else s - if t in ('B-LOC', 'I-LOC'): - loc += ' ' + s if (t == 'B-LOC') else s - - print(['person:' + per, 'location:' + loc, 'organzation:' + org]) + # 保存每句话的 实体和观点 + result = {} + tag_list = [i for i in chunk_tags if i not in ["O"]] + for word, t in zip(content, result_tags): + # print(word, t) + if t not in tag_list: + continue + for i in range(0, len(tag_list), 2): + if t in tag_list[i:i+2]: + # print("\n>>> %s---%s==%s" % (word, t, tag_list[i:i+2])) + tag = tag_list[i].split("-")[-1] + if tag not in result: + result[tag] = "" + result[tag] += ' '+word if t==tag_list[i] else word + print(result) def main(): # print("--") train() + test() # if __name__ == "__main__": # train()