Week 8: Notes

iterables and sequences

Let's briefly review the concept of iterables and sequences in Python. An object is iterable if you can loop over it with the 'for' statement. An object is a sequence if you can access its elements by integer index, i.e. using the syntax s[i]. All sequences are iterable, but not all iterables are sequences.

We've now seen these kinds of sequences: lists, tuples, strings, and ranges.

We've also seen these kinds of iterables which are not sequences: sys.stdin, file objects returned by open(), sets, and dictionaries.
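One quick way to see the difference is to try indexing. In the sketch below (the variable names are only for illustration), both a list and a set can be looped over. Only the list supports s[i]; indexing a set raises TypeError.

```python
nums_list = [3, 1, 2]
nums_set = {3, 1, 2}

# Both objects are iterable, so a 'for' loop works on either one.
for x in nums_list:
    print(x)
for x in nums_set:
    print(x)            # the iteration order of a set is not guaranteed

# Only the list is a sequence, so only the list supports s[i].
print(nums_list[0])     # 3
try:
    print(nums_set[0])
except TypeError as e:
    print('TypeError:', e)   # sets cannot be indexed
```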

Note that if you iterate over a dictionary you get its keys:

d = { 'red' : 1, 'green' : 2, 'blue' : 3 }
for x in d:
    print(x)

produces the output

red
green
blue

As we saw last week, a dictionary has methods keys(), values(), and items() that produce the keys, values, and key-value pairs in the dictionary. These methods all return iterable objects.
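Here is a short sketch that reuses the dictionary above. It assumes Python 3.7 or later, where dictionaries remember insertion order, so values come out in the order the keys were written. The objects these methods return are iterable but are not sequences. To index one, first convert it with list().

```python
d = {'red': 1, 'green': 2, 'blue': 3}

for k in d.keys():
    print(k)                 # red, green, blue
for v in d.values():
    print(v)                 # 1, 2, 3
for k, v in d.items():       # each item is a (key, value) tuple
    print(k, v)

# d.items()[0] would raise TypeError; list() gives us something indexable.
print(list(d.items())[0])    # ('red', 1)
```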

list comprehensions

A list comprehension is an expression that loops over an iterable (such as a sequence) and collects a series of computed values into a list. List comprehensions are powerful and convenient.

For example, consider this loop that builds a list of all perfect squares from 1 to 400:

squares = []
for i in range(1, 21):
    squares.append(i * i)

We may replace the three lines above by a single line with a list comprehension:

squares = [i * i for i in range(1, 21)]

In general, a list comprehension may have the form

[ <expression> for <var> in <iterable> ]

A comprehension of this form will loop over the given <iterable>. On each loop iteration, it sets <var> to the next element of the iterable, then evaluates the given <expression>. All results are collected into a list.

Here are some more examples of list comprehensions. This comprehension builds a list of the numbers from 1 to 10 paired with their squares:

>>> [(i, i * i) for i in range(1, 11)]
[(1, 1), (2, 4), (3, 9), (4, 16), (5, 25), (6, 36), (7, 49), (8, 64), (9, 81), (10, 100)]

We may add 1 to each element in a list:

>>> l = [2, 5, 7, 10]
>>> [i + 1 for i in l]
[3, 6, 8, 11]

Let's write a program that will read a single input line containing a number of integers, separated by spaces. The program will print the sum of the integers. Using a list comprehension, we can write this in a single line:

print(sum([int(w) for w in input().split()]))

Consider Project Euler's Problem 6 (find the difference between the sum of the squares of the first one hundred natural numbers and the square of the sum). Here's a solution using a list comprehension:

sum_of_squares = sum([x * x for x in range(1, 101)])
square_of_sum = sum(range(1, 101)) ** 2
answer = square_of_sum - sum_of_squares
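As a sanity check, both quantities have well-known closed forms: the sum of the first n squares is n(n + 1)(2n + 1) / 6, and the sum of the first n integers is n(n + 1) / 2. The sketch below compares the comprehension-based values against those formulas for n = 100.

```python
n = 100

sum_of_squares = sum([x * x for x in range(1, n + 1)])
square_of_sum = sum(range(1, n + 1)) ** 2
answer = square_of_sum - sum_of_squares

# Closed forms, computed with integer arithmetic.
assert sum_of_squares == n * (n + 1) * (2 * n + 1) // 6     # 338350
assert square_of_sum == (n * (n + 1) // 2) ** 2             # 25502500
print(answer)                                               # 25164150
```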

As another example, here's a file animals_en_cz containing English and Czech animal names:

bear medvěd
bird pták
cat kočka
cow kráva
dog pes
goat kozel
horse kůň
mouse myš
pig prase
sheep ovce

Using a list comprehension, we may read it into a dictionary in a single line of code:

with open('animals_en_cz') as f:
    d = dict([line.split() for line in f])

This works because the split() method splits each line above into a 2-element list such as ['bear', 'medvěd']. We can pass a list of these lists to the dict() constructor, which interprets each 2-element list as a key-value pair.
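If you don't have the file at hand, you can try the same idea with io.StringIO, which wraps a string in a file-like object that can be iterated line by line. The three-line sample text here is just an excerpt of the file above.

```python
import io

# A stand-in for open('animals_en_cz'): a file-like object over a string.
f = io.StringIO('bear medvěd\nbird pták\ncat kočka\n')

# Each line splits into a 2-element list [english, czech];
# dict() treats each such list as a key-value pair.
d = dict([line.split() for line in f])

print(d['cat'])    # kočka
```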

Finally, here's a function that builds a nested list representing a matrix of zeroes:

def empty(rows, cols):
    return [cols * [0] for r in range(rows)]
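Each iteration of the comprehension evaluates cols * [0] again, so every row is a separate list. By contrast, multiplying a list of lists repeats a reference to a single inner list. The sketch below (the names good and bad are only for illustration) shows the difference when we modify one element.

```python
def empty(rows, cols):
    return [cols * [0] for r in range(rows)]

good = empty(2, 3)
good[0][0] = 5
print(good)     # [[5, 0, 0], [0, 0, 0]]: only the first row changed

# Pitfall: list multiplication repeats the *same* inner list.
bad = 2 * [3 * [0]]
bad[0][0] = 5
print(bad)      # [[5, 0, 0], [5, 0, 0]]: both "rows" are one list
```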

if clauses in list comprehensions

A list comprehension may have an if clause containing an arbitrary condition. Only the elements for which the condition is true are included in the generated list.

For example, this comprehension collects all characters in the string 'watermelon' that are in the second half of the alphabet:

>>> [c for c in 'watermelon' if c >= 'n']
['w', 't', 'r', 'o', 'n']

Here's a 1-line solution to Project Euler's Problem 1 (find the sum of all the multiples of 3 or 5 below 1000):

>>> sum([i for i in range(1000) if i % 3 == 0 or i % 5 == 0])
233168
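A comprehension with an if clause corresponds to a loop with an if statement inside it. Here is the Problem 1 solution written both ways; total is just a name chosen for this sketch.

```python
# Comprehension with an if clause.
answer = sum([i for i in range(1000) if i % 3 == 0 or i % 5 == 0])

# The equivalent loop: the condition decides which values are kept.
total = 0
for i in range(1000):
    if i % 3 == 0 or i % 5 == 0:
        total += i

print(answer, total)    # 233168 233168
```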

multiple for clauses in list comprehensions

A list comprehension may have more than one 'for' clause. The 'for' clauses represent a nested loop. All values generated by the inner loop are collected into the resulting list.

For example, this comprehension generates all pairs of values (x, y), where 0 ≤ x, y < 3:

>>> [(x, y) for x in range(3) for y in range(3)]
[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]

The following comprehension is similar, but here y iterates only up to (but not including) x, so it only generates pairs where y < x:

>>> [(x, y) for x in range(3) for y in range(x)]
[(1, 0), (2, 0), (2, 1)]

The comprehension above is equivalent to the following:

l = []
for x in range(3):
    for y in range(x):
        l.append( (x, y) )

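Notice that when there are multiple 'for' clauses in a single comprehension, they appear in the same order as the equivalent nested loops: the first 'for' clause is the outermost loop, and the last is the innermost. A later clause may use a variable bound by an earlier one, as range(x) uses x above. Only the expression at the front is out of place: it is evaluated inside the innermost loop, like the append() call above.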

Let's write a function to flatten a 2-dimensional matrix, i.e. return a 1-dimensional list of its elements:

def flatten(m):
    return [ x for row in m for x in row ]

For example:

>>> mat = [ [2, 4, 6],
...         [1, 3, 7],
...         [8, 9, 1] ]
>>> flatten(mat)
[2, 4, 6, 1, 3, 7, 8, 9, 1]

Let's now write a program that will read any number of lines of input, each containing any number of integers separated by whitespace. The program will print the sum of all the numbers on all the lines:

import sys

nums = [int(w) for line in sys.stdin for w in line.split()]
print(sum(nums))
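To try this logic without typing input, we can move it into a function that accepts any iterable of lines. The name sum_lines is hypothetical. Passing sys.stdin reproduces the program above, and passing a list of strings lets us test it.

```python
def sum_lines(lines):
    """Sum every whitespace-separated integer on every line (hypothetical helper)."""
    return sum([int(w) for line in lines for w in line.split()])


# Any iterable of strings works, so a list can stand in for sys.stdin.
print(sum_lines(['1 2 3\n', '  10\t20\n', '\n']))   # 36

# With 'import sys' at the top, the original program becomes:
# print(sum_lines(sys.stdin))
```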

It's possible to mix 'for' and 'if' clauses in a comprehension. For example, here's code to generate all pairs (x, y) where 0 ≤ x, y < 3, but skipping those with x = 1:

>>> [(x, y) for x in range(3) if x != 1 for y in range(3)]
[(0, 0), (0, 1), (0, 2), (2, 0), (2, 1), (2, 2)]
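As with multiple 'for' clauses, the clauses correspond to nested statements in the same order. So the if clause sits between the two loops and filters x before the inner loop runs. The list name pairs below is just for illustration.

```python
pairs = []
for x in range(3):
    if x != 1:                  # the 'if' clause, placed after 'for x'
        for y in range(3):      # the second 'for' clause
            pairs.append((x, y))

print(pairs)
# [(0, 0), (0, 1), (0, 2), (2, 0), (2, 1), (2, 2)]
```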

As a larger example, here's a comprehension that finds all triples of integers (a, b, c) such that a² + b² = c², with 0 ≤ a < b < c ≤ 20:

>>> [(a, b, c) for c in range(21)
...            for b in range(c)
...            for a in range(b)
...            if a * a + b * b == c * c]
[(3, 4, 5), (6, 8, 10), (5, 12, 13), (9, 12, 15), (8, 15, 17), (12, 16, 20)]

functions as values

In Python, functions are first-class values. That means that we can work with functions just like with other values such as integers and strings: we can refer to functions with variables, pass them as arguments, return them from other functions, and so on.

Here is a Python function that adds the numbers from 1 to 1,000,000:

def bigsum():
    total = 0    # named 'total' so we don't hide the built-in sum()
    for i in range(1, 1_000_001):
        total += i
    return total

We can put this function into a variable f:

>>> f = bigsum

And now we can call f just like the original function bigsum:

>>> f()
500000500000

Let's write a function time_it that takes a function as an argument:

import time


def time_it(f):
    start = time.time()
    x = f()
    end = time.time()
    print(f'function ran in {end - start:.2f} seconds')
    return x

Given any function f, time_it runs f and measures the time that elapses while f is running. It prints this elapsed time, and then returns whatever f returned:

>>> time_it(bigsum)
function ran in 0.04 seconds
500000500000

This is a first example illustrating that it can be useful to pass functions to functions. As we will see, there are many other reasons why we might want to do this.
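The introduction above also mentioned storing functions in variables and returning them from other functions. Here is a small sketch of both; the names ops and make_adder are made up for this example.

```python
def double(x):
    return 2 * x


def square(x):
    return x * x


# Functions stored as values in a dictionary.
ops = {'double': double, 'square': square}
print(ops['square'](7))      # 49


# A function that builds and returns another function.
def make_adder(n):
    def add(x):
        return x + n         # 'add' remembers n from the enclosing call
    return add


add5 = make_adder(5)
print(add5(10))              # 15
```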

map and filter

Python's standard library contains the map() function, which takes a function f and an iterable (such as a list). It returns an iterator that applies f to each element of the original iterable in turn. For example:

def square(x):
    return x * x

>>> for x in map(square, [2, 3, 4]):
...     print(x)
... 
4
9
16

We may wish to collect the resulting values into a list:

>>> list(map(square, [2, 3, 4]))
[4, 9, 16]

We may achieve the same result using a list comprehension:

>>> [x * x for x in [2, 3, 4]]
[4, 9, 16]

Which is better, using map() or a comprehension? To some degree this is a matter of style. However, in some situations one approach or the other may be more compact.

Consider this program, which reads three integers from a single line of standard input:

words = input().split()
a = int(words[0])
b = int(words[1])
c = int(words[2])
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')

We may rewrite it using map():

a, b, c = map(int, input().split())
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')

In fact we may write a function similar to map() ourselves. Here's an implementation my_map() that takes arguments f (a function) and a (a list). The function applies f to each element of a and collects the results into a list, which it returns:

def my_map(f, a):
    return [f(x) for x in a]

Here is how we might use my_map:

def double(x):
    return x * 2 

>>> my_map(double, [10, 20, 30])
[20, 40, 60]

In the returned list, every value in the input list has been doubled.

Note that my_map() is not exactly like map(), since my_map() returns a list, whereas map() returns an iterator that computes each value only when it is requested. In some situations this makes map() more efficient than my_map(). For example, consider these calls to map() and my_map():

>>> from math import sqrt
>>> sum(map(sqrt, range(1_000_000)))
666666166.4588418
>>> sum(my_map(sqrt, range(1_000_000)))
666666166.4588418

The call to my_map() builds a list with 1,000,000 elements. By contrast, the expression that calls map() runs in O(1) memory, since map() does not build a list: it returns an iterator that produces the successive values sqrt(0), sqrt(1), ..., sqrt(999_999), and sum() consumes each value as it is produced.

(You might ask: can we write a function that imitates the built-in map(), returning an iterator, not a list? The answer is yes, though we don't know how to do that yet. We would have to use a generator function or a generator expression, which are features we might see in a later lecture.)
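For the curious, here is a minimal sketch of such a function written as a generator function, that is, a def whose body uses yield. The name lazy_map is hypothetical, and unlike the built-in map() it handles only a single iterable.

```python
from math import sqrt


def lazy_map(f, a):
    """Yield f(x) for each x in a, one value at a time (no list is built)."""
    for x in a:
        yield f(x)


# Values are produced on demand, so sum() never holds a million-element list.
print(sum(lazy_map(sqrt, range(1_000_000))))
```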

By the way, why is the sum above so close to (2 / 3)(1,000,000,000)? Answering this is an elementary exercise in integral calculus. :)
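One way to see it numerically: the sum sqrt(0) + ... + sqrt(n - 1) is close to the integral of √x from 0 to n, which is (2/3)n^(3/2), and for n = 1,000,000 that is (2/3)(1,000,000,000). The sketch below prints both values. The gap between them is only about √n / 2 ≈ 500, which is tiny compared with the total.

```python
from math import sqrt

n = 1_000_000
total = sum(map(sqrt, range(n)))        # sqrt(0) + sqrt(1) + ... + sqrt(n - 1)
integral = (2 / 3) * n ** 1.5           # integral of sqrt(x) from 0 to n

print(total)
print(integral)
print(integral - total)                 # roughly 500, i.e. about sqrt(n) / 2
```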

A related built-in function is filter(), which takes a function and an iterable such as a list, and returns a new iterable containing only the values for which the function returns true. For example:

def odd(x):
    return x % 2 == 1

>>> for x in filter(odd, [2, 4, 5, 7, 8, 9, 11]):
...     print(x)
... 
5
7
9
11

Again, we may wish to collect the results into a list:

>>> list(filter(odd, [2, 4, 5, 7, 8, 9, 11]))
[5, 7, 9, 11]

Once again, we could achieve the same result using a list comprehension:

>>> [x for x in [2, 4, 5, 7, 8, 9, 11] if odd(x)]
[5, 7, 9, 11]

Let's write our own version of filter():

# Produce a list containing all elements of a for which f is true.
def my_filter(f, a):
    return [x for x in a if f(x)]
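Used with the odd() function from above, my_filter() returns a list, whereas filter() returns an iterator. This sketch repeats both definitions so that it runs on its own.

```python
def odd(x):
    return x % 2 == 1


# Produce a list containing all elements of a for which f is true.
def my_filter(f, a):
    return [x for x in a if f(x)]


print(my_filter(odd, [2, 4, 5, 7, 8, 9, 11]))   # [5, 7, 9, 11]
print(my_filter(odd, []))                        # []
```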

max() and sort() with a key function

Here's a function max_by that returns the element of an input sequence with the largest key, where the key of each element x is the value f(x):

def max_by(seq, f):
    first = True
    max_elem = None
    max_val = None
    for x in seq:
        v = f(x)
        if first or v > max_val:
            max_elem = x
            max_val = v
            first = False
    return max_elem

We can use max_by to find the longest list in a list of lists:

>>> max_by([[1, 7], [3, 4, 5], [2]], len)
[3, 4, 5]

Or we can use it to find the list whose last element is greatest:

def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]

This capability is so useful that it's also built into Python. The built-in function max() can take a keyword argument key holding a function that works exactly like the second argument to max_by:

>>> max([[1, 7], [3, 4, 5], [2]], key = len)
[3, 4, 5]

The built-in function sorted() and the sort() method take a similar argument 'key', so that you can sort by any criterion you like. For example:

>>> l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
>>> l.sort(key = len)
>>> l
[[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
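The difference between the two is that sort() rearranges the list in place (and returns None), while sorted() leaves its argument unchanged and returns a new list. Both also accept reverse=True to sort from the largest key to the smallest. A short sketch using the same list as above:

```python
l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]

# sorted() returns a new list and leaves l unchanged.
by_len = sorted(l, key=len)
print(by_len)    # [[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
print(l)         # [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]

# reverse=True puts the longest lists first.
print(sorted(l, key=len, reverse=True))
# [[1, 3, 5, 2], [3, 10, 6], [2, 7], [8]]
```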

Let's write a similar function that sorts a list using bubble sort, with an arbitrary key function:

def sort_by(a, f):
    n = len(a)
    for i in range(n - 1, 0, -1):   # (n - 1), ..., 1
        for j in range(i):
            if f(a[j]) > f(a[j + 1]):
                a[j], a[j + 1] = a[j + 1], a[j]
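Like the sort() method, sort_by() modifies its argument in place and returns nothing. Applied to the same list as in the sort() example, it should produce the same order:

```python
def sort_by(a, f):
    """Bubble sort a in place, comparing elements by their keys f(x)."""
    n = len(a)
    for i in range(n - 1, 0, -1):   # (n - 1), ..., 1
        for j in range(i):
            if f(a[j]) > f(a[j + 1]):
                a[j], a[j + 1] = a[j + 1], a[j]


l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
sort_by(l, len)
print(l)    # [[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
```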

lambda expressions

Let's return to the previous example where we were given a list of lists, and found the list whose last element is greatest:

def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]

It's a bit of a nuisance to have to define a separate function last here. Instead, we can use a lambda expression:

>>> max_by([[1, 7], [3, 4, 5], [2]], lambda l: l[-1])
[1, 7]

A lambda expression creates a function "on the fly", without giving it a name. In other words, a lambda expression creates an anonymous function.

A function created by a lambda expression is no different from any other function: we can call it, pass it as an argument, and so forth. Even though the function is initially anonymous, we can certainly put it into a variable:

>>> abc = lambda x, y: 2 * x + y
>>> abc(10, 3)
23

The assignment to abc above is basically equivalent to

def abc(x, y):
    return 2 * x + y

which is how we would more typically define this function.

A lambda function may even take no arguments at all:

>>> f = lambda: 14
>>> f()
14

As another example, suppose that we'd like to write a function that takes a string and returns its most frequent character. We may use max() with a key function that is a lambda expression:

# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    return max(d.keys(), key = lambda k: d[k])

We see here that a lambda expression may refer to a local variable in an enclosing scope: this lambda reads the dictionary d, which is local to freq(). We could not define this key function outside freq(), since it would then have no access to d.
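Here is freq() applied to two words. If several characters tie for the highest count, max() returns the first maximal item it encounters. Since Python 3.7 dictionaries preserve insertion order, so that is the tied character that appears first in s.

```python
# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    # The lambda reads d from freq's local scope.
    return max(d.keys(), key=lambda k: d[k])


print(freq('banana'))        # a
print(freq('mississippi'))   # i  ('i' and 's' both occur 4 times; 'i' comes first)
```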

As a final example, we may write a selection sort using min() with a lambda function rather than an inner loop:

def selection_sort(a):
    n = len(a)
    for i in range(n):
        j = min(range(i, n), key = lambda k: a[k])
        a[i], a[j] = a[j], a[i]
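The call min(range(i, n), key=lambda k: a[k]) returns the index k in i, ..., n - 1 whose element a[k] is smallest. That is exactly what the inner loop of a conventional selection sort computes. A quick check:

```python
def selection_sort(a):
    n = len(a)
    for i in range(n):
        # index of the smallest element among a[i], ..., a[n - 1]
        j = min(range(i, n), key=lambda k: a[k])
        a[i], a[j] = a[j], a[i]


nums = [5, 2, 9, 1, 5, 6]
selection_sort(nums)
print(nums)    # [1, 2, 5, 5, 6, 9]
```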