Recent papers on linear state-space models, such as Mamba, often use the matrix exponential to discretize the system. They define the system in continuous time and then discretize it so that it can be run like a vanilla RNN / CNN.
Consider the simplest continuous-time autonomous system $d\mathbf{h}(t)/dt = A\mathbf{h}(t)$ (assuming the time constant $\tau = 1$), where $A$ is a matrix and $\mathbf{h}$ is a vector. The discretized version of this system is $\mathbf{h}_t = \exp(A)\mathbf{h}_{t-1}$ (again assuming the discretization window $\Delta = 1$ for simplicity). The output used for the loss calculation is $\mathbf{y}_t = f(C\mathbf{h}_t)$, where $f$ is an activation function; it can simply be linear.
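To make the setup concrete, here is a minimal NumPy sketch of the recurrence. The matrix `A`, the initial state `h0`, and the helper `expm_taylor` (a truncated power series, adequate only for small matrices with modest norm) are all hypothetical choices for illustration.

```python
import numpy as np

def expm_taylor(A, terms=30):
    """Truncated power series exp(A) = sum_n A^n / n! (illustrative only)."""
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for n in range(1, terms):
        term = term @ A / n          # term now holds A^n / n!
        result = result + term
    return result

# Hypothetical 2x2 system (a damped rotation) and initial state.
A = np.array([[-0.1, 1.0],
              [-1.0, -0.1]])
h0 = np.array([1.0, 0.0])

# Discrete recurrence h_t = exp(A) h_{t-1} with Delta = 1, run for 5 steps.
A_bar = expm_taylor(A)
h = h0.copy()
for _ in range(5):
    h = A_bar @ h

# Continuous-time solution evaluated at t = 5: h(5) = exp(5A) h(0).
h_exact = expm_taylor(5 * A, terms=60) @ h0
print(h, h_exact)
```

With no input term, stepping the discrete system five times should agree with the continuous solution at $t = 5$ up to floating-point error, since $\exp(A)^5 = \exp(5A)$.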
I'm having difficulty working out mathematically how to calculate $\partial L / \partial A_{ij}$; the matrix exponential is what complicates the process for me. One might approximate $\exp(A) \approx I + A$, but I suspect this approximation may not accurately capture the loss landscape. Does anyone have suggestions or references that could help clarify this calculation? I am not very familiar with tensor calculus, so the blame is on me.
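As a rough numerical illustration of that concern, the sketch below measures how far $I + A$ is from $\exp(A)$ as the scale of $A$ grows. The base matrix and the scales are hypothetical, and `expm_taylor` is a simple truncated-series helper rather than a production routine.

```python
import numpy as np

def expm_taylor(A, terms=30):
    """Truncated power series exp(A) = sum_n A^n / n! (illustrative only)."""
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for n in range(1, terms):
        term = term @ A / n          # term now holds A^n / n!
        result = result + term
    return result

# Hypothetical base matrix, rescaled to vary the norm of A.
base = np.array([[-0.1, 1.0],
                 [-1.0, -0.1]])

errors = []
for scale in (0.01, 0.1, 1.0):
    A = scale * base
    first_order = np.eye(2) + A
    # Frobenius norm of the neglected terms A^2/2! + A^3/3! + ...
    errors.append(np.linalg.norm(expm_taylor(A) - first_order))

print(errors)
```

The leading neglected term is $A^2/2$, so the gap should grow roughly quadratically with the scale of $A$ while the scale stays small.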
EDIT 1
$L$ is a loss function defined as $L = g(\mathbf{y}^*, \mathbf{y}_t)$, where $\mathbf{y}^*$ is the target. I would love to know how to correctly update the weight $A_{ij}$ with respect to this loss function. By the chain rule, $\frac{\partial L}{\partial A_{ij}} = \frac{\partial L}{\partial \mathbf{y}_t} \cdot \frac{\partial \mathbf{y}_t}{\partial \mathbf{h}_t} \cdot \frac{\partial \mathbf{h}_t}{\partial A_{ij}}$.
I'm having trouble calculating the $\frac{\partial \mathbf{h}_t}{\partial A_{ij}}$ term.
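For anyone who wants to check a candidate formula numerically, here is a hedged sketch for a single step, treating $\mathbf{h}_{t-1}$ as a constant (over several steps $\mathbf{h}_{t-1}$ itself depends on $A$, and those contributions would have to be accumulated as well). It relies on the block-matrix identity $\exp\begin{pmatrix} A & E \\ 0 & A \end{pmatrix} = \begin{pmatrix} \exp(A) & L(A,E) \\ 0 & \exp(A) \end{pmatrix}$, where $L(A,E)$ is the directional derivative of $\exp$ at $A$ along $E$. The values of `A`, `C`, `h_prev`, and `y_star`, the linear $f$, the squared-error loss, and both helper functions are assumptions made for illustration.

```python
import numpy as np

def expm_taylor(A, terms=30):
    """Truncated power series exp(A) = sum_n A^n / n! (illustrative only)."""
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for n in range(1, terms):
        term = term @ A / n          # term now holds A^n / n!
        result = result + term
    return result

def expm_frechet_block(A, E):
    """Directional derivative of exp at A along E, read off the top-right
    block of exp([[A, E], [0, A]])."""
    n = A.shape[0]
    M = np.block([[A, E], [np.zeros((n, n)), A]])
    return expm_taylor(M)[:n, n:]

# Hypothetical values, for illustration only.
A = np.array([[-0.1, 1.0],
              [-1.0, -0.1]])
C = np.array([[1.0, 2.0]])
h_prev = np.array([1.0, 0.5])        # h_{t-1}, held fixed (single step)
y_star = np.array([0.3])

def loss(A):
    y = C @ expm_taylor(A) @ h_prev  # linear f, so y_t = C h_t
    return 0.5 * np.sum((y - y_star) ** 2)

# Gradient of the loss, one entry A_ij at a time.
y = C @ expm_taylor(A) @ h_prev
grad = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        E_ij = np.zeros_like(A)
        E_ij[i, j] = 1.0                                  # unit perturbation of A_ij
        dh_dAij = expm_frechet_block(A, E_ij) @ h_prev    # dh_t / dA_ij
        grad[i, j] = (y - y_star) @ C @ dh_dAij           # chain rule
print(grad)
```

Comparing `grad` against central finite differences of `loss` is a quick way to confirm the indexing.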
EDIT 2
I realized my question boils down to this:
Given $B = \exp(A)$, can we express an element $B_{kl}$ in terms of elementary operations on the elements $A_{ij}$ of $A$?
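For reference, the power-series definition of the matrix exponential does give an element-wise expression, although it is an infinite series of sums and products rather than a finite closed form in general:

$$B_{kl} = \sum_{n=0}^{\infty} \frac{1}{n!}\,(A^n)_{kl} = \delta_{kl} + A_{kl} + \frac{1}{2}\sum_{p} A_{kp}A_{pl} + \cdots$$

Differentiating term by term, and using $\partial A_{pq} / \partial A_{ij} = \delta_{pi}\delta_{qj}$ (all entries of $A$ treated as independent), gives

$$\frac{\partial B_{kl}}{\partial A_{ij}} = \sum_{n=1}^{\infty} \frac{1}{n!} \sum_{m=0}^{n-1} (A^m)_{ki}\,(A^{n-1-m})_{jl}.$$

The $n = 1$ term is $\delta_{ki}\delta_{jl}$, which is exactly what the approximation $\exp(A) \approx I + A$ would give; the higher-order terms are what that approximation drops.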
As my original question is not well structured, I may open a new question about the elements of the matrix exponential.
EDIT 3
My question is boiled down at this link. Any further updates on this question will be made in the linked post.
Also, I realized my original question did not do a very good job of indexing the elements of the matrix partial derivatives. Apologies for that.