読者です 読者をやめる 読者になる 読者になる

なにも わからぬ

パソコンとプログラミング関係をメモっていきたい

pythonでlru_cacheでメモ化?

例えばフィボナッチ数を求める関数

import sys
a = int(sys.argv[1])
def fib(n):
    return fib(n-1)+fib(n-2) if n>1 else n

print(fib(a))

こんなどう見たって再起しまくるのを回すと

$ time python test.py 40
102334155

real    1m39.927s
user    1m39.884s
sys     0m0.012s

とアホみたいに時間がかかるわけで、こういうのはdp[]を用意してメモ化するわけだけど、頭にlru_cacheのデコレータを付けるだけで勝手にメモ化してくれるとか。

import sys
from functools import lru_cache
a = int(sys.argv[1])
@lru_cache(None)
def fib(n):
    return fib(n-1)+fib(n-2) if n>1 else n

print(fib(a))

結果は

$ time python test.py 40
102334155

real    0m0.105s
user    0m0.100s
sys     0m0.004s

うぉーすげー!ところが上の例だとfib(333)以上を求めようとするとmaximum recursion depth exceeded in comparisonが出てしまう。普通にdpテーブルを作ると

import sys
a = int(sys.argv[1])

dp = [None for i in range(a+1)]
def fib(n):
    if dp[n] == None:
        dp[n] = fib(n-1)+fib(n-2) if n>1 else n
    return dp[n]

print(fib(a))
$ time python test.py 998
166027476624520970495418004728977018349480511983848280623585530919185737177011702010655101855958986051040947369
18879278462233015981029522997836311232618760539199036765399799926731433239718860373345088375054249

real    0m0.085s
user    0m0.060s
sys     0m0.020s

998まで出る(999以上はエラー)。こっちのが早いっぽいし、内部でどのようにメモ化してるのかわからないけどあんまり上手いことやってくれてるわけではなさそうなので、普通に自分でメモ化したほうが良さそうな気がする。