This repository was archived by the owner on Oct 31, 2023. It is now read-only.
Failed to train Mask-Predict with a larger model/hidden dimension #7
Open
Description
Elegant work! In addition to training a transformer_base-scale model, I am also trying to train a large model (e.g., 1024 model dimension and 4096 hidden dimension) so that I can fine-tune Mask-Predict with XLM.
However, when I change only the dimensions and keep the other arguments fixed, training fails: the perplexity actually keeps growing. Can you give me some advice?
Below is my training command:
python train.py data-bin/xlm_pretained-wmt14.en-de \
  --arch bert_transformer_seq2seq --share-all-embeddings \
  --criterion label_smoothed_length_cross_entropy --label-smoothing 0.1 \
  --lr 5e-4 --warmup-init-lr 1e-7 --min-lr 1e-9 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
  --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-6 \
  --task translation_self --max-tokens 11000 --weight-decay 0.01 --dropout 0.3 \
  --encoder-layers 6 --encoder-embed-dim 1024 --decoder-layers 6 --decoder-embed-dim 1024 \
  --encoder-attention-heads 8 --decoder-attention-heads 8 \
  --max-source-positions 10000 --max-target-positions 10000 \
  --max-update 300000 --seed 0 --save-dir ${model_dir} \
  --update-freq 3 --ddp-backend=no_c10d --fp16 --keep-last-epochs 10
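The lr value in the log line below can be reproduced from the schedule flags in this command. This is a minimal sketch assuming fairseq's inverse_sqrt scheduler does a linear warmup from --warmup-init-lr to --lr over --warmup-updates steps and then decays as lr * sqrt(warmup_updates / num_updates); inverse_sqrt_lr is a hypothetical helper, not a fairseq function.

```python
import math


def inverse_sqrt_lr(step, lr=5e-4, warmup_init_lr=1e-7, warmup_updates=10000):
    """Illustrative inverse-square-root schedule with linear warmup (hypothetical helper)."""
    if step < warmup_updates:
        # Warmup: move linearly from warmup_init_lr to the peak lr.
        return warmup_init_lr + step * (lr - warmup_init_lr) / warmup_updates
    # Decay: lr * sqrt(warmup_updates / step), which equals lr at step == warmup_updates.
    return lr * math.sqrt(warmup_updates) / math.sqrt(step)


# num_updates taken from the training log line; the log reports lr=0.000440328.
print(f"{inverse_sqrt_lr(12894):.9f}")  # prints 0.000440328
```

Called with num_updates=12894 from the log, it returns the logged lr, so under this assumption the scheduler is applying the configured warmup and decay.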
And here is the log of one training step:
| epoch 012: 74%|▋| 814/1099 [24:26<08:28, 1.79s/it, loss=12.243, nll_loss=11.121, ppl=2226.58, wps=33332, ups=1, wpb=60068.756, bsz=4060.299, num_updates=12894, lr=0.000440328, gnorm=0.341, clip=0.000, oom=0.000, loss_scale=0.250, wall=23856, train_wall=20393, length_loss=6.6472]
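Several fields in this log line can be cross-checked against the settings. The sketch below assumes, as I understand fairseq's logging, that nll_loss is reported in bits (base 2) so that ppl = 2 ** nll_loss, that the fp16 loss scaler halves the scale on overflow and doubles it after a window of overflow-free steps, and that --max-tokens caps each per-GPU batch before gradient accumulation. The settings are copied from the Namespace dump below; none of this is fairseq code.

```python
import math

# Fields copied from the log line.
nll_loss = 11.121       # assumed to be bits per target token
reported_ppl = 2226.58
loss_scale = 0.250
wpb = 60068.756         # target tokens per update
bsz = 4060.299          # sentences per update

# Settings copied from the Namespace dump.
vocab_size = 60192
fp16_init_scale = 128
max_tokens, update_freq, world_size = 11000, 3, 4

# ppl is 2 ** nll_loss; the log rounds nll_loss to 3 decimals, so expect a small mismatch.
ppl = 2 ** nll_loss

# A model guessing uniformly over the vocabulary would score log2(V) bits per token.
uniform_nll = math.log2(vocab_size)

# The scaler halves on overflow and doubles after clean windows,
# so this is the *net* number of halvings since the initial scale.
net_halvings = math.log2(fp16_init_scale / loss_scale)

# Upper bound on (padded) tokens per update: per-GPU budget x accumulation x GPUs.
token_budget = max_tokens * update_freq * world_size

print(f"ppl recomputed from nll_loss: {ppl:.1f} (log says {reported_ppl})")
print(f"uniform-guess nll: {uniform_nll:.3f} bits (ppl {vocab_size})")
print(f"net loss-scale halvings: {net_halvings:.0f}")
print(f"token budget per update: {token_budget}, wpb/budget: {wpb / token_budget:.3f}")
print(f"target tokens per sentence: {wpb / bsz:.2f}")
```

Under these assumptions the recomputed ppl agrees with the logged value to within the rounding of nll_loss, the loss scale sits nine net halvings below fp16_init_scale=128, and the target tokens per update fill a bit under half of the 132,000-token budget.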
BTW, because I reused the XLM vocabulary, the vocabulary of the larger Mask-Predict model has more than 60k types (60,192, per the dictionary lines in the log below).
Namespace(adam_betas='(0.9,0.999)', adam_eps=1e-06, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='bert_transformer_seq2seq', attention_dropout=0.0, best_checkpoint_metric='loss', bilm_add_bos=False, bilm_attention_dropout=0.0, bilm_mask_last_state=False, bilm_model_dropout=0.1, bilm_relu_dropout=0.0, bucket_cap_mb=25, clip_norm=25, cpu=False, criterion='label_smoothed_length_cross_entropy', curriculum=0, data=['data-bin/xlm_pretained-wmt14.en-de'], dataset_impl=None, ddp_backend='no_c10d', decoder_attention_heads=8, decoder_embed_dim=1024, decoder_embed_path=None, decoder_embed_scale=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layers=6, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method='tcp://localhost:10859', distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=4, dropout=0.3, dynamic_length=False, embedding_only=False, encoder_attention_heads=8, encoder_embed_dim=1024, encoder_embed_path=None, encoder_embed_scale=None, encoder_ffn_embed_dim=4096, encoder_layers=6, encoder_learned_pos=False, encoder_normalize_before=False, find_unused_parameters=False, fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, keep_interval_updates=-1, keep_last_epochs=10, label_smoothing=0.1, left_pad_source='True', left_pad_target='False', log_format=None, log_interval=1000, lr=[0.0005], lr_scheduler='inverse_sqrt', mask_range=False, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=10000, max_target_positions=10000, max_tokens=11000, max_tokens_valid=11000, max_update=500000, maximize_best_checkpoint_metric=False, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, no_dec_token_positional_embeddings=False, no_enc_token_positional_embeddings=False, no_epoch_checkpoints=False, no_last_checkpoints=False, no_progress_bar=False, no_save=False, no_save_optimizer_state=False, num_workers=0, optimizer='adam', optimizer_overrides='{}', raw_text=False, relu_dropout=0.0, required_batch_size_multiple=8, reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='./distill_model_from_scratch_1024_xlm', save_interval=1, save_interval_updates=0, seed=0, self_target=False, sentence_avg=False, share_all_embeddings=True, share_decoder_input_output_embed=False, skip_invalid_size_inputs_valid_test=False, source_lang=None, target_lang=None, task='translation_self', tbmf_wrapper=False, tensorboard_logdir='', threshold_loss_scale=None, train_subset='train', update_freq=[3], upsample_primary=1, use_bmuf=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=10000, weight_decay=0.01)
| [en] dictionary: 60192 types
| [de] dictionary: 60192 types
| data-bin/xlm_pretained-wmt14.en-de valid 3000 examples
Transformer_nonautoregressive(
  (encoder): TransformerEncoder(
    (embed_tokens): Embedding(60192, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(10002, 1024, padding_idx=1)
    (embed_lengths): Embedding(10000, 1024)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (layer_norms): ModuleList(
          (0): BertLayerNorm()
          (1): BertLayerNorm()
        )
      )(1)(2)...(5)
    )
  )
  (decoder): SelfTransformerDecoder(
    (embed_tokens): Embedding(60192, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(10002, 1024, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (self_attn_layer_norm): BertLayerNorm()
        (encoder_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (encoder_attn_layer_norm): BertLayerNorm()
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): BertLayerNorm()
      )(1)(2)...(5)
    )
  )
)
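To compare the two model sizes, here is a rough parameter count built from the shapes in the model printout. It rests on my own assumptions, not on anything the printout shows: that each MultiheadAttention also holds a fused q/k/v input projection of shape (3d, d) with a bias, that each BertLayerNorm has a weight and a bias of size d, and that with --share-all-embeddings a single token embedding is shared by the encoder, the decoder, and the output projection. Any modules the printout left out are not counted, so treat the totals as approximate; approx_param_count is a hypothetical helper.

```python
def approx_param_count(d, ffn, vocab=60192, layers=6, max_pos=10002, max_len=10000):
    """Rough parameter count for the encoder-decoder in the printout (hypothetical helper)."""
    attn = 4 * d * d + 4 * d                  # fused q/k/v (3d x d + 3d) plus out_proj (d x d + d)
    ffn_block = d * ffn + ffn + ffn * d + d   # fc1 and fc2, each with a bias
    norm = 2 * d                              # layer norm weight and bias
    enc_layer = attn + ffn_block + 2 * norm       # self-attn, FFN, two layer norms
    dec_layer = 2 * attn + ffn_block + 3 * norm   # self-attn, encoder-attn, FFN, three layer norms
    token_emb = vocab * d                     # shared across encoder, decoder, and output projection
    positions = 2 * max_pos * d               # one learned position table each for encoder and decoder
    lengths = max_len * d                     # encoder embed_lengths table
    return token_emb + positions + lengths + layers * (enc_layer + dec_layer)


large = approx_param_count(d=1024, ffn=4096)
base = approx_param_count(d=512, ffn=2048)    # assumes the same vocab and position settings
print(f"large: {large / 1e6:.1f}M, base: {base / 1e6:.1f}M, ratio: {large / base:.2f}")
```

Under these assumptions the 1024/4096 configuration has roughly three times as many parameters as a 512/2048 model with the same vocabulary, and the shared 60k-type token embedding alone accounts for about 61.6M of them.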
jungokasai