Files
ailearning/src/py3.x/tensorflow2.x/text_NER.py
2020-03-15 14:34:02 +08:00

127 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import pickle
import numpy as np
import platform
from collections import Counter
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
from keras.preprocessing.sequence import pad_sequences
EMBED_DIM = 200
BiRNN_UNITS = 200
def load_data():
    """Load the zh-NER train/test corpora and build model configuration.

    Builds the character vocabulary (characters occurring at least twice in
    the training set) and the fixed BIO tag set, pickles both to
    zh-NER/model/config.pkl for later inference, then converts both corpora
    to padded index arrays.

    Returns:
        (train, test, (vocab, chunk_tags)) where train and test are the
        (x, y) pairs produced by _process_data.
    """
    train = _parse_data(open('zh-NER/data/train_data.data', 'rb'))
    test = _parse_data(open('zh-NER/data/test_data.data', 'rb'))

    # Count characters (lower-cased) across the whole training set and keep
    # only those seen at least twice; rarer characters fall back to <unk>.
    char_freq = Counter(row[0].lower() for sample in train for row in sample)
    vocab = [w for w, f in char_freq.items() if f >= 2]
    chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', "B-ORG", "I-ORG"]

    # Persist the configuration so inference can reload it (see create_model).
    with open('zh-NER/model/config.pkl', 'wb') as outp:
        pickle.dump((vocab, chunk_tags), outp)

    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    return train, test, (vocab, chunk_tags)
def _parse_data(fh):
# in windows the new line is '\r\n\r\n' the space is '\r\n' . so if you use windows system,
# you have to use recorsponding instructions
if platform.system() == 'Windows':
split_text = '\r\n'
else:
split_text = '\n'
string = fh.read().decode('utf-8')
data = [[row.split() for row in sample.split(split_text)] for
sample in
string.strip().split(split_text + split_text)]
fh.close()
return data
def _process_data(data, vocab, chunk_tags, maxlen=None, onehot=False):
    """Convert parsed [char, tag] samples into padded index arrays.

    Args:
        data: list of sentences, each a list of [char, tag] rows
            (as produced by _parse_data).
        vocab: vocabulary list; a character's position is its index.
        chunk_tags: tag list; a tag's position is its label index.
        maxlen: pad/truncate length; defaults to the longest sentence.
        onehot: if True, return labels as one-hot vectors instead of a
            trailing singleton dimension.

    Returns:
        (x, y_chunk): left-padded input indices and label arrays
        (label padding value is -1).
    """
    if maxlen is None:
        maxlen = max(len(s) for s in data)
    word2idx = {w: i for i, w in enumerate(vocab)}
    # Unknown characters map to index 1 (<unk>).
    x = [[word2idx.get(row[0].lower(), 1) for row in s] for s in data]
    y_chunk = [[chunk_tags.index(row[1]) for row in s] for s in data]

    x = pad_sequences(x, maxlen)  # left padding
    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)

    if onehot:
        y_chunk = np.eye(len(chunk_tags), dtype='float32')[y_chunk]
    else:
        y_chunk = np.expand_dims(y_chunk, 2)
    return x, y_chunk
def process_data(data, vocab, maxlen=100):
    """Encode one raw text string for inference.

    Args:
        data: the input text (iterated character by character).
        vocab: vocabulary list saved at training time.
        maxlen: pad length; must match what the model was trained with.

    Returns:
        (x, length): a left-padded 1 x maxlen index array and the number
        of real (unpadded) characters, used to slice the prediction.
    """
    word2idx = {w: i for i, w in enumerate(vocab)}
    # Unknown characters map to index 1 (<unk>), as in training.
    encoded = [word2idx.get(ch[0].lower(), 1) for ch in data]
    length = len(encoded)
    padded = pad_sequences([encoded], maxlen)  # left padding
    return padded, length
def create_model(train=True):
    """Build the Embedding -> BiLSTM -> CRF sequence-labeling model.

    Args:
        train: if True, load and preprocess the corpora via load_data();
            if False, reload only the pickled (vocab, chunk_tags) config
            for inference.

    Returns:
        If train: (model, (train_x, train_y), (test_x, test_y)).
        Otherwise: (model, (vocab, chunk_tags)).
    """
    if train:
        (train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()
    else:
        # FIX: load_data() writes the config to 'zh-NER/model/config.pkl',
        # but this branch previously read 'model/config.pkl' — a path never
        # written. Read from the same path the training run saves to.
        with open('zh-NER/model/config.pkl', 'rb') as inp:
            (vocab, chunk_tags) = pickle.load(inp)
    model = Sequential()
    # mask_zero=True marks index 0 as padding so downstream layers skip it.
    model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # random-init embedding
    # BiRNN_UNITS // 2 per direction, so the concatenated output is BiRNN_UNITS wide.
    model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
    crf = CRF(len(chunk_tags), sparse_target=True)
    model.add(crf)
    model.summary()
    # The CRF layer supplies its own loss and accuracy (keras_contrib API).
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    if train:
        return model, (train_x, train_y), (test_x, test_y)
    else:
        return model, (vocab, chunk_tags)
def train():
    """Train the BiLSTM-CRF on the zh-NER corpus and save the weights."""
    n_epochs = 10
    model, (train_x, train_y), (test_x, test_y) = create_model()
    # Fit with the held-out test split as validation data.
    model.fit(
        train_x,
        train_y,
        batch_size=16,
        epochs=n_epochs,
        validation_data=[test_x, test_y],
    )
    model.save('model/crf.h5')
def test():
    """Run inference on a fixed demo sentence and print found entities."""
    model, (vocab, chunk_tags) = create_model(train=False)
    predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
    # FIX: this local was previously named `str`, shadowing the builtin.
    x, length = process_data(predict_text, vocab)
    model.load_weights('model/crf.h5')
    # Input is left-padded, so only the last `length` timesteps are real.
    raw = model.predict(x)[0][-length:]
    result = [np.argmax(row) for row in raw]
    result_tags = [chunk_tags[i] for i in result]
    per, loc, org = '', '', ''
    # Collect characters per entity type; a B- tag starts a new entity,
    # which is marked by a leading space in the accumulated string.
    for s, t in zip(predict_text, result_tags):
        if t in ('B-PER', 'I-PER'):
            per += ' ' + s if (t == 'B-PER') else s
        if t in ('B-ORG', 'I-ORG'):
            org += ' ' + s if (t == 'B-ORG') else s
        if t in ('B-LOC', 'I-LOC'):
            loc += ' ' + s if (t == 'B-LOC') else s
    # FIX: corrected 'organzation' -> 'organization' in the printed label.
    print(['person:' + per, 'location:' + loc, 'organization:' + org])
if __name__ == "__main__":
train()