0%
July 22, 2022

Retrain Model when NaN Occurs

deep-learning

pytorch

Given a training function with the following interface:

class TrainingErrorMessage(TypedDict):
    """Error report that ``train`` returns when it stops early."""
    # Epoch value reported by the training loop when the error was detected.
    curr_epoch: int
    # Kind of error; "nan_loss" is the only value declared so far.
    message: Literal["nan_loss"]

def train(
    faster_rcnn: FasterRCNNResnet50FPN,
    lr,
    start_epoch,
    epochs,
    save_weight_interval=5
) -> None | TrainingErrorMessage:
    """Train the model, bailing out with an error report when the loss is NaN.

    This is an interface sketch: every ``...`` stands for code that is not
    shown, so only the NaN check below is documented as fact.

    Args:
        faster_rcnn: The model instance to train.
        lr: Presumably the learning rate (not used in the visible part).
        start_epoch: Presumably the epoch number to start counting from
            (not used in the visible part) -- TODO confirm.
        epochs: Number of iterations of the outer epoch loop.
        save_weight_interval: Presumably how often, in epochs, weights are
            saved (not used in the visible part) -- TODO confirm.

    Returns:
        A ``TrainingErrorMessage`` with ``message="nan_loss"`` and the
        current loop ``epoch`` as soon as ``total_loss`` is NaN. The ``None``
        in the annotation is presumably the normal-completion result; that
        code path is elided here.
    """
    ...
    for epoch in range(epochs):
        for batch_id, data in enumerate(tqdm(data_loader)):
            ...
            total_loss = ...

            # Stop at the first NaN loss and tell the caller which epoch it
            # happened in, rather than continuing to train on a broken loss.
            # NOTE(review): as sketched, `epoch` is the 0-based index from
            # range(epochs); a caller that treats curr_epoch as an absolute
            # epoch number needs the real implementation to offset it by
            # start_epoch -- TODO confirm.
            if torch.isnan(total_loss):
                return TrainingErrorMessage(message="nan_loss", curr_epoch=epoch)
    ...

Then we wrap the function train in a helper that restarts training from the latest saved checkpoint whenever a NaN loss is reported:

def train_with_nan(
    faster_rcnn,
    lr=1e-5,
    start_epoch=1,
    epoches=60,
    save_weight_interval=5
):
    """Run ``train`` and resume from a saved checkpoint whenever it hits a NaN loss.

    ``train`` is called repeatedly. Each time it returns a ``"nan_loss"``
    report, a fresh ``FasterRCNNResnet50FPN`` is built, the weights of the
    latest checkpoint strictly before the failing epoch are loaded from
    ``pths/model_epoch_{k}.pth`` (``k`` is rounded down to a multiple of
    ``save_weight_interval``), and training restarts at epoch ``k + 1`` for
    the epochs that remain of the originally requested schedule.

    Args:
        faster_rcnn: Model used for the first ``train`` call.
        lr: Learning rate forwarded to ``train``.
        start_epoch: Epoch number at which the first ``train`` call starts.
        epoches: Total number of epochs requested, counted from
            ``start_epoch``.
        save_weight_interval: Forwarded to ``train`` and used to locate the
            latest checkpoint. Assumes ``train`` saves a checkpoint at every
            epoch that is a multiple of this value -- TODO confirm.

    Returns:
        None. The loop ends when ``train`` returns ``None``, returns a
        message other than ``"nan_loss"``, or reports an epoch beyond
        ``start_epoch + epoches``.

    Notes:
        * ``curr_epoch`` in the report is treated as an absolute epoch
          number -- verify against the real ``train`` implementation.
        * If the NaN occurs before the first checkpoint, ``k`` is 0 and
          ``pths/model_epoch_0.pth`` must exist, otherwise ``torch.load``
          raises.
        * There is no retry limit: a NaN that recurs after every reload
          makes this loop run indefinitely.
    """
    restart_ep = start_epoch
    restart_for_eps = epoches
    curr_model = faster_rcnn

    while True:
        result = train(
            curr_model,
            lr,
            restart_ep,
            restart_for_eps,
            save_weight_interval
        )

        # Normal completion, or an error kind we do not know how to recover from.
        if result is None or result["message"] != "nan_loss":
            return

        curr_epoch = result["curr_epoch"]
        if curr_epoch > (start_epoch + epoches):
            print("stop training")
            return

        # Latest checkpoint strictly before the failing epoch: round
        # (curr_epoch - 1) down to a multiple of save_weight_interval.
        model_latest_epoch = (curr_epoch - 1) - ((curr_epoch - 1) % save_weight_interval)
        restart_ep = model_latest_epoch + 1
        # Epochs still owed so the resumed run ends on the same final epoch
        # as an uninterrupted one (start_epoch + epoches - 1). The previous
        # formula subtracted (model_latest_epoch - start_epoch), which
        # scheduled one extra epoch after every restart.
        restart_for_eps = epoches - (restart_ep - start_epoch)

        model_path = f"pths/model_epoch_{model_latest_epoch}.pth"
        # Rebuild the model from scratch so no NaN-contaminated weights
        # survive from the failed run.
        curr_model = FasterRCNNResnet50FPN().to(device)
        curr_model.load_state_dict(torch.load(model_path))
        curr_model.train()

        # "\x1B[0K" erases to the end of the terminal line, presumably to
        # clear leftovers from the tqdm progress bar.
        print(f"Get nan loss, restart training at epoch {restart_ep} for additional {restart_for_eps} epochs" + "\x1B[0K")
        print(f"Loading weight from {model_path}" + "\x1B[0K")