首页 > 编程笔记

Python生成器的用法

生成器类似于列表,其输出为一个线性的数据链。但生成器并不是一次将所有的数据都生成,而是仅在需要时生成一个数据。

下面的例子定义一个最简单的生成器:
>>> generator_Demo1 = (x*x for x in range(3))  # 创建一个生成器
>>> type(generator_Demo1)                      # 查看类型
<type 'generator'>
>>> generator_Demo1.next()                     # 读出一个数据
0
>>> generator_Demo1.next()                     # 读出一个数据
1
>>> generator_Demo1.next()                     # 读出一个数据
4
>>> generator_Demo1.next()                     # 读出一个数据,失败,抛出异常
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

定义生成器

生成器的第一种定义方法和列表解析类似,不过需要将 [] 换成 ()。下面是 3 种基本的定义方法。
  1. (表达式 for 变量 in 可迭代对象)
  2. (表达式 for 变量 in 可迭代对象 if 条件)
  3. (表达式1 if 条件 else 表达式2 for 变量 in 可迭代对象)

生成器的第二种定义方法是定义一个包含 yield 语句的函数。yield 语句和 return 语句类似,它们的不同之处在于 return 会完全退出函数,而 yield 会保存函数的当前状态,下次还可以接着执行。

当第一次调用该函数时,其返回 yield 语句的参数;当再次调用该函数时,其从 yield 语句的下一行代码开始执行。
>>> def hello_generator():             # 定义函数
...     l = range(3)
...     for i in l:
...         yield i*i                  # 每次在这里返回
...                                    # 函数定义结束
>>> gen_obj = hello_generator()        # 得到生成器对象
>>> gen_obj
<generator object hello_generator at 0x00000000026715E8>
>>> type(gen_obj)                      # 查看gen_obj的类型
<type 'generator'>                     # 生成器对象
>>> for ele in gen_obj:                # 依次访问其生成的元素
...     print(ele)                     # 打印元素的值
...                                    # for循环介绍
0                                      # 遍历得到的元素
1
4

接口函数

对生成器的最简单操作就是调用 next() 函数来得到下一个数据。但是生成器也有一些高级接口函数,如 close(),它可以强制结束生成器。接下来就来介绍生成器的常用接口函数。

1) next()

next() 得到下一个数据。它是最基本的生成器接口函数。

下面定义一个生成器 fib,它可以生成斐波拉契数列。斐波拉契数列是这样的数列,前两个元素值为 1,从第三个元素开始,所有元素的值等于前两个元素的和。代码如下:
>>> def fib():                # 定义斐波拉契函数
...     yield 1               # 第一个元素
...     first_ele = 1
...     yield 1               # 第二个元素
...     second_ele = 1
...     while True:           # 后面的元素
...         next_ele = first_ele + second_ele
...         first_ele, second_ele = second_ele, next_ele
...         yield next_ele
...                           # 函数定义结束
>>> gen1 = fib()              # 得到生成器
>>> type(gen1)                # 查看类型
<class 'generator'>
>>> next(gen1)                # 得到第一个元素
1
>>> next(gen1)                # 得到下一个元素,该生成器产生的元素个数是无限的
1
>>> next(gen1)                # 得到下一个元素
2
>>> next(gen1)                # 得到下一个元素
3
>>> next(gen1)                # 得到下一个元素
5
>>> next(gen1)                # 得到下一个元素
8
>>> next(gen1)                # 得到下一个元素
13
>>> next(gen1)                # 得到下一个元素
21

2) close():停止生成

close() 是对象的接口函数。调用该接口函数后,如果以后再次试图通过调用 next() 来生成数据,就会抛出 StopIteration 异常。
>>> gen_a = (x*x for x in range(100))  # 可以生成100个元素
>>> gen_a.next()                       # 生成第一个元素
0     
>>> gen_a.close()                      # 关闭,不能再生成元素了
>>> gen_a.next()                       # 后面的next()导致异常抛出
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

3) send():发送数据

前面我们生成的数据都是预先就可以知道的,如斐波拉契数列。如果希望结合生成器的状态和输入来生成某个数据,这时 send() 就能派上用场了。其基本过程如下:
  1. 创建生成器 A。
  2. 给生成器 A 发送开始命令。
  3. 给生成器 A 一个输入数据 d,生成器 A 生成一个结果 r。
  4. 跳到第三步。

下面是这样一个生成器,其接收一个整数序列,并计算到目前为止接收到的输入序列的和。例如开始输入的是 1,那么其生成 1;第二次输入的是 10,那么生成值是 10+1=11;第三次输入的是 10,那么生成值是 1+10+10=21,以此类推。
>>> def get_sum():                                     # 定义一个迭代器函数
...     sum = 0
...     while True:
...         input_val = yield sum       # 返回一个值,并且等待用户输入
...         if input_val == "quit":
...             break
...         sum = sum + input_val
...                                                                     # 函数定义结束
>>> gen1 = get_sum()                           # 构造一个生成器
>>> gen1.send(None)                                    # 开始运行,跳转到第4行
0
>>> gen1.send(1)                                       # 输入值是1,返回值是1
1
>>> gen1.send(10)                                      # 输入值是10,返回值是11
11
>>> gen1.send(10)                                      # 输入值是10,返回值是21
21
>>> gen1.send(100)                                     # 输入值是100,返回值是121
121
>>> gen1.send("quit")                          # 输入值是quit,输出是异常
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration

4) throw():抛出异常

可以用 throw() 函数直接让生成器在当前位置抛出指定的异常。

下面的例子是创建一个生成器,该生成器每次的返回值都加1,即依次生成“1,2,3,4,...”这样的序列。但如果发生了异常,那么下次重新从 1 开始,也就是如果发生了异常,那么下次还是生成“1,2,3,...”这样的序列。代码如下:
>>> def Generator_A():                         # 定义函数
...     value = 1
...     while True:
...         try:
        # 返回值,如果输入了异常,跳转到第8行,然后再次到这里
...             yield value
...             value += 1
...         except:
...             print("Got Exception")
...             value = 1
...                                                             # 函数定义结束
>>> gen_obj = Generator_A()            # 构造生成器
>>> next(gen_obj)                              # 生成数据
1
>>> next(gen_obj)                              # 生成数据
2
>>> next(gen_obj)                              # 生成数据
3
>>> gen_obj.throw(Exception, "Method throw called!")   # 输入异常
Got Exception                                   # 第8行的输出
1                                                               # 第5行的返回值
>>> next(gen_obj)                              # 生成数据
2
>>> next(gen_obj)                              # 生成数据
3
>>> next(gen_obj)                              # 生成数据
4
>>> next(gen_obj)                              # 生成数据
5

5) yield from生成器:从其他生成器生成数据

yield from 是一个语句,不是函数。我们知道列表可以嵌套,即某个列表元素可能也是一个列表,这就构成了一个树状结构。有时我们希望使用中序遍历将这个树转换成一个线性的数据结构。

例如有这样一个列表:['str1',[11,32,13],27,24,[45,[106,[89,[92]],'str2'],27]],如果要用树来表示该列表的数据结构,结果如图 1 所示。


图 1 树状结构表示的嵌套列表

现在我们希望将其变成序列:['str1',11,32,13,27,24,45,106,89,92,'str2',27],可以用下面的生成器来完成该任务。
>>> def in_order_traverse_tree(tree_obj):              # 定义一个函数
...         if not isinstance(tree_obj, list):  # 如果不是列表
...             yield tree_obj                                  # 直接返回该元素
...         else:                                               # 如果是一个列表
...             for ele in tree_obj:                            # 遍历所有的列表元素
...                 yield from in_order_traverse_tree(ele)      # 嵌套输出
...                                                             # 函数定义结束
>>> input_list = ['str1',[11, 32, 13], 27, 24, [45, [106, [89,
        [92] ],'str2' ], 27] ]
>>> gen1 = in_order_traverse_tree(input_list)  # 构造生成器
>>> next(gen1)                                 # 得到第一个元素
'str1'
>>> next(gen1)                                 # 得到下一个元素
11
>>> next(gen1)                                 # 得到下一个元素
32
>>> next(gen1)                                 # 得到下一个元素
13
>>> next(gen1)                                 # 得到下一个元素
27
>>> next(gen1)                                 # 得到下一个元素
24
>>> next(gen1)                                 # 得到下一个元素
45
>>> next(gen1)                                 # 得到下一个元素
106
>>> next(gen1)                                 # 得到下一个元素
89
>>> next(gen1)                                 # 得到下一个元素
92
>>> next(gen1)                                 # 得到下一个元素
'str2'
>>> next(gen1)                                 # 得到下一个元素
27
>>> next(gen1)                 # 没有更多元素可以输出了,抛出异常StopIteration
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration

优秀文章