Programming 1, 2021-2
Week 14: Notes

generator functions

Python and some other languages have generator functions, which are functions that return an iterator producing a series of values. Each time the caller asks for the next value in the series, the function's code runs until it yields the next value to the caller.

Here is a simple generator function in Python:

def foo():
    yield 3
    yield 5
    yield 7

Let's try using it:

>>> s = foo()
>>> next(s)
3
>>> next(s)
5
>>> next(s)
7
>>> next(s)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

Each time we call the built-in function next(), we receive the next value from the iterator. When no more values are available, next() raises a StopIteration exception.
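If you would rather not handle the exception, next() also accepts an optional second argument. This is a default value that it returns instead of raising StopIteration once the iterator is exhausted. Here is a short sketch that reuses the three-value generator foo() from above:

```python
def foo():
    yield 3
    yield 5
    yield 7

s = foo()
print(next(s))          # 3
print(next(s))          # 5
print(next(s))          # 7
print(next(s, None))    # None: the iterator is exhausted, so the default is returned
```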

Alternatively, we may retrieve the values using a 'for' loop:

>>> s = foo()
>>> for x in s:
...   print(x)
3
5
7
>>>

(Internally, 'for' will actually call next() repeatedly.)
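Roughly speaking, the 'for' loop above behaves like the hand-written loop below. It calls the built-in iter() to obtain an iterator, which for a generator is the generator itself. It then calls next() until StopIteration is raised. This is a simplified sketch of the behavior, not CPython's actual implementation. The helper name manual_for() is hypothetical:

```python
def foo():
    yield 3
    yield 5
    yield 7

def manual_for(iterable):
    """Collect the values that 'for x in iterable' would visit, using iter() and next()."""
    visited = []
    it = iter(iterable)          # for a generator, iter() returns the generator itself
    while True:
        try:
            x = next(it)         # ask for the next value
        except StopIteration:    # no more values: the loop ends
            break
        visited.append(x)        # the body of the 'for' loop would run here
    return visited

print(manual_for(foo()))   # [3, 5, 7]
```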

Now let's modify our generator function to print some strings:

def foo():
    print('hello')
    yield 3
    print('hello again')
    yield 5
    print('yet again')
    yield 7

And let's use it:

>>> s = foo()
>>> next(s)
hello
3
>>> next(s)
hello again
5
>>> next(s)
yet again
7
>>>

This output may come as a surprise. When we call the generator function, it returns an iterator, but none of the code in the function has run yet. Each time we request a value from the iterator, the function runs until it yields a value, and then its execution is suspended until the next value is requested. This is quite different from any other control mechanism that we have seen in Python thus far.
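To see the suspension explicitly, here is a small sketch that records events in a list instead of printing them. The names events and traced() are hypothetical and chosen only for this illustration:

```python
events = []

def traced():
    events.append('started')
    yield 1
    events.append('resumed')
    yield 2
    events.append('finished')

s = traced()
print(events)         # []: calling traced() ran none of its body
print(next(s))        # 1
print(events)         # ['started']
print(next(s))        # 2
print(events)         # ['started', 'resumed']
print(next(s, None))  # None: the body ran to the end and no value was yielded
print(events)         # ['started', 'resumed', 'finished']
```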

Note that an iterator can produce values only once. Once it has reached the end of the sequence, a subsequent attempt to read values using 'for' will produce nothing:

>>> s = foo()
>>> for x in s:
...   print(x)
hello
3
hello again
5
yet again
7
>>> for x in s:
...   print(x)
>>> 

Let's write a generator function that produces a sequence of squares 1, 4, 9, …, n² for any given n:

# return a sequence of values 1, 4, 9, ..., n^2
def squares_to(n):
    for i in range(1, n + 1):   # 1, ..., n
        yield i * i

Let's try it:

>>> for x in squares_to(10):
...   print(x)
... 
1
4
9
16
25
36
49
64
81
100
>>>

The list() function can gather all elements of an iterator into a list:

>>> list(squares_to(10))
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]

Let's write a function that prints the first few elements of a very large sequence of squares:

def abc():
    for i in squares_to(1_000_000_000_000):
        print(i)
        if i > 50:
            break

Let's run it:

>>> abc()
1
4
9
16
25
36
49
64
>>>

Notice that the function runs quickly, even though the sequence is enormous. The program has not computed all values in the sequence; instead, the generator function squares_to() generates values on demand when and only if they are requested.

In fact, we may write a version of squares_to() that generates an infinite sequence of squares:

# return an infinite sequence of squares:
# 1, 4, 9, 16, ...
def all_squares():
    i = 1
    while True:
        yield i * i
        i += 1

If we now modify abc() to call all_squares() rather than squares_to(), it will still work fine.
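Here is that modification as a sketch. It also uses itertools.islice() from the standard library, which yields at most the first k values of any iterator. This is a convenient way to sample an infinite sequence without writing a loop:

```python
from itertools import islice

def all_squares():
    i = 1
    while True:
        yield i * i
        i += 1

# abc() from above, now reading from the infinite sequence
def abc():
    for i in all_squares():
        print(i)
        if i > 50:
            break      # without this break, the loop would never end

abc()                  # prints 1, 4, 9, ..., 49, 64

# islice(iterator, k) yields at most the first k values
print(list(islice(all_squares(), 5)))   # [1, 4, 9, 16, 25]
```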

Infinite sequences can be useful in solving many problems. Recall Project Euler's problem 2:

By considering the terms in the Fibonacci sequence whose values do not exceed four million, find the sum of the even-valued terms.

Let's write a solution using a generator function that generates an infinite sequence of Fibonacci numbers:

# return an infinite sequence of all Fibonacci numbers:
# 1, 1, 2, 3, 5, 8, ...
def fibs():
    a = b = 1
    while True:
        yield a
        a, b = b, a + b

def euler2():
    total = 0          # not named 'sum', which would hide the built-in sum()
    for f in fibs():
        if f > 4_000_000:
            break
        if f % 2 == 0:     # f is even
            total += f
    return total
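The same computation can also be written more compactly with itertools.takewhile(). This function yields values from an iterator only as long as a condition holds. Here it is combined with a generator comprehension (introduced later in these notes) inside sum(). This is an alternative sketch. The name euler2_compact() and its limit parameter are our own:

```python
from itertools import takewhile

def fibs():
    a = b = 1
    while True:
        yield a
        a, b = b, a + b

def euler2_compact(limit=4_000_000):
    # takewhile() stops the infinite sequence at the first value above limit
    small_fibs = takewhile(lambda f: f <= limit, fibs())
    return sum(f for f in small_fibs if f % 2 == 0)

print(euler2_compact())   # 4613732
```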

Note that the built-in function map() works even on infinite sequences. For example, we may use it to double each value in the Fibonacci sequence, using a function twice() that doubles its argument:

>>> def twice(x):
...   return 2 * x
... 
>>> s = map(twice, fibs())
>>> next(s)
2
>>> next(s)
2
>>> next(s)
4
>>> next(s)
6
>>> next(s)
10
>>>

In this example, map() returns an iterator representing an infinite sequence. In fact, map() always returns an iterator, and reads values from the input sequence only when the caller requests values from the mapped sequence. We can implement a simplified version of map() that takes a single input sequence like this:

def map(f, seq):
    for x in seq:
        yield f(x)
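To check that this hand-written version behaves like the built-in one, even on an infinite input, here is a sketch that calls it my_map(). This is a hypothetical name, chosen so that it does not hide the built-in map(). The sketch compares the first five values of each, using itertools.islice() to stop:

```python
from itertools import islice

def fibs():
    a = b = 1
    while True:
        yield a
        a, b = b, a + b

def twice(x):
    return 2 * x

# the generator version of map() from the text, renamed to avoid shadowing the built-in
def my_map(f, seq):
    for x in seq:
        yield f(x)

print(list(islice(my_map(twice, fibs()), 5)))   # [2, 2, 4, 6, 10]
print(list(islice(map(twice, fibs()), 5)))      # [2, 2, 4, 6, 10]
```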

making classes iterable

Recall that Python's built-in set and dictionary objects are iterable, meaning that we can iterate over them using the 'for' statement. For example:

>>> s = {11, 14, 17}
>>> for x in s:
...   print(x)
... 
17
11
14
>>>

Lists and strings are sequences, which are of course also iterable.

Note the difference between an iterable object and an iterator. You can iterate as many times as you like over an iterable object. However, as we saw above, an iterator produces a series of values only once. In fact, each time that you iterate over an iterable object, Python uses a new iterator for the iteration.
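The built-in function iter() makes this distinction visible. Calling it on an iterable such as a list returns a fresh iterator each time. Calling it on an iterator returns that same iterator. A short sketch:

```python
nums = [1, 2, 3]

it1 = iter(nums)
it2 = iter(nums)
print(it1 is it2)                        # False: each call gives a new, independent iterator
print(next(it1), next(it1), next(it2))   # 1 2 1

print(iter(it1) is it1)  # True: an iterator is its own iterator

print(list(it1))         # [3]: only the value that it1 has not produced yet
print(list(nums))        # [1, 2, 3]: the list itself can be iterated again
```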

Using generators, we can even make our own classes iterable in Python. If a class implements the __iter__() magic method, then Python will consider instances to be iterable, and will call __iter__() to obtain an iterator when needed. For example, here is an iterable LinkedList class:

class Node:
    def __init__(self, val, next):
        self.val = val
        self.next = next

class LinkedList:
    def __init__(self):
        self.root = None

    def prepend(self, x):
        self.root = Node(x, self.root)

    def __iter__(self):
        n = self.root
        while n is not None:
            yield n.val
            n = n.next

We may use it like this, for example:

>>> l = LinkedList()
>>> l.prepend(4)
>>> l.prepend(2)
>>> l.prepend(0)
>>> for x in l:
...   print(x)
... 
0
2
4
>>>

Each time that we request a new value from the iterator, the while loop in __iter__() runs until it yields the requested value, then is suspended until the next value is requested.
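Because LinkedList is an iterable rather than an iterator, each traversal calls __iter__() again and gets a fresh generator, so we can traverse the same list many times. Other built-ins that accept iterables also work, such as list() and sum(). So does the 'in' operator, which falls back to iteration when a class defines no __contains__() method. The sketch below repeats the classes from above so that it runs on its own:

```python
class Node:
    def __init__(self, val, next):
        self.val = val
        self.next = next

class LinkedList:
    def __init__(self):
        self.root = None

    def prepend(self, x):
        self.root = Node(x, self.root)

    def __iter__(self):
        n = self.root
        while n is not None:
            yield n.val
            n = n.next

lst = LinkedList()
for x in [4, 2, 0]:
    lst.prepend(x)

print(list(lst))    # [0, 2, 4]
print(list(lst))    # [0, 2, 4] again: each traversal uses a new iterator
print(sum(lst))     # 6
print(2 in lst)     # True, found by iterating
print(3 in lst)     # False
```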

As a more advanced example, consider a node class for a binary search tree:

class Node:
    def __init__(self, val, left = None, right = None):
        self.val = val
        self.left = left
        self.right = right

Let's write a generator function that can produce all values in a tree in ascending order. Naturally, our function will be recursive. In the base case when the tree is empty, we don't need to yield anything. In the recursive case, when we recurse left or right, we will get an iterator that produces all the values in the left or right subtree. We must yield all the values yielded by that iterator:

# argument n is the root of a binary tree
# yield all values in the binary tree, in order
def all_vals(n):
    if n is not None:
        for x in all_vals(n.left):
            yield x
        yield n.val
        for x in all_vals(n.right):
            yield x

In fact this pattern is so common that Python has a statement 'yield from' that yields all the values produced by a given iterator. We can rewrite the code above more simply using 'yield from':

# argument n is the root of a binary tree
# yield all values in the binary tree, in order
def all_vals(n):
    if n is not None:
        yield from all_vals(n.left)
        yield n.val
        yield from all_vals(n.right)
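As a quick check, here is a sketch that builds a small binary search tree by hand from the Node class above and collects the values that all_vals() yields. The particular tree shape is just an example:

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def all_vals(n):
    if n is not None:
        yield from all_vals(n.left)
        yield n.val
        yield from all_vals(n.right)

#        5
#       / \
#      3   8
#     /   / \
#    1   7   9
root = Node(5, Node(3, Node(1)), Node(8, Node(7), Node(9)))

print(list(all_vals(root)))   # [1, 3, 5, 7, 8, 9]
print(list(all_vals(None)))   # []: the empty tree yields nothing
```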

Now suppose that we have a TreeSet class representing a binary search tree. We can make the class iterable by adding an __iter__() method that calls all_vals():

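class TreeSet:
    def __init__(self):
        self.root = None

    def insert(self, x):
        ...     # add x to the tree (details omitted here)

    def contains(self, x):
        ...     # return True if x is in the tree (details omitted here)

    def __iter__(self):
        # all_vals() is a generator function, so each call returns a new iterator
        return all_vals(self.root)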
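Here is a usage sketch. It repeats Node, all_vals() and TreeSet so that it runs on its own. The bodies of insert() and contains() are our own assumption of a typical binary search tree implementation, not necessarily the version from the lecture:

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def all_vals(n):
    if n is not None:
        yield from all_vals(n.left)
        yield n.val
        yield from all_vals(n.right)

class TreeSet:
    def __init__(self):
        self.root = None

    def insert(self, x):
        # assumed implementation: recurse down the tree, adding a leaf if x is new
        def go(n):
            if n is None:
                return Node(x)
            if x < n.val:
                n.left = go(n.left)
            elif x > n.val:
                n.right = go(n.right)
            return n            # if x was already present, the tree is unchanged
        self.root = go(self.root)

    def contains(self, x):
        # assumed implementation: walk down from the root
        n = self.root
        while n is not None and n.val != x:
            n = n.left if x < n.val else n.right
        return n is not None

    def __iter__(self):
        return all_vals(self.root)

t = TreeSet()
for x in [50, 20, 80, 20, 10, 60]:
    t.insert(x)

print(list(t))          # [10, 20, 50, 60, 80]: ascending, duplicate ignored
print(t.contains(60))   # True
print(t.contains(30))   # False
```

Since __iter__() returns a new generator on each call, a TreeSet can be iterated any number of times, just like a built-in set.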

using generators in combinatorial recursion

In the last lecture we studied how to solve various combinatorial problems using recursion. For example, we studied the 'abc problem', in which we want to generate all possible strings of n characters in which every character is 'a', 'b', or 'c'. We saw a top-down solution that looked like this:

def abc(n):
    def go(result, i):
        if i == 0:
            print(result)
        else:
            for c in 'abc':
                go(result + c, i - 1)
    go('', n)

The function go() builds up a string 'result' as it recurses. When there are no more characters to produce, it prints out the accumulated string.

We also saw a bottom-up solution:

def abc(n):
    # base case
    if n == 0:
        return ['']
        
    # recursive case
    l = []
    for c in 'abc':
        for s in abc(n - 1):
            l.append(c + s)

    return l

This solution produces a list of all possible results.

The bottom-up solution has a nice recursive structure, but it has a significant disadvantage. If the solution set is large, it may use a lot of memory, since it must store all solutions at once. To avoid this, we may rewrite the bottom-up solution as a generator function that yields results one at a time:

def abc(n):
    # base case
    if n == 0:
        yield ''
    else:
        # recursive case
        for c in 'abc':
            for s in abc(n - 1):
                yield c + s

Now even if there are a huge number of solutions, we may iterate over as many as we like using only a small amount of memory, and the first solutions will be generated instantly:

>>> for x in abc(30):
...   print(x)
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
aaaaaaaaaaaaaaaaaaaaaaaaaaaaab
aaaaaaaaaaaaaaaaaaaaaaaaaaaaac
aaaaaaaaaaaaaaaaaaaaaaaaaaaaba
aaaaaaaaaaaaaaaaaaaaaaaaaaaabb
…
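The top-down solution can be turned into a generator in the same way. In this sketch the helper go() yields each finished string instead of printing it. It uses 'yield from' to pass along everything produced by its recursive calls. The name abc_top_down() is ours. It generates the strings in the same order as the bottom-up generator:

```python
from itertools import islice

def abc_top_down(n):
    def go(result, i):
        if i == 0:
            yield result                          # a complete string
        else:
            for c in 'abc':
                yield from go(result + c, i - 1)  # pass along every completion
    return go('', n)

print(list(abc_top_down(2)))
# ['aa', 'ab', 'ac', 'ba', 'bb', 'bc', 'ca', 'cb', 'cc']

# only the first two of the 3^30 strings are ever generated
first_two = list(islice(abc_top_down(30), 2))
print(first_two == ['a' * 30, 'a' * 29 + 'b'])   # True
```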

generator comprehensions

We have already seen that a list comprehension is equivalent to a 'for' loop that appends results to a list. In other words, the following two functions are equivalent:

def squares():
    results = []
    for x in range(100):
        results.append(x * x)
    return results

def squares2():
    return [x * x for x in range(100)]

We've also seen that rather than returning a list of values, it's sometimes useful to return an iterator that produces the values on demand:

def squares3():
    for x in range(100):
        yield x * x

As an alternative to a 'for' loop that yields values, we may use a generator comprehension, which looks like a list comprehension but with parentheses in place of the square brackets. The following function is equivalent to the previous one:

def squares4():
    return (x * x for x in range(100))

To put it differently, a generator comprehension is like a list comprehension, but produces an iterator that generates values on demand, rather than computing them all up front and gathering them into a list.
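Like any iterator, the object produced by a generator comprehension does no work until values are requested, and can be consumed only once. A short sketch that also confirms squares4() produces the same values as squares2():

```python
def squares2():
    return [x * x for x in range(100)]

def squares4():
    return (x * x for x in range(100))

print(list(squares4()) == squares2())   # True: same values, produced lazily

g = (x * x for x in range(5))
print(next(g))      # 0
print(list(g))      # [1, 4, 9, 16]: the remaining values
print(list(g))      # []: a generator can be consumed only once
```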

Consider this expression, which computes the sum of the squares of all positive integers below 10,000,000:

>>> sum([x * x for x in range(10_000_000)])
333333283333335000000

Evaluating this expression will use a fair amount of memory, since it will hold all the numbers to be added in a list. Instead, we may use a generator comprehension to achieve the same result:

>>> sum(x * x for x in range(10_000_000))
333333283333335000000

This will use much less memory, and works because sum() can take any iterable object as an argument.
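Laziness pays off even more with functions that can stop early. The built-ins any() and next() stop pulling values as soon as they have their answer. So the expressions below finish almost immediately, even though range(10_000_000) is large. The function name first_square_over() is ours:

```python
def first_square_over(limit):
    # next() takes only the first value that the generator comprehension produces
    return next(x * x for x in range(10_000_000) if x * x > limit)

print(first_square_over(50))                        # 64
print(any(x * x > 50 for x in range(10_000_000)))   # True, after checking only x = 0..8
```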

Coming back to the abc problem, consider the bottom-up solution that we saw in an earlier lecture using a list comprehension:

def abc(n):
    if n == 0:
        return ['']
    
    return [c + s for c in 'abc' for s in abc(n - 1)]

We may modify it to use a generator comprehension:

def abc(n):
    if n == 0:
        return ['']
    
    return (c + s for c in 'abc' for s in abc(n - 1))

Now the function will compute values only as needed. Rather than storing all 3^n strings, it uses only a small amount of memory, roughly proportional to n for the chain of nested generators:

>>> s = abc(30)
>>> next(s)
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
>>> next(s)
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaab'
>>> next(s)
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaac'
>>>
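Because abc() now returns an iterator (for n > 0), we can also count the solutions or sample a few of them without ever building the full list. This sketch repeats the function so that it runs on its own. For n == 0 it still returns the one-element list [''], which is fine because the caller only iterates over it:

```python
from itertools import islice

def abc(n):
    if n == 0:
        return ['']
    return (c + s for c in 'abc' for s in abc(n - 1))

# count the 3^5 = 243 strings of length 5 without storing them
print(sum(1 for _ in abc(5)))          # 243

# take a few strings from the middle of the sequence for n = 3
print(list(islice(abc(3), 9, 12)))     # ['baa', 'bab', 'bac']
```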