What is the intuitive idea behind automatic differentiation?

If I have a program which computes $f(x, y)=x^2+yx$, which steps lead to the program which computes the derivative $df/dx$ of f?

double f(double x, double y)
{
    double res = x*x;
    double res2 = x*y;
    double res3 = res + res2;
    return res3;
}

If I read the corresponding Wikipedia entry correctly, I need nothing besides the chain rule, but I don't understand how this is done in practice.

My understanding so far is that I use a table of the derivatives of all involved basic functions, such as:

  • $dx/dx=1$
  • $dy/dx=0$
  • $d(x*y)/dx=y$
  • $d(x+y)/dx=dx/dx+dy/dx=1+0$
  • $d(x^2)/dx=2x$

How do I get the final function from this?

I would expect something like this:

double f(double x, double y, double* dx)
{
    *dx = 0;
    double res = x*x;
    ? *dx = 2*x ?
    double res2 = x*y;
    ? *dx = *dx ?
    double res3 = res + res2;
    ? *dx = ??? ?

    return res3;
}

Edit: I have found the video "Beautiful Differentiation", a talk given by Conal Elliott at the International Conference on Functional Programming (ICFP) 2009 (the recording was posted by Malcolm Wallace), which explains the topic quite well.

1 Answer

The function is given as a "decorated" tree. Your sample function has a root node decorated "+", with one child representing $x^2$ and the other $yx$. The first child has root node "^", with children "$x$" and "2", both leaves. The second child has root node "*" with children "$y$" and "$x$".
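To make the tree concrete, here is a minimal sketch in C of how such a decorated tree could be stored and evaluated. The names `Kind`, `Node`, `f_tree` and `eval` are made up for this illustration, and the sketch assumes that a "^" node always has a small non-negative integer constant as its exponent.

```c
#include <stddef.h>

/* Node labels for this sketch (hypothetical names, not from any library). */
typedef enum { VAR_X, VAR_Y, CONSTANT, ADD, MUL, POW } Kind;

typedef struct Node {
    Kind kind;
    double value;               /* only meaningful for CONSTANT */
    const struct Node *left;    /* NULL for leaves */
    const struct Node *right;   /* NULL for leaves */
} Node;

/* f(x, y) = x^2 + y*x as a tree: "+" at the root, "^" and "*" below it. */
static const Node node_x  = { VAR_X,    0.0, NULL, NULL };
static const Node node_y  = { VAR_Y,    0.0, NULL, NULL };
static const Node node_2  = { CONSTANT, 2.0, NULL, NULL };
static const Node node_sq = { POW, 0.0, &node_x, &node_2 };
static const Node node_yx = { MUL, 0.0, &node_y, &node_x };
static const Node f_tree  = { ADD, 0.0, &node_sq, &node_yx };

/* Evaluate the tree at (x, y). Assumption: the exponent of a POW node is a
   small non-negative integer constant, so repeated multiplication suffices. */
static double eval(const Node *n, double x, double y)
{
    switch (n->kind) {
    case VAR_X:    return x;
    case VAR_Y:    return y;
    case CONSTANT: return n->value;
    case ADD:      return eval(n->left, x, y) + eval(n->right, x, y);
    case MUL:      return eval(n->left, x, y) * eval(n->right, x, y);
    case POW: {
        double base = eval(n->left, x, y);
        double result = 1.0;
        for (int i = 0; i < (int)n->right->value; i++)
            result *= base;
        return result;
    }
    }
    return 0.0; /* not reached for well-formed trees */
}
```

Calling `eval(&f_tree, 3.0, 4.0)` walks the tree bottom-up and computes $3^2 + 4 \cdot 3$.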

In principle, automatic differentiation just applies the rules of differentiation recursively. In your case, we start with $\partial f/\partial x = \partial(x^2)/\partial x + \partial (yx)/\partial x$. To compute $\partial (yx)/\partial x$, we use the product rule: $$ \frac{\partial(yx)}{\partial x} = y \frac{\partial(x)}{\partial x} + \frac{\partial(y)}{\partial x} x. $$ For variables we know that $\partial x/\partial x = 1$ while $\partial y/\partial x = 0$, and simplification (one of the basic building blocks in any computer algebra system) gives $\partial(yx)/\partial x = y$. It is similar for $x^2$, though here the rule is more complicated: $(f^g)' = f^g (gf'/f+g'\log f)$. With $f = x$ and $g = 2$ we have $f' = 1$ and $g' = 0$, so the rule gives $x^2 \cdot 2/x = 2x$.
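The recursion can be sketched in C roughly as follows. This is only an illustration under some assumptions: the names (`Node`, `eval`, `ddx`, `f_tree`) are invented here, the derivative is evaluated numerically at a point rather than returned as a new tree, and the exponent of a "^" node is assumed to be a non-negative integer constant. In that case $g' = 0$, and the general power rule reduces to $g f^{g-1} f'$, so no logarithm is needed.

```c
#include <stddef.h>

/* Hypothetical tree representation for this sketch. */
typedef enum { VAR_X, VAR_Y, CONSTANT, ADD, MUL, POW } Kind;

typedef struct Node {
    Kind kind;
    double value;               /* only meaningful for CONSTANT */
    const struct Node *left;    /* NULL for leaves */
    const struct Node *right;   /* NULL for leaves */
} Node;

/* f(x, y) = x^2 + y*x */
static const Node node_x  = { VAR_X,    0.0, NULL, NULL };
static const Node node_y  = { VAR_Y,    0.0, NULL, NULL };
static const Node node_2  = { CONSTANT, 2.0, NULL, NULL };
static const Node node_sq = { POW, 0.0, &node_x, &node_2 };
static const Node node_yx = { MUL, 0.0, &node_y, &node_x };
static const Node f_tree  = { ADD, 0.0, &node_sq, &node_yx };

/* base^n for a non-negative integer n, by repeated multiplication. */
static double int_pow(double base, int n)
{
    double result = 1.0;
    for (int i = 0; i < n; i++)
        result *= base;
    return result;
}

/* Value of the subtree at (x, y). */
static double eval(const Node *n, double x, double y)
{
    switch (n->kind) {
    case VAR_X:    return x;
    case VAR_Y:    return y;
    case CONSTANT: return n->value;
    case ADD:      return eval(n->left, x, y) + eval(n->right, x, y);
    case MUL:      return eval(n->left, x, y) * eval(n->right, x, y);
    case POW:      return int_pow(eval(n->left, x, y), (int)n->right->value);
    }
    return 0.0; /* not reached for well-formed trees */
}

/* Partial derivative with respect to x of the subtree, at (x, y). */
static double ddx(const Node *n, double x, double y)
{
    switch (n->kind) {
    case VAR_X:    return 1.0;                       /* dx/dx = 1 */
    case VAR_Y:    return 0.0;                       /* dy/dx = 0 */
    case CONSTANT: return 0.0;
    case ADD:      /* sum rule: (u + v)' = u' + v' */
        return ddx(n->left, x, y) + ddx(n->right, x, y);
    case MUL:      /* product rule: (uv)' = u v' + u' v */
        return eval(n->left, x, y) * ddx(n->right, x, y)
             + ddx(n->left, x, y) * eval(n->right, x, y);
    case POW: {    /* constant exponent k: (u^k)' = k u^(k-1) u' */
        int k = (int)n->right->value;
        if (k == 0)
            return 0.0;
        return k * int_pow(eval(n->left, x, y), k - 1) * ddx(n->left, x, y);
    }
    }
    return 0.0; /* not reached for well-formed trees */
}
```

Here `ddx(&f_tree, x, y)` computes $2x + y$ by combining the results for the two children of the root. Note that this naive version re-evaluates subtrees several times; a real implementation would likely compute the value and the derivative of each node in a single pass.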

In practice, there are probably some optimizations. Using pattern matching techniques (that is, by having specific rules) we can, for example, recognize $x^2$ for what it is, and directly calculate $\partial(x^2)/\partial x = 2x$. Similarly, if there is a subtree not depending on $x$, we could directly conclude that its derivative with respect to $x$ is zero.
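With these shortcuts, the program from the question could be transformed by hand into something like the sketch below, where each intermediate value gets a companion variable holding its derivative with respect to $x$. The function name `f_with_derivative` and the `d_` prefixes are made up for this illustration; this is one possible shape of the result, not the output of any particular tool.

```c
/* Returns f(x, y) = x^2 + y*x and stores df/dx in *dfdx. */
double f_with_derivative(double x, double y, double *dfdx)
{
    double res    = x * x;
    double d_res  = 2.0 * x;          /* shortcut: d(x^2)/dx = 2x */

    double res2   = x * y;
    double d_res2 = y;                /* product rule with dx/dx = 1, dy/dx = 0 */

    double res3   = res + res2;
    double d_res3 = d_res + d_res2;   /* sum rule */

    *dfdx = d_res3;
    return res3;
}
```

For example, at $x = 3$, $y = 4$ the function returns $3^2 + 4 \cdot 3$ and stores $2 \cdot 3 + 4$ in `*dfdx`.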

Yuval Filmus