# ailearning/tutorials/keras/text_NER.py
import pickle
import numpy as np
import pandas as pd
import platform
from collections import Counter
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
"""
# padding: 'pre' (default) pads zeros at the front, 'post' pads zeros at the end
# truncating: if the text is longer than pad_num, 'pre' (default) drops tokens from the front, 'post' drops them from the end
# x_train = pad_sequences(x, maxlen=pad_num, value=0, padding='post', truncating="post")
# print("--- ", x_train[0][:20])
# (a short illustration of this behaviour follows the imports below)

Bug note when using the CRF layer from keras_bert / keras_contrib:
TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [bool, float32] that don't all match
Workaround: change line 516 of crf.py from
    mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1),
to:
    mask2 = K.cast(K.concatenate([mask, K.cast(K.zeros_like(mask[:, :1]), mask.dtype)], axis=1),
"""
from keras.preprocessing.sequence import pad_sequences
from config.setting import Config
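# A minimal illustration of the padding/truncating behaviour described in the
# docstring above (values are illustrative, assuming maxlen=5):
#   pad_sequences([[1, 2, 3]], maxlen=5)                              -> [[0, 0, 1, 2, 3]]   # pre-padding (default)
#   pad_sequences([[1, 2, 3]], maxlen=5, padding='post')              -> [[1, 2, 3, 0, 0]]
#   pad_sequences([[1, 2, 3, 4, 5, 6]], maxlen=5)                     -> [[2, 3, 4, 5, 6]]   # pre-truncating (default)
#   pad_sequences([[1, 2, 3, 4, 5, 6]], maxlen=5, truncating='post')  -> [[1, 2, 3, 4, 5]]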

def load_data():
    train = _parse_data(Config.nlp_ner.path_train)
    test = _parse_data(Config.nlp_ner.path_test)
    print("--- init: data loaded and parsed ---")

    # e.g. Counter({'的': 8, '中': 7, '致': 7, '党': 7})
    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    # keep only characters that occur at least twice
    vocab = [w for w, f in iter(word_counts.items()) if f >= 2]
    chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', "B-ORG", "I-ORG"]

    # persist the filtered vocab and the corresponding chunk_tags
    with open(Config.nlp_ner.path_config, 'wb') as outp:
        pickle.dump((vocab, chunk_tags), outp)
    print("--- init: config file saved ---")

    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    print("--- init: data encoded into the format required for training ---")
    return train, test, (vocab, chunk_tags)
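# Sketch of what load_data() returns (shapes are illustrative):
#   train / test        : (x, y_chunk), x of shape (num_sentences, maxlen),
#                         y_chunk of shape (num_sentences, maxlen, 1) since onehot=False
#   (vocab, chunk_tags) : the filtered character list and the 7 BIO tags above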

def _parse_data(filename):
    """
    A single leading underscore (_foo) marks a helper that is not meant to be accessed directly.
    Parses the raw annotation file into the structure used for model training.
    :param filename: path to the data file
    :return: data: parsed result, e.g.
        [[['', 'B-ORG'], ['', 'I-ORG']], [['', 'B-ORG'], ['', 'I-ORG']]]
    """
    with open(filename, 'rb') as fn:
        split_text = '\n'
        # sentence segmentation: each character sits on its own line, so sentences
        # are separated by a blank line, i.e. two consecutive split_text
        texts = fn.read().decode('utf-8').strip().split(split_text + split_text)
        # one split_text per character line; within a line, character and tag are separated by whitespace
        data = [[row.split() for row in text.split(split_text)] for text in texts]
    return data
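# Illustrative sketch of the expected file format (assumed from the parsing logic
# above: one character and its tag per line, separated by whitespace, with a
# blank line between sentences):
#
#   中 B-ORG
#   国 I-ORG
#
#   他 O
#
# which _parse_data turns into:
#   [[['中', 'B-ORG'], ['国', 'I-ORG']], [['他', 'O']]]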

def _process_data(data, vocab, chunk_tags, maxlen=None, onehot=False):
    if maxlen is None:
        maxlen = max(len(s) for s in data)
    # map each character to an integer id
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    # characters not found in vocab get the unk id 1
    x = [[word2idx.get(w[0].lower(), 1) for w in s] for s in data]
    y_chunk = [[chunk_tags.index(w[1]) for w in s] for s in data]
    x = pad_sequences(x, maxlen)  # left padding
    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)
    if onehot:
        # returns a one-hot encoded multi-dimensional array
        y_chunk = np.eye(len(chunk_tags), dtype='float32')[y_chunk]
    else:
        # np.expand_dims: expands the array shape by inserting a new axis
        # https://blog.csdn.net/hong615771420/article/details/83448878
        y_chunk = np.expand_dims(y_chunk, 2)
    return x, y_chunk
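# Quick sketch of the one-hot trick above (assuming 3 chunk tags):
#   np.eye(3)[[0, 2]] -> [[1., 0., 0.],
#                         [0., 0., 1.]]
# i.e. indexing the identity matrix with the tag ids yields one-hot rows.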

def process_data(data, vocab, maxlen=100):
    """Encode a single raw text string for prediction; returns (padded ids, original length)."""
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    # unknown characters map to the unk id 1
    x = [word2idx.get(w[0].lower(), 1) for w in data]
    length = len(x)
    x = pad_sequences([x], maxlen)  # left padding
    return x, length
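# Usage example with a hypothetical vocab (indices 0/1 standing in for padding/unk):
#   vocab = ['<pad>', '<unk>', '中', '国']
#   process_data('中国人', vocab, maxlen=5)
#   -> (array([[0, 0, 2, 3, 1]]), 3)   # '人' is out of vocabulary, so it maps to 1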

def create_model(len_vocab, len_chunk_tags):
    model = Sequential()
    model.add(Embedding(len_vocab, Config.nlp_ner.EMBED_DIM, mask_zero=True))  # random embedding
    model.add(Bidirectional(LSTM(Config.nlp_ner.BiLSTM_UNITS // 2, return_sequences=True)))
    crf = CRF(len_chunk_tags, sparse_target=True)
    model.add(crf)
    model.summary()
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    return model
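# Sketch of the shape flow through the model above (B = batch size, T = sequence length):
#   input ids   (B, T)
#   Embedding   (B, T, EMBED_DIM)
#   Bi-LSTM     (B, T, BiLSTM_UNITS)       # two directions of BiLSTM_UNITS // 2 each
#   CRF         (B, T, len_chunk_tags)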

def train():
    (train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()
    model = create_model(len(vocab), len(chunk_tags))
    # train the model, using the test split as validation data
    model.fit(train_x, train_y, batch_size=16, epochs=Config.nlp_ner.EPOCHS,
              validation_data=(test_x, test_y))
    model.save(Config.nlp_ner.path_model)

def test():
    with open(Config.nlp_ner.path_config, 'rb') as inp:
        (vocab, chunk_tags) = pickle.load(inp)
    model = create_model(len(vocab), len(chunk_tags))
    predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
    text_EMBED, length = process_data(predict_text, vocab)
    model.load_weights(Config.nlp_ner.path_model)
    # keep only the predictions for the last `length` positions (the input is left-padded)
    raw = model.predict(text_EMBED)[0][-length:]
    result = [np.argmax(row) for row in raw]
    result_tags = [chunk_tags[i] for i in result]
    per, loc, org = '', '', ''
    for s, t in zip(predict_text, result_tags):
        if t in ('B-PER', 'I-PER'):
            per += ' ' + s if (t == 'B-PER') else s
        if t in ('B-ORG', 'I-ORG'):
            org += ' ' + s if (t == 'B-ORG') else s
        if t in ('B-LOC', 'I-LOC'):
            loc += ' ' + s if (t == 'B-LOC') else s
    print(['person:' + per, 'location:' + loc, 'organization:' + org])
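# Sketch of the entity assembly above: with predict_text = '周恩来在北京' and
# (hypothetical) predicted tags ['B-PER', 'I-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC'],
# per becomes ' 周恩来' and loc becomes ' 北京' (each B- tag starts a new
# space-separated entity, and I- tags extend the current one).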

def main():
    # print("--")
    train()


if __name__ == "__main__":
    main()