Week 8: Notes

iterables and sequences

Let's briefly review the concept of iterables and sequences in Python. An object is iterable if you can loop over it with the 'for' statement. An object is a sequence if you can access its elements by integer index, i.e. using the syntax s[i]. All sequences are iterable, but not all iterables are sequences.

We've now seen these kinds of sequences: lists, tuples, strings, and ranges.

We've also seen these kinds of iterables which are not sequences: sys.stdin, file objects returned by open(), sets, and dictionaries.
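As a quick illustration of the difference, here is a small sketch (the variable names are made up for this example). Both a list and a set can be looped over, but only the list can be indexed with s[i]:

```python
nums_list = [10, 20, 30]     # a sequence: iterable and indexable
nums_set = {10, 20, 30}      # iterable, but not a sequence

# Both objects can be used in a 'for' loop.
total_list = 0
for x in nums_list:
    total_list += x

total_set = 0
for x in nums_set:
    total_set += x

# Only the list supports integer indexing.
first = nums_list[0]

try:
    nums_set[0]
    indexable = True
except TypeError:
    # Sets do not support s[i], so Python raises TypeError.
    indexable = False

print(total_list, total_set, first, indexable)
```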

Note that if you iterate over a dictionary you get its keys:

d = { 'red' : 1, 'green' : 2, 'blue' : 3 }
for x in d:
    print(x)

produces the output

red
green
blue

As we saw last week, a dictionary has methods keys(), values(), and items() that produce the keys, values, and key-value pairs in the dictionary. These methods all return iterable objects.
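For instance, using the same dictionary as above, we might collect the results of these three methods into lists, or loop over items() directly. This is only a short sketch; the names keys, values, and items are chosen for the example:

```python
d = { 'red' : 1, 'green' : 2, 'blue' : 3 }

# Each method returns an iterable, which list() can collect.
keys = list(d.keys())
values = list(d.values())
items = list(d.items())

# items() yields (key, value) pairs, which we can unpack in the loop header.
for color, number in d.items():
    print(color, number)
```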

list comprehensions

A list comprehension is an expression that loops over a sequence (or any other iterable) and collects a series of computed values into a list. List comprehensions are powerful and convenient.

For example, consider this loop that builds a list of all perfect squares from 1 to 400:

squares = []
for i in range(1, 21):
    squares.append(i * i)

We may replace the three lines above by a single line with a list comprehension:

squares = [i * i for i in range(1, 21)]

In general, a list comprehension may have the form

[ <expression> for <var> in <sequence> ]

A comprehension of this form will loop over the given <sequence>. On each loop iteration, it sets <var> to the value of an element of the sequence, then evaluates the given <expression>. All results are collected into a list.

Here are some more examples of list comprehensions. This comprehension builds a list of numbers from 1 to 10 and their squares:

>>> [(i, i * i) for i in range(1, 11)]
[(1, 1), (2, 4), (3, 9), (4, 16), (5, 25), (6, 36), (7, 49), (8, 64), (9, 81), (10, 100)]

We may add 1 to each element in a list:

>>> l = [2, 5, 7, 10]
>>> [i + 1 for i in l]
[3, 6, 8, 11]

Let's write a program that will read a single input line containing a number of integers, separated by spaces. The program will print the sum of the integers. Using a list comprehension, we can write this in a single line:

print(sum([int(w) for w in input().split()]))

Consider Project Euler's Problem 6 (find the difference between the sum of the squares of the first one hundred natural numbers and the square of the sum). Here's a solution using a list comprehension:

sum_of_squares = sum([x * x for x in range(1, 101)])
square_of_sum = sum(range(1, 101)) ** 2
answer = square_of_sum - sum_of_squares

As another example, here's a file animals_en_cz containing English and Czech animal names:

bear medvěd
bird pták
cat kočka
cow kráva
dog pes
goat kozel
horse kůň
mouse myš
pig prase
sheep ovce

Using a list comprehension, we may read it into a dictionary in a single line of code:

with open('animals_en_cz') as f:
    d = dict([line.split() for line in f])

This works because the split() method splits each line above into a 2-element list such as ['bear', 'medvěd']. We can pass a list of these lists to the dict() constructor, which interprets each 2-element list as a key-value pair.
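To see the intermediate step, here is a sketch that uses a short list of strings in place of the file, so that it runs without animals_en_cz. The list lines is an assumption standing in for what iterating over the file would produce:

```python
# Stand-in for the lines we would get by iterating over the file.
lines = ['bear medvěd\n', 'bird pták\n', 'cat kočka\n']

# split() with no arguments splits on whitespace and drops the newline.
pairs = [line.split() for line in lines]
print(pairs)

# dict() treats each 2-element list as a key-value pair.
d = dict(pairs)
print(d['cat'])
```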

Finally, here's a function that builds a nested list representing a matrix of zeroes:

def empty(rows, cols):
    return [cols * [0] for r in range(rows)]
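One reason to prefer a comprehension here is that it builds a separate row list on each iteration. The sketch below compares it with the tempting shortcut rows * [cols * [0]], which repeats a reference to a single row. The names good and bad are only for illustration:

```python
def empty(rows, cols):
    return [cols * [0] for r in range(rows)]

# Each row is a distinct list, so changing one row leaves the other alone.
good = empty(2, 3)
good[0][0] = 7

# Here the outer list holds two references to the same row object.
bad = 2 * [3 * [0]]
bad[0][0] = 7

print(good)   # [[7, 0, 0], [0, 0, 0]]
print(bad)    # [[7, 0, 0], [7, 0, 0]]
```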

if clauses in list comprehensions

A list comprehension may have an if clause containing an arbitrary condition. Only list elements that satisfy the condition are included in the generated list.

For example, this comprehension collects all characters in the string 'watermelon' that are in the second half of the alphabet:

>>> [c for c in 'watermelon' if c >= 'n']
['w', 't', 'r', 'o', 'n']

Here's a 1-line solution to Project Euler's Problem 1 (find the sum of all the multiples of 3 or 5 below 1000):

>>> sum([i for i in range(1000) if i % 3 == 0 or i % 5 == 0])
233168

multiple for clauses in list comprehensions

A list comprehension may have more than one 'for' clause. The 'for' clauses represent a nested loop. All values generated by the inner loop are collected into the resulting list.

For example, this comprehension generates all pairs of values (x, y), where 0 ≤ x, y < 3:

>>> [(x, y) for x in range(3) for y in range(3)]
[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]

The following comprehension is similar, but in it 'y' only iterates up to the value 'x', so it only generates pairs where y < x:

>>> [(x, y) for x in range(3) for y in range(x)]
[(1, 0), (2, 0), (2, 1)]

The comprehension above is equivalent to the following:

l = []
for x in range(3):
    for y in range(x):
        l.append( (x, y) )

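Notice that when there are multiple 'for' clauses in a single comprehension, they appear in the same order as in the equivalent nested loop: the leftmost 'for' is the outer loop.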

Let's write a function to flatten a 2-dimensional matrix, i.e. return a 1-dimensional list of its elements:

def flatten(m):
    return [ x for row in m for x in row ]

For example:

>>> mat = [ [2, 4, 6],
...         [1, 3, 7],
...         [8, 9, 1] ]
>>> flatten(mat)
[2, 4, 6, 1, 3, 7, 8, 9, 1]

Let's now write a program that will read any number of lines of input, each containing any number of integers separated by whitespace. The program will print the sum of all the numbers on all the lines:

import sys

nums = [int(w) for line in sys.stdin for w in line.split()]
print(sum(nums))

It's possible to mix 'for' and 'if' clauses in a comprehension. For example, here's code to generate all pairs (x, y) where 0 ≤ x, y < 3, but skipping those with x = 1:

>>> [(x, y) for x in range(3) if x != 1 for y in range(3)]
[(0, 0), (0, 1), (0, 2), (2, 0), (2, 1), (2, 2)]
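The clauses in such a comprehension can be read from left to right as nested statements. As a sketch, the comprehension above should behave like the following loop (the variable pairs is just a name chosen here):

```python
pairs = []
for x in range(3):
    if x != 1:                    # the 'if' clause guards everything to its right
        for y in range(3):
            pairs.append((x, y))

print(pairs)
```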

As a larger example, here's a comprehension that finds all triples of integers (a, b, c) such that a² + b² = c², with 0 ≤ a < b < c ≤ 20:

>>> [(a, b, c) for c in range(21)
...            for b in range(c)
...            for a in range(b)
...            if a * a + b * b == c * c]
[(3, 4, 5), (6, 8, 10), (5, 12, 13), (9, 12, 15), (8, 15, 17), (12, 16, 20)]

nested list comprehensions

It's possible to nest one list comprehension inside another. Unlike the examples we saw above, this will produce a 2-dimensional result, i.e. a list of lists.

For example, suppose that we'd like to generate an N x N matrix, where each element is the sum of its row and column numbers (where rows and columns are numbered from 0). If N is 4, then the matrix will look like this:

0 1 2 3
1 2 3 4
2 3 4 5
3 4 5 6

Here is a function that takes n as a parameter and produces this matrix:

def num_matrix(n):
    return [ [ i + j for j in range(n) ] for i in range(n) ]

Notice that in this nested comprehension, the for loop on the right represents the outer loop. The inner comprehension will be evaluated once on each iteration of that loop.
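To make the evaluation order concrete, here is a sketch of the same computation written with explicit loops. The function name num_matrix_loops is made up for this comparison:

```python
def num_matrix_loops(n):
    m = []
    for i in range(n):            # outer loop: the 'for' on the right
        row = []
        for j in range(n):        # inner loop: the inner comprehension
            row.append(i + j)
        m.append(row)
    return m

print(num_matrix_loops(4))
```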

Similarly, we may write a function that produces an identity matrix of dimensions N x N:

def identity_matrix(n):
    return [ [ int(i == j) for j in range(n) ] for i in range(n) ]

Let's try it:

>>> identity_matrix(4)
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]

This works because int(True) is 1, and int(False) is 0.

As another example, here is code that will read a matrix from standard input, where each input line contains the numbers in one row of the matrix:

import sys

m = [ [ int(w) for w in line.split() ] for line in sys.stdin ]
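To try this without typing input, we can substitute an io.StringIO object for sys.stdin, since it is also an iterable of lines. This is only a sketch for experimenting; fake_stdin and its contents are assumptions made for the example:

```python
import io

# A file-like object that yields the same lines a user might type.
fake_stdin = io.StringIO('1 2 3\n4 5 6\n')

m = [ [ int(w) for w in line.split() ] for line in fake_stdin ]
print(m)    # [[1, 2, 3], [4, 5, 6]]
```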

set comprehensions

A set comprehension is similar to a list comprehension, but collects values into a set, not a list.

For example, suppose that we have a set s of integers. We'd like to add one to each integer, and collect the resulting values into a new set t. We can perform this easily using a set comprehension:

>>> s = {7, 9, 100}
>>> t = {x + 1 for x in s}
>>> t
{8, 10, 101}

Suppose that we'd like to know how many distinct integers are products of two integers from 1 to 10. We can solve this in a single line using a set comprehension:

>>> len({ a * b for a in range(1, 11) for b in range(1, 11) })
42

Note that a list comprehension would not produce the same result, since the generated list would contain duplicate elements. With a set comprehension, the duplicates are automatically eliminated.
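We can check this by building both collections and comparing their sizes. In this sketch the names products_list and products_set are chosen just for the comparison:

```python
# The list keeps every product, including repeats such as 2 * 3 and 3 * 2.
products_list = [a * b for a in range(1, 11) for b in range(1, 11)]

# The set keeps each distinct product only once.
products_set = {a * b for a in range(1, 11) for b in range(1, 11)}

print(len(products_list))   # 100
print(len(products_set))    # 42
```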

dictionary comprehensions

A dictionary comprehension is similar to a list or set comprehension, but collects values into a dictionary. For example:

>>> { x : x * x for x in range(15) }
{0: 0, 1: 1, 2: 4, 3: 9, 4: 16, 5: 25, 6: 36, 7: 49, 8: 64, 9: 81, 10: 100, 11: 121, 12: 144, 13: 169, 14: 196}

We can easily swap the key-value pairs in a dictionary using a dictionary comprehension:

>>> d = { 'red' : 'červený', 'green' : 'zelený', 'blue' : 'modrý' }
>>> { v : k for k, v in d.items() }
{'červený': 'red', 'zelený': 'green', 'modrý': 'blue'}
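One caveat when swapping: dictionary keys are unique, so if two keys share a value, only one of them survives the swap. The dictionary sizes below is a made-up example that shows this:

```python
sizes = { 'cat' : 'small', 'mouse' : 'small', 'horse' : 'large' }

# 'small' appears twice as a value, so the later pair overwrites the earlier one.
inverted = { v : k for k, v in sizes.items() }
print(inverted)
```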

iterating with enumerate()

The enumerate() function lets you iterate over a sequence and its indices simultaneously.

For example, consider this function, which returns the index of the first odd element in a list or other sequence (or -1 if no odd elements are found):

def first_odd(a):
    for i in range(len(a)):
        if a[i] % 2 == 1:
            return i

    return -1

We may rewrite the function using enumerate():

def first_odd(a):
    for i, x in enumerate(a):
        if x % 2 == 1:
            return i

    return -1

On each iteration of the loop, i receives the index of an element in a, and x is the element at that index.
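One practical difference between the two versions: the enumerate() version never calls len() or indexes a, so it should also accept iterables that are not sequences. Here is a short sketch; the sample inputs are made up for the example:

```python
def first_odd(a):
    for i, x in enumerate(a):
        if x % 2 == 1:
            return i

    return -1

# An iterator has no len() and cannot be indexed, but enumerate() accepts it.
result_iter = first_odd(iter([4, 8, 7, 10]))

# A generator expression is also just an iterable.
result_gen = first_odd(x * 2 for x in range(5))

print(result_iter, result_gen)
```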

enumerate() actually returns an iterable of pairs. If we like, we may convert it to a list:

>>> list(enumerate([20, 40, 50, 80, 100]))
[(0, 20), (1, 40), (2, 50), (3, 80), (4, 100)]

iterating with zip()

The handy zip() function lets you iterate over two (or more) sequences simultaneously.

For example, let's use zip() to iterate over two lists of integers:

>>> l = [2, 4, 6, 8, 10]
>>> m = [20, 40, 60, 80, 100]
>>> for x, y in zip(l, m):
...     print(f'x = {x}, y = {y}')
x = 2, y = 20
x = 4, y = 40
x = 6, y = 60
x = 8, y = 80
x = 10, y = 100

Notice that on each iteration we receive a pair of values, one from each of the lists that we zipped. (Think of a zipper on a jacket that pulls together two edges as it moves upward.)

zip() actually returns an iterable of pairs, which we may collect into a list:

>>> l = [2, 4, 6, 8, 10]
>>> m = [20, 40, 60, 80, 100]
>>> list(zip(l, m))
[(2, 20), (4, 40), (6, 60), (8, 80), (10, 100)]

Note that zip() will stop as soon as it reaches the end of the shortest list:

>>> l = [2, 4, 6]
>>> m = [20, 40, 60, 80, 100]
>>> list(zip(l, m))
[(2, 20), (4, 40), (6, 60)]

We can use zip() to simplify some loops. For example, consider a function that takes two lists a and b, and produces a new list in which each element is the sum of two corresponding elements from a and b:

# produce a list of sums of values in a and b
def list_sum(a, b):
    assert len(a) == len(b)
    return [a[i] + b[i] for i in range(len(a))]

Instead of iterating over indices, we may use zip():

# produce a list of sums of values in a and b
def list_sum(a, b):
    assert len(a) == len(b)
    return [x + y for x, y in zip(a, b)]

In fact zip() may take any number of arguments. If we give it three lists of integers, it will produce triples, which we may again collect into a list:

>>> list(zip([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]))
[(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)]

As we saw above, zip() (together with list()) converts two lists to a list of pairs:

>>> l = list(zip([1, 2, 3, 4, 5], [6, 7, 8, 9, 10]))
>>> l
[(1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]

Now, what will happen if we take these pairs and pass them as separate arguments to zip()? Let's try it:

>>> list(zip((1, 6), (2, 7), (3, 8), (4, 9), (5, 10)))
[(1, 2, 3, 4, 5), (6, 7, 8, 9, 10)]

We get back two tuples that look just like the lists that we started with! This works because zip() accepts any number of arguments. When we give it any number of pairs, it will zip together the first elements of all pairs, as well as the second elements.

And so we can use this clever trick to unzip any list of pairs. Let's write a function that will do that:

def unzip(pairs):
    return zip(*pairs)

For example:

>>> list(unzip([(1, 10), (2, 20), (3, 30), (4, 40)]))
[(1, 2, 3, 4), (10, 20, 30, 40)]
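Since unzip() returns an iterable of two tuples, we can also unpack its result directly into two variables. The same zip(*...) idea can transpose a matrix stored as a list of rows. The sketch below shows both; the names xs, ys, matrix, and transposed are chosen for this example:

```python
def unzip(pairs):
    return zip(*pairs)

pairs = [(1, 10), (2, 20), (3, 30), (4, 40)]

# Unpack the two resulting tuples into separate variables.
xs, ys = unzip(pairs)
print(xs)    # (1, 2, 3, 4)
print(ys)    # (10, 20, 30, 40)

# Passing the rows of a matrix as separate arguments yields its columns.
matrix = [ [1, 2, 3],
           [4, 5, 6] ]
transposed = [list(col) for col in zip(*matrix)]
print(transposed)    # [[1, 4], [2, 5], [3, 6]]
```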