xmodaler.checkpoint¶
- class xmodaler.checkpoint.PeriodicEpochCheckpointer(checkpointer: Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')[source]¶
Bases: PeriodicCheckpointer
- __init__(checkpointer: Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model') None ¶
- Parameters:
checkpointer – the checkpointer object used to save checkpoints.
period (int) – the period at which to save checkpoints.
max_iter (int) – maximum number of iterations. When it is reached, a checkpoint named “{file_prefix}_final” will be saved.
max_to_keep (int) – maximum number of the most recent checkpoints to keep; older checkpoints will be deleted.
file_prefix (str) – the prefix of the checkpoint’s filename.
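A construction sketch; the model, optimizer, and the per-epoch step() call (inherited from the PeriodicCheckpointer base) are illustrative assumptions, not part of this reference:

```python
import torch

from xmodaler.checkpoint import PeriodicEpochCheckpointer, XmodalerCheckpointer

# Placeholder model/optimizer; in practice these come from the training setup.
model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The wrapped checkpointer tracks the model plus any extra checkpointables.
checkpointer = XmodalerCheckpointer(model, save_dir="./output", optimizer=optimizer)

# Save every 5 epochs, keep only the 3 most recent checkpoints, and write
# "model_final" once max_iter is reached.
periodic_checkpointer = PeriodicEpochCheckpointer(
    checkpointer, period=5, max_iter=50, max_to_keep=3, file_prefix="model"
)

for epoch in range(50):
    # ... one epoch of training ...
    # step() is assumed here to be driven once per epoch, as inherited from
    # PeriodicCheckpointer; it decides whether this epoch gets a checkpoint.
    periodic_checkpointer.step(epoch)
```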
- save(name: str, **kwargs: Any) None ¶
Same arguments as Checkpointer.save(). Use this method to manually save checkpoints outside the schedule.
- Parameters:
name (str) – file name.
kwargs (Any) – extra data to save, same as in Checkpointer.save().
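A sketch of saving outside the schedule, e.g. after a validation run; the checkpoint name and extra keyword data are illustrative:

```python
# Assuming the periodic_checkpointer built above; **kwargs are stored in the
# checkpoint file alongside the model state.
periodic_checkpointer.save("model_best", epoch=12, best_score=1.23)
```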
- class xmodaler.checkpoint.XmodalerCheckpointer(model, save_dir='', *, save_to_disk=None, **checkpointables)[source]¶
Bases: Checkpointer
Same as Checkpointer, but is able to handle models in the xmodaler model zoo and apply conversions for legacy models.
- __init__(model, save_dir='', *, save_to_disk=None, **checkpointables)[source]¶
- Parameters:
model (nn.Module) – model.
save_dir (str) – a directory to save and find checkpoints.
save_to_disk (bool) – if True, save checkpoint to disk, otherwise disable saving for this checkpointer.
checkpointables (object) – any checkpointable objects, i.e., objects that have the state_dict() and load_state_dict() methods. For example, it can be used like Checkpointer(model, “dir”, optimizer=optimizer).
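A construction sketch; the optimizer and scheduler keywords are only examples of checkpointables and assume objects that expose state_dict()/load_state_dict():

```python
import torch

from xmodaler.checkpoint import XmodalerCheckpointer

model = torch.nn.Linear(16, 4)                    # placeholder model
optimizer = torch.optim.Adam(model.parameters())  # placeholder optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Every keyword argument with state_dict()/load_state_dict() is tracked and
# saved/loaded together with the model.
checkpointer = XmodalerCheckpointer(
    model,
    save_dir="./output",
    save_to_disk=True,
    optimizer=optimizer,
    scheduler=scheduler,
)
```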
- _convert_ndarray_to_tensor(state_dict: Dict[str, Any]) None ¶
In-place convert all numpy arrays in the state_dict to torch tensors.
- Parameters:
state_dict – a state dict to be loaded into the model. Will be modified.
- _load_file(filename)[source]¶
Load a checkpoint file. Can be overwritten by subclasses to support different formats.
- Parameters:
filename (str) – a locally mounted file path.
- Returns:
a dict with keys “model” and optionally others that are saved by the checkpointer. dict[“model”] must be a dict which maps strings to torch.Tensor or numpy arrays.
- Return type:
dict
- _load_model(checkpoint)[source]¶
Load weights from a checkpoint.
- Parameters:
checkpoint (Any) – checkpoint contains the weights.
- Returns:
NamedTuple with missing_keys, unexpected_keys, and incorrect_shapes fields:
- missing_keys is a list of str containing the missing keys
- unexpected_keys is a list of str containing the unexpected keys
- incorrect_shapes is a list of (key, shape in checkpoint, shape in model)
This is just like the return value of torch.nn.Module.load_state_dict(), but with extra support for incorrect_shapes.
- _log_incompatible_keys(incompatible: _IncompatibleKeys) None [source]¶
Log information about the incompatible keys returned by _load_model.
- add_checkpointable(key: str, checkpointable: Any) None ¶
Add checkpointable object for this checkpointer to track.
- Parameters:
key (str) – the key used to save the object
checkpointable – any object with state_dict() and load_state_dict() methods.
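A sketch of registering an extra object after construction, assuming the checkpointer built above; GradScaler is just one example of an object exposing state_dict()/load_state_dict():

```python
from torch.cuda.amp import GradScaler

scaler = GradScaler()

# After registration, the scaler state is saved under the key "grad_scaler"
# and restored together with the model on load()/resume_or_load().
checkpointer.add_checkpointable("grad_scaler", scaler)
```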
- get_all_checkpoint_files() List[str] ¶
- Returns:
All available checkpoint files (.pth files) in the target directory.
- Return type:
list
- get_checkpoint_file() str ¶
- Returns:
The latest checkpoint file in target directory.
- Return type:
str
- has_checkpoint() bool ¶
- Returns:
whether a checkpoint exists in the target directory.
- Return type:
bool
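A sketch of inspecting the target directory with the three query methods above, again assuming the checkpointer built earlier:

```python
if checkpointer.has_checkpoint():
    latest = checkpointer.get_checkpoint_file()          # path to the newest checkpoint
    all_ckpts = checkpointer.get_all_checkpoint_files()  # every .pth file in save_dir
    print(f"latest: {latest} ({len(all_ckpts)} checkpoints in total)")
else:
    print("no checkpoint found in the target directory")
```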
- load(path: str, checkpointables: Optional[List[str]] = None) Dict[str, Any] ¶
Load from the given checkpoint.
- Parameters:
path (str) – path or url to the checkpoint. If empty, will not load anything.
checkpointables (list) – List of checkpointable names to load. If not specified (None), will load all the possible checkpointables.
- Returns:
extra data loaded from the checkpoint that has not been processed. For example, those saved with save(**extra_data).
- Return type:
dict
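A load sketch; the checkpoint path and the restriction to the "optimizer" checkpointable are illustrative:

```python
# Hypothetical checkpoint path; load the model weights plus only the
# "optimizer" checkpointable.
extra = checkpointer.load("./output/model_epoch_0005.pth", checkpointables=["optimizer"])

# Anything that was saved via save(**extra_data) and not consumed by a
# checkpointable is returned as-is.
print(extra)
```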
- resume_or_load(path: str, *, resume: bool = True) Dict[str, Any] ¶
If resume is True, this method attempts to resume from the last checkpoint, if it exists. Otherwise, load the checkpoint from the given path. This is useful when restarting an interrupted training job.
- Parameters:
path (str) – path to the checkpoint.
resume (bool) – if True, resume from the last checkpoint if it exists and load the model together with all the checkpointables. Otherwise only load the model without loading any checkpointables.
- Returns:
same as load().
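A typical training-entry sketch; weights_path, the resume flag, and the "epoch" key are placeholders for values usually taken from the config or from earlier save() calls:

```python
weights_path = "./pretrained/model.pth"  # hypothetical starting weights

# resume=True: if a checkpoint exists in save_dir, resume from it (model plus
# all checkpointables); otherwise load only the model weights from weights_path.
extra = checkpointer.resume_or_load(weights_path, resume=True)

# "epoch" is only present if it was previously saved via save(..., epoch=...).
start_epoch = extra.get("epoch", 0)
```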
- save(name: str, **kwargs: Any) None ¶
Dump model and checkpointables to a file.
- Parameters:
name (str) – name of the file.
kwargs (dict) – extra arbitrary data to save.
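A save sketch; the file name and extra keyword data are illustrative:

```python
# Writes "<save_dir>/model_epoch_0020.pth" containing the model, all registered
# checkpointables, and the extra keyword data.
checkpointer.save("model_epoch_0020", epoch=20)
```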
- tag_last_checkpoint(last_filename_basename: str) None ¶
Tag the last checkpoint.
- Parameters:
last_filename_basename (str) – the basename of the last filename.