PyTorch Tidbits: Measuring Streams and Times

Simple utilities for statistics collection

Stefan Schroedl
Towards Data Science

A typical PyTorch training loop contains code to keep track of training and testing metrics; they help us monitor progress and draw learning curves.

Let’s say that for a classification task, we use binary cross entropy as the training loss, and we are also interested in accuracy. After each epoch, we compute the same metrics on a held-out validation set. We want to write periodic progress information to the console as well as to TensorBoard (or any other dashboard tool of your choice).
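The following is a framework-agnostic sketch of such a naive loop, written so that it runs without PyTorch. Here, train_step is a hypothetical callable that performs one optimization step and returns the loss and accuracy of the batch as Python floats. A real loop would additionally write the values to TensorBoard, e.g., via SummaryWriter.add_scalar from torch.utils.tensorboard.

```python
import statistics


def train_epoch(batches, train_step, log_every=100, log=print):
    """Naive bookkeeping: every per-batch metric value is kept in a list."""
    losses, accuracies = [], []
    for i, batch in enumerate(batches, start=1):
        loss, acc = train_step(batch)  # hypothetical: returns two floats
        losses.append(loss)
        accuracies.append(acc)
        if i % log_every == 0:
            # Statistics are recomputed from the full lists every time.
            log(f"batch {i} loss {statistics.mean(losses):.4f} "
                f"acc {statistics.mean(accuracies):.4f}")
    # Epoch-level summaries, e.g. for the console and for TensorBoard.
    return statistics.mean(losses), statistics.mean(accuracies)
```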

Note the pattern of recording a number of successive values and then computing statistics over them. While code like this is simple enough, it involves some repetition, and the logging code starts to obscure the main algorithmic core. The problem grows with the number of metrics and statistics we want to track (say, standard deviations in addition to the means).

This need was already recognized in the PyTorch ImageNet example code, which contains an auxiliary class, AverageMeter, for this purpose.

Pretty straightforward to use: We create an AverageMeter instance m, repeatedly call m.update(loss), and at the end of an epoch retrieve the average as m.avg.
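For reference, here is a sketch of such a class, modeled on the AverageMeter in the PyTorch ImageNet example. The original additionally stores a name and a format string for printing, which are omitted here. The update method takes an optional weight n, so a per-batch average can be weighted by the batch size.

```python
class AverageMeter:
    """Computes and stores the average and the most recent value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight, e.g. number of samples
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        # val may be a per-batch average; n is then the batch size.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


m = AverageMeter()
for loss in [0.9, 0.7, 0.5]:  # stand-in for per-batch losses
    m.update(loss)
```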

Admittedly, this is a rather modest class; nevertheless, tracking statistics like this is ubiquitous when writing machine learning applications in Python. Therefore, it might be worthwhile to have a closer look at this concept, as it lends itself to multiple extensions.

Aggregation Functions

Surely you have noticed that storing every single loss value, as the naive training loop does, is wasteful. To compute an average, we only need to incrementally update the sum and the number of elements. Data-processing languages such as SQL or Pig Latin call this a decomposable aggregation function. Such a function records only a constant number of sufficient statistics, and its updates are commutative and associative. Common aggregation functions besides the mean are the minimum, the maximum, and the standard deviation (for which we also need the sum of squared values). It is easy to come up with other useful metrics: the most recent value, the percentage of zero or very small values, or the absolute minimum and maximum. The update function then only has to maintain the corresponding running statistics.

The average and the standard deviation can then be derived from these statistics as simple properties, using mean = sum / n and variance = sum_sq / n - mean^2.
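Below is a minimal sketch of such a Meter, assuming scalar Python numbers as input. The attribute names are illustrative and not the article's actual code. The update keeps only constant-size state, and mean and std are derived on demand from the formulas above.

```python
import math


class Meter:
    """Decomposable aggregator: keeps a constant number of sufficient statistics."""

    def __init__(self):
        self.count = 0
        self.sum = 0.0
        self.sum_sq = 0.0      # sum of squared values, needed for the variance
        self.min = math.inf
        self.max = -math.inf
        self.last = None       # most recent value

    def update(self, x):
        x = float(x)
        self.count += 1
        self.sum += x
        self.sum_sq += x * x
        self.min = min(self.min, x)
        self.max = max(self.max, x)
        self.last = x

    @property
    def mean(self):
        return self.sum / self.count if self.count else math.nan

    @property
    def std(self):
        # Population standard deviation: sqrt(E[x^2] - E[x]^2).
        # Clamp at zero, since floating-point cancellation can make the
        # difference slightly negative.
        if not self.count:
            return math.nan
        var = self.sum_sq / self.count - self.mean ** 2
        return math.sqrt(max(var, 0.0))
```

This textbook formula is convenient because every statistic merges trivially. However, it can lose precision when the mean is large relative to the spread. Welford's online algorithm, or Chan et al.'s pairwise variant for merging, is the numerically more robust alternative.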

If we don’t limit ourselves to constant space and time requirements, we can of course implement more complex aggregation functions. These can be computed either exactly or with one of the many known approximation algorithms, e.g., for the k most frequent elements, the k smallest elements, or the unique elements.
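As a sketch of one exact example that uses O(k) rather than constant memory, the k smallest values can be kept in a bounded heap with Python's heapq module. KSmallest is a hypothetical name. The aggregator remains mergeable, because the k smallest values of a union are among the k smallest values of each part. For counting unique elements, approximate sketches such as HyperLogLog play the analogous role.

```python
import heapq


class KSmallest:
    """Exact aggregator for the k smallest values, using O(k) memory.

    Python's heapq is a min-heap, so values are stored negated; the root
    then holds the largest of the k candidates, which is evicted first.
    """

    def __init__(self, k):
        if k < 1:
            raise ValueError("k must be at least 1")
        self.k = k
        self._heap = []  # negated values; -self._heap[0] is the largest kept

    def update(self, x):
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, -x)
        elif x < -self._heap[0]:
            # Replace the current largest candidate with the smaller x.
            heapq.heapreplace(self._heap, -x)

    def merge(self, other):
        # Feeding the other aggregator's candidates keeps the result exact.
        for neg in other._heap:
            self.update(-neg)

    @property
    def values(self):
        return sorted(-v for v in self._heap)
```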

Syntactic Sugar

As a shortcut, the in-place addition operator += can serve as an alias for update. Plain addition (+), in contrast, returns an updated copy and leaves the original Meter untouched. The standard __len__ method returns the count. It is also convenient to retrieve the values as a dictionary and to have a suitable string representation.
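Here is a minimal sketch of these operators on a stripped-down Meter that tracks only the count and the sum; to_dict is a hypothetical method name. Note that __iadd__ must return self, so that m += x rebinds m to the same, updated object.

```python
import copy


class Meter:
    """Stripped-down meter (count and sum only) to illustrate the operators."""

    def __init__(self):
        self.count = 0
        self.sum = 0.0

    def update(self, x):
        self.count += 1
        self.sum += float(x)

    @property
    def mean(self):
        return self.sum / self.count if self.count else float("nan")

    def __iadd__(self, x):
        # m += x is an in-place alias for m.update(x).
        self.update(x)
        return self

    def __add__(self, x):
        # m + x leaves m untouched and returns an updated copy.
        result = copy.deepcopy(self)
        result.update(x)
        return result

    def __len__(self):
        # len(m) is the number of aggregated values.
        return self.count

    def to_dict(self):
        return {"count": self.count, "sum": self.sum, "mean": self.mean}

    def __repr__(self):
        return f"Meter(count={self.count}, mean={self.mean:.4g})"
```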

Aggregation with other Meters, NumPy Arrays, and PyTorch Tensors

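Suppose two Meters aggregate metrics in two concurrent threads, and we want to summarize their results together. Extending the update function to accept another Meter is straightforward, because each sufficient statistic can be merged directly. Counts, sums, and sums of squares add up, and minima and maxima combine via min and max.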
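The sketch below shows the merge on a Meter holding the sufficient statistics discussed earlier; the attribute names are illustrative. Up to floating-point rounding in the sums, the merged Meter equals one that had seen all values directly. Merging is meant for combining per-thread Meters after the fact. Updating a single shared Meter from several threads would instead require a lock.

```python
import math


class Meter:
    def __init__(self):
        self.count = 0
        self.sum = 0.0
        self.sum_sq = 0.0
        self.min = math.inf
        self.max = -math.inf

    def update_scalar(self, x):
        x = float(x)
        self.count += 1
        self.sum += x
        self.sum_sq += x * x
        self.min = min(self.min, x)
        self.max = max(self.max, x)

    def update_meter(self, other):
        # Every sufficient statistic combines associatively and
        # commutatively, so a merge costs about as much as one update.
        self.count += other.count
        self.sum += other.sum
        self.sum_sq += other.sum_sq
        self.min = min(self.min, other.min)
        self.max = max(self.max, other.max)
```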

So far, we have assumed our values are native Python numbers. But sometimes we want to look at statistics over PyTorch Tensors (such as model activations, weights, and their gradients). We define the semantics of updating a Meter with an array to be the same as updating it with each individual element.

Recall that a PyTorch CPU tensor can be converted into a NumPy ndarray that shares the underlying memory (GPU tensors first have to be copied to host memory). So when we want to support both types, one case can be reduced to the other.
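This sketch of the reduction assumes NumPy is available; update_array is a hypothetical name. The tensor path calls detach() to drop autograd tracking and cpu() to move GPU tensors to host memory (a no-op for CPU tensors). It then calls numpy(), which for a CPU tensor shares memory instead of copying. The subsequent conversion to float64 may copy, but it keeps the running sums from accumulating in float32 precision.

```python
import math

import numpy as np


class Meter:
    def __init__(self):
        self.count = 0
        self.sum = 0.0
        self.sum_sq = 0.0
        self.min = math.inf
        self.max = -math.inf

    def update_array(self, a):
        # Same result as updating with every element individually.
        a = np.asarray(a, dtype=np.float64).ravel()
        if a.size == 0:
            return
        self.count += a.size
        self.sum += float(a.sum())
        self.sum_sq += float(np.dot(a, a))
        self.min = min(self.min, float(a.min()))
        self.max = max(self.max, float(a.max()))

    def update_tensor(self, t):
        # Reduce the tensor case to the array case (no torch import needed).
        self.update_array(t.detach().cpu().numpy())
```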

We want to provide the user with a single update function, so that they don’t have to remember whether to call update_scalar, update_tensor, or update_meter. Unfortunately, unlike statically typed languages such as C++, Python doesn’t support overloading functions by parameter type. Instead, we can move the common functionality into an abstract base class whose update method dispatches on the type of its argument.
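The following minimal sketch of the dispatch assumes NumPy; Aggregator and CountSum are hypothetical names. Torch tensors are recognized by duck typing (a detach method), so the example runs without PyTorch installed. Since Python 3.8, functools.singledispatchmethod offers an alternative that registers one implementation per argument type.

```python
import abc
import numbers

import numpy as np


class Aggregator(abc.ABC):
    """Subclasses implement the typed updates; update() dispatches to them."""

    @abc.abstractmethod
    def update_scalar(self, x): ...

    @abc.abstractmethod
    def update_array(self, a): ...

    @abc.abstractmethod
    def update_meter(self, other): ...

    def update(self, value):
        if isinstance(value, Aggregator):
            self.update_meter(value)
        elif isinstance(value, numbers.Number):
            self.update_scalar(value)
        elif isinstance(value, np.ndarray):
            self.update_array(value)
        elif hasattr(value, "detach"):
            # Duck-typed check for torch.Tensor, avoiding a hard torch import.
            self.update_array(value.detach().cpu().numpy())
        else:
            raise TypeError(f"cannot aggregate {type(value).__name__}")
        return self


class CountSum(Aggregator):
    """Concrete example: tracks only the count and the sum."""

    def __init__(self):
        self.count, self.sum = 0, 0.0

    def update_scalar(self, x):
        self.count += 1
        self.sum += float(x)

    def update_array(self, a):
        self.count += a.size
        self.sum += float(a.sum())

    def update_meter(self, other):
        self.count += other.count
        self.sum += other.sum
```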

Meter Dictionaries

When tracking more than one metric, the code can be simplified by using a dictionary whose keys and values are the names and aggregators of the metrics, respectively. The operations for aggregating two MeterDicts, or a MeterDict and a Meter, are straightforward.
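Here is a sketch of a MeterDict built on collections.defaultdict, so that a new metric name gets a fresh Meter on first use. The method layout is an assumption for illustration. A single update method handles both cases from the text: a mapping of names to numbers or Meters (so md.update({"loss": some_meter}) merges one Meter), and another MeterDict, which is merged key by key.

```python
from collections import defaultdict


class Meter:
    """Minimal meter; update() accepts a number or another Meter."""

    def __init__(self):
        self.count, self.sum = 0, 0.0

    def update(self, x):
        if isinstance(x, Meter):
            # Merge: the sufficient statistics simply add up.
            self.count += x.count
            self.sum += x.sum
        else:
            self.count += 1
            self.sum += float(x)

    @property
    def mean(self):
        return self.sum / self.count if self.count else float("nan")


class MeterDict:
    """Maps metric names to Meters; a new name gets a fresh Meter on first use."""

    def __init__(self):
        self._meters = defaultdict(Meter)

    def __getitem__(self, name):
        return self._meters[name]

    def keys(self):
        return self._meters.keys()

    def items(self):
        return self._meters.items()

    def update(self, values):
        # values maps names to numbers or Meters. Since a MeterDict also
        # provides items(), two MeterDicts merge key by key the same way.
        for name, value in values.items():
            self._meters[name].update(value)
        return self
```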

When we were growing up as software engineers, the horrors of global variables were drilled into our heads. But this sometimes makes us forget that global variables can indeed be useful in a few cases. They can save us from passing the same or non-essential arguments around between functions, or from repeatedly spelling out long chains of attribute references. Examples in the Python standard library are the logging module and the global random number generator of the random module. In our implementation of the training loop, we started out using member variables of our Engine class for tracking statistics, but this turned out to be too cumbersome. For example, worker processes in the PyTorch DataLoader don’t have direct access to the Engine, so how can we share metrics easily? We argue that it can be useful to share a singleton, global MeterDict for metrics tracking across the application.
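The sketch below shows such a global registry in the spirit of logging.getLogger; get_meters and _GLOBAL_METERS are hypothetical names, and a lock guards updates coming from several threads. One caveat applies: DataLoader workers are separate processes, each with its own copy of module-level globals. Metrics recorded inside a worker therefore still have to be sent back to the main process (for instance, alongside the batch) and merged there.

```python
import threading
from collections import defaultdict


class Meter:
    def __init__(self):
        self.count, self.sum = 0, 0.0

    def update(self, x):
        self.count += 1
        self.sum += float(x)


class MeterDict:
    def __init__(self):
        self._meters = defaultdict(Meter)
        self._lock = threading.Lock()  # updates may come from several threads

    def update(self, values):
        with self._lock:
            for name, value in values.items():
                self._meters[name].update(value)

    def __getitem__(self, name):
        return self._meters[name]


# Module-level singleton, created once at import time.
_GLOBAL_METERS = MeterDict()


def get_meters():
    """Return the process-wide MeterDict (hypothetical helper name)."""
    return _GLOBAL_METERS


def compute_loss(pred, target):
    # Any function can record metrics without an extra parameter.
    loss = (pred - target) ** 2
    get_meters().update({"loss": loss})
    return loss
```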

Formatting

For human-readable output, numbers should always be formatted appropriately, e.g., by rounding to a fixed number of significant digits. Moreover, log files can contain a lot of output. We can search for the things we are interested in using tools like grep, but for this to work, it helps to use a uniform format and to keep related information on the same line. We define a formatter class that registers one or several Meters, and whose string representation is the desired logging output.
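Here is a sketch of such a formatter; MeterFormatter, register, and the name=value layout are assumptions for illustration. Each registered Meter is rendered with a fixed number of significant digits via the g format specifier, and all of them go on a single, greppable line.

```python
class Meter:
    """Minimal mean-only meter for the example."""

    def __init__(self):
        self.count, self.sum = 0, 0.0

    def update(self, x):
        self.count += 1
        self.sum += float(x)

    @property
    def mean(self):
        return self.sum / self.count if self.count else float("nan")


class MeterFormatter:
    """Registers named Meters; str() renders them as one log line."""

    def __init__(self, digits=4, prefix=""):
        self.digits = digits   # significant digits per value
        self.prefix = prefix   # e.g. "train" or "val", to tag the line
        self._meters = {}

    def register(self, name, meter):
        self._meters[name] = meter
        return self

    def __str__(self):
        fields = [f"{name}={meter.mean:.{self.digits}g}"
                  for name, meter in self._meters.items()]
        return " ".join([self.prefix] + fields if self.prefix else fields)
```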

Putting everything together, our initial, naive training loop can now be rewritten so that Meters take care of all the bookkeeping.
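Below is a compact, self-contained sketch of what the result might look like. The minimal Meter and MeterDict here stand in for the fuller classes, and train_step and validate are hypothetical callables yielding (loss, accuracy) pairs. If a TensorBoard SummaryWriter is passed in, each epoch-level mean is also logged through its add_scalar method.

```python
from collections import defaultdict


class Meter:
    def __init__(self):
        self.count, self.sum = 0, 0.0

    def update(self, x):
        self.count += 1
        self.sum += float(x)

    @property
    def mean(self):
        return self.sum / self.count if self.count else float("nan")


class MeterDict(defaultdict):
    def __init__(self):
        super().__init__(Meter)

    def __str__(self):
        # One greppable line with rounded values, sorted by metric name.
        return " ".join(f"{name}={m.mean:.4g}"
                        for name, m in sorted(self.items(), key=lambda kv: kv[0]))


def run_epoch(epoch, batches, train_step, validate, writer=None, log=print):
    meters = MeterDict()
    for batch in batches:
        loss, acc = train_step(batch)   # hypothetical: one optimization step
        meters["train_loss"].update(loss)
        meters["train_acc"].update(acc)
    for loss, acc in validate():        # hypothetical: held-out evaluation
        meters["val_loss"].update(loss)
        meters["val_acc"].update(acc)
    log(f"epoch {epoch} {meters}")
    if writer is not None:
        for name, meter in meters.items():
            writer.add_scalar(name, meter.mean, epoch)
    return meters
```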

Timing

For diagnostic and monitoring purposes, it is useful to log execution times to TensorBoard by default. Typically, such measurements are averaged over a number of invocations, which suggests aggregating them with a Meter. It can be instructive to look at both the wall-clock time and the CPU time, so our timing functions provide both (distinguished by a suffix). Timing a code section calls for a context manager (class Timing); for complete function calls, we prefer a decorator, @timed. All these functions accept an optional MeterDict argument, but they default to the global MeterDict.
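Here is a minimal sketch of Timing and @timed, using time.perf_counter for wall-clock time and time.process_time for the CPU time of the current process. The "_wall"/"_cpu" suffixes and the GLOBAL_METERS name are assumptions. The decorator reuses the context manager and works both as @timed and as @timed(name=..., meters=...).

```python
import functools
import time
from collections import defaultdict


class Meter:
    def __init__(self):
        self.count, self.sum = 0, 0.0

    def update(self, x):
        self.count += 1
        self.sum += float(x)


class MeterDict(defaultdict):
    def __init__(self):
        super().__init__(Meter)


GLOBAL_METERS = MeterDict()  # hypothetical global default


class Timing:
    """Context manager recording wall-clock and CPU time of a code section."""

    def __init__(self, name, meters=None):
        self.name = name
        # Compare with None: an empty MeterDict is falsy, so "or" would be wrong.
        self.meters = GLOBAL_METERS if meters is None else meters

    def __enter__(self):
        self._wall = time.perf_counter()
        self._cpu = time.process_time()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.meters[self.name + "_wall"].update(time.perf_counter() - self._wall)
        self.meters[self.name + "_cpu"].update(time.process_time() - self._cpu)
        return False  # never swallow exceptions from the timed section


def timed(func=None, *, name=None, meters=None):
    """Decorator timing every call of func; usable as @timed or @timed(...)."""
    if func is None:
        return functools.partial(timed, name=name, meters=meters)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with Timing(name or func.__qualname__, meters):
            return func(*args, **kwargs)

    return wrapper
```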

Code

The complete code for this article is available here; I hope you find it useful in your next Python ML project!

Head of Machine Learning @ Atomwise — Deep Learning for Better Medicines, Faster. Formerly Amazon, Yahoo, DaimlerChrysler.