xmodaler.checkpoint

class xmodaler.checkpoint.PeriodicEpochCheckpointer(checkpointer: Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')[source]

Bases: PeriodicCheckpointer

__init__(checkpointer: Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model') → None
Parameters:
  • checkpointer – the checkpointer object used to save checkpoints.

  • period (int) – the period at which to save checkpoints.

  • max_iter (int) – maximum number of iterations. When it is reached, a checkpoint named “{file_prefix}_final” will be saved.

  • max_to_keep (int) – maximum number of the most recent checkpoints to keep; older checkpoints will be deleted.

  • file_prefix (str) – the prefix of the checkpoint filenames.
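For example, a minimal sketch of wiring up an epoch-based periodic checkpointer (the model, optimizer, and output directory below are hypothetical placeholders):

    import os
    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer, PeriodicEpochCheckpointer

    os.makedirs("./output", exist_ok=True)
    model = torch.nn.Linear(16, 4)                 # stand-in for a real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    checkpointer = XmodalerCheckpointer(model, "./output", optimizer=optimizer)
    periodic = PeriodicEpochCheckpointer(
        checkpointer,
        period=1,       # save every epoch
        max_iter=20,    # "model_final" is saved once this is reached
        max_to_keep=3,  # keep only the 3 most recent checkpoints
    )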

save(name: str, **kwargs: Any) → None

Same arguments as Checkpointer.save(). Use this method to manually save checkpoints outside the schedule.

Parameters:
  • name (str) – file name.

  • kwargs (Any) – extra data to save, same as in Checkpointer.save().
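A hedged sketch of a manual save outside the schedule, e.g. on a new best validation score (periodic is the PeriodicEpochCheckpointer constructed above; the file name and the extra "score" entry are hypothetical):

    # Writes "model_best.pth" and stores "score" as extra data in the file.
    periodic.save("model_best", score=0.91)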

step(iteration: int, epoch: int, **kwargs: Any) → None[source]

Perform the appropriate action at the given iteration.

Parameters:
  • iteration (int) – the current iteration, in the range [0, max_iter-1].

  • epoch (int) – the current epoch.

  • kwargs (Any) – extra data to save, same as in Checkpointer.save().
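A minimal sketch of driving the checkpointer from a training loop (the loop bounds and epoch bookkeeping below are hypothetical placeholders):

    import os
    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer, PeriodicEpochCheckpointer

    os.makedirs("./output", exist_ok=True)
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    checkpointer = XmodalerCheckpointer(model, "./output", optimizer=optimizer)

    iters_per_epoch, max_iter = 5, 20
    periodic = PeriodicEpochCheckpointer(checkpointer, period=1, max_iter=max_iter)

    for iteration in range(max_iter):          # iteration in [0, max_iter - 1]
        epoch = iteration // iters_per_epoch
        # ... forward / backward / optimizer.step() would go here ...
        periodic.step(iteration, epoch)        # saves on schedule; writes
                                               # "model_final" at max_iter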

class xmodaler.checkpoint.XmodalerCheckpointer(model, save_dir='', *, save_to_disk=None, **checkpointables)[source]

Bases: Checkpointer

Same as Checkpointer, but able to handle models in the xmodaler model zoo and apply conversions for legacy models.

__init__(model, save_dir='', *, save_to_disk=None, **checkpointables)[source]
Parameters:
  • model (nn.Module) – model.

  • save_dir (str) – a directory to save and find checkpoints.

  • save_to_disk (bool) – if True, save checkpoints to disk; otherwise, disable saving for this checkpointer.

  • checkpointables (object) – any checkpointable objects, i.e., objects that have the state_dict() and load_state_dict() methods. For example, it can be used like Checkpointer(model, “dir”, optimizer=optimizer).
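A minimal sketch of construction with extra checkpointables (the optimizer and scheduler are hypothetical stand-ins; any objects with state_dict() and load_state_dict() work):

    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer

    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)

    # Saves and loads will now include "optimizer" and "scheduler" state dicts.
    checkpointer = XmodalerCheckpointer(
        model, "./output", optimizer=optimizer, scheduler=scheduler
    )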

_convert_ndarray_to_tensor(state_dict: Dict[str, Any]) → None

In-place convert all numpy arrays in the state_dict to torch Tensors.

Parameters:

state_dict (Dict[str, Any]) – a state dict to be loaded into the model. Will be modified in place.
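A small sketch of the in-place behavior (the state dict below is a hypothetical stand-in):

    import numpy as np
    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer

    checkpointer = XmodalerCheckpointer(torch.nn.Linear(16, 4))
    state_dict = {"weight": np.zeros((4, 16), dtype=np.float32)}
    checkpointer._convert_ndarray_to_tensor(state_dict)
    assert isinstance(state_dict["weight"], torch.Tensor)  # converted in place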

_load_file(filename)[source]

Load a checkpoint file. Can be overridden by subclasses to support different formats.

Parameters:

filename (str) – a locally mounted file path.

Returns:

a dict with keys “model” and optionally others that are saved by the checkpointer. dict[“model”] must be a dict which maps strings to torch.Tensor or numpy arrays.

Return type:

dict

_load_model(checkpoint)[source]

Load weights from a checkpoint.

Parameters:

checkpoint (Any) – the checkpoint containing the weights.

Returns:

NamedTuple with missing_keys, unexpected_keys, and incorrect_shapes fields:
  • missing_keys is a list of str containing the missing keys
  • unexpected_keys is a list of str containing the unexpected keys
  • incorrect_shapes is a list of (key, shape in checkpoint, shape in model)

This is just like the return value of torch.nn.Module.load_state_dict(), but with extra support for incorrect_shapes.

_log_incompatible_keys(incompatible: _IncompatibleKeys) → None[source]

Log information about the incompatible keys returned by _load_model.

add_checkpointable(key: str, checkpointable: Any) → None

Add a checkpointable object for this checkpointer to track.

Parameters:
  • key (str) – the key used to save the object

  • checkpointable – any object with state_dict() and load_state_dict() methods
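A minimal sketch of registering an object after construction (the scheduler here is a hypothetical example):

    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer

    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    checkpointer = XmodalerCheckpointer(model, "./output", optimizer=optimizer)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)
    checkpointer.add_checkpointable("scheduler", scheduler)
    # Subsequent saves/loads now include the scheduler under the "scheduler" key.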

get_all_checkpoint_files() → List[str]
Returns:

all available checkpoint files (.pth files) in the target directory.

Return type:

list

get_checkpoint_file() → str
Returns:

the latest checkpoint file in the target directory.

Return type:

str

has_checkpoint() → bool
Returns:

whether a checkpoint exists in the target directory.

Return type:

bool
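A minimal sketch of inspecting the target directory with these three queries (the directory below is a hypothetical placeholder):

    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer

    checkpointer = XmodalerCheckpointer(torch.nn.Linear(16, 4), "./output")
    if checkpointer.has_checkpoint():
        print("latest:", checkpointer.get_checkpoint_file())
        print("all:", checkpointer.get_all_checkpoint_files())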

load(path: str, checkpointables: Optional[List[str]] = None) → Dict[str, Any]

Load from the given checkpoint.

Parameters:
  • path (str) – path or url to the checkpoint. If empty, will not load anything.

  • checkpointables (list) – List of checkpointable names to load. If not specified (None), will load all the possible checkpointables.

Returns:

extra data loaded from the checkpoint that has not been processed. For example, those saved with save(**extra_data).

Return type:

dict
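A minimal sketch of loading a specific checkpoint (the path below is a hypothetical placeholder):

    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer

    model = torch.nn.Linear(16, 4)
    checkpointer = XmodalerCheckpointer(model, "./output")
    # Pass checkpointables=[] to load only the model weights.
    extra = checkpointer.load("./output/model_final.pth")
    print(extra)  # any unprocessed data that was passed to save(**kwargs)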

resume_or_load(path: str, *, resume: bool = True) → Dict[str, Any]

If resume is True, this method attempts to resume from the last checkpoint, if it exists. Otherwise, it loads the checkpoint from the given path. This is useful when restarting an interrupted training job.

Parameters:
  • path (str) – path to the checkpoint.

  • resume (bool) – if True, resume from the last checkpoint if it exists and load the model together with all the checkpointables. Otherwise only load the model without loading any checkpointables.

Returns:

same as load().
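A sketch of the usual restart pattern (the pretrained weights path is a hypothetical placeholder, and reading “iteration” from the returned dict assumes it was saved as extra data):

    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer

    model = torch.nn.Linear(16, 4)
    checkpointer = XmodalerCheckpointer(model, "./output")
    # Resumes from ./output's last checkpoint if present; otherwise loads
    # the given weights without any checkpointables.
    extra = checkpointer.resume_or_load("./pretrained/weights.pth", resume=True)
    start_iter = extra.get("iteration", -1) + 1  # assumes "iteration" was saved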

save(name: str, **kwargs: Any) → None

Dump model and checkpointables to a file.

Parameters:
  • name (str) – name of the file.

  • kwargs (dict) – extra arbitrary data to save.
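A minimal sketch of saving with extra data (the “epoch” and “score” entries are hypothetical):

    import os
    import torch
    from xmodaler.checkpoint import XmodalerCheckpointer

    os.makedirs("./output", exist_ok=True)
    model = torch.nn.Linear(16, 4)
    checkpointer = XmodalerCheckpointer(model, "./output")
    # Writes ./output/model_epoch_3.pth; the extra entries come back from load().
    checkpointer.save("model_epoch_3", epoch=3, score=0.91)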

tag_last_checkpoint(last_filename_basename: str) → None

Tag the last checkpoint.

Parameters:

last_filename_basename (str) – the basename of the last filename.