zamba.pytorch.finetuning¶
Classes¶
BackboneFinetuning (BackboneFinetuning)¶
Derived from PTL's built-in BackboneFinetuning, but during the backbone freeze phase, choose whether to freeze batch norm layers, even if train_bn is True (i.e., even if we train them during the backbone unfreeze phase).
Finetune a backbone model based on a user-defined learning rate schedule. When the backbone learning rate reaches the current model learning rate and should_align is set to True, it will align with it for the rest of the training.
Parameters:
Name | Type | Description | Default
---|---|---|---
unfreeze_backbone_at_epoch | | Epoch at which the backbone will be unfrozen. | required
lambda_func | | Scheduling function for increasing the backbone learning rate. | required
backbone_initial_ratio_lr | | Used to scale down the backbone learning rate compared to the rest of the model. | required
backbone_initial_lr | | Optional, initial learning rate for the backbone. By default, current_lr * backbone_initial_ratio_lr will be used. | required
should_align | | Whether to align with the current learning rate when the backbone learning rate reaches it. | required
initial_denom_lr | | When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr. | required
train_bn | | Whether to make Batch Normalization trainable. | required
verbose | | Display current learning rate for model and backbone. | required
round | | Precision for displaying learning rate. | required
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])
Source code in zamba/pytorch/finetuning.py
class BackboneFinetuning(pl.callbacks.finetuning.BackboneFinetuning):
r"""
Derived from PTL's built-in ``BackboneFinetuning``, but during the backbone freeze phase,
choose whether to freeze batch norm layers, even if ``train_bn`` is True (i.e., even if we train them
during the backbone unfreeze phase).
Finetune a backbone model based on a learning rate user-defined scheduling.
When the backbone learning rate reaches the current model learning rate
and ``should_align`` is set to True, it will align with it for the rest of the training.
Args:
unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfreezed.
lambda_func: Scheduling function for increasing backbone learning rate.
backbone_initial_ratio_lr:
Used to scale down the backbone learning rate compared to rest of model
backbone_initial_lr: Optional, Inital learning rate for the backbone.
By default, we will use current_learning / backbone_initial_ratio_lr
should_align: Wheter to align with current learning rate when backbone learning
reaches it.
initial_denom_lr: When unfreezing the backbone, the intial learning rate will
current_learning_rate / initial_denom_lr.
train_bn: Wheter to make Batch Normalization trainable.
verbose: Display current learning rate for model and backbone
round: Precision for displaying learning rate
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])
"""
def __init__(
self, *args, multiplier: Optional[float] = 1, pre_train_bn: bool = False, **kwargs
):
if multiplier is not None:
kwargs["lambda_func"] = multiplier_factory(multiplier)
super().__init__(*args, **kwargs)
# choose whether to train batch norm layers prior to finetuning phase
self.pre_train_bn = pre_train_bn
def freeze_before_training(self, pl_module: "pl.LightningModule"):
self.freeze(pl_module.backbone, train_bn=self.pre_train_bn)
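A minimal usage sketch of this subclass follows; the epoch, multiplier, and flag values are illustrative, not defaults:
from zamba.pytorch.finetuning import BackboneFinetuning

# Freeze the backbone (including its batch norm layers, since pre_train_bn=False)
# for the first 3 epochs, then unfreeze it and grow its learning rate by a constant
# factor of 1.5 per epoch via multiplier_factory.
backbone_finetuning = BackboneFinetuning(
    unfreeze_backbone_at_epoch=3,
    multiplier=1.5,
    pre_train_bn=False,
    train_bn=True,
    verbose=True,
)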
Attributes¶
state_key: str
inherited
property
readonly
¶
Identifier for the state of the callback.
Used to store and retrieve a callback's state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
Methods¶
__init__(self, *args, multiplier: Optional[float] = 1, pre_train_bn: bool = False, **kwargs)
special
¶
Source code in zamba/pytorch/finetuning.py
def __init__(
self, *args, multiplier: Optional[float] = 1, pre_train_bn: bool = False, **kwargs
):
if multiplier is not None:
kwargs["lambda_func"] = multiplier_factory(multiplier)
super().__init__(*args, **kwargs)
# choose whether to train batch norm layers prior to finetuning phase
self.pre_train_bn = pre_train_bn
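As the constructor shows, multiplier is a convenience around lambda_func. The two constructions below are equivalent (a sketch with illustrative values):
from zamba.pytorch.finetuning import BackboneFinetuning, multiplier_factory

# Both callbacks produce the same constant-factor backbone learning rate schedule.
cb_a = BackboneFinetuning(unfreeze_backbone_at_epoch=3, multiplier=1.5)
cb_b = BackboneFinetuning(
    unfreeze_backbone_at_epoch=3,
    multiplier=None,  # disable the convenience wrapper and pass lambda_func directly
    lambda_func=multiplier_factory(1.5),
)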
filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List
inherited
¶
This function is used to exclude any parameter which already exists in this optimizer.
Parameters:
Name | Type | Description | Default
---|---|---|---
optimizer | Optimizer | Optimizer used for parameter exclusion | required
params | Iterable | Iterable of parameters used to check against the provided optimizer | required
Returns:
Type | Description
---|---
List | List of parameters not contained in this optimizer param groups
Source code in zamba/pytorch/finetuning.py
@staticmethod
def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
"""This function is used to exclude any parameter which already exists in this optimizer.
Args:
optimizer: Optimizer used for parameter exclusion
params: Iterable of parameters used to check against the provided optimizer
Returns:
List of parameters not contained in this optimizer param groups
"""
out_params = []
removed_params = []
for param in params:
if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]):
out_params.append(param)
else:
removed_params.append(param)
if removed_params:
rank_zero_warn(
"The provided params to be frozen already exist within another group of this optimizer."
" Those parameters will be skipped.\n"
"HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
UserWarning,
)
return out_params
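A small sketch of the filtering behavior; the toy modules below are illustrative:
import torch
from torch import nn
from zamba.pytorch.finetuning import BackboneFinetuning

head = nn.Linear(8, 2)
backbone = nn.Linear(4, 8)
optimizer = torch.optim.SGD(head.parameters(), lr=1e-3)

# None of the backbone's parameters are already in the optimizer's param groups,
# so all of them pass through; parameters already present would be skipped with a warning.
new_params = BackboneFinetuning.filter_on_optimizer(optimizer, backbone.parameters())
print(len(new_params))  # 2 (the backbone's weight and bias)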
filter_params(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True) -> Generator
inherited
¶
Yields the requires_grad parameters of a given module or list of modules.
Parameters:
Name | Type | Description | Default
---|---|---|---
modules | Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]] | A given module or an iterable of modules | required
train_bn | bool | Whether to train BatchNorm module | True
requires_grad | bool | Whether to create a generator for trainable or non-trainable parameters. | True
Returns:
Type | Description
---|---
Generator | Generator
Source code in zamba/pytorch/finetuning.py
@staticmethod
def filter_params(
modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True
) -> Generator:
"""Yields the `requires_grad` parameters of a given module or list of modules.
Args:
modules: A given module or an iterable of modules
train_bn: Whether to train BatchNorm module
requires_grad: Whether to create a generator for trainable or non-trainable parameters.
Returns:
Generator
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
if isinstance(mod, _BatchNorm) and not train_bn:
continue
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in mod.parameters(recurse=False):
if param.requires_grad == requires_grad:
yield param
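For example, an illustrative sketch with a toy module:
from torch import nn
from zamba.pytorch.finetuning import BackboneFinetuning

module = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
# With train_bn=False, the BatchNorm2d weight and bias are skipped,
# leaving only the Conv2d weight and bias.
params = list(BackboneFinetuning.filter_params(module, train_bn=False, requires_grad=True))
print(len(params))  # 2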
finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int) -> None
inherited
¶
Called when the epoch begins.
Source code in zamba/pytorch/finetuning.py
def finetune_function(
self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
) -> None:
"""Called when the epoch begins."""
if epoch == self.unfreeze_backbone_at_epoch:
current_lr = optimizer.param_groups[0]["lr"]
initial_backbone_lr = (
self.backbone_initial_lr
if self.backbone_initial_lr is not None
else current_lr * self.backbone_initial_ratio_lr
)
self.previous_backbone_lr = initial_backbone_lr
self.unfreeze_and_add_param_group(
pl_module.backbone,
optimizer,
initial_backbone_lr,
train_bn=self.train_bn,
initial_denom_lr=self.initial_denom_lr,
)
if self.verbose:
log.info(
f"Current lr: {round(current_lr, self.rounding)}, "
f"Backbone lr: {round(initial_backbone_lr, self.rounding)}"
)
elif epoch > self.unfreeze_backbone_at_epoch:
current_lr = optimizer.param_groups[0]["lr"]
next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
next_current_backbone_lr = (
current_lr
if (self.should_align and next_current_backbone_lr > current_lr)
else next_current_backbone_lr
)
optimizer.param_groups[-1]["lr"] = next_current_backbone_lr
self.previous_backbone_lr = next_current_backbone_lr
if self.verbose:
log.info(
f"Current lr: {round(current_lr, self.rounding)}, "
f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}"
)
flatten_modules(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]) -> List[torch.nn.modules.module.Module]
inherited
¶
This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves.
Parameters:
Name | Type | Description | Default
---|---|---|---
modules | Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]] | A given module or an iterable of modules | required
Returns:
Type | Description
---|---
List[torch.nn.modules.module.Module] | List of modules
Source code in zamba/pytorch/finetuning.py
@staticmethod
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
"""This function is used to flatten a module or an iterable of modules into a list of its leaf modules
(modules with no children) and parent modules that have parameters directly themselves.
Args:
modules: A given module or an iterable of modules
Returns:
List of modules
"""
if isinstance(modules, ModuleDict):
modules = modules.values()
if isinstance(modules, Iterable):
_modules = []
for m in modules:
_modules.extend(BaseFinetuning.flatten_modules(m))
else:
_modules = modules.modules()
# Capture all leaf modules as well as parent modules that have parameters directly themsleves
return [m for m in _modules if not list(m.children()) or m._parameters]
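For example (illustrative):
from torch import nn
from zamba.pytorch.finetuning import BackboneFinetuning

model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), nn.Linear(4, 2))
leaves = BackboneFinetuning.flatten_modules(model)
# The nn.Sequential containers have no parameters of their own, so only the
# leaf modules are returned.
print([type(m).__name__ for m in leaves])  # ['Linear', 'ReLU', 'Linear']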
freeze(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]], train_bn: bool = True) -> None
inherited
¶
Freezes the parameters of the provided modules.
Parameters:
Name | Type | Description | Default
---|---|---|---
modules | Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]] | A given module or an iterable of modules | required
train_bn | bool | If True, leave the BatchNorm layers in training mode | True
Returns:
Type | Description
---|---
None | None
Source code in zamba/pytorch/finetuning.py
@staticmethod
def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
"""Freezes the parameters of the provided modules.
Args:
modules: A given module or an iterable of modules
train_bn: If True, leave the BatchNorm layers in training mode
Returns:
None
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
if isinstance(mod, _BatchNorm) and train_bn:
BaseFinetuning.make_trainable(mod)
else:
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in mod.parameters(recurse=False):
param.requires_grad = False
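For example, freezing a toy backbone while leaving its batch norm layer trainable (illustrative):
from torch import nn
from zamba.pytorch.finetuning import BackboneFinetuning

backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
BackboneFinetuning.freeze(backbone, train_bn=True)
print([p.requires_grad for p in backbone[0].parameters()])  # [False, False] (Conv2d frozen)
print([p.requires_grad for p in backbone[1].parameters()])  # [True, True] (BatchNorm2d still trainable)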
freeze_before_training(self, pl_module: pl.LightningModule)
¶
Override to add your freeze logic.
Source code in zamba/pytorch/finetuning.py
def freeze_before_training(self, pl_module: "pl.LightningModule"):
self.freeze(pl_module.backbone, train_bn=self.pre_train_bn)
make_trainable(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]) -> None
inherited
¶
Unfreezes the parameters of the provided modules.
Parameters:
Name | Type | Description | Default
---|---|---|---
modules | Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]] | A given module or an iterable of modules | required
Source code in zamba/pytorch/finetuning.py
@staticmethod
def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None:
"""Unfreezes the parameters of the provided modules.
Args:
modules: A given module or an iterable of modules
"""
modules = BaseFinetuning.flatten_modules(modules)
for module in modules:
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in module.parameters(recurse=False):
param.requires_grad = True
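For example (illustrative):
from torch import nn
from zamba.pytorch.finetuning import BackboneFinetuning

layer = nn.Linear(4, 2)
layer.requires_grad_(False)  # freeze every parameter
BackboneFinetuning.make_trainable(layer)
print(all(p.requires_grad for p in layer.parameters()))  # True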
on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called after loss.backward() and before optimizers are stepped.
Source code in zamba/pytorch/finetuning.py
def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called after ``loss.backward()`` and before optimizers are stepped."""
pass
on_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the training batch ends.
Source code in zamba/pytorch/finetuning.py
def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the training batch ends."""
pass
on_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the training batch begins.
Source code in zamba/pytorch/finetuning.py
def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the training batch begins."""
pass
on_before_accelerator_backend_setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule)
inherited
¶
Called before accelerator is being setup.
Source code in zamba/pytorch/finetuning.py
def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
self.freeze_before_training(pl_module)
on_before_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule, loss: Tensor) -> None
inherited
¶
Called before loss.backward().
Source code in zamba/pytorch/finetuning.py
def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
"""Called before ``loss.backward()``."""
pass
on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer, opt_idx: int) -> None
inherited
¶
Called before optimizer.step().
Source code in zamba/pytorch/finetuning.py
def on_before_optimizer_step(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int
) -> None:
"""Called before ``optimizer.step()``."""
pass
on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer) -> None
inherited
¶
Called before optimizer.zero_grad().
Source code in zamba/pytorch/finetuning.py
def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
"""Called before ``optimizer.zero_grad()``."""
pass
on_configure_sharded_model(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called before configure sharded model.
Source code in zamba/pytorch/finetuning.py
def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called before configure sharded model."""
on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when either of train/val/test epoch ends.
Source code in zamba/pytorch/finetuning.py
def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when either of train/val/test epoch ends."""
pass
on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when either of train/val/test epoch begins.
Source code in zamba/pytorch/finetuning.py
def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when either of train/val/test epoch begins."""
pass
on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None
inherited
¶
Called when any trainer execution is interrupted by an exception.
Source code in zamba/pytorch/finetuning.py
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
"""Called when any trainer execution is interrupted by an exception."""
pass
on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when fit ends.
Source code in zamba/pytorch/finetuning.py
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when fit ends."""
pass
on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Source code in zamba/pytorch/finetuning.py
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""
Raises:
MisconfigurationException:
If LightningModule has no nn.Module `backbone` attribute.
"""
if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
return super().on_fit_start(trainer, pl_module)
raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
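The check above means the LightningModule must expose its feature extractor as a backbone attribute. A minimal, hypothetical sketch (training logic omitted):
import pytorch_lightning as pl
from torch import nn

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # The finetuning callback freezes/unfreezes this submodule by name.
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)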
on_init_end(self, trainer: pl.Trainer) -> None
inherited
¶
Called when the trainer initialization ends, model has not yet been set.
Source code in zamba/pytorch/finetuning.py
def on_init_end(self, trainer: "pl.Trainer") -> None:
"""Called when the trainer initialization ends, model has not yet been set."""
pass
on_init_start(self, trainer: pl.Trainer) -> None
inherited
¶
Called when the trainer initialization begins, model has not yet been set.
Source code in zamba/pytorch/finetuning.py
def on_init_start(self, trainer: "pl.Trainer") -> None:
"""Called when the trainer initialization begins, model has not yet been set."""
pass
on_keyboard_interrupt(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Deprecated since v1.5: this callback hook was deprecated in v1.5 in favor of on_exception and will be removed in v1.7.
Called when any trainer execution is interrupted by KeyboardInterrupt.
Source code in zamba/pytorch/finetuning.py
def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
r"""
.. deprecated:: v1.5
This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
Called when any trainer execution is interrupted by KeyboardInterrupt.
"""
pass
on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, callback_state: Dict[int, List[Dict[str, Any]]]) -> None
inherited
¶
Called when loading a model checkpoint, use to reload state.
Parameters:
Name | Type | Description | Default
---|---|---|---
trainer | pl.Trainer | the current Trainer instance | required
pl_module | pl.LightningModule | the current LightningModule instance | required
callback_state | Dict[int, List[Dict[str, Any]]] | the callback state returned by on_save_checkpoint | required
!!! note
    The on_load_checkpoint won't be called with an undefined state. If your on_load_checkpoint hook behavior doesn't rely on a state, you will still need to override on_save_checkpoint to return a dummy state.
Source code in zamba/pytorch/finetuning.py
def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[int, List[Dict[str, Any]]]
) -> None:
self.previous_backbone_lr = callback_state["previous_backbone_lr"]
super().on_load_checkpoint(trainer, pl_module, callback_state["internal_optimizer_metadata"])
on_predict_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None
inherited
¶
Called when the predict batch ends.
Source code in zamba/pytorch/finetuning.py
def on_predict_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: Any,
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
"""Called when the predict batch ends."""
pass
on_predict_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int) -> None
inherited
¶
Called when the predict batch begins.
Source code in zamba/pytorch/finetuning.py
def on_predict_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""Called when the predict batch begins."""
pass
on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when predict ends.
Source code in zamba/pytorch/finetuning.py
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when predict ends."""
pass
on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: List[Any]) -> None
inherited
¶
Called when the predict epoch ends.
Source code in zamba/pytorch/finetuning.py
def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: List[Any]) -> None:
"""Called when the predict epoch ends."""
pass
on_predict_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the predict epoch begins.
Source code in zamba/pytorch/finetuning.py
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the predict epoch begins."""
pass
on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the predict begins.
Source code in zamba/pytorch/finetuning.py
def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the predict begins."""
pass
on_pretrain_routine_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the pretrain routine ends.
Source code in zamba/pytorch/finetuning.py
def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the pretrain routine ends."""
pass
on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the pretrain routine begins.
Source code in zamba/pytorch/finetuning.py
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the pretrain routine begins."""
pass
on_sanity_check_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the validation sanity check ends.
Source code in zamba/pytorch/finetuning.py
def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the validation sanity check ends."""
pass
on_sanity_check_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the validation sanity check starts.
Source code in zamba/pytorch/finetuning.py
def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the validation sanity check starts."""
pass
on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]) -> Dict[str, Any]
inherited
¶
Called when saving a model checkpoint, use to persist state.
Parameters:
Name | Type | Description | Default
---|---|---|---
trainer | pl.Trainer | the current Trainer instance | required
pl_module | pl.LightningModule | the current LightningModule instance | required
checkpoint | Dict[str, Any] | the checkpoint dictionary that will be saved. | required
Returns:
Type | Description
---|---
Dict[str, Any] | The callback state.
Source code in zamba/pytorch/finetuning.py
def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> Dict[str, Any]:
return {
"internal_optimizer_metadata": self._internal_optimizer_metadata,
"previous_backbone_lr": self.previous_backbone_lr,
}
on_test_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Union[torch.Tensor, Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None
inherited
¶
Called when the test batch ends.
Source code in zamba/pytorch/finetuning.py
def on_test_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: Optional[STEP_OUTPUT],
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
"""Called when the test batch ends."""
pass
on_test_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int) -> None
inherited
¶
Called when the test batch begins.
Source code in zamba/pytorch/finetuning.py
def on_test_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""Called when the test batch begins."""
pass
on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the test ends.
Source code in zamba/pytorch/finetuning.py
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the test ends."""
pass
on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the test epoch ends.
Source code in zamba/pytorch/finetuning.py
def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the test epoch ends."""
pass
on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the test epoch begins.
Source code in zamba/pytorch/finetuning.py
def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the test epoch begins."""
pass
on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the test begins.
Source code in zamba/pytorch/finetuning.py
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the test begins."""
pass
on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Union[torch.Tensor, Dict[str, Any]], batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None
inherited
¶
Called when the train batch ends.
Source code in zamba/pytorch/finetuning.py
def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
"""Called when the train batch ends."""
pass
on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None
inherited
¶
Called when the train batch begins.
Source code in zamba/pytorch/finetuning.py
def on_train_batch_start(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
batch: Any,
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
"""Called when the train batch begins."""
pass
on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the train ends.
Source code in zamba/pytorch/finetuning.py
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train ends."""
pass
on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
- Implement training_epoch_end in the LightningModule and access outputs via the module, OR
- Cache data across train batch hooks inside the callback implementation to post-process in this hook.
Source code in zamba/pytorch/finetuning.py
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
"""
pass
on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the epoch begins.
Source code in zamba/pytorch/finetuning.py
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the epoch begins."""
# import is here to avoid circular imports
from pytorch_lightning.loops.utilities import _get_active_optimizers
for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
num_param_groups = len(optimizer.param_groups)
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
current_param_groups = optimizer.param_groups
self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the train begins.
Source code in zamba/pytorch/finetuning.py
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train begins."""
pass
on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Union[torch.Tensor, Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None
inherited
¶
Called when the validation batch ends.
Source code in zamba/pytorch/finetuning.py
def on_validation_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: Optional[STEP_OUTPUT],
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
"""Called when the validation batch ends."""
pass
on_validation_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int) -> None
inherited
¶
Called when the validation batch begins.
Source code in zamba/pytorch/finetuning.py
def on_validation_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""Called when the validation batch begins."""
pass
on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the validation loop ends.
Source code in zamba/pytorch/finetuning.py
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the validation loop ends."""
pass
on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the val epoch ends.
Source code in zamba/pytorch/finetuning.py
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the val epoch ends."""
pass
on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the val epoch begins.
Source code in zamba/pytorch/finetuning.py
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the val epoch begins."""
pass
on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
inherited
¶
Called when the validation loop begins.
Source code in zamba/pytorch/finetuning.py
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the validation loop begins."""
pass
setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None
inherited
¶
Called when fit, validate, test, predict, or tune begins.
Source code in zamba/pytorch/finetuning.py
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
"""Called when fit, validate, test, predict, or tune begins."""
pass
teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None
inherited
¶
Called when fit, validate, test, predict, or tune ends.
Source code in zamba/pytorch/finetuning.py
def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
"""Called when fit, validate, test, predict, or tune ends."""
pass
unfreeze_and_add_param_group(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]], optimizer: Optimizer, lr: Optional[float] = None, initial_denom_lr: float = 10.0, train_bn: bool = True) -> None
inherited
¶
Unfreezes a module and adds its parameters to an optimizer.
Parameters:
Name | Type | Description | Default
---|---|---|---
modules | Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]] | A module or iterable of modules to unfreeze. Their parameters will be added to an optimizer as a new param group. | required
optimizer | Optimizer | The provided optimizer will receive the new parameters and will add them via add_param_group. | required
lr | Optional[float] | Learning rate for the new param group. | None
initial_denom_lr | float | If no lr is provided, the learning rate from the first param group will be used and divided by initial_denom_lr. | 10.0
train_bn | bool | Whether to train the BatchNormalization layers. | True
Source code in zamba/pytorch/finetuning.py
@staticmethod
def unfreeze_and_add_param_group(
modules: Union[Module, Iterable[Union[Module, Iterable]]],
optimizer: Optimizer,
lr: Optional[float] = None,
initial_denom_lr: float = 10.0,
train_bn: bool = True,
) -> None:
"""Unfreezes a module and adds its parameters to an optimizer.
Args:
modules: A module or iterable of modules to unfreeze.
Their parameters will be added to an optimizer as a new param group.
optimizer: The provided optimizer will receive new parameters and will add them to
`add_param_group`
lr: Learning rate for the new param group.
initial_denom_lr: If no lr is provided, the learning from the first param group will be used
and divided by `initial_denom_lr`.
train_bn: Whether to train the BatchNormalization layers.
"""
BaseFinetuning.make_trainable(modules)
params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
denom_lr = initial_denom_lr if lr is None else 1.0
params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True)
params = BaseFinetuning.filter_on_optimizer(optimizer, params)
if params:
optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})
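A small sketch of how the new param group is added (toy modules, illustrative values):
import torch
from torch import nn
from zamba.pytorch.finetuning import BackboneFinetuning

head = nn.Linear(8, 2)
backbone = nn.Linear(4, 8)
BackboneFinetuning.freeze(backbone)
optimizer = torch.optim.SGD(head.parameters(), lr=1e-2)

# No explicit lr, so the new group uses the first group's lr divided by initial_denom_lr.
BackboneFinetuning.unfreeze_and_add_param_group(backbone, optimizer, initial_denom_lr=10.0)
print(len(optimizer.param_groups))       # 2
print(optimizer.param_groups[-1]["lr"])  # 0.001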
Functions¶
multiplier_factory(rate: float)
¶
Returns a function that returns a constant value for use in computing a constant learning rate multiplier.
Parameters:
Name | Type | Description | Default
---|---|---|---
rate | float | Constant multiplier. | required
Source code in zamba/pytorch/finetuning.py
def multiplier_factory(rate: float):
"""Returns a function that returns a constant value for use in computing a constant learning
rate multiplier.
Args:
rate (float): Constant multiplier.
"""
def multiplier(*args, **kwargs):
return rate
return multiplier
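For example, the returned function ignores its arguments and always yields the same multiplier:
from zamba.pytorch.finetuning import multiplier_factory

lambda_func = multiplier_factory(1.5)
print(lambda_func(0), lambda_func(10))  # 1.5 1.5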