I am struggling to wrap my head around this compiler technique, so let's say this is my factorial function:
```python
def factorial(value: int) -> int:
    if value == 0:
        return 1
    else:
        return factorial(value - 1) * value
```
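To make "not TCO friendly" concrete: each call has to wait for `factorial(value - 1)` to return before it can multiply, so CPython keeps one stack frame per level. A quick check, assuming CPython's default recursion limit (usually 1000, see `sys.getrecursionlimit()`):

```python
import sys


def factorial(value: int) -> int:
    if value == 0:
        return 1
    else:
        # the multiplication happens *after* the recursive call returns,
        # so this is not a tail call and every level keeps its frame alive
        return factorial(value - 1) * value


print(factorial(5))  # 120
print(sys.getrecursionlimit())  # typically 1000 in CPython

try:
    factorial(10_000)
except RecursionError:
    print("RecursionError: stack depth grows linearly with value")
```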
It is recursive, but not TCO-friendly yet, so, as the theory goes, the first thing to try is translating it to continuation-passing style (CPS):
```python
import typing

T = typing.TypeVar("T")

def factorial_cont(value: int, cont: typing.Callable[[int], T]) -> T:
    if value == 0:
        return cont(1)
    else:
        return factorial_cont(value - 1, lambda result: cont(value * result))
```
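For reference, the CPS version is started by passing an initial continuation, typically the identity function. Keep in mind that CPython performs no tail-call elimination, so even though every recursive call is now in tail position, the stack still grows with `value`; only a manual rewrite into a loop actually removes that growth. A self-contained sketch (the `identity` helper is just for illustration):

```python
import typing

T = typing.TypeVar("T")


def factorial_cont(value: int, cont: typing.Callable[[int], T]) -> T:
    if value == 0:
        return cont(1)
    else:
        # tail call: nothing is left to do in this frame after it returns,
        # but CPython still allocates a new frame for it
        return factorial_cont(value - 1, lambda result: cont(value * result))


def identity(x: int) -> int:
    return x


print(factorial_cont(5, identity))  # 120
# the continuation decides what the final result looks like
print(repr(factorial_cont(5, str)))  # '120'
```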
Now that the function is tail-recursive, I can apply the usual trick of turning the tail call into a while loop:
```python
def factorial_while(value: int, cont: typing.Callable[[int], T]) -> T:
    current_cont = cont
    current_value = value
    while True:
        if current_value == 0:
            return current_cont(1)
        else:
            # conceptually: current_cont = lambda result: current_cont(current_value * result)
            # the default arguments capture the *current* continuation and value;
            # without them Python's late binding would make the lambda call itself
            current_cont = lambda result, c=current_cont, v=current_value: c(v * result)
            current_value = current_value - 1
```
This `current_cont` thing effectively becomes a huge composition chain. In Haskell terms, for `value == 3` it would be `let resulting_cont = ((initial_cont . (3*)) . (2*)) . (1*)`, where `initial_cont` can safely default to `id`, and sure enough `resulting_cont 1 == 3!` (in general, the chain applied to `1` gives `value!`).
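The chain can be built explicitly to check the claim. This sketch uses hypothetical helpers `compose` (plain function composition, like Haskell's `.`), `times` (the section `(k*)`) and `build_chain`, and folds `(3*)`, `(2*)`, `(1*)` onto `id` with `functools.reduce` in the same order the loop does:

```python
import functools
import typing

IntFn = typing.Callable[[int], int]


def compose(f: IntFn, g: IntFn) -> IntFn:
    # hypothetical helper: compose(f, g) behaves like Haskell's (f . g)
    return lambda x: f(g(x))


def times(k: int) -> IntFn:
    # hypothetical helper: the section (k*) from the Haskell expression
    return lambda x: k * x


def build_chain(value: int, initial_cont: IntFn) -> IntFn:
    # builds ((initial_cont . (value*)) . ((value-1)*)) . ... . (1*)
    return functools.reduce(
        lambda acc, k: compose(acc, times(k)),
        range(value, 0, -1),
        initial_cont,
    )


resulting_cont = build_chain(3, lambda x: x)
print(resulting_cont(1))  # 6, i.e. 3!
```

One side effect worth noting: applying the finished chain makes one nested Python call per link, so for large `value` calling `resulting_cont(1)` (or `current_cont(1)` in `factorial_while`) can itself hit the recursion limit, even though building the chain did not.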
But I also know the accumulator trick:
```python
def factorial_acc(value: int, acc: int = 1) -> int:
    current_acc = acc
    current_value = value
    while True:
        # same base case as factorial above, so value == 0 terminates too
        if current_value == 0:
            return current_acc
        else:
            current_acc = current_acc * current_value
            current_value = current_value - 1
```
which looks almost identical to the CPS version once the while loop is introduced.
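The resemblance can be made precise by running both loop bodies side by side. In this sketch, the hypothetical helper `run_in_lockstep` carries both `current_cont` and `current_acc`, and at every iteration checks (on a few sample inputs only, not a proof) the relationship `current_cont(x) == cont(current_acc * x)`, i.e. the continuation built so far always behaves like `cont . (current_acc*)`:

```python
import typing


def run_in_lockstep(value: int, cont: typing.Callable[[int], int]) -> int:
    # hypothetical helper: the loop bodies of factorial_while and
    # factorial_acc, advanced together one iteration at a time
    current_cont = cont
    current_acc = 1
    current_value = value
    while True:
        # checked on sample inputs: current_cont behaves like cont . (current_acc*)
        for x in (1, 2, 7):
            assert current_cont(x) == cont(current_acc * x)
        if current_value == 0:
            result = current_cont(1)
            assert result == cont(current_acc)
            return result
        current_cont = lambda r, c=current_cont, v=current_value: c(v * r)
        current_acc = current_acc * current_value
        current_value = current_value - 1


print(run_in_lockstep(5, lambda r: r))  # 120
print(run_in_lockstep(5, lambda r: r + 1))  # 121
```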
The question is: how exactly do I massage the continuation `let resulting_cont = ((initial_cont . (3*)) . (2*)) . (1*)` into a form resembling the accumulator version?