I am a little confused about the chain rule for matrix derivatives. For example, let
$$ f(X) := \operatorname{tr} \left( \left[ \log \left( W X W^\top + B \right) \right]^2 \right) $$
where $\log(\cdot)$ is the matrix logarithm, $X$ is an $m \times m$ symmetric positive definite (SPD) matrix, $B$ is an $n \times n$ SPD matrix ($n>m$), and $W\in \mathbb{R}^{n\times m}$ is a rectangular matrix. If I apply the chain rule, I get
$$\frac{\partial f}{\partial X} = 2\log(W X W^\top + B) W^\top(W X W^\top + B)^{-1}W$$
However, the dimensions of $\log(W X W^\top + B)$ and $W^\top(W X W^\top + B)^{-1}W$ are $n \times n$ and $m \times m$, respectively, so their product is not even defined. There must be something wrong with my derivation, but I don't know where. Any comments?
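A finite-difference check is one way to test a candidate gradient numerically. The sketch below assumes NumPy and random SPD test matrices. The helper names (`spd_log`, `random_spd`, `candidate_grad`) are illustrative, and the candidate $2W^\top\log(A)A^{-1}W$ with $A = WXW^\top + B$ is only an assumed, shape-consistent ($m \times m$) reordering of the factors above, not a result established in this post.

```python
import numpy as np

def spd_log(A):
    """Matrix logarithm of a symmetric positive definite matrix via eigh."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

def random_spd(k, rng):
    """Random well-conditioned SPD test matrix (illustrative helper)."""
    M = rng.standard_normal((k, k))
    return M @ M.T + k * np.eye(k)

def f(X, W, B):
    """f(X) = tr(log(W X W^T + B)^2)."""
    L = spd_log(W @ X @ W.T + B)
    return np.trace(L @ L)

def candidate_grad(X, W, B):
    """Assumed candidate: 2 W^T log(A) A^{-1} W with A = W X W^T + B."""
    A = W @ X @ W.T + B
    return 2.0 * W.T @ spd_log(A) @ np.linalg.solve(A, W)

rng = np.random.default_rng(0)
n, m = 5, 3
W = rng.standard_normal((n, m))
X, B = random_spd(m, rng), random_spd(n, rng)

# Symmetric perturbation direction, so X + h*E stays symmetric.
E = rng.standard_normal((m, m))
E = (E + E.T) / 2

h = 1e-5
fd = (f(X + h * E, W, B) - f(X - h * E, W, B)) / (2 * h)  # central difference
an = np.sum(candidate_grad(X, W, B) * E)                   # Frobenius product <G, E>

print("gradient shape:", candidate_grad(X, W, B).shape)
print("finite difference:", fd, " candidate:", an)
```

If the two printed numbers agree to several digits for a few random directions `E`, the candidate is at least numerically plausible. The check only probes symmetric directions, which is all that matters while $X$ is constrained to be symmetric.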
*** Addition ***
Let $Z=(W X W^\top + B)S$, where $S$ is also an SPD matrix.
What if $f(X) = \text{tr}((\log(Z))^\top\log(Z))$? Do we have
$\frac{\partial f}{\partial X} = 2W^\top \log[(W X W^\top + B)S] (W X W^\top + B)^{-1}W$,
or
$\frac{\partial f}{\partial X} = 2W^\top(W X W^\top + B)^{-1}S^{-1} \log[(W X W^\top + B)S] SW$?
Using the notation of this post, where $A:B=\text{tr}(A^\top B)$ is the Frobenius inner product, my attempt is:
Define $Z=(W X W^\top + B)S$, and $\phi=\text{tr}([\log(Z)]^2)$. Then we have
$d\phi = 2\log(Z)\cdot Z^{-\top}:dZ=2\log(Z)\cdot Z^{-\top}: WdXW^\top S = 2W^\top\log(Z)\cdot Z^{-\top}SW: dX$.
Therefore, we have
$\frac{\partial \phi}{\partial X} = 2W^\top\log(Z)\cdot Z^{-\top}SW=2W^\top \log[(W X W^\top + B)S] (W X W^\top + B)^{-1}W$.
Is it correct?
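The candidate expressions can be compared numerically with a finite-difference check. The sketch below assumes NumPy and random SPD test matrices, and all helper names (`spd_fun`, `log_AS`, `random_spd`) are illustrative. It tests $\phi = \text{tr}([\log(Z)]^2)$ as defined in the derivation, which need not equal $\text{tr}((\log(Z))^\top\log(Z))$ when $\log(Z)$ is not symmetric. The third ordering, $2W^\top A^{-1}\log(Z)W$ with $A = WXW^\top + B$, does not come from the derivation above; it is an extra assumption included only for comparison.

```python
import numpy as np

def spd_fun(A, fun):
    """Apply a scalar function to an SPD matrix through its eigenvalues."""
    w, V = np.linalg.eigh(A)
    return (V * fun(w)) @ V.T

def log_AS(A, S):
    """log(A S) for SPD A, S, using A S = S^{-1/2} (S^{1/2} A S^{1/2}) S^{1/2}."""
    R = spd_fun(S, np.sqrt)          # R = S^{1/2}
    M = R @ A @ R                    # SPD, similar to A S
    return np.linalg.solve(R, spd_fun((M + M.T) / 2, np.log) @ R)

def random_spd(k, rng):
    """Random well-conditioned SPD test matrix (illustrative helper)."""
    M = rng.standard_normal((k, k))
    return M @ M.T + k * np.eye(k)

def phi(X, W, B, S):
    """phi(X) = tr(log(Z)^2) with Z = (W X W^T + B) S."""
    L = log_AS(W @ X @ W.T + B, S)
    return np.trace(L @ L)

rng = np.random.default_rng(1)
n, m = 5, 3
W = rng.standard_normal((n, m))
X, B, S = random_spd(m, rng), random_spd(n, rng), random_spd(n, rng)

A = W @ X @ W.T + B
L = log_AS(A, S)
Ainv, Sinv = np.linalg.inv(A), np.linalg.inv(S)

candidates = {
    "post, first": 2 * W.T @ L @ Ainv @ W,
    "post, second": 2 * W.T @ Ainv @ Sinv @ L @ S @ W,
    "assumed alternative": 2 * W.T @ Ainv @ L @ W,  # not from the post
}

# Symmetric perturbation direction, so X + h*E stays symmetric.
E = rng.standard_normal((m, m))
E = (E + E.T) / 2
h = 1e-5
fd = (phi(X + h * E, W, B, S) - phi(X - h * E, W, B, S)) / (2 * h)

# Mismatch between each candidate's directional derivative <G, E> and fd.
errors = {name: abs(np.sum(G * E) - fd) for name, G in candidates.items()}
for name, err in errors.items():
    print(f"{name}: |<G,E> - fd| = {err:.2e}")
```

A candidate whose mismatch stays near the finite-difference noise level across several random draws is numerically plausible; a mismatch of order one suggests the factors are in the wrong order. Because only symmetric directions `E` are probed, the check is sensitive only to the symmetric part of each candidate.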