"""Running experiments."""
import datetime
import time
from diffcp import SolverError
import numpy as np
import pandas as pd
import torch
from .benchmarks import Benchmark
from .callbacks import (
BenchmarkCallback,
EarlyStoppingException,
ProgressBarCallback,
ValidationCallback,
)
from .data import FlexibleDataLoader, RigidDataLoader
from .losses import Loss
class History:
    """A shared information database for the training process.

    Attributes
    ----------
    database : dict
        Keys are epochs (int) and values are lists of per-sample entry dicts
        as produced by :meth:`add_entry`.
    """

    def __init__(self):
        # Maps epoch -> list of entry dicts; populated lazily by `add_entry`.
        self.database = {}

    @property
    def metrics(self):
        """Concatenate metrics over all epochs.

        Returns
        -------
        df : pd.DataFrame
            Each row represents a unique logging sample. Columns are `batch`,
            `current_time`, `dataloader`, `epoch`, `lookback`, `metric`,
            `model`, `timestamp`, `value`.
        """
        master_list = []  # flattened entries over all epochs
        for entries in self.database.values():
            master_list.extend(entries)

        return pd.DataFrame(master_list)

    def metrics_per_epoch(self, epoch):
        """Results over a specified epoch.

        Parameters
        ----------
        epoch : int
            Epoch whose logged entries to return.

        Returns
        -------
        df : pd.DataFrame
            Each row represents a unique logging sample. Columns are `batch`,
            `current_time`, `dataloader`, `epoch`, `lookback`, `metric`,
            `model`, `timestamp`, `value`.

        Raises
        ------
        KeyError
            If no entries were ever logged for ``epoch``.
        """
        return pd.DataFrame(self.database[epoch])

    def add_entry(
        self,
        model=None,
        metric=None,
        batch=None,
        epoch=None,
        dataloader=None,
        lookback=None,
        timestamp=None,
        value=np.nan,
    ):
        """Add entry to the internal database.

        The wall-clock time of the call is stored under ``current_time``.
        """
        self.database.setdefault(epoch, []).append(
            {
                "model": model,
                "metric": metric,
                "value": value,
                "batch": batch,
                "epoch": epoch,
                "dataloader": dataloader,
                "lookback": lookback,
                "timestamp": timestamp,
                "current_time": datetime.datetime.now(),
            }
        )

    def pretty_print(self, epoch=None):
        """Print nicely the internal database.

        Parameters
        ----------
        epoch : int or None
            If epoch given, print results only over this epoch. If epoch is
            `None` then print results over all epochs.
        """
        if epoch is None:
            df = self.metrics
        else:
            df = self.metrics_per_epoch(epoch)

        # Scope the float format to this call instead of mutating the global
        # pandas display options (the previous implementation leaked the
        # setting into the caller's environment).
        with pd.option_context("display.float_format", "{:,.3f}".format):
            print(
                df.groupby(["model", "metric", "epoch", "dataloader"])["value"]
                .mean()
                .to_string()
            )
class Run:
    """Represents one experiment.

    Parameters
    ----------
    network : nn.Module and Benchmark
        Network.

    loss : callable
        A callable computing per sample loss. Specifically it accepts `weights`, `y` and returns ``torch.Tensor``
        of shape `(n_samples,)`.

    train_dataloader : FlexibleDataLoader or RigidDataLoader
        DataLoader used for training.

    val_dataloaders : None or dict
        If None no validation is performed. If ``dict`` then keys are names and values are
        instances of ``RigidDataLoader``.

    metrics : None or dict
        If None the only metric is the loss function. If ``dict`` then keys are names and values are
        instances of ``Loss``.

    benchmarks : None or dict
        If None then no benchmark models used. If ``dict`` then keys are names and values are instances of
        ``Benchmark``.

    device : torch.device or None
        Device on which to perform the deep network calculations. If None then `torch.device('cpu')` used.

    dtype : torch.dtype or None
        Dtype to use for all torch tensors. If None then `torch.float` used.

    optimizer : None or torch.optim.Optimizer
        Optimizer to be used. If None then using Adam with lr=0.01.

    callbacks : list
        List of callbacks to be used.

    Attributes
    ----------
    metrics : dict
        Keys represent metric names and values are callables. Note that it always has an element
        called 'loss' representing the actual loss.

    val_dataloaders : dict
        Keys represent dataloaders names and values are ``RigidDataLoader`` instances. Note that if empty then no
        logging is made.

    models : dict
        Keys represent model names and values are either `Benchmark` or `torch.nn.Module`. Note that it always
        has an element called `main` representing the main network.

    callbacks : list
        Full list of callbacks. There are three defaults ones `BenchmarkCallback`, `ValidationCallback` and
        `ProgressBarCallback`. On top of them there are the manually selected callbacks from `callbacks`.
    """

    def __init__(
        self,
        network,
        loss,
        train_dataloader,
        val_dataloaders=None,
        metrics=None,
        benchmarks=None,
        device=None,
        dtype=None,
        optimizer=None,
        callbacks=None,
    ):
        # --- validation of mandatory arguments ---
        if not isinstance(train_dataloader, (FlexibleDataLoader, RigidDataLoader)):
            raise TypeError(
                "The train_dataloader needs to be an instance of RigidDataLoader or FlexibleDataLoader."
            )

        if not isinstance(loss, Loss):
            raise TypeError("The loss needs to be an instance of Loss.")

        # The network must be both trainable (nn.Module) and comparable to
        # benchmarks (Benchmark) so it can live in ``self.models`` alongside them.
        if not (isinstance(network, torch.nn.Module) and isinstance(network, Benchmark)):
            raise TypeError("The network needs to be a torch.nn.Module and Benchmark. ")

        self.network = network
        self.loss = loss
        self.train_dataloader = train_dataloader

        # --- metrics: 'loss' is always present and reserved ---
        self.metrics = {"loss": loss}

        if metrics is None:
            pass

        elif isinstance(metrics, dict):
            if not all(isinstance(x, Loss) for x in metrics.values()):
                raise TypeError("All values of metrics need to be Loss.")

            if "loss" in metrics:
                raise ValueError("Cannot name a metric 'loss' - restricted for the actual loss.")

            self.metrics.update(metrics)

        else:
            raise TypeError("Invalid type of metrics: {}".format(type(metrics)))

        # --- validation dataloaders (empty dict disables validation logging) ---
        self.val_dataloaders = {}

        if val_dataloaders is None:
            pass

        elif isinstance(val_dataloaders, dict):
            if not all(isinstance(x, RigidDataLoader) for x in val_dataloaders.values()):
                raise TypeError("All values of val_dataloaders need to be RigidDataLoader.")

            self.val_dataloaders.update(val_dataloaders)

        else:
            raise TypeError("Invalid type of val_dataloaders: {}".format(type(val_dataloaders)))

        # --- benchmark models: 'main' is always the trained network ---
        self.models = {"main": network}

        if benchmarks is None:
            pass

        elif isinstance(benchmarks, dict):
            if not all(isinstance(x, Benchmark) for x in benchmarks.values()):
                raise TypeError("All values of benchmarks need to be a Benchmark.")

            if "main" in benchmarks:
                raise ValueError("Cannot name a benchmark 'main' - restricted for the main network.")

            self.models.update(benchmarks)

        else:
            raise TypeError("Invalid type of benchmarks: {}".format(type(benchmarks)))

        # Default callbacks come first, then user-supplied ones.
        self.callbacks = [
            BenchmarkCallback(),
            ValidationCallback(),
            ProgressBarCallback(),
        ] + (callbacks or [])

        # Inject self into callbacks so they can access the run state.
        for cb in self.callbacks:
            cb.run = self

        self.history = History()
        self.device = device or torch.device("cpu")
        self.dtype = dtype or torch.float
        self.optimizer = (
            torch.optim.Adam(self.network.parameters(), lr=1e-2)
            if optimizer is None
            else optimizer
        )

        # -1 means training has not started; first epoch will be 0.
        self.current_epoch = -1

    def launch(self, n_epochs=1):
        """Launch the training and logging loop.

        Parameters
        ----------
        n_epochs : int
            Number of epochs.

        Returns
        -------
        History
            The shared history populated during training (also returned when
            training is interrupted).
        """
        try:
            self.network.to(device=self.device, dtype=self.dtype)
            # Train begin (only once, even if launch is called repeatedly)
            if self.current_epoch == -1:
                self.on_train_begin(metadata={"n_epochs": n_epochs})

            for _ in range(n_epochs):
                self.current_epoch += 1

                # Epoch begin
                self.on_epoch_begin(metadata={"epoch": self.current_epoch})

                for batch_ix, (X_batch, y_batch, timestamps, asset_names) in enumerate(
                    self.train_dataloader
                ):
                    # Batch begin (tensors not yet moved to device/dtype)
                    self.on_batch_begin(
                        metadata={
                            "asset_names": asset_names,
                            "batch": batch_ix,
                            "epoch": self.current_epoch,
                            "timestamps": timestamps,
                            "X_batch": X_batch,
                            "y_batch": y_batch,
                        }
                    )

                    # Move batch to the target device and dtype
                    X_batch = X_batch.to(self.device).to(self.dtype)
                    y_batch = y_batch.to(self.device).to(self.dtype)

                    # Make sure network is in train mode for the update
                    self.network.train()

                    # Forward & Backward
                    weights = self.network(X_batch)
                    loss_per_sample = self.loss(weights, y_batch)
                    loss = loss_per_sample.mean()
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # Switch back to eval mode so callbacks see inference behavior
                    self.network.eval()

                    # Batch end
                    self.on_batch_end(
                        metadata={
                            "asset_names": asset_names,
                            "batch": batch_ix,
                            "batch_loss": loss.item(),
                            "epoch": self.current_epoch,
                            "timestamps": timestamps,
                            "weights": weights,
                            "X_batch": X_batch,
                            "y_batch": y_batch,
                        }
                    )

                # Epoch end
                self.on_epoch_end(
                    metadata={"epoch": self.current_epoch, "n_epochs": n_epochs}
                )

            # Train end
            self.on_train_end()

        except (EarlyStoppingException, KeyboardInterrupt, SolverError) as ex:
            print("Training interrupted")
            time.sleep(1)  # let progress bars flush before callback output
            self.on_train_interrupt(metadata={"exception": ex, "locals": locals()})

        return self.history

    def on_train_begin(self, metadata=None):
        """Take actions at the beginning of the training."""
        for cb in self.callbacks:
            cb.on_train_begin(metadata=metadata)

    def on_train_interrupt(self, metadata=None):
        """Take actions when training interrupted."""
        for cb in self.callbacks:
            cb.on_train_interrupt(metadata=metadata)

    def on_train_end(self, metadata=None):
        """Take actions at the end of the training."""
        for cb in self.callbacks:
            cb.on_train_end(metadata=metadata)

    def on_epoch_begin(self, metadata=None):
        """Take actions at the beginning of an epoch."""
        for cb in self.callbacks:
            cb.on_epoch_begin(metadata=metadata)

    def on_epoch_end(self, metadata=None):
        """Take actions at the end of an epoch."""
        for cb in self.callbacks:
            cb.on_epoch_end(metadata=metadata)

    def on_batch_begin(self, metadata=None):
        """Take actions at the beginning of a batch."""
        for cb in self.callbacks:
            cb.on_batch_begin(metadata=metadata)

    def on_batch_end(self, metadata=None):
        """Take actions at the end of a batch."""
        for cb in self.callbacks:
            cb.on_batch_end(metadata=metadata)

    @property
    def hparams(self):
        """Collect relevant hyperparameters specifying an experiment."""
        res = {}
        res.update(self.network.hparams)
        res.update(self.train_dataloader.hparams)
        res.update(
            {
                "device": str(self.device),
                "dtype": str(self.dtype),
                "loss": str(self.loss),
                # Optimizer defaults may omit these keys; fall back to ''.
                "weight_decay": self.optimizer.defaults.get("weight_decay", ""),
                "lr": self.optimizer.defaults.get("lr", ""),
            }
        )

        return res