学会这个思路,你写递归时再也不用担心慢了

/ python / 没有评论 / 1586浏览

之前我在学 Python 的时候,第一次觉得它慢是执行一个递归函数,来求斐波那契数列,计算第 40 个数就需要 37 秒,同样的逻辑使用 java,则不到 1 秒就执行完毕。以下是在 IPython 环境下的运行耗时:

In [5]: def fib(n):
   ...:     if n == 0:
   ...:         return 0
   ...:     elif n == 1:
   ...:         return 1
   ...:     else:
   ...:         return fib(n-1) + fib(n-2)
   ...:

In [6]: %%time
   ...: fib(40)
   ...:
   ...:
CPU times: user 37.2 s, sys: 54.7 ms, total: 37.2 s
Wall time: 37.3 s
Out[6]: 102334155

而 Java 则非常只需要 0.383 秒:

public class MainClass {
    public static long fibonacci(long number) {
        if (number == 0)
            return 0;
        else if (number == 1)
            return 1;
        else
            return fibonacci(number - 1) + fibonacci(number - 2);
        }
        public static void main(String[] args) {
            long startTime=System.currentTimeMillis();
            System.out.println(fibonacci(40));
            long endTime =System.currentTimeMillis();
            System.out.println("耗时 "+ (endTime-startTime) + " ms");
    }
}

执行结果如下:

➜  ~ javac MainClass.java
➜  ~ java MainClass
102334155
耗时 383 ms

当时我觉得非常气馁,这么好用的 Python 竟然这么慢,还要不要坚持学下去?当然是要的,不能因噎废食,每个语言都有优点和缺点,我们要集中精力学习并发挥他们的长处,试想一下,你的编程生涯中有多少情况是需要这种手写大规模计算的代码的? 此外,虽然 Python 慢,但 Python 足够灵活,有很多方法可以进行优化,今天就分享一种利用缓存的优化方法。学完后再也不怕递归了。

方法就是使用 lru_cache,很简单,show you the code:

In [8]: from functools import lru_cache

In [9]: @lru_cache(100)
   ...: def fib(n):
   ...:     if n == 0:
   ...:         return 0
   ...:     elif n == 1:
   ...:         return 1
   ...:     else:
   ...:         return fib(n-1) + fib(n-2)
   ...:

In [10]: %%time
    ...: fib(40)
    ...:
    ...:
CPU times: user 25 µs, sys: 0 ns, total: 25 µs
Wall time: 27.9 µs
Out[10]: 102334155

你看,这次只用了 27.9 微秒,也就是 0.0279 毫秒,耗时是 java 的 1/13,是不是非常牛逼!官方文档是这样描述 lru_cache 的功能的:

一个为函数提供缓存功能的装饰器,缓存 maxsize 组传入参数,在下次以相同参数调用时直接返回上一次的结果。用以节约高开销或I/O函数的调用时间。由于使用了字典存储缓存,所以该函数的固定参数和关键字参数必须是可哈希的。不同模式的参数可能被视为不同从而产生多个缓存项,例如, f(a=1, b=2) 和 f(b=2, a=1) 因其参数顺序不同,可能会被缓存两次。

根据官方的解释,我们可以试着自己编写一个类似 lru_cache 的装饰器 my_cache 来实现同样的效果。

def my_cache(func):
    cache = {}
    def helper(x):
        if x not in cache:
            cache[x] = func(x)
        return cache[x]
    return helper

然后执行一下看耗时:

In [11]: def my_cache(func):
    ...:     cache = {}
    ...:     def helper(x):
    ...:         if x not in cache:
    ...:             cache[x] = func(x)
    ...:         return cache[x]
    ...:     return helper
    ...:

In [12]: @my_cache
    ...: def fib(n):
    ...:     if n == 0:
    ...:         return 0
    ...:     elif n == 1:
    ...:         return 1
    ...:     else:
    ...:         return fib(n-1) + fib(n-2)
    ...:

In [13]: %%time
    ...: fib(40)
    ...:
    ...:
CPU times: user 49 µs, sys: 42 µs, total: 91 µs
Wall time: 94.9 µs
Out[13]: 102334155

可以看出,即使是自己编写的 cache 也只用了 94.9 微秒,依然比 java 快。 本文的重点不是哪个语言快,而是这种缓存的思路可以大大提升程序的运行速度。缓存是一种用空间换取时间的思想,递归调用存在多次调用同一函数的情况,把每一次的调用结果使用缓存来存下来,下次调用是直接返回,可以大大提升程序的运行速度。

空间换时间这一种思路在现实生活中也非常实用,比如开车绕远路躲避拥堵可以更快到达目的地,为了赶工增加人力资源,为了更高效的运维把常用的命令牢记在脑海中,或编写批处理脚本等。

还记得之前吴军老师在谷歌方法论中提到过一个面试题,如何统计一个数字的的二进制数有多少个 1 ,请你试着从空间换时间的角度思考下如何更快的统计出来?

如果你问这么无聊的问题有意义吗?那我猜测你一定不太喜欢数学。这类问题其实是对具体问题的一种抽象,比如计算机只认识二进制的 0 和 1,这两个 0 和 1 经过运算和转换,却能表达整改世界。你也许认为人工智能非常高大上,而在我眼里,不过是 if、else、循环的组合罢了。因此不要忽视此类看似没有意义的问题,仔细思考并试着回答,可以训练我们的计算机思维。

回到题目,大多数人最先想到的就是直接数一下有多少个 1,这个方法可以得到结果,但肯定不是最优的,一个数有多少个二进制位,就需要数多少次,一个 32 个二进制的整数,就要数 32 次。

def count_bit(num :int) ->int :
    count = 0
    print(bin(num))
    while num > 0 :
        if num & 1 == 1: #说明该位为 1
            count += 1
        num = num >> 1  #循环移动
    return count

稍微高明一些做法通常会是这样:

熟悉二进制的话就会知道,任何一个二进制数都可以转成 2 的 n 次方的和,这样,一个二进制数有 n 个 1 ,只需要判断 n 次就知道有多少个 1 而不是全部位数都判断一次。举个例子: 8 是 2 的 3 次方,那 8(0b1000) 就只有一个 1,而 7(0b0111) 是 2 的 2 次方 加上 2 的 1 次方 加上 2 的 0 次方,共加了 3 次,因此有 3 个 1,依次类推。

二进制数有个特点,可以将上述思路代码化:a 与比它小 1 的数 a - 1 进行与(&) 运算,可以将 a 最右边的 1 变成 0,比如 5 = 0b101 ,4 = 0b100 ,5 & 4 = 0b101 & 0b100 = 0b100 = 4(相当于 5 的二进制数右边的 1 变成了 0,计数一次)4 继续循环,直到变成 0b00,共循环两次,因此 5 有 2 个 1。

def count_bit2(num :int) ->int :
    print(bin(num))
    count = 0
    while num:
        num = num & num - 1
        count += 1
    return count

更聪明的回答者可以将此问题推广到任意进制数的判断。

还有一个更简单高效的答案,就是查表法,利用空间换取时间。如果要统计一个数的二进制数有多少个 1,直接先算好放在一张缓存表里,需要时直接去表里查就得到了结果,这样的查询时间复杂度为 O(1), 效率比上述第二种与算法的方式还要快。比如 cache = {103:5},那么 直接 cache[103] 就得出结果 5,只需要查找一次。

但是问题来了,一个 32 位的计算机可以表示的整数有 2 的 32 次方个,每个整数假如是 4 字节,如果要把这些数都存在表里,至少需要 16 GB 的内存空间,如果是 64 位,则需要的内存不小于 67108864 TB,那么查表法是不是就不行了?

当然不是,我们可以只保留 16 位整数的缓存表,只需要 256 KB 左右的内存空间,然后将 32 位或 64 位的整数拆成每 16 位一组,这样 32 位的只需要查 2 次,64 位的只需要查 4 次。甚至可以只保留 8 位的,这样 32 位的只需要查 4 次,64 位的只需要查 8 次。代码如下:

def count_bit3(num: int, bit_type: str = "32") -> int:
    print(bin(num))
    ##初始化缓存表
    cache_16 = [0] * 256
    for i in range(256):
        cache_16[i]=(i & 1) + cache_16[i // 2]
    count = 0

    ##将一个数转为字节数组
    byte_nums = []
    if bit_type == "32":
        byte_nums = [num >> i & 0xFF for i in (24, 16, 8, 0)]
    elif bit_type == "64":
        byte_nums = [num >> i & 0xFF for i in (56, 48, 40, 32, 24, 16, 8, 0)]

    for byte in byte_nums:
        count += cache_16[byte]
    return count

假如不考虑内存够不够用,使用 32 位的缓存表会比 16 位的快吗?,从理论上上看,32 位的缓存表查询次数更少,应该更快,实际上,计算机的 cpu 和内存之间还有一个高速缓存,高速缓存的空间非常小,通常只有几兆,计算机往往需要把内存先往高速缓存中搬运,然后做相应的处理,缓存太大,搬运工作就做的越多,因此并不是缓存表越大越快。