This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Failed to train Mask-Predict with a larger model/hidden dimension #7

@alphadl

Description


Elegant work! In addition to training a transformer_base-scale model, I am also trying to train a large model (e.g., 1024 model dimension and 4096 hidden/FFN dimension), so that I can fine-tune Mask-Predict with XLM.

However, when I simply change the dimensions and keep the other arguments fixed, training fails: the ppl actually keeps growing. Can you give me some advice?

Below is my training command:

python train.py data-bin/xlm_pretained-wmt14.en-de --arch bert_transformer_seq2seq --share-all-embeddings --criterion label_smoothed_length_cross_entropy --label-smoothing 0.1 --lr 5e-4 --warmup-init-lr 1e-7 --min-lr 1e-9 --lr-scheduler inverse_sqrt --warmup-updates 10000 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-6 --task translation_self --max-tokens 11000 --weight-decay 0.01 --dropout 0.3 --encoder-layers 6 --encoder-embed-dim 1024 --decoder-layers 6 --decoder-embed-dim 1024 --encoder-attention-heads 8 --decoder-attention-heads 8 --max-source-positions 10000 --max-target-positions 10000 --max-update 300000 --seed 0 --save-dir ${model_dir} --update-freq 3 --ddp-backend=no_c10d --fp16 --keep-last-epochs 10
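For a sense of the effective batch size these flags imply, here is a back-of-the-envelope sketch; the 4-GPU world size is taken from the Namespace dump below, and the figure is a nominal ceiling since real batches may not fill --max-tokens exactly:

```python
# Nominal upper bound on tokens per optimizer update for this command.
# distributed_world_size=4 is taken from the Namespace dump below.
max_tokens = 11000   # --max-tokens (per GPU, per forward pass)
update_freq = 3      # --update-freq (gradient accumulation steps)
num_gpus = 4         # distributed_world_size

tokens_per_update = max_tokens * update_freq * num_gpus
print(tokens_per_update)  # 132000 -> up to ~132k tokens per update
```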

The following is the log from one training step:

| epoch 012:  74%|▋| 814/1099 [24:26<08:28,  1.79s/it, loss=12.243, nll_loss=11.121, ppl=2226.58, wps=33332, ups=1, wpb=60068.756, bsz=4060.299, num_updates=12894, lr=0.000440328, gnorm=0.341, clip=0.000, oom=0.000, loss_scale=0.250, wall=23856, train_wall=20393, length_loss=6.6472] 
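Two numbers in this line can be sanity-checked directly (a quick arithmetic check, separate from the training code): fairseq reports perplexity in base 2, so ppl should equal 2 ** nll_loss, and the fp16 loss scale has decayed from its initial value of 128 (fp16_init_scale in the Namespace dump below) to 0.25, i.e., nine halvings, which points to repeated gradient overflows:

```python
import math

# Values copied from the log line above.
nll_loss = 11.121
reported_ppl = 2226.58

# fairseq perplexity is base 2: ppl = 2 ** nll_loss.
assert math.isclose(2 ** nll_loss, reported_ppl, rel_tol=1e-3)

# Dynamic loss scaling starts at fp16_init_scale=128 and is halved on each
# overflow (fp16_scale_tolerance=0.0 here); a current scale of 0.25 means
# log2(128 / 0.25) = 9 halvings, i.e., repeated fp16 overflows so far.
print(math.log2(128 / 0.25))  # 9.0
```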

BTW, because I reused the XLM vocabulary list, the vocab size of the larger Mask-Predict model is over 60k.
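Since --share-all-embeddings is set, the encoder input, decoder input, and output projection all share one vocab × dim matrix, so the larger vocabulary is a sizable chunk of the model on its own (simple arithmetic from the dimensions in the log below):

```python
vocab_size = 60192  # from the dictionary sizes reported below
embed_dim = 1024    # --encoder-embed-dim / --decoder-embed-dim

# One shared matrix serves encoder input, decoder input, and the
# output projection (--share-all-embeddings).
shared_embedding_params = vocab_size * embed_dim
print(f"~{shared_embedding_params / 1e6:.1f}M parameters")  # ~61.6M
```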

Namespace(adam_betas='(0.9,0.999)', adam_eps=1e-06, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='bert_transformer_seq2seq', attention_dropout=0.0, best_checkpoint_metric='loss', bilm_add_bos=False, bilm_attention_dropout=0.0, bilm_mask_last_state=False, bilm_model_dropout=0.1, bilm_relu_dropout=0.0, bucket_cap_mb=25, clip_norm=25, cpu=False, criterion='label_smoothed_length_cross_entropy', curriculum=0, data=['data-bin/xlm_pretained-wmt14.en-de'], dataset_impl=None, ddp_backend='no_c10d', decoder_attention_heads=8, decoder_embed_dim=1024, decoder_embed_path=None, decoder_embed_scale=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layers=6, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method='tcp://localhost:10859', distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=4, dropout=0.3, dynamic_length=False, embedding_only=False, encoder_attention_heads=8, encoder_embed_dim=1024, encoder_embed_path=None, encoder_embed_scale=None, encoder_ffn_embed_dim=4096, encoder_layers=6, encoder_learned_pos=False, encoder_normalize_before=False, find_unused_parameters=False, fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, keep_interval_updates=-1, keep_last_epochs=10, label_smoothing=0.1, left_pad_source='True', left_pad_target='False', log_format=None, log_interval=1000, lr=[0.0005], lr_scheduler='inverse_sqrt', mask_range=False, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=10000, max_target_positions=10000, max_tokens=11000, max_tokens_valid=11000, max_update=500000, maximize_best_checkpoint_metric=False, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, no_dec_token_positional_embeddings=False, no_enc_token_positional_embeddings=False, no_epoch_checkpoints=False, no_last_checkpoints=False, no_progress_bar=False, no_save=False, no_save_optimizer_state=False, num_workers=0, optimizer='adam', optimizer_overrides='{}', raw_text=False, relu_dropout=0.0, required_batch_size_multiple=8, reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='./distill_model_from_scratch_1024_xlm', save_interval=1, save_interval_updates=0, seed=0, self_target=False, sentence_avg=False, share_all_embeddings=True, share_decoder_input_output_embed=False, skip_invalid_size_inputs_valid_test=False, source_lang=None, target_lang=None, task='translation_self', tbmf_wrapper=False, tensorboard_logdir='', threshold_loss_scale=None, train_subset='train', update_freq=[3], upsample_primary=1, use_bmuf=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=10000, weight_decay=0.01)
| [en] dictionary: 60192 types
| [de] dictionary: 60192 types
| data-bin/xlm_pretained-wmt14.en-de valid 3000 examples
Transformer_nonautoregressive(
  (encoder): TransformerEncoder(
    (embed_tokens): Embedding(60192, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(10002, 1024, padding_idx=1)
    (embed_lengths): Embedding(10000, 1024)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (layer_norms): ModuleList(
          (0): BertLayerNorm()
          (1): BertLayerNorm()
        )
      )
      (1)-(5): same as (0)
    )
  )
  (decoder): SelfTransformerDecoder(
    (embed_tokens): Embedding(60192, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(10002, 1024, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (self_attn_layer_norm): BertLayerNorm()
        (encoder_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (encoder_attn_layer_norm): BertLayerNorm()
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): BertLayerNorm()
      )
      (1)-(5): same as (0)
    )
  )
)
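For scale, a rough parameter count implied by this printout (a back-of-the-envelope estimate from the dimensions above only; it ignores biases, LayerNorms, and the positional/length embeddings):

```python
d, ffn, vocab = 1024, 4096, 60192

# Per layer: Q/K/V/out projections (4 * d * d) plus the two FFN matrices.
self_attn = 4 * d * d                  # ~4.2M
ffn_block = d * ffn + ffn * d          # ~8.4M
enc_layer = self_attn + ffn_block      # ~12.6M per encoder layer
dec_layer = 2 * self_attn + ffn_block  # extra cross-attention, ~16.8M

embeddings = vocab * d                 # shared, ~61.6M
total = embeddings + 6 * enc_layer + 6 * dec_layer
print(f"~{total / 1e6:.0f}M parameters")  # ~238M (weights only)
```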
