Python:基于新函数参数调用缓存函数结果
Python:基于新函数参数调用缓存函数结果
我对缓存和记忆化的概念还比较新。我已经在这里、这里和这里阅读了一些相关讨论和资源,但是还没有完全理解。
假设我有一个类中的两个成员函数(如下所示的简化示例)。假设第一个函数total
计算代价较高。第二个函数subtotal
计算代价较小,但是它使用第一个函数的返回值,因此由于这个原因也变得计算代价很高,因为它当前需要重新调用total
以获取其返回结果。
我想缓存第一个函数的结果,并将其用作第二个函数的输入,如果输入y
到subtotal
与近期调用total
的输入x
共享。也就是说:
- 如果调用subtotal(),其中
y
等于先前调用total
的输入x
的值,那么使用缓存的结果而不是重新调用total
。 - 否则,只需使用
x = y
调用total()
。
示例:
class MyObject(object): def __init__(self, a, b): self.a, self.b = a, b def total(self, x): return (self.a + self.b) * x # some time-expensive calculation def subtotal(self, y, z): return self.total(x=y) + z # Don't want to have to re-run total() here # IF y == x from a recent call of total(), # otherwise, call total().
谢谢大家的回复,阅读它们并了解底层发生了什么非常有帮助。正如@Tadhg McDonald-Jensen所说,我在这里似乎不需要比@functools.lru_cache
更多的东西。(我在Python 3.5中。)关于@unutbu的评论,我没有从使用@lru_cache
装饰total()中获得任何错误。让我纠正我自己的例子,我将保留它供其他初学者使用:
from functools import lru_cache from datetime import datetime as dt class MyObject(object): def __init__(self, a, b): self.a, self.b = a, b @lru_cache(maxsize=None) def total(self, x): lst = [] for i in range(int(1e7)): val = self.a + self.b + x # time-expensive loop lst.append(val) return np.array(lst) def subtotal(self, y, z): return self.total(x=y) + z # if y==x from a previous call of # total(), used cached result. myobj = MyObject(1, 2) # Call total() with x=20 a = dt.now() myobj.total(x=20) b = dt.now() c = (b - a).total_seconds() # Call subtotal() with y=21 a2 = dt.now() myobj.subtotal(y=21, z=1) b2 = dt.now() c2 = (b2 - a2).total_seconds() # Call subtotal() with y=20 - should take substantially less time # with x=20 used in previous call of total(). a3 = dt.now() myobj.subtotal(y=20, z=1) b3 = dt.now() c3 = (b3 - a3).total_seconds() print('c: {}, c2: {}, c3: {}'.format(c, c2, c3)) c: 2.469753, c2: 2.355764, c3: 0.016998
使用Python3.2或更新版本,你可以使用functools.lru_cache
。
如果你直接在total
上使用functools.lru_cache
修饰符,那么lru_cache
将基于self
和x
参数的值缓存total
的返回值。由于lru_cache
的内部字典存储对self
的引用,直接在类方法上应用@lru_cache将创建一个对self
的循环引用,使得该类的实例无法解除引用(因此会造成内存泄漏)。
这里有一个解决方法,可以让你在类方法上使用lru_cache
-- 它会基于除第一个参数self
以外的所有参数缓存结果,并使用weakref避免循环引用问题:
import functools import weakref def memoized_method(*lru_args, **lru_kwargs): """ https://stackoverflow.com/a/33672499/190597 (orly) """ def decorator(func): @functools.wraps(func) def wrapped_func(self, *args, **kwargs): # We're storing the wrapped method inside the instance. If we had # a strong reference to self the instance would never die. self_weak = weakref.ref(self) @functools.wraps(func) @functools.lru_cache(*lru_args, **lru_kwargs) def cached_method(*args, **kwargs): return func(self_weak(), *args, **kwargs) setattr(self, func.__name__, cached_method) return cached_method(*args, **kwargs) return wrapped_func return decorator class MyObject(object): def __init__(self, a, b): self.a, self.b = a, b @memoized_method() def total(self, x): print('Calling total (x={})'.format(x)) return (self.a + self.b) * x def subtotal(self, y, z): return self.total(x=y) + z mobj = MyObject(1,2) mobj.subtotal(10, 20) mobj.subtotal(10, 30)
只打印了一次:
Calling total (x=10)
或者你也可以使用字典自己实现缓存功能:
class MyObject(object): def __init__(self, a, b): self.a, self.b = a, b self._total = dict() def total(self, x): print('Calling total (x={})'.format(x)) self._total[x] = t = (self.a + self.b) * x return t def subtotal(self, y, z): t = self._total[y] if y in self._total else self.total(y) return t + z mobj = MyObject(1,2) mobj.subtotal(10, 20) mobj.subtotal(10, 30)
lru_cache
相对于基于字典的缓存的一个优点是它是线程安全的。lru_cache
也有一个maxsize
参数,可以帮助防止内存使用无限增长(例如,由于长时间运行的进程多次使用不同的x
值调用total
)。