A Hitchhiker’s Guide to Mixture Density Networks

Given a vector x of inputs (product attributes, customer, …, you name it again), we wish to predict y (price, website visits, …).

More precisely: We aim to obtain the probability of y given x: p(y|x).

If we assume a Gaussian distribution of the real-valued target data (as we usually do when we minimize the squared error), then p(y|x) takes the well-known form of:

$$p(y \mid x) = \mathcal{N}\big(y \mid \mu(x, \Theta), \sigma^2\big) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(\mu(x, \Theta) - y)^2}{2\sigma^2}\right)$$

In practical applications, we minimize the squared error term ( μ(x, Θ)− y )² of the output of the linear function μ given x, its parameters Θ, and the target value y for all pairs of (x, y) in some dataset 𝔻.

The learned function essentially “spits” out the conditional mean of the Gaussian distribution μ(x, Θ) given the data and parameters.

It throws away the standard deviation and the normalization constant, which do not depend on Θ.
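To make the link between the Gaussian assumption and the squared error explicit, the negative log-likelihood of a single pair (x, y) can be written out as below. This is the standard textbook derivation, sketched here for convenience, with σ treated as a fixed constant that does not depend on x or Θ:

```latex
-\log p(y \mid x) = \frac{(\mu(x, \Theta) - y)^2}{2\sigma^2} + \log \sigma + \frac{1}{2}\log(2\pi)
```

Only the first term depends on Θ, and σ merely rescales it, so minimizing the negative log-likelihood over Θ is equivalent to minimizing the squared error.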

By doing so, the model imposes multiple important assumptions, which can be very limiting in practice:

1. The data distribution is Gaussian. To quote [1]: “Practical machine learning problems can often have significantly non-Gaussian distributions” (p. …).

2. The output distribution is unimodal. Thereby, we cannot account for the situation that x can produce multiple valid answers, which a multimodal distribution can capture (as in the headphones example above).

3. The standard deviation σ of the noise distribution is assumed to be constant and thus independent of x (homoscedasticity, isotropic covariance matrix). Again, this is not always the case in the real world.

4. The function μ(x, Θ) is linear, i.e., μ(x, Θ) = x×w + b, where Θ = {w, b}.

Linear models are widely considered to be more interpretable.

Neural networks on the other hand provide excellent predictive capabilities because they are theoretically able to model any function.

Let us consider two situations, which graphically motivate the previously outlined technical problems:

(LHS): The underlying function is linear. However, we observe two violations: Firstly, the standard deviation of the (noise) distribution is not constant. Secondly, the noise does depend on the input.

(RHS): Not only is the standard deviation of the noise distribution dependent on x, but the output is additionally non-linear.

Furthermore, the output distribution is multimodal.

The simple mean is not a reasonable solution for some areas of the data (around ± 8).

Imposing the previously outlined assumptions can easily mislead us when we predict outcomes that follow such complex patterns.
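The problem with the simple mean can be illustrated numerically. The sketch below assumes a toy setup, not the dataset behind the figures in this post: for one fixed input, the target comes from one of two well-separated branches. The branch locations and the noise level are arbitrary illustrative values.

```python
import random
import statistics

rng = random.Random(0)

# Assumed toy setup: for one fixed input x, the target is drawn from one of
# two branches (around -3 and +3) with equal probability.
ys = [rng.choice((-3.0, 3.0)) + rng.gauss(0.0, 0.3) for _ in range(2000)]

# The squared-error optimal point prediction is the (conditional) mean.
mean_prediction = statistics.mean(ys)

# Share of observations that actually lie close to that prediction.
near_mean = sum(abs(y - mean_prediction) < 1.0 for y in ys) / len(ys)

print("mean prediction:", round(mean_prediction, 2))
print("share of targets within 1.0 of the mean:", near_mean)
```

The mean lands between the two branches, in a region where practically no data was observed. A multimodal conditional density avoids exactly this failure.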

Background of the MDN

To account for the outlined limitations, [2] proposed to parametrize a mixture of distributions by a DNN.

Originally conceived in 1994 [1] [2], the MDN recently found a series of different applications.

For example: Apple’s Siri in iOS 11 uses MDN for speech generation [3].

Alex Graves used MDNs in combination with RNNs to generate artificial handwriting [4].

Additionally, there are multiple blog posts dedicated to the topic [5] [6] [7] [8].

Amazon Forecast offers the MDN as an algorithm for its customers [9], while [10] wrote a master's thesis about the topic.

However, we wish to establish the method for a broader audience, for the simple reason that many modern neural network architectures can be extended to become MDNs (Transformers, LSTMs, ConvNets, …).

MDNs can essentially be seen as an extension module, applicable to a broad variety of business relevant tasks.

At its very core, the MDN concept is simple, straightforward, and appealing: Combine a deep neural network (DNN) and a mixture of distributions.

The DNN provides the parameters for multiple distributions, which are then mixed by some weights.

These weights are also provided by the DNN.

The resulting (multimodal) conditional probability distribution helps us to model complex patterns found in real-world data.

We are thus better able to assess how likely certain values of our predictions are.

Formalization of the Mixture Model

Theoretically, a Gaussian mixture is capable of modeling arbitrary probability densities [2], if it is adequately parametrized (e.g., given enough components).

Formally, the conditional probability for a mixture is defined as

$$p(y \mid x) = \sum_{c=1}^{C} \alpha_c(x)\, \mathcal{D}\big(y \mid \lambda_{1,c}(x), \lambda_{2,c}(x), \dots\big)$$

Let us elaborate on each parameter individually:

c denotes the index of the corresponding mixture component. There are up to C mixture components (i.e., distributions) per output, where C is a user-definable hyperparameter.

⍺ denotes the mixing parameter. Think of the mixing parameter as sliders mixing together C different audio signals at different intensities, producing a richer output. The mixing parameter is conditioned on the input x.

𝒟 is the corresponding distribution (audio signal) to be mixed. The distribution can be chosen depending on the task or the application.

λ denotes the parameters of the distribution 𝒟. In case we choose 𝒟 to be a Gaussian distribution, λ1 corresponds to the conditional mean μ(x) and λ2 to the conditional standard deviation σ(x). Distributions can have multiple parameters (e.g., Bernoulli and Chi2 have one, Gaussian and Beta have two, and a truncated Gaussian has up to four parameters). These are the parameters which the neural network outputs.
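As a concrete reading of the definition, the following sketch evaluates a Gaussian mixture density for one fixed input x. The values of ⍺, μ, and σ are hypothetical stand-ins for what a network would output, and the function names are chosen here purely for illustration.

```python
import math

def gaussian_pdf(y, mu, sigma):
    """Density of a univariate Gaussian with mean mu and std. deviation sigma."""
    z = (y - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def mixture_pdf(y, alphas, mus, sigmas):
    """p(y|x) = sum_c alpha_c * N(y | mu_c, sigma_c) for one fixed input x."""
    return sum(a * gaussian_pdf(y, m, s) for a, m, s in zip(alphas, mus, sigmas))

# Hypothetical network outputs for a single x: two components, i.e. C = 2.
alphas = [0.3, 0.7]
mus = [-2.0, 3.0]
sigmas = [0.5, 1.0]

# Riemann sum over a wide grid to check that the mixture is a valid density.
step = 0.01
grid = [-10.0 + i * step for i in range(2001)]
area = sum(mixture_pdf(y, alphas, mus, sigmas) for y in grid) * step
print("area under the mixture density:", round(area, 4))
```

Because the weights sum to one and every component is itself a density, the mixture integrates to one as well, while still showing two separate modes.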

The formulation of the conditional probability as a mixture of distributions already solves multiple problems associated with the outlined assumptions.

First, the distribution can be arbitrary, as we are theoretically able to model every distribution as a mixture of Gaussians [2].

Second, using multiple distributions helps us to model multimodal signals.

Consider our headphone price example, which is clearly multimodal.

Third, the standard deviation is now conditioned on the input, allowing us to account for a variable standard deviation.

Even when we are just using a single Gaussian distribution, this advantage applies.

Fourth, the problem of the linearity of the function can be circumvented by choosing a non-linear model, which conditions the distribution parameters on the input.

To obtain the parameters for the mixture, a DNN is modified to output multiple parameter vectors.

We start off with a single layer DNN and a ReLU activation.

Using the hidden layer h1(x), we proceed by computing the parameter vectors of the mixture: ⍺(x), λ1(x), and λ2(x).

The mixing coefficients must sum to unity: ∑ ⍺(x) = 1, where the sum runs over the C components.

Therefore, we are using a softmax function to constrain the output.

This step is important, as the mixture of probabilities must integrate to one.
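A minimal sketch of the softmax constraint in plain Python follows. The raw scores are hypothetical, and subtracting the maximum before exponentiating is a common stabilization trick rather than something specific to this post.

```python
import math

def softmax(logits):
    """Map unconstrained scores to positive weights that sum to one."""
    shift = max(logits)                      # subtracting the max avoids overflow
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw outputs of the mixing-coefficient head for C = 3 components.
alphas = softmax([2.0, -1.0, 0.5])
print(alphas, sum(alphas))
```

Whatever the raw scores are, the resulting weights are strictly positive and sum to one, which is what the mixture requires.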

The constraints for λ1 and λ2 themselves depend on the distribution we are choosing for our model.

The only constraint we must enforce for a Gaussian is that the standard deviation is positive: σ(x) > 0.

This effect can be achieved in multiple ways.

For example, we could use an exponential activation as originally proposed by Bishop [1] [2].

The exponential, however, can lead to numerical instability.

Alternatively, we can use a simple softplus activation, similar to the oneplus activation used in [11].

Or we employ a variant of the ELU activation with an offset.

Due to the recent surge in prominence of the ELU, we are opting for the latter.
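The three options can be compared side by side. The sketch below is a framework-free illustration in plain Python; the function names are chosen here, and the offset of one for the ELU variant is an assumption about what "ELU with an offset" means (an ELU with unit scale is bounded below by -1, so adding one makes it positive).

```python
import math

def exp_activation(z):
    """Exponential activation: always positive, but grows very fast."""
    return math.exp(z)

def softplus(z):
    """log(1 + exp(z)), written so that large |z| does not overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def elu(z, a=1.0):
    """Exponential linear unit: z for z > 0, a * (exp(z) - 1) otherwise."""
    return z if z > 0 else a * math.expm1(z)

def nnelu(z):
    """ELU shifted up by one, so the output lies in (0, inf)."""
    return 1.0 + elu(z)

for z in (-5.0, 0.0, 5.0, 50.0):
    print(z, exp_activation(z), softplus(z), nnelu(z))
```

For large inputs the exponential explodes, whereas softplus and the offset ELU grow only linearly. Note that in floating point the offset ELU can still round to exactly zero for strongly negative inputs, so adding a small positive constant on top may be worth considering.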

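Therefore, we end up with the following transformations: a softmax for the mixing coefficients ⍺(x), no constraint on the mean μ(x), and the offset ELU for the standard deviation σ(x).

The choice of the constraints is dependent on the distributions and the data.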

As always: Different constraints might perform better on different datasets.

One might even argue that it is sensible from a business perspective to constrain μ(x) to positive values as well.

Thereby we could ensure that negative prices are not within the realm of possibility.

Having specified the parameters and the conditional probability, we have everything we need to directly minimize the average negative log-likelihood (NLL) using some form of gradient descent (SGD, Adagrad, Adadelta, Adam, RMSProp, etc.).
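Written out for a dataset of N pairs, the objective presumably takes the following form. The index n, the count N, and the symbol 𝒟 for the component distribution are notation assumed here for illustration, and the two-parameter case is shown:

```latex
\mathrm{NLL}(\Theta) = -\frac{1}{N} \sum_{n=1}^{N} \log \sum_{c=1}^{C} \alpha_c(x_n)\, \mathcal{D}\big(y_n \mid \lambda_{1,c}(x_n), \lambda_{2,c}(x_n)\big)
```

The inner sum is the mixture density of one observation, and the outer average runs over the dataset. All dependence on Θ enters through the network outputs ⍺, λ1, and λ2.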


Implementation of the MDN

The code is available on GitHub / Colab.

Having established the basic theory for the MDN, we now show how to implement the model in Tensorflow/Keras.

We essentially need two components: a custom layer to compute the parameters and the loss function to be minimized.

For reasons of numerical stability and convenience, we will do most of the computation within Tensorflow functions.

As we emphasized the flexibility of the MDN framework earlier, we do not discuss every detail, only the parts that are relevant for building your own version of it.

Defining the DNN is straightforward.

Depending on the data and the application, it may also make sense to impose additional activity regularization on the sigmas to prevent the standard deviation from blowing up.

A simple L2 regularization would be a sensible option.

The code example requires a “non-negative exponential linear unit” activation function, which ensures that the sigmas are strictly greater than zero.

Tensorflow provides a very friendly way to define the required activation function.

We simply make nnelu a callable function and register it as a custom activation function in Keras.

The remaining building block is the implementation of the loss function.

Tensorflow-Probability comes in handy here because we only need to slightly adapt the example from the beginning of the post.

The MixtureSameFamily requires a mixture distribution and component distributions.

The former is a simple categorical distribution, which receives the mixture weights ⍺(x).

The latter is a normal distribution, parametrized by mean and standard deviation.


Subsequently, we just compute the log-likelihood of y and its negative average.

By relying on Tensorflow-Probability, we avoid numerical over-/underflows (implementing this by hand can actually be quite tricky).
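To see why a hand-written version is tricky, consider the sketch below. It is a plain-Python illustration of the log-sum-exp trick, not the Tensorflow-Probability implementation, and all function names and parameter values are assumptions made for this example.

```python
import math

def gaussian_log_pdf(y, mu, sigma):
    """Log-density of a univariate Gaussian."""
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

def log_sum_exp(values):
    """log(sum(exp(v))) computed without overflow or underflow."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mixture_log_prob(y, alphas, mus, sigmas):
    """log p(y|x) for a Gaussian mixture, evaluated entirely in log-space."""
    terms = [math.log(a) + gaussian_log_pdf(y, m, s)
             for a, m, s in zip(alphas, mus, sigmas)]
    return log_sum_exp(terms)

def mean_nll(ys, params):
    """Average NLL; params holds one (alphas, mus, sigmas) tuple per target."""
    return -sum(mixture_log_prob(y, *p) for y, p in zip(ys, params)) / len(ys)

# Hypothetical mixture parameters for a single input x.
alphas, mus, sigmas = [0.5, 0.5], [0.0, 1.0], [0.1, 0.1]

# A target far from both means: the naive density underflows to 0.0, so its
# logarithm is undefined, while the log-space version stays finite.
y_far = 10.0
naive_density = sum(a * math.exp(gaussian_log_pdf(y_far, m, s))
                    for a, m, s in zip(alphas, mus, sigmas))
print("naive density:", naive_density)
print("stable log-prob:", mixture_log_prob(y_far, alphas, mus, sigmas))
```

Early in training, predictions are often far off and the sigmas may be small, so this kind of underflow is not an exotic corner case.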

Following the definition of the most important components of the MDN, only the compilation of the model remains to be done.

Application on Simulated Data

Time to go back to our examples from earlier.

We train a simple MDN with two layers, 200 neurons per layer, and one Gaussian component on the linear dataset.

The MDN shows its strength: Due to conditioning the standard deviation of the distribution on the input, the MDN can adapt to the change of the underlying data distribution.

It neatly captures the linear trend (as expected), but adapts the standard deviation according to the increase of uncertainty present in the data (I like this graph. It looks like a shooting star).

To get a better grasp of the results, we additionally perform a comparison of the mean negative log-likelihood of several models.

Namely, let’s look at the null-model (sample mean and sample standard deviation), a linear model (linear conditioned mean and sample standard deviation), a DNN (non-linear conditioned mean and sample standard deviation), and the MDN (non-linear conditioned mean and non-linear conditioned standard deviation).


The DNN and MDN use the same parameters and training routine.

Thankfully, we can monitor the training progress of the MDN using Tensorboard.

All it takes is a callback to the fit routine.

Thus, we don’t need to bother with storing training losses separately.

And we are converging!

Source: Author

All models are able to beat the null-model.

The remaining models perform equally in terms of the MSE because the MSE assumes that the standard deviation of the underlying distribution is constant.

With the MSE alone, we cannot adequately capture the behavior of the data!

The NLL, which incorporates the standard deviation, paints a more nuanced picture.

As the underlying function is linear, the DNN and linear model perform equally.

The MDN, however, is better able to accommodate the data distribution, resulting in the lowest NLL value.
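The effect can be reproduced in a few lines without any neural network. The sketch below assumes a toy generating process (linear mean, noise growing with x; all constants are illustrative) and compares two Gaussian models that share the same mean, and therefore the same MSE, but treat σ differently. The oracle σ(x) stands in for what an MDN would have to learn.

```python
import math
import random

def gaussian_nll(y, mu, sigma):
    """Negative log-likelihood of y under N(mu, sigma^2)."""
    z = (y - mu) / sigma
    return 0.5 * z * z + math.log(sigma) + 0.5 * math.log(2.0 * math.pi)

def true_sigma(x):
    """Assumed noise schedule: the std. deviation grows linearly with x."""
    return 0.1 + 0.3 * x

# Assumed data-generating process: y = 2x + noise with input-dependent spread.
rng = random.Random(1)
xs = [rng.uniform(0.0, 10.0) for _ in range(5000)]
ys = [2.0 * x + rng.gauss(0.0, true_sigma(x)) for x in xs]

# Both models use the same conditional mean, hence the same MSE.
mse = sum((y - 2.0 * x) ** 2 for x, y in zip(xs, ys)) / len(xs)

# Model A: one global std. deviation (the root of the MSE).
sigma_const = math.sqrt(mse)
nll_const = sum(gaussian_nll(y, 2.0 * x, sigma_const) for x, y in zip(xs, ys)) / len(xs)

# Model B: std. deviation conditioned on x (oracle schedule).
nll_cond = sum(gaussian_nll(y, 2.0 * x, true_sigma(x)) for x, y in zip(xs, ys)) / len(xs)

print("shared MSE:", round(mse, 3))
print("NLL, constant sigma   :", round(nll_const, 3))
print("NLL, conditional sigma:", round(nll_cond, 3))
```

The MSE cannot tell the two models apart, while the NLL rewards the model whose σ follows the data.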

To derive the conditional mean from the probability density of the MDN for a single datapoint, one computes:

$$\mathbb{E}[y \mid x] = \sum_{c=1}^{C} \alpha_c(x)\, \mu_c(x)$$

Looking at the formula explains the result: The mean does not incorporate the standard deviation σ(x).

Just looking at the mean of the MDN throws away valuable information, which we might need in real world applications.
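A small sketch shows how much the mean alone hides. It uses the standard moment formulas for a Gaussian mixture; the two parameter sets are hypothetical predictions chosen so that their means coincide.

```python
def mixture_mean(alphas, mus):
    """E[y|x] = sum_c alpha_c * mu_c; the sigmas do not enter."""
    return sum(a * m for a, m in zip(alphas, mus))

def mixture_variance(alphas, mus, sigmas):
    """Var[y|x] = sum_c alpha_c * (sigma_c^2 + mu_c^2) - E[y|x]^2."""
    mean = mixture_mean(alphas, mus)
    second_moment = sum(a * (s * s + m * m) for a, m, s in zip(alphas, mus, sigmas))
    return second_moment - mean * mean

# Two hypothetical predictions with identical means but different spread.
narrow = ([0.5, 0.5], [1.0, 3.0], [0.1, 0.1])
wide = ([0.5, 0.5], [1.0, 3.0], [2.0, 2.0])

print(mixture_mean(*narrow[:2]), mixture_variance(*narrow))
print(mixture_mean(*wide[:2]), mixture_variance(*wide))
```

Both predictions report the same mean, yet their variances differ by roughly a factor of five. A point estimate would treat them as identical.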

Having this distribution at our hands, we can compute more elaborate quantities.

For example, the Shannon entropy can serve as an indicator of how certain we are.

Or we could compute f-divergences to assess how similar predictions are.
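For a Gaussian mixture the entropy has no closed form in general, so one option is numerical integration. The sketch below approximates the differential entropy (the continuous analogue of the Shannon entropy) with a midpoint Riemann sum; the integration bounds, grid size, and parameter values are assumptions chosen for this toy case.

```python
import math

def gaussian_pdf(y, mu, sigma):
    """Density of a univariate Gaussian."""
    z = (y - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def mixture_pdf(y, alphas, mus, sigmas):
    """Gaussian mixture density for one fixed input x."""
    return sum(a * gaussian_pdf(y, m, s) for a, m, s in zip(alphas, mus, sigmas))

def differential_entropy(alphas, mus, sigmas, lo=-30.0, hi=30.0, n=60000):
    """Approximate -integral p(y) log p(y) dy with a midpoint sum on [lo, hi]."""
    step = (hi - lo) / n
    total = 0.0
    for i in range(n):
        p = mixture_pdf(lo + (i + 0.5) * step, alphas, mus, sigmas)
        if p > 0.0:                       # skip points where the density underflows
            total -= p * math.log(p) * step
    return total

confident = differential_entropy([1.0], [0.0], [0.5])
uncertain = differential_entropy([0.5, 0.5], [-4.0, 4.0], [2.0, 2.0])
print("entropy, narrow unimodal prediction:", round(confident, 3))
print("entropy, wide bimodal prediction   :", round(uncertain, 3))
```

The wide bimodal prediction yields the larger entropy, which matches the intuition that we are less certain about it.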

Let us now turn to the second — non-linear — example.

We first use the min-max scaler to transform y into a reasonable range for the DNN / MDN to speed up learning.
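As a reminder of what the min-max scaler does, here is a plain-Python stand-in. The helper names and sample values are hypothetical, and a library scaler would normally be used instead.

```python
def fit_min_max(values):
    """Return the (min, max) pair that defines the scaling."""
    return min(values), max(values)

def scale(values, lo, hi):
    """Map values linearly so that lo -> 0 and hi -> 1."""
    return [(v - lo) / (hi - lo) for v in values]

def unscale(scaled, lo, hi):
    """Invert the scaling, e.g. to report predictions in original units."""
    return [s * (hi - lo) + lo for s in scaled]

y = [12.0, -3.5, 40.0, 7.25]          # hypothetical raw targets
lo, hi = fit_min_max(y)
y_scaled = scale(y, lo, hi)
y_back = unscale(y_scaled, lo, hi)
print(y_scaled)
print(y_back)
```

When the targets are scaled this way, predicted means map back via the same inverse transform, while predicted standard deviations only need to be multiplied by (hi - lo).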

Not only does the MDN capture the underlying non-linearity, it also captures the multimodality of the output and the change of the standard deviation.


The data generating distribution is captured adequately.

Looking at the conditional density for x = 8, we see that the MDN produces two disjoint peaks:

Source: Author

The capability to model these complex distributions is reflected in the NLL, where the MDN achieves the best value.

Source: Author

Application on Real World Data

We started this post with the example of predicting prices. After long technical elaborations, let us return to it.

For ease of analysis, we are using the drosophila of datasets: Boston housing.

Given 13 independent variables, the goal is to predict the median value of owner-occupied homes in $1000’s (MEDV).

The example might not necessarily make full use of the MDN’s capability to model multimodal distributions.

Nonetheless, it shows how the MDN can model uncertainty of prices.

The independent variables are transformed using the min-max scaler, while the prices are log transformed.

Source: Author

Looking at the NLL, we observe a similar behavior as in the previous examples.

The MDN is better able to cope with the data.

Thus, while we might not have multimodality in the example, we are certainly benefitting from modeling the full conditional probability instead of just a point estimate.

Analyzing the conditional densities for different houses helps us to make better decisions.

We are fairly confident about the high price of house 18. Thus, as a manager, we can set the price accordingly.

The prediction for house 12 is very inconclusive.

It may be necessary for a human expert to directly assess the case to set the price.

Houses 13 and 45 overlap with regard to their price.

It makes sense to analyze their attributes directly to see whether they can serve as objects of interest for buyers in the same price range.

Although we do not tap the full potential of the model in this simple dataset, we still benefit from the additional capabilities.

Summary

Assessing uncertainty is a crucial aspect for modern businesses.

This blog post highlighted the theoretical reasoning, the implementation details, and some tips and tricks when using MDNs.

We demonstrated the capabilities of the MDN in simulated and practical applications.

Due to its simplicity and modularity, we expect a broad variety of applications.

For inquiries and questions, feel free to contact me.

Additional Information

The code is available on GitHub / Colab.

The guide was written for Tensorflow 1.


0 and Tensorflow-Probability 0.



Literature

[1] Christopher M. Bishop, Pattern Recognition and Machine Learning (2006)
[2] Christopher M. Bishop, Mixture Density Networks (1994)
[3] Siri Team, Deep Learning for Siri’s Voice: On-device Deep Mixture Density Networks for Hybrid Unit Selection Synthesis (2017)
[4] Alex Graves, Generating Sequences With Recurrent Neural Networks (2014)
[5] Christopher Bonnett, Mixture Density Networks with Edward, Keras and TensorFlow (2016)
[6] Binghao Ng, Mixture Density Networks: Basics (2017)
[7] Otoro, Mixture Density Networks with TensorFlow (2016)
[8] Mike Dusenberry, Mixture Density Networks (2017)
[9] Amazon, Mixture Density Networks (MDN) Recipe (2019)
[10] Axel Brando, Mixture Density Networks implementation for distribution and uncertainty estimation (2017)
[11] Alex Graves et al., Hybrid computing using a neural network with dynamic external memory (2016)

Disclaimer

Opinions expressed are solely my own and do not express the views or opinions of my employer.

The author assumes no responsibility or liability for any errors or omissions in the content of this site.

The information contained in this site is provided on an “as is” basis with no guarantees of completeness, accuracy, usefulness or timeliness.

