
I am training an HMM with Baum-Welch for part-of-speech tagging. The model has 79 hidden variables (part-of-speech tags) and 80,000 observed variables (words), and I am working with log probabilities. To give you an idea, I defined the necessary arithmetic operations like so:

#include <cmath>
#include <limits>

struct log_policy
{
    // Multiplying probabilities is adding log-probabilities.
    template<class T, class U>
    static auto mul(const T t1, const U t2)
    {
        return t1 + t2;
    }

    // Dividing probabilities is subtracting log-probabilities.
    template<class T, class U>
    static auto div(const T t1, const U t2)
    {
        return t1 - t2;
    }

    // Adding probabilities, computed directly as log(exp(a) + exp(b)).
    template<class T, class U>
    static auto add(const T t1, const U t2)
    {
        return std::log(std::exp(t1) + std::exp(t2));
    }

    // Subtracting probabilities, computed directly as log(exp(a) - exp(b)).
    template<class T, class U>
    static auto sub(const T t1, const U t2)
    {
        return std::log(std::exp(t1) - std::exp(t2));
    }

    // Convert a log-probability back to a linear-scale probability.
    template<class T>
    static auto linear_scale(const T t)
    {
        return std::exp(t);
    }

    // Convert a linear-scale probability to a log-probability.
    template<class T>
    static auto native_scale(const T t)
    {
        return std::log(t);
    }

    // Neutral element of addition: probability 0, i.e. log 0 = -infinity.
    template<class T>
    constexpr static auto make_additive_neutral_element()
    {
        return -std::numeric_limits<T>::infinity();
    }

    // Neutral element of multiplication: probability 1, i.e. log 1 = 0.
    template<class T>
    constexpr static auto make_multiplicative_neutral_element()
    {
        return static_cast<T>(0);
    }

    // Annihilator of multiplication: probability 0, i.e. -infinity.
    template<class T>
    constexpr static auto make_annihilator()
    {
        return -std::numeric_limits<T>::infinity();
    }
};
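For context, here is a simplified sketch of how one time step of the forward pass uses this policy (the names forward_step, trans and emit are just illustrative; my actual implementation has more bookkeeping, but the arithmetic is the same):

#include <cstddef>
#include <vector>

// One column of the forward trellis in log space: trans[i][j] and emit[j][w]
// are assumed to already hold log-probabilities, prev is the previous column.
template<class Policy>
std::vector<double> forward_step(const std::vector<std::vector<double>>& trans,
                                 const std::vector<std::vector<double>>& emit,
                                 const std::vector<double>& prev,
                                 std::size_t word_id)
{
    const std::size_t n_states = prev.size();
    std::vector<double> cur(n_states);
    for (std::size_t j = 0; j < n_states; ++j)
    {
        // Start from the additive neutral element (log 0 = -inf) and
        // accumulate the sum over i of prev[i] * trans[i][j] in log space.
        double acc = Policy::template make_additive_neutral_element<double>();
        for (std::size_t i = 0; i < n_states; ++i)
            acc = Policy::add(acc, Policy::mul(prev[i], trans[i][j]));
        // Multiply (add in log space) by the emission probability of the word.
        cur[j] = Policy::mul(acc, emit[j][word_id]);
    }
    return cur;
}

// Usage: auto column = forward_step<log_policy>(trans, emit, prev, word_id);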

Anyway, despite working with log probabilities I still get numerical instability for longer training sentences. This is primarily due to the size of the set of observed variables: since there are 80,000 of them, each state in the HMM has an average emission probability of 1/80,000 per word. I was wondering how to counteract this. Would reducing the number of emissions be a valid option?

To further illustrate, this is a section of a debug representation of the forward trellis after running the forward algorithm on a training sequence. As you can see, the right-most fifth of the table is full of -inf, and the -inf values all appear at the same time step, which I don't quite understand yet.

(The format of a cell is <hidden variable human readable name> <hidden variable internal name and state ID> <probability of cell> <backpointer>.)

($$(      (0)   :  -40.3203,  -1)  ($$(      (0)   :  -78.1552,  -1)  ($$(      (0)   :  -115.029,  -1)  ($$(      (0)   :  -151.865,  -1)  ($$(      (0)   :  -188.706,  -1)  ($$(      (0)   :  -225.547,  -1)  ($$(      (0)   :  -262.389,  -1)  ($$(      (0)   :  -299.23,   -1)  ($$(      (0)   :  -336.071,  -1)  ($$(      (0)   :  -372.913,  -1)  ($$(      (0)   :  -409.754,  -1)  ($$(      (0)   :  -446.596,  -1)  ($$(      (0)   :  -483.437,  -1)  ($$(      (0)   :  -520.278,  -1)  ($$(      (0)   :  -557.12,   -1)  ($$(      (0)   :  -593.961,  -1)  ($$(      (0)   :  -630.802,  -1)  ($$(      (0)   :  -667.644,  -1)  ($$(      (0)   :  -704.485,  -1)  ($$(      (0)   :  -741.955,  -1)  ($$(      (0)   :  -inf,  -1)  ($$(      (0)   :  -inf,  -1)  ($$(      (0)   :  -inf,  -1)  ($$(      (0)   :  -inf,  -1)  ($$(      (0)   :  -inf,  -1)  ($$(      (0)   :  -inf,  -1)  ($$(      (0)   :  -inf,  -1)
($$,      (1)   :  -43.2152,  -1)  ($$,      (1)   :  -77.935,   -1)  ($$,      (1)   :  -114.923,  -1)  ($$,      (1)   :  -151.76,   -1)  ($$,      (1)   :  -188.602,  -1)  ($$,      (1)   :  -225.444,  -1)  ($$,      (1)   :  -262.285,  -1)  ($$,      (1)   :  -299.127,  -1)  ($$,      (1)   :  -335.968,  -1)  ($$,      (1)   :  -372.809,  -1)  ($$,      (1)   :  -409.651,  -1)  ($$,      (1)   :  -446.492,  -1)  ($$,      (1)   :  -483.333,  -1)  ($$,      (1)   :  -520.175,  -1)  ($$,      (1)   :  -557.016,  -1)  ($$,      (1)   :  -593.857,  -1)  ($$,      (1)   :  -630.699,  -1)  ($$,      (1)   :  -667.54,   -1)  ($$,      (1)   :  -704.382,  -1)  ($$,      (1)   :  -741.667,  -1)  ($$,      (1)   :  -inf,  -1)  ($$,      (1)   :  -inf,  -1)  ($$,      (1)   :  -inf,  -1)  ($$,      (1)   :  -inf,  -1)  ($$,      (1)   :  -inf,  -1)  ($$,      (1)   :  -inf,  -1)  ($$,      (1)   :  -inf,  -1)
($$.      (2)   :  -40.3012,  -1)  ($$.      (2)   :  -77.9027,  -1)  ($$.      (2)   :  -114.861,  -1)  ($$.      (2)   :  -151.692,  -1)  ($$.      (2)   :  -188.534,  -1)  ($$.      (2)   :  -225.375,  -1)  ($$.      (2)   :  -262.217,  -1)  ($$.      (2)   :  -299.058,  -1)  ($$.      (2)   :  -335.899,  -1)  ($$.      (2)   :  -372.741,  -1)  ($$.      (2)   :  -409.582,  -1)  ($$.      (2)   :  -446.423,  -1)  ($$.      (2)   :  -483.265,  -1)  ($$.      (2)   :  -520.106,  -1)  ($$.      (2)   :  -556.948,  -1)  ($$.      (2)   :  -593.789,  -1)  ($$.      (2)   :  -630.63,   -1)  ($$.      (2)   :  -667.472,  -1)  ($$.      (2)   :  -704.313,  -1)  ($$.      (2)   :  -741.444,  -1)  ($$.      (2)   :  -inf,  -1)  ($$.      (2)   :  -inf,  -1)  ($$.      (2)   :  -inf,  -1)  ($$.      (2)   :  -inf,  -1)  ($$.      (2)   :  -inf,  -1)  ($$.      (2)   :  -inf,  -1)  ($$.      (2)   :  -inf,  -1)
($(       (3)   :  -45.7222,  -1)  ($(       (3)   :  -78.2411,  -1)  ($(       (3)   :  -115.133,  -1)  ($(       (3)   :  -151.978,  -1)  ($(       (3)   :  -188.82,   -1)  ($(       (3)   :  -225.661,  -1)  ($(       (3)   :  -262.502,  -1)  ($(       (3)   :  -299.344,  -1)  ($(       (3)   :  -336.185,  -1)  ($(       (3)   :  -373.026,  -1)  ($(       (3)   :  -409.868,  -1)  ($(       (3)   :  -446.709,  -1)  ($(       (3)   :  -483.551,  -1)  ($(       (3)   :  -520.392,  -1)  ($(       (3)   :  -557.233,  -1)  ($(       (3)   :  -594.075,  -1)  ($(       (3)   :  -630.916,  -1)  ($(       (3)   :  -667.757,  -1)  ($(       (3)   :  -704.599,  -1)  ($(       (3)   :  -742.137,  -1)  ($(       (3)   :  -inf,  -1)  ($(       (3)   :  -inf,  -1)  ($(       (3)   :  -inf,  -1)  ($(       (3)   :  -inf,  -1)  ($(       (3)   :  -inf,  -1)  ($(       (3)   :  -inf,  -1)  ($(       (3)   :  -inf,  -1)
($,       (4)   :  -43.7637,  -1)  ($,       (4)   :  -78.1276,  -1)  ($,       (4)   :  -115.012,  -1)  ($,       (4)   :  -151.848,  -1)  ($,       (4)   :  -188.689,  -1)  ($,       (4)   :  -225.53,   -1)  ($,       (4)   :  -262.371,  -1)  ($,       (4)   :  -299.213,  -1)  ($,       (4)   :  -336.054,  -1)  ($,       (4)   :  -372.896,  -1)  ($,       (4)   :  -409.737,  -1)  ($,       (4)   :  -446.578,  -1)  ($,       (4)   :  -483.42,   -1)  ($,       (4)   :  -520.261,  -1)  ($,       (4)   :  -557.102,  -1)  ($,       (4)   :  -593.944,  -1)  ($,       (4)   :  -630.785,  -1)  ($,       (4)   :  -667.626,  -1)  ($,       (4)   :  -704.468,  -1)  ($,       (4)   :  -741.732,  -1)  ($,       (4)   :  -inf,  -1)  ($,       (4)   :  -inf,  -1)  ($,       (4)   :  -inf,  -1)  ($,       (4)   :  -inf,  -1)  ($,       (4)   :  -inf,  -1)  ($,       (4)   :  -inf,  -1)  ($,       (4)   :  -inf,  -1)
($.       (5)   :  -43.5252,  -1)  ($.       (5)   :  -77.9508,  -1)  ($.       (5)   :  -114.713,  -1)  ($.       (5)   :  -151.56,   -1)  ($.       (5)   :  -188.401,  -1)  ($.       (5)   :  -225.243,  -1)  ($.       (5)   :  -262.084,  -1)  ($.       (5)   :  -298.926,  -1)  ($.       (5)   :  -335.767,  -1)  ($.       (5)   :  -372.608,  -1)  ($.       (5)   :  -409.45,   -1)  ($.       (5)   :  -446.291,  -1)  ($.       (5)   :  -483.132,  -1)  ($.       (5)   :  -519.974,  -1)  ($.       (5)   :  -556.815,  -1)  ($.       (5)   :  -593.656,  -1)  ($.       (5)   :  -630.498,  -1)  ($.       (5)   :  -667.339,  -1)  ($.       (5)   :  -704.18,   -1)  ($.       (5)   :  -741.349,  -1)  ($.       (5)   :  -inf,  -1)  ($.       (5)   :  -inf,  -1)  ($.       (5)   :  -inf,  -1)  ($.       (5)   :  -inf,  -1)  ($.       (5)   :  -inf,  -1)  ($.       (5)   :  -inf,  -1)  ($.       (5)   :  -inf,  -1)

When I reduce the number of emissions and thus increase the average emission probability, the problem goes away. But I am unsure whether that is the right way to go about it, since the more emissions the model is trained with, the better it will perform later on.

lo tolmencre

1 Answer


Avoiding zero probabilities

You probably want to be using additive smoothing when estimating probabilities from count data.

With a dictionary of 80,000 words, most of those words will be very rare: many of them might never appear anywhere in your training data, or will never appear associated with a particular part of speech. Thus, your counts for those words will be zero. That causes you to estimate an emission probability of zero, if you naively estimate the probability as a ratio of counts. However, a zero probability is unlikely to be physically meaningful (I doubt that there is truly zero probability of outputting that word; instead, it's probably some small probability that's close to zero but not exactly zero). Additive smoothing is one way to address this.

Lack of additive smoothing probably explains the -inf. You are computing with log-probabilities. A probability of zero corresponds to a log-probability of $-\infty$ (i.e., -inf). Implement additive smoothing, and your -inf's might largely go away.
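For concreteness, here is a minimal sketch of additive (Laplace) smoothing applied to emission counts, returning log-probabilities; the names counts and smoothed_log_emissions are illustrative, and alpha = 1 is just one common choice for the smoothing constant:

#include <cmath>
#include <cstddef>
#include <vector>

// counts[s][w] = how often state s emitted word w in the training data.
// Returns log emission probabilities with additive smoothing: every word,
// even an unseen one, gets probability at least alpha / (total + alpha * V),
// so no log-probability is -inf.
std::vector<std::vector<double>>
smoothed_log_emissions(const std::vector<std::vector<std::size_t>>& counts,
                       double alpha = 1.0)
{
    std::vector<std::vector<double>> log_emit(counts.size());
    for (std::size_t s = 0; s < counts.size(); ++s)
    {
        const std::size_t vocab = counts[s].size();
        double total = 0.0;
        for (std::size_t w = 0; w < vocab; ++w)
            total += static_cast<double>(counts[s][w]);
        total += alpha * static_cast<double>(vocab);   // smoothed denominator
        log_emit[s].resize(vocab);
        for (std::size_t w = 0; w < vocab; ++w)
            log_emit[s][w] = std::log((counts[s][w] + alpha) / total);
    }
    return log_emit;
}

With a vocabulary of 80,000 words you may want alpha well below 1, so that unseen words don't soak up too much of each state's probability mass.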

Numerical stability of log probabilities

While computing with log probabilities does generally improve stability compared to computing with probabilities directly, there are some tricky aspects you must be careful of.

For instance, instead of adding probabilities by computing $\log(\exp(x)+\exp(y))$, it's usually better to compute $x + \log(1 + \exp(y-x))$ (assuming $x \ge y$; if $x < y$, swap $x$ and $y$ before doing this computation). This avoids underflow when $x$ and $y$ are both large negative numbers, as log-probabilities of long sequences tend to be, and it keeps the exponent non-positive so $\exp(y-x)$ cannot overflow. You might want to use a built-in library function for computing $u \mapsto \log(1+u)$ accurately for small $u$ (often called log1p). Some languages also have a built-in library function for computing $\log(\exp(x)+\exp(y))$ in a numerically stable way; it might be called something like logsumexp.
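In C++ (the language of the question's snippet), a stable log-add might look like the following sketch, using the standard std::log1p; it could replace the naive add() in log_policy above, though treat it as an illustration rather than tested code:

#include <cmath>
#include <limits>
#include <utility>

// Numerically stable log(exp(x) + exp(y)): pull out the larger argument so
// the remaining exponent is <= 0 (no overflow), and use log1p to keep
// precision when exp(y - x) is tiny.
double log_add(double x, double y)
{
    if (x < y) std::swap(x, y);                          // ensure x >= y
    if (y == -std::numeric_limits<double>::infinity())
        return x;                                        // log(exp(x) + 0) = x
    return x + std::log1p(std::exp(y - x));
}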

Other relevant references: https://stackoverflow.com/q/7480996/781723, https://math.stackexchange.com/q/2189716/14578, https://stackoverflow.com/q/42355196/781723, https://stackoverflow.com/q/23630277/781723.

D.W.