Can ALBERT, the successfully slimmed-down BERT, replace BERT?
Shisan (十三), from Aofeisi (凹非寺)
Reported by 量子位 | Official account QbitAI
80% fewer parameters than BERT, yet better performance.
That is ALBERT, the "successfully slimmed-down BERT" that Google proposed last year.
As soon as it was released, the model drew a great deal of attention, and comparisons between the two became a hot topic.
Recently, a netizen named Naman Bansal raised a question:
Should ALBERT be used in place of BERT?
Whether it can replace BERT, a comparison will tell.
BERT vs. ALBERT
The BERT model is the one most people are already familiar with.
Proposed by Google in 2018, it was trained on a very large corpus of 3.3 billion words.
Its innovations are concentrated in the pre-training stage: it uses Masked LM and Next Sentence Prediction to capture word-level and sentence-level representations, respectively.
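As a rough illustration of the two objectives (a minimal sketch, not BERT's actual preprocessing code; the real recipe also replaces some chosen tokens with random tokens or leaves them unchanged):

import random

def mask_tokens(tokens, mask_prob=0.15):
    # Masked LM: hide roughly 15% of the input tokens; the model must predict the originals.
    masked, labels = [], []
    for tok in tokens:
        if random.random() < mask_prob:
            masked.append("[MASK]")
            labels.append(tok)      # position the model is scored on
        else:
            masked.append(tok)
            labels.append(None)     # position ignored by the loss
    return masked, labels

# Next Sentence Prediction: a binary label saying whether sentence B really follows sentence A.
positive_pair = (["the", "soup", "was", "hot"], ["it", "burned", "my", "tongue"], 1)
negative_pair = (["the", "soup", "was", "hot"], ["stocks", "fell", "on", "monday"], 0)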
BERT's arrival completely changed the relationship between pre-trained word vectors and downstream NLP tasks.
A year later, Google proposed ALBERT, also billed as "a lite BERT". Its backbone is similar to BERT's: it still uses a Transformer encoder, and the activation function is also GELU.
Its biggest achievement is cutting the parameter count by 80% compared with BERT while obtaining even better results.
The main improvements over BERT are factorized embedding parameterization, cross-layer parameter sharing, an inter-sentence coherence loss based on sentence-order prediction (SOP), and the removal of dropout.
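Of these, factorized embedding parameterization is the most direct source of the parameter savings: instead of one large vocab-size x hidden-size embedding matrix, ALBERT first embeds tokens into a small space of size E and then projects up to the hidden size H. A minimal PyTorch sketch with illustrative sizes (treat the numbers as assumptions, not the official configuration):

import torch.nn as nn

V, H, E = 30000, 768, 128   # vocab size, hidden size, small embedding size (illustrative)

# BERT-style embedding: V x H parameters
bert_style = nn.Embedding(V, H)

# ALBERT-style factorized embedding: V x E + E x H parameters
albert_style = nn.Sequential(nn.Embedding(V, E),
                             nn.Linear(E, H, bias=False))

print(sum(p.numel() for p in bert_style.parameters()))    # 23,040,000
print(sum(p.numel() for p in albert_style.parameters()))  # 3,938,304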
The figure below compares the performance of BERT and ALBERT on the SQuAD and RACE datasets.
As you can see, ALBERT achieves quite good results.
How to implement ALBERT (pre-training) on a custom corpus?
To get to know ALBERT better, we will now implement it on a custom corpus.
The dataset used is a restaurant-review dataset, and the goal is to have the ALBERT model identify the names of dishes mentioned in the reviews.
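Concretely, dish-name extraction is framed as token classification: every token in a review is labeled 1 if it belongs to a dish name and 0 otherwise (this mirrors the label_sent helper defined in Step 3 below). A hypothetical example of what the labels look like:

review_tokens = ["the", "kung", "pao", "chicken", "was", "amazing"]
dish_tokens   = ["kung", "pao", "chicken"]
review_label  = [0, 1, 1, 1, 0, 0]   # 1 = token is part of a dish name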
Step 1: Download the dataset and prepare the files
# Downloading all files and data

!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_train.csv
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_val.csv
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review.txt
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review_nopunct.txt
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/models_toy/albert_config.json
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/finetune_checkpoint
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/pretrain_checkpoint

# Creating files and setting up ALBERT

!pip install sentencepiece
!git clone https://github.com/google-research/ALBERT
!python ./ALBERT/create_pretraining_data.py --input_file "restaurant_review.txt" --output_file "restaurant_review_train" --vocab_file "vocab.txt" --max_seq_length=64
!pip install transformers
!pip install tfrecord
Step 2: Use transformers and define the layers
# Defining Layers for ALBERT

from transformers.modeling_albert import AlbertModel, AlbertPreTrainedModel
from transformers.configuration_albert import AlbertConfig
import torch
import torch.nn as nn

class AlbertSequenceOrderHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 2)
        self.bias = nn.Parameter(torch.zeros(2))

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        prediction_scores = hidden_states + self.bias

        return prediction_scores

from torch.nn import CrossEntropyLoss
from transformers.modeling_bert import ACT2FN

class AlbertForPretrain(AlbertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)

        # For Masked LM
        # The original huggingface implementation created new output weights via a dense layer.
        # However, the original ALBERT ties the decoder weights to the input word embeddings (see below).
        self.predictions_dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.predictions_activation = ACT2FN[config.hidden_act]
        self.predictions_LayerNorm = nn.LayerNorm(config.embedding_size)
        self.predictions_bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.predictions_decoder = nn.Linear(config.embedding_size, config.vocab_size)

        self.predictions_decoder.weight = self.albert.embeddings.word_embeddings.weight

        # For sequence-order prediction
        self.seq_relationship = AlbertSequenceOrderHead(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        seq_relationship_labels=None,
    ):

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        loss_fct = CrossEntropyLoss()

        sequence_output = outputs[0]

        sequence_output = self.predictions_dense(sequence_output)
        sequence_output = self.predictions_activation(sequence_output)
        sequence_output = self.predictions_LayerNorm(sequence_output)
        prediction_scores = self.predictions_decoder(sequence_output)

        if masked_lm_labels is not None:
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size),
                                      masked_lm_labels.view(-1))

        pooled_output = outputs[1]
        seq_relationship_scores = self.seq_relationship(pooled_output)
        if seq_relationship_labels is not None:
            seq_relationship_loss = loss_fct(seq_relationship_scores.view(-1, 2), seq_relationship_labels.view(-1))

        # Total pre-training loss = masked LM loss + sentence-order prediction loss
        loss = masked_lm_loss + seq_relationship_loss

        return loss
Step 3: Use the LAMB optimizer and fine-tune ALBERT
# Using LAMB optimizer
# LAMB - "https://github.com/cybertronai/pytorch-lamb"

import torch
from torch.optim import Optimizer

class Lamb(Optimizer):
    r"""Implements the Lamb algorithm.
    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust_ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group['lr']  # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(group['weight_decay'], p.data)

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                p.data.add_(-step_size * trust_ratio, adam_step)

        return loss

import time
import torch.nn as nn
import torch
from tfrecord.torch.dataset import TFRecordDataset
import numpy as np
import os

LEARNING_RATE = 0.001
EPOCH = 40
BATCH_SIZE = 2
MAX_GRAD_NORM = 1.0

print(f"--- Resume/Start training ---")
feat_map = {"input_ids": "int",
            "input_mask": "int",
            "segment_ids": "int",
            "next_sentence_labels": "int",
            "masked_lm_positions": "int",
            "masked_lm_ids": "int"}
pretrain_file = 'restaurant_review_train'

# Create albert pretrain model
config = AlbertConfig.from_json_file("albert_config.json")
albert_pretrain = AlbertForPretrain(config)
# Create optimizer
optimizer = Lamb([{"params": [p for n, p in list(albert_pretrain.named_parameters())]}], lr=LEARNING_RATE)
albert_pretrain.train()
dataset = TFRecordDataset(pretrain_file, index_path=None, description=feat_map)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

tmp_loss = 0
start_time = time.time()

if os.path.isfile('pretrain_checkpoint'):
    print(f"--- Load from checkpoint ---")
    checkpoint = torch.load("pretrain_checkpoint")
    albert_pretrain.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    losses = checkpoint['losses']

else:
    epoch = -1
    losses = []

for e in range(epoch + 1, EPOCH):
    for batch in loader:
        b_input_ids = batch['input_ids'].long()
        b_token_type_ids = batch['segment_ids'].long()
        b_seq_relationship_labels = batch['next_sentence_labels'].long()

        # Convert the data from the loaded, decoded format (created by google's
        # ALBERT create_pretraining_data.py script) into the format required by
        # huggingface's pytorch implementation of albert.
        mask_rows = np.nonzero(batch['masked_lm_positions'].numpy())[0]
        mask_cols = batch['masked_lm_positions'].numpy()[batch['masked_lm_positions'].numpy() != 0]
        b_attention_mask = np.zeros((BATCH_SIZE, 64), dtype=np.int64)
        b_attention_mask[mask_rows, mask_cols] = 1
        b_masked_lm_labels = np.zeros((BATCH_SIZE, 64), dtype=np.int64) - 100
        b_masked_lm_labels[mask_rows, mask_cols] = batch['masked_lm_ids'].numpy()[batch['masked_lm_positions'].numpy() != 0]
        b_attention_mask = torch.tensor(b_attention_mask).long()
        b_masked_lm_labels = torch.tensor(b_masked_lm_labels).long()

        loss = albert_pretrain(input_ids=b_input_ids,
                               attention_mask=b_attention_mask,
                               token_type_ids=b_token_type_ids,
                               masked_lm_labels=b_masked_lm_labels,
                               seq_relationship_labels=b_seq_relationship_labels)

        # clear old gradients
        optimizer.zero_grad()
        # backward pass
        loss.backward()
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=albert_pretrain.parameters(), max_norm=MAX_GRAD_NORM)
        # update parameters
        optimizer.step()

        tmp_loss += loss.detach().item()

    # print metrics and save to checkpoint every epoch
    print(f"Epoch: {e}")
    print(f"Train loss: {(tmp_loss / 20)}")
    print(f"Train Time: {(time.time() - start_time) / 60} mins")
    losses.append(tmp_loss / 20)

    tmp_loss = 0
    start_time = time.time()

    torch.save({'model_state_dict': albert_pretrain.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                'epoch': e, 'loss': loss, 'losses': losses},
               'pretrain_checkpoint')

from matplotlib import pyplot as plot
plot.plot(losses)

# Finetuning ALBERT

# At the time of writing, Huggingface didn't provide the class object for
# AlbertForTokenClassification, hence write your own definition below
from transformers.modeling_albert import AlbertModel, AlbertPreTrainedModel
from transformers.configuration_albert import AlbertConfig
from transformers.tokenization_bert import BertTokenizer
import torch.nn as nn
from torch.nn import CrossEntropyLoss

class AlbertForTokenClassification(AlbertPreTrainedModel):

    def __init__(self, albert, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = albert
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        return logits

import numpy as np

def label_sent(name_tokens, sent_tokens):
    # Label each sentence token with 1 if it is part of the dish name, 0 otherwise.
    label = []
    i = 0
    if len(name_tokens) > len(sent_tokens):
        label = np.zeros(len(sent_tokens))
    else:
        while i < len(sent_tokens):
            found_match = False
            if name_tokens[0] == sent_tokens[i]:
                found_match = True
                for j in range(len(name_tokens) - 1):
                    if (i + j + 1) >= len(sent_tokens):
                        return label
                    if name_tokens[j + 1] != sent_tokens[i + j + 1]:
                        found_match = False
                if found_match:
                    label.extend(list(np.ones(len(name_tokens)).astype(int)))
                    i = i + len(name_tokens)
                else:
                    label.extend([0])
                    i = i + 1
            else:
                label.extend([0])
                i = i + 1
    return label

import pandas as pd
import glob
import os

tokenizer = BertTokenizer(vocab_file="vocab.txt")

df_data_train = pd.read_csv("dish_name_train.csv")
df_data_train['name_tokens'] = df_data_train['dish_name'].apply(tokenizer.tokenize)
df_data_train['review_tokens'] = df_data_train.review.apply(tokenizer.tokenize)
df_data_train['review_label'] = df_data_train.apply(lambda row: label_sent(row['name_tokens'], row['review_tokens']), axis=1)

df_data_val = pd.read_csv("dish_name_val.csv")
df_data_val = df_data_val.dropna().reset_index()
df_data_val['name_tokens'] = df_data_val['dish_name'].apply(tokenizer.tokenize)
df_data_val['review_tokens'] = df_data_val.review.apply(tokenizer.tokenize)
df_data_val['review_label'] = df_data_val.apply(lambda row: label_sent(row['name_tokens'], row['review_tokens']), axis=1)

MAX_LEN = 64
BATCH_SIZE = 1
from keras.preprocessing.sequence import pad_sequences
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

tr_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in df_data_train['review_tokens']], maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
tr_tags = pad_sequences(df_data_train['review_label'], maxlen=MAX_LEN, padding="post", dtype="long", truncating="post")
# create the mask to ignore the padded elements in the sequences.
tr_masks = [[float(i > 0) for i in ii] for ii in tr_inputs]
tr_inputs = torch.tensor(tr_inputs)
tr_tags = torch.tensor(tr_tags)
tr_masks = torch.tensor(tr_masks)
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)

val_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in df_data_val['review_tokens']], maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
val_tags = pad_sequences(df_data_val['review_label'], maxlen=MAX_LEN, padding="post", dtype="long", truncating="post")
# create the mask to ignore the padded elements in the sequences.
val_masks = [[float(i > 0) for i in ii] for ii in val_inputs]
val_inputs = torch.tensor(val_inputs)
val_tags = torch.tensor(val_tags)
val_masks = torch.tensor(val_masks)
val_data = TensorDataset(val_inputs, val_masks, val_tags)
val_sampler = RandomSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=BATCH_SIZE)

model_tokenclassification = AlbertForTokenClassification(albert_pretrain.albert, config)
from torch.optim import Adam
LEARNING_RATE = 0.0000003
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model_tokenclassification.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model_tokenclassification.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = Adam(optimizer_grouped_parameters, lr=LEARNING_RATE)
Step 4: Train the model on the custom corpus
# Training the model

# from torch.utils.tensorboard import SummaryWriter
import time
import os.path
import torch.nn as nn
import torch

EPOCH = 800
MAX_GRAD_NORM = 1.0

start_time = time.time()
tr_loss, tr_acc, nb_tr_steps = 0, 0, 0
eval_loss, eval_acc, nb_eval_steps = 0, 0, 0

if os.path.isfile('finetune_checkpoint'):
    print(f"--- Load from checkpoint ---")
    checkpoint = torch.load("finetune_checkpoint")
    model_tokenclassification.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    train_losses = checkpoint['train_losses']
    train_accs = checkpoint['train_accs']
    eval_losses = checkpoint['eval_losses']
    eval_accs = checkpoint['eval_accs']

else:
    epoch = -1
    train_losses, train_accs, eval_losses, eval_accs = [], [], [], []

print(f"--- Resume/Start training ---")
for e in range(epoch + 1, EPOCH):

    # TRAIN loop
    model_tokenclassification.train()

    for batch in train_dataloader:
        # add batch to gpu
        batch = tuple(t for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # forward pass
        b_outputs = model_tokenclassification(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

        ce_loss_fct = CrossEntropyLoss()
        # Only keep active parts of the loss
        b_active_loss = b_input_mask.view(-1) == 1
        b_active_logits = b_outputs.view(-1, config.num_labels)[b_active_loss]
        b_active_labels = b_labels.view(-1)[b_active_loss]

        loss = ce_loss_fct(b_active_logits, b_active_labels)
        acc = torch.mean((torch.max(b_active_logits.detach(), 1)[1] == b_active_labels.detach()).float())

        model_tokenclassification.zero_grad()
        # backward pass
        loss.backward()
        # track train loss
        tr_loss += loss.item()
        tr_acc += acc
        nb_tr_steps += 1
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=model_tokenclassification.parameters(), max_norm=MAX_GRAD_NORM)
        # update parameters
        optimizer.step()

    # VALIDATION on validation set
    model_tokenclassification.eval()
    for batch in val_dataloader:
        batch = tuple(t for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        with torch.no_grad():

            b_outputs = model_tokenclassification(b_input_ids, token_type_ids=None,
                                                  attention_mask=b_input_mask, labels=b_labels)

        loss_fct = CrossEntropyLoss()
        # Only keep active parts of the loss
        b_active_loss = b_input_mask.view(-1) == 1
        b_active_logits = b_outputs.view(-1, config.num_labels)[b_active_loss]
        b_active_labels = b_labels.view(-1)[b_active_loss]
        loss = loss_fct(b_active_logits, b_active_labels)
        acc = np.mean(np.argmax(b_active_logits.detach().cpu().numpy(), axis=1).flatten() == b_active_labels.detach().cpu().numpy().flatten())

        eval_loss += loss.mean().item()
        eval_acc += acc
        nb_eval_steps += 1

    if e % 10 == 0:

        print(f"Epoch: {e}")
        print(f"Train loss: {(tr_loss / nb_tr_steps)}")
        print(f"Train acc: {(tr_acc / nb_tr_steps)}")
        print(f"Train Time: {(time.time() - start_time) / 60} mins")

        print(f"Validation loss: {eval_loss / nb_eval_steps}")
        print(f"Validation Accuracy: {(eval_acc / nb_eval_steps)}")

        train_losses.append(tr_loss / nb_tr_steps)
        train_accs.append(tr_acc / nb_tr_steps)
        eval_losses.append(eval_loss / nb_eval_steps)
        eval_accs.append(eval_acc / nb_eval_steps)

        tr_loss, tr_acc, nb_tr_steps = 0, 0, 0
        eval_loss, eval_acc, nb_eval_steps = 0, 0, 0
        start_time = time.time()

        torch.save({'model_state_dict': model_tokenclassification.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': e, 'train_losses': train_losses, 'train_accs': train_accs, 'eval_losses': eval_losses, 'eval_accs': eval_accs},
                   'finetune_checkpoint')

plot.plot(train_losses)
plot.plot(train_accs)
plot.plot(eval_losses)
plot.plot(eval_accs)
plot.legend(labels=['train_loss', 'train_accuracy', 'validation_loss', 'validation_accuracy'])
Step 5: Prediction
# Prediction

def predict(texts):
    tokenized_texts = [tokenizer.tokenize(txt) for txt in texts]
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                              maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
    attention_mask = [[float(i > 0) for i in ii] for ii in input_ids]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)

    dataset = TensorDataset(input_ids, attention_mask)
    datasampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset, sampler=datasampler, batch_size=BATCH_SIZE)

    predicted_labels = []

    for batch in dataloader:
        batch = tuple(t for t in batch)
        b_input_ids, b_input_mask = batch

        with torch.no_grad():
            logits = model_tokenclassification(b_input_ids, token_type_ids=None,
                                               attention_mask=b_input_mask)

        predicted_labels.append(np.multiply(np.argmax(logits.detach().cpu().numpy(), axis=2), b_input_mask.detach().cpu().numpy()))
    # np.concatenate(predicted_labels) to flatten a list of arrays of batch_size * max_len into a list of arrays of max_len
    return np.concatenate(predicted_labels).astype(int), tokenized_texts

def get_dish_candidate_names(predicted_label, tokenized_text):
    name_lists = []
    if len(np.where(predicted_label > 0)[0]) > 0:
        name_idx_combined = np.where(predicted_label > 0)[0]
        name_idxs = np.split(name_idx_combined, np.where(np.diff(name_idx_combined) != 1)[0] + 1)
        name_lists.append([" ".join(np.take(tokenized_text, name_idx)) for name_idx in name_idxs])
        # Remove duplicate names in the name_lists
        name_lists = np.unique(name_lists)
        return name_lists
    else:
        return None

texts = df_data_val.review.values
predicted_labels, _ = predict(texts)
df_data_val['predicted_review_label'] = list(predicted_labels)
df_data_val['predicted_name'] = df_data_val.apply(lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens),
                                                  axis=1)

texts = df_data_train.review.values
predicted_labels, _ = predict(texts)
df_data_train['predicted_review_label'] = list(predicted_labels)
df_data_train['predicted_name'] = df_data_train.apply(lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens),
                                                      axis=1)

# display the validation dataframe with the predicted dish names
(df_data_val)
Experimental results
As you can see, the model successfully extracts dish names from the restaurant reviews.
Model showdown
The hands-on example above shows that, lite as ALBERT is, its results are quite good.
So, with fewer parameters and good results, can it simply replace BERT?
Let's take a closer look at the experimental comparison between the two; note that "Speedup" here refers to training time.
Because ALBERT has far fewer parameters, there is less data to move around and the throughput of distributed training goes up, so ALBERT trains faster. At inference time, however, it still has to run the same Transformer computation as BERT (the short sketch after the list below makes this concrete).
So it can be summed up as follows:
- For the same training time, ALBERT performs better than BERT.
- For the same inference time, neither ALBERT base nor ALBERT large performs as well as BERT.
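To make the trade-off concrete, here is a minimal PyTorch sketch (with hypothetical layer dimensions, not the actual BERT/ALBERT configurations): sharing one Transformer layer across 12 positions cuts the parameter count by roughly 12x, yet a forward pass still executes 12 layer computations, so the inference cost barely changes.

import torch
import torch.nn as nn

# BERT-style: 12 independent encoder layers -> 12x the parameters
unshared = nn.ModuleList([nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072)
                          for _ in range(12)])
# ALBERT-style: one encoder layer reused 12 times -> 1x the parameters
shared = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072)

print(sum(p.numel() for layer in unshared for p in layer.parameters()))  # ~85M
print(sum(p.numel() for p in shared.parameters()))                       # ~7M

x = torch.randn(64, 1, 768)   # (seq_len, batch, hidden)
for layer in unshared:        # 12 layer computations
    x = layer(x)
y = torch.randn(64, 1, 768)
for _ in range(12):           # still 12 layer computations, just with shared weights
    y = shared(y)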
In addition, Naman Bansal believes that, because of ALBERT's structure, the computational cost of running ALBERT is somewhat higher than BERT's.
So it is still a case of "you can't have your cake and eat it too": for ALBERT to fully surpass and replace BERT, further research and refinement will be needed.
Portal
Blog post:
https://medium.com/@namanbansal9909/should-we-shift-from-bert-to-albert-e6fbb7779d3e
(End)
量子位 QbitAI · Signed contributor of Toutiao (头条号)
Follow us to be the first to learn about cutting-edge technology.