Restore deleted code
@@ -93,6 +93,64 @@ if __name__ == "__main__":
## bert / Embedding + lstm + crf

#%%
# Load the data
# Imports needed by this snippet (presumably already present earlier in the full file):
import codecs

import numpy as np
import pandas as pd
from keras.layers import LSTM, Bidirectional, Dense, Embedding, Input, Lambda
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras_bert import load_trained_model_from_checkpoint
from keras_contrib.layers import CRF


class TextBert():

    def __init__(self):
        self.path_config = Config.bert.path_config
        self.path_checkpoint = Config.bert.path_checkpoint

        # Build the vocabulary from BERT's vocab file:
        # one token per line, and the line number becomes the token id
        self.token_dict = {}
        with codecs.open(Config.bert.dict_path, 'r', 'utf8') as reader:
            for line in reader:
                token = line.strip()
                self.token_dict[token] = len(self.token_dict)

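Config is defined elsewhere in the repo; to run this snippet standalone, a minimal stand-in with the same attributes could look like the sketch below (all paths are hypothetical placeholders, not the repo's actual configuration):

# Hypothetical stand-in for the repo's Config object, for standalone testing only.
from types import SimpleNamespace

Config = SimpleNamespace(bert=SimpleNamespace(
    path_config='chinese_L-12_H-768_A-12/bert_config.json',    # assumed checkpoint layout
    path_checkpoint='chinese_L-12_H-768_A-12/bert_model.ckpt',
    dict_path='chinese_L-12_H-768_A-12/vocab.txt',
    path_neg='data/neg.xls',                                   # hypothetical data paths
    path_pos='data/pos.xls',
))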
    def prepare_data(self):
        neg = pd.read_excel(Config.bert.path_neg, header=None)
        pos = pd.read_excel(Config.bert.path_pos, header=None)
        data = []
        for d in neg[0]:
            data.append((d, 0))
        for d in pos[0]:
            data.append((d, 1))
        # Split into training and validation sets at a 9:1 ratio
        random_order = list(range(len(data)))
        np.random.shuffle(random_order)
        train_data = [data[j] for i, j in enumerate(random_order) if i % 10 != 0]
        valid_data = [data[j] for i, j in enumerate(random_order) if i % 10 == 0]
        return train_data, valid_data

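prepare_data returns raw (text, label) pairs; before they can be fed to the BERT branch of build_model they must be encoded into token-id and segment-id arrays. A minimal sketch of that step, using keras_bert's Tokenizer with the token_dict built in __init__ (the helper name, maxlen, and padding scheme are assumptions, not part of this commit):

    def encode_data(self, data, maxlen=128):
        # Hypothetical helper (not in the original commit): turn (text, label)
        # pairs into the [token_ids, segment_ids] inputs the BERT model expects.
        from keras_bert import Tokenizer
        tokenizer = Tokenizer(self.token_dict)
        X1, X2, Y = [], [], []
        for text, label in data:
            # encode() adds [CLS]/[SEP] and pads or truncates to max_len
            ids, segments = tokenizer.encode(first=text, max_len=maxlen)
            X1.append(ids)
            X2.append(segments)
            Y.append(label)
        return [np.array(X1), np.array(X2)], np.array(Y)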
    def build_model(self, m_type="bert"):
        if m_type == "bert":
            bert_model = load_trained_model_from_checkpoint(self.path_config, self.path_checkpoint, seq_len=None)
            for l in bert_model.layers:
                l.trainable = True  # fine-tune all BERT layers
            x1_in = Input(shape=(None,))  # token ids
            x2_in = Input(shape=(None,))  # segment ids
            x = bert_model([x1_in, x2_in])
            x = Lambda(lambda x: x[:, 0])(x)  # keep only the [CLS] vector
            # Adjust the output layer to the number of classes; extra layers can also be added
            p = Dense(1, activation='sigmoid')(x)
            model = Model([x1_in, x2_in], p)
            model.compile(
                loss='binary_crossentropy',
                optimizer=Adam(1e-5),  # use a sufficiently small learning rate
                metrics=['accuracy']
            )
        else:
            # Otherwise use a plain Embedding + BiLSTM + CRF
            # (vocab, EMBED_DIM, BiRNN_UNITS and chunk_tags are expected to be defined elsewhere in the file)
            model = Sequential()
            model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # random embedding
            model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
            crf = CRF(len(chunk_tags), sparse_target=True)
            model.add(crf)
            model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])

        model.summary()
        return model

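An end-to-end usage sketch under the assumptions above (Config paths resolving to a downloaded Chinese BERT checkpoint, and the hypothetical encode_data helper sketched earlier); batch size and epoch count are illustrative only:

#%%
# Hypothetical usage (not in the original commit):
tb = TextBert()
train_data, valid_data = tb.prepare_data()
model = tb.build_model(m_type="bert")
x_train, y_train = tb.encode_data(train_data)  # encode_data is the sketch above
x_valid, y_valid = tb.encode_data(valid_data)
model.fit(x_train, y_train, batch_size=16, epochs=2,
          validation_data=(x_valid, y_valid))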
#%%
# Load the data
from keras_bert import Tokenizer
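The diff hunk ends at this import, so how the Tokenizer is used is not shown. For reference, keras_bert's Tokenizer wraps a token dict and returns (token_ids, segment_ids) pairs; the toy vocabulary below is illustrative only:

# Illustrative only; the code that follows this import is outside the hunk.
token_dict = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '语': 4, '言': 5}  # toy vocab
tokenizer = Tokenizer(token_dict)
ids, segments = tokenizer.encode(first='语言', max_len=8)  # padded token ids + segment ids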