Source code for xmodaler.checkpoint.xmodaler_checkpoint

"""
From original at https://github.com/facebookresearch/detectron2/blob/master/detectron2/checkpoint/detection_checkpoint.py
Original copyright of Facebook code below, modifications by Yehao Li, Copyright 2021.
"""
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
import pickle
import torch
from typing import Any
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer, _IncompatibleKeys
from fvcore.common.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from torch.nn.parallel import DistributedDataParallel

import xmodaler.utils.comm as comm
from xmodaler.utils.env import TORCH_VERSION
from xmodaler.utils.file_io import PathManager

from .c2_model_loading import align_and_update_state_dicts

class PeriodicEpochCheckpointer(PeriodicCheckpointer):
    def step(self, iteration: int, epoch: int, **kwargs: Any) -> None:
        """
        Perform the appropriate action at the given iteration.

        Args:
            iteration (int): the current iteration, ranged in [0, max_iter-1].
            epoch (int): the current epoch, used to name the saved checkpoint.
            kwargs (Any): extra data to save, same as in :meth:`Checkpointer.save`.
        """
        iteration = int(iteration)
        epoch = int(epoch)
        additional_state = {"iteration": iteration}
        additional_state.update(kwargs)

        if (iteration + 1) % self.period == 0:
            self.checkpointer.save(
                "{}_Epoch_{:05d}_Iter_{:07d}".format(self.file_prefix, epoch, iteration),
                **additional_state,
            )

            if self.max_to_keep is not None:
                self.recent_checkpoints.append(self.checkpointer.get_checkpoint_file())
                # pyre-fixme[58]: `>` is not supported for operand types `int` and
                #  `Optional[int]`.
                if len(self.recent_checkpoints) > self.max_to_keep:
                    file_to_delete = self.recent_checkpoints.pop(0)
                    if self.path_manager.exists(
                        file_to_delete
                    ) and not file_to_delete.endswith(f"{self.file_prefix}_final.pth"):
                        self.path_manager.rm(file_to_delete)

        if self.max_iter is not None:
            # pyre-fixme[58]
            if iteration >= self.max_iter - 1:
                self.checkpointer.save(f"{self.file_prefix}_final", **additional_state)

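# Illustrative sketch, not part of the original module: one way a training loop
# might drive PeriodicEpochCheckpointer. The function name, argument names, and
# the commented `run_train_step` call are hypothetical stand-ins, not xmodaler APIs.
def _example_periodic_epoch_checkpointing(checkpointer, iters_per_epoch, max_iter):
    periodic_checkpointer = PeriodicEpochCheckpointer(
        checkpointer,
        period=iters_per_epoch,   # save once at the end of every epoch
        max_iter=max_iter,
        max_to_keep=3,            # prune all but the 3 most recent epoch checkpoints
    )
    for iteration in range(max_iter):
        epoch = iteration // iters_per_epoch
        # run_train_step(...)  # one optimization step would happen here
        periodic_checkpointer.step(iteration, epoch)
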
class XmodalerCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to handle models in the xmodaler
    model zoo and apply conversions for legacy models.
    """
    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        self.path_manager = PathManager

    def _load_file(self, filename):
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # file is in Detectron2 model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume file is from Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # Detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}

        loaded = super()._load_file(filename)  # load native pth checkpoint
        if "model" not in loaded:
            loaded = {"model": loaded}
        return loaded

    def _load_model(self, checkpoint):
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # convert weights by name-matching heuristics
            model_state_dict = self.model.state_dict()
            align_and_update_state_dicts(
                model_state_dict,
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )
            checkpoint["model"] = model_state_dict
        # for non-caffe2 models, use standard ways to load it
        incompatible = super()._load_model(checkpoint)
        if incompatible is None:  # support older versions of fvcore
            return None

        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # Ignore missing key message about pixel_mean/std.
            # Though they may be missing in old checkpoints, they will be correctly
            # initialized from config anyway.
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    pass
        return incompatible

    def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
        """
        Log information about the incompatible keys returned by ``_load_model``.
        """
        for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
            self.logger.warning(
                "Skip loading parameter '{}' to the model due to incompatible "
                "shapes: {} in the checkpoint but {} in the "
                "model! You might want to double check if this is expected.".format(
                    k, shape_checkpoint, shape_model
                )
            )
        if incompatible.missing_keys:
            self.logger.info(get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            self.logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
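
# Illustrative sketch, not part of the original module: constructing the
# checkpointer and loading weights before training or evaluation. `model`,
# `weights_path`, and `output_dir` are placeholders supplied by the caller.
def _example_load_weights(model, weights_path, output_dir, resume=False):
    checkpointer = XmodalerCheckpointer(model, output_dir)
    # ".pkl" files go through the Detectron2 / Caffe2 handling in _load_file;
    # native ".pth" checkpoints are loaded by the base Checkpointer.
    extra = checkpointer.resume_or_load(weights_path, resume=resume)
    # `extra` holds any additional saved state, e.g. the "iteration" recorded
    # by PeriodicEpochCheckpointer.step, when present in the checkpoint.
    return extra.get("iteration", -1)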