diff --git a/src/py3.x/tensorflow2.x/text_bert.py b/src/py3.x/tensorflow2.x/text_bert.py index 42c63168..bec2ebee 100644 --- a/src/py3.x/tensorflow2.x/text_bert.py +++ b/src/py3.x/tensorflow2.x/text_bert.py @@ -93,6 +93,64 @@ if __name__ == "__main__": ## bert / Embedding/ + lstm + crt + #%% + # 加载数据 + class TextBert(): + def __init__(self): + self.path_config = Config.bert.path_config + self.path_checkpoint = Config.bert.path_checkpoint + + self.token_dict = {} + with codecs.open(Config.bert.dict_path, 'r', 'utf8') as reader: + for line in reader: + token = line.strip() + self.token_dict[token] = len(self.token_dict) + + + def prepare_data(self): + neg = pd.read_excel(Config.bert.path_neg, header=None) + pos = pd.read_excel(Config.bert.path_pos, header=None) + data = [] + for d in neg[0]: + data.append((d, 0)) + for d in pos[0]: + data.append((d, 1)) + # 按照9:1的比例划分训练集和验证集 + random_order = list(range(len(data))) + np.random.shuffle(random_order) + train_data = [data[j] for i, j in enumerate(random_order) if i % 10 != 0] + valid_data = [data[j] for i, j in enumerate(random_order) if i % 10 == 0] + return train_data, valid_data + + def build_model(self, m_type="bert"): + if m_type == "bert": + bert_model = load_trained_model_from_checkpoint(self.path_config, self.path_checkpoint, seq_len=None) + for l in bert_model.layers: + l.trainable = True + x1_in = Input(shape=(None,)) + x2_in = Input(shape=(None,)) + x = bert_model([x1_in, x2_in]) + x = Lambda(lambda x: x[:, 0])(x) + p = Dense(1, activation='sigmoid')(x)#根据分类种类自行调节,也可以多加一些层数 + model = Model([x1_in, x2_in], p) + model.compile( + loss='binary_crossentropy', + optimizer=Adam(1e-5), # 用足够小的学习率 + metrics=['accuracy'] + ) + else: + # 否则用 Embedding + model = Sequential() + model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True)) # Random embedding + model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True))) + crf = CRF(len(chunk_tags), sparse_target=True) + model.add(crf) + model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy]) + + model.summary() + return model + + #%% # 加载数据 from keras_bert import Tokenizer