Programming 1, 2022-3
Week 9: Notes

iterating with enumerate()

The enumerate() function lets you iterate over a sequence and its indices simultaneously.

For example, consider this function, which returns the index of the first odd element in a list or other sequence (or -1 if no odd elements are found):

def first_odd(a):
    for i in range(len(a)):
        if a[i] % 2 == 1:
            return i

    return -1

We may rewrite the function using enumerate():

def first_odd(a):
    for i, x in enumerate(a):
        if x % 2 == 1:
            return i

    return -1

On each iteration of the loop, i receives the index of an element in a, and x is the element at that index.

enumerate() actually returns an iterable of pairs. If we like, we may convert it to a list:

>>> list(enumerate([20, 40, 50, 80, 100]))
[(0, 20), (1, 40), (2, 50), (3, 80), (4, 100)]
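To see roughly what enumerate() is doing, here is a minimal sketch of a function my_enumerate() (a hypothetical name, not part of the standard library) that builds the same pairs as a list. The real enumerate() differs in that it produces its pairs lazily, one at a time, instead of building a whole list first.

```python
# my_enumerate is a hypothetical list-based version of enumerate(),
# for illustration only. The real enumerate() yields pairs lazily.
def my_enumerate(a):
    result = []
    i = 0
    for x in a:
        result.append((i, x))   # pair each element with its index
        i += 1
    return result

print(my_enumerate([20, 40, 50, 80, 100]))
# [(0, 20), (1, 40), (2, 50), (3, 80), (4, 100)]

# The real enumerate() also takes an optional start value for the indices.
print(list(enumerate(['x', 'y'], 1)))
# [(1, 'x'), (2, 'y')]
```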

iterating with zip()

The handy zip() function lets you iterate over two (or more) sequences simultaneously.

For example, let's use zip() to iterate over two lists of integers:

>>> l = [2, 4, 6, 8, 10]
>>> m = [20, 40, 60, 80, 100]
>>> for x, y in zip(l, m):
...     print(f'x = {x}, y = {y}')
x = 2, y = 20
x = 4, y = 40
x = 6, y = 60
x = 8, y = 80
x = 10, y = 100

Notice that on each iteration we receive a pair of values, one from each of the lists that we zipped. (Think of a zipper on a jacket, which pulls two edges together as it moves upward.)

zip() actually returns an iterable of pairs, which we may collect into a list:

>>> l = [2, 4, 6, 8, 10]
>>> m = [20, 40, 60, 80, 100]
>>> list(zip(l, m))
[(2, 20), (4, 40), (6, 60), (8, 80), (10, 100)]

Note that zip() will stop as soon as it reaches the end of any of its inputs, so the result is only as long as the shortest list:

>>> l = [2, 4, 6]
>>> m = [20, 40, 60, 80, 100]
>>> list(zip(l, m))
[(2, 20), (4, 40), (6, 60)]
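As with enumerate(), we can sketch a simplified version ourselves. my_zip() below is a hypothetical name; it handles exactly two sequences and returns a list, whereas the real zip() accepts any number of iterables and produces its tuples lazily. Like zip(), it stops at the end of the shorter input.

```python
# my_zip is a hypothetical two-argument, list-based version of zip(),
# for illustration only.
def my_zip(a, b):
    n = min(len(a), len(b))    # number of complete pairs available
    return [(a[i], b[i]) for i in range(n)]

print(my_zip([2, 4, 6], [20, 40, 60, 80, 100]))
# [(2, 20), (4, 40), (6, 60)]
```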

We can use zip() to simplify some loops. For example, consider a function that takes two lists a and b, and produces a new list in which each element is the sum of two corresponding elements from a and b:

# produce a list of sums of values in a and b
def list_sum(a, b):
    assert len(a) == len(b)
    return [a[i] + b[i] for i in range(len(a))]

Instead of iterating over indices, we may use zip():

# produce a list of sums of values in a and b
def list_sum(a, b):
    assert len(a) == len(b)
    return [x + y for x, y in zip(a, b)]

In fact, zip() may take any number of arguments. If we give it three lists of integers, it will produce triples:

>>> list(zip([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]))
[(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)]

Let's now generalize the function list_sum() above into a function multi_sum(), which takes any number of lists and adds the corresponding elements in all lists:

# produce a list of sums of corresponding values in all the given lists
def multi_sum(*args):
    return [sum(t) for t in zip(*args)]

For example:

>>> multi_sum([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12])
[15, 18, 21, 24]
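The expression zip(*args) in multi_sum() uses the * operator to unpack the tuple args into separate arguments, so zip(*[a, b, c]) means the same as zip(a, b, c). The same trick gives a compact way to transpose a matrix stored as a list of rows. Here transpose() is an illustrative helper, not a standard function:

```python
# transpose is an illustrative helper name.
# zip(*m) passes each row of m as a separate argument to zip(), so the
# i-th tuple it produces holds the i-th element of every row.
def transpose(m):
    return [list(col) for col in zip(*m)]

m = [[1, 2, 3],
     [4, 5, 6]]
print(transpose(m))   # [[1, 4], [2, 5], [3, 6]]
```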

functions as values

In Python, functions are first-class values. That means that we can work with functions just like with other values such as integers and strings: we can refer to functions with variables, pass them as arguments, return them from other functions, and so on.

Here is a Python function that adds the numbers from 1 to 1,000,000:

def bigsum():
    total = 0
    for i in range(1, 1_000_001):
        total += i
    return total

We can put this function into a variable f:

>>> f = bigsum

And now we can call f just like the original function bigsum:

>>> f()
500000500000

Let's write a function time_it that takes a function as an argument:

import time

def time_it(f):
    start = time.time()
    x = f()          # call the function we were given
    end = time.time()
    print(f'function ran in {end - start:.2f} seconds')
    return x

Given any function f, time_it runs f and measures the time that elapses while f is running. It prints this elapsed time, and then returns whatever f returned:

>>> time_it(bigsum)
function ran in 0.04 seconds
500000500000
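As written, time_it() can only time functions that take no arguments. Here is a minimal sketch of a variant, time_call() (a hypothetical name), that forwards any extra positional arguments to f using *args, the same syntax we used in multi_sum(). It uses time.perf_counter(), a standard-library clock intended for measuring short durations:

```python
import time

# time_call is a hypothetical variant of time_it that forwards arguments.
# It calls f(*args), prints how long the call took, and returns the result.
def time_call(f, *args):
    start = time.perf_counter()
    x = f(*args)
    end = time.perf_counter()
    print(f'function ran in {end - start:.2f} seconds')
    return x

# add the numbers from 1 to n
def sum_to(n):
    total = 0
    for i in range(1, n + 1):
        total += i
    return total

print(time_call(sum_to, 1_000_000))   # prints the timing, then 500000500000
```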

This is a first example illustrating that it can be useful to pass functions to functions. As we will see, there are many other reasons why we might want to do this.

map and filter

Python's standard library contains the map() function, which takes a function f and an iterable (such as a list). It returns a new iterable in which f has been applied to every element of the original iterable. For example:

def square(x):
    return x * x

>>> for x in map(square, [2, 3, 4]):
...     print(x)
... 
4
9
16

We may wish to collect the resulting values into a list:

>>> list(map(square, [2, 3, 4]))
[4, 9, 16]
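One thing to watch for: map() returns an iterator, which produces its values only once. After they have been consumed (for example, by converting them to a list), iterating over the same map object again produces nothing:

```python
def square(x):
    return x * x

squares = map(square, [2, 3, 4])
print(list(squares))   # [4, 9, 16]
print(list(squares))   # [] : the iterator has already been used up
```

If you need the values more than once, store list(map(...)) in a variable. The objects returned by filter(), zip() and enumerate() behave the same way.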

We may achieve the same result using a list comprehension:

>>> [x * x for x in [2, 3, 4]]
[4, 9, 16]

Which is better, using map() or a comprehension? To some degree this is a matter of style. However, in some situations one approach or the other may be more compact.

Consider this program, which reads three integers from a single line of standard input:

words = input().split()
a = int(words[0])
b = int(words[1])
c = int(words[2])
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')

We may rewrite it using map():

a, b, c = map(int, input().split())
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')
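Because input() reads from the keyboard, the parsing step is easier to experiment with if we move it into a function that takes the line as a string. In this sketch parse_three() is an illustrative name. Note that the unpacking a, b, c = ... succeeds only if the line holds exactly three values; otherwise it raises a ValueError, as does int() on text that is not an integer.

```python
# parse_three is an illustrative name.
# Parse a line containing exactly three whitespace-separated integers.
def parse_three(line):
    a, b, c = map(int, line.split())
    return a, b, c

a, b, c = parse_three('3 5 7')
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')
# a = 3, b = 5, c = 7, sum = 15
```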

In fact we may write a function similar to map() ourselves. Here's an implementation my_map() that takes arguments f (a function) and a (a list). It applies f to each element of a, and collects the results into a list that it returns:

def my_map(f, a):
    return [f(x) for x in a]

Here is how we might use my_map:

def double(x):
    return x * 2 

>>> my_map(double, [10, 20, 30])
[20, 40, 60]

In the returned list, every value in the input list has been doubled.

A related built-in function is filter(), which takes a function and an iterable such as a list, and returns a new iterable containing only the values for which the function returns a true value. For example:

def odd(x):
    return x % 2 == 1

>>> for x in filter(odd, [2, 4, 5, 7, 8, 9, 11]):
...     print(x)
... 
5
7
9
11

Again, we may wish to collect the results into a list:

>>> list(filter(odd, [2, 4, 5, 7, 8, 9, 11]))
[5, 7, 9, 11]

Once again, we could achieve the same result using a list comprehension:

>>> [x for x in [2, 4, 5, 7, 8, 9, 11] if odd(x)]
[5, 7, 9, 11]

Let's write our own version of filter():

# Produce a list containing all elements of a for which f is true.
def my_filter(f, a):
    return [x for x in a if f(x)]
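As a quick check, my_filter() gives the same result as filter() in the example above. Because my_map() and my_filter() both return lists, they also combine naturally; for instance, we can square only the odd numbers. The earlier definitions are repeated so that the block runs on its own:

```python
def my_map(f, a):
    return [f(x) for x in a]

def my_filter(f, a):
    return [x for x in a if f(x)]

def odd(x):
    return x % 2 == 1

def square(x):
    return x * x

nums = [2, 4, 5, 7, 8, 9, 11]
print(my_filter(odd, nums))                  # [5, 7, 9, 11]
print(my_map(square, my_filter(odd, nums)))  # [25, 49, 81, 121]
```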

max() and sort() with a key function

Here's a function max_by that finds the maximum value in an input sequence, applying a function f to each element to yield a comparison key:

# Return the element x of seq for which f(x) is greatest (None if seq is empty).
def max_by(seq, f):
    max_elem = None
    max_val = None
    for x in seq:
        v = f(x)
        if max_elem is None or v > max_val:
            max_elem = x
            max_val = v
    return max_elem

We can use max_by to find the longest list in a list of lists:

>>> max_by([[1, 7], [3, 4, 5], [2]], len)
[3, 4, 5]

Or we can use it to find the list whose last element is greatest:

def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]

This capability is so useful that it's also built into the standard library. The built-in function max() can take a keyword argument key holding a function that works exactly like the second argument to max_by:

>>> max([[1, 7], [3, 4, 5], [2]], key = len)
[3, 4, 5]
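One difference between max_by and the built-in max() shows up on an empty sequence: max_by returns None, while max() raises a ValueError unless we also pass the keyword argument default, whose value max() returns when the input is empty:

```python
def max_by(seq, f):
    max_elem = None
    max_val = None
    for x in seq:
        v = f(x)
        if max_elem is None or v > max_val:
            max_elem = x
            max_val = v
    return max_elem

print(max_by([], len))                     # None
print(max([], key = len, default = None))  # None

try:
    max([], key = len)
except ValueError:
    print('max() of an empty sequence raises ValueError')
```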

The built-in function sorted() and the list method sort() take a similar argument key, so that you can sort by any criterion you like. For example:

>>> l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
>>> l.sort(key = len)
>>> l
[[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
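Unlike the sort() method, which rearranges a list in place and returns None, sorted() builds and returns a new list and leaves the original unchanged. Both also accept reverse=True. A key function may return a tuple, in which case elements are compared by the first component, then by the second to break ties, and so on. In this sketch len_then_first is an illustrative helper name:

```python
l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8], [1, 9]]

# len_then_first is an illustrative name: sort by length, and among
# lists of equal length, by first element.
def len_then_first(s):
    return (len(s), s[0])

print(sorted(l, key = len_then_first))
# [[8], [1, 9], [2, 7], [3, 10, 6], [1, 3, 5, 2]]

print(sorted(l, key = len, reverse = True))
# [[1, 3, 5, 2], [3, 10, 6], [2, 7], [1, 9], [8]]

print(l)   # unchanged, since sorted() returned new lists
```

In the second call, [2, 7] and [1, 9] have the same length and keep their original relative order: Python's sort is stable, even with reverse=True.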

methods as values

We have just seen that a Python variable may refer to a function. It may also refer to a method of a particular object.

For example, consider this class:

class Counter:
    def __init__(self):
        self.count = 0
    
    def inc(self):
        self.count += 1

Let's create a couple of instances of Counter, and a variable 'f' that refers to the 'inc' method of one of those instances:

>>> c = Counter()
>>> c.inc()
>>> c.inc()
>>> c.count
2
>>> d = Counter()
>>> d.count
0
>>> f = c.inc

When we call f(), it will increment the count in the object c:

>>> f()
>>> f()
>>> c.count
4

The value in 'd' remains unchanged, since f refers to the inc() method of c, not d:

>>> d.count
0
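Because c.inc is an ordinary value, we can pass it to any function that expects a function. Here call_n_times() is a hypothetical helper that calls whatever zero-argument function it is given n times. Passing it c.inc increments c's count, since the bound method remembers which object it belongs to:

```python
class Counter:
    def __init__(self):
        self.count = 0

    def inc(self):
        self.count += 1

# call_n_times is a hypothetical helper: call the zero-argument function f n times.
def call_n_times(f, n):
    for _ in range(n):
        f()

c = Counter()
d = Counter()
call_n_times(c.inc, 5)
print(c.count, d.count)   # 5 0
```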

lambda expressions

Let's return to the previous example where we were given a list of lists, and found the list whose last element is greatest:

def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]

It's a bit of a nuisance to have to define a separate function last here. Instead, we can use a lambda expression:

>>> max_by([[1, 7], [3, 4, 5], [2]], lambda l: l[-1])
[1, 7]

A lambda expression creates a function "on the fly", without giving it a name. In other words, a lambda expression creates an anonymous function.

A function created by a lambda expression is no different from any other function: we can call it, pass it as an argument, and so forth. Even though the function is initially anonymous, we can certainly put it into a variable:

>>> abc = lambda x, y: 2 * x + y
>>> abc(10, 3)
23

The assignment to abc above is basically equivalent to

def abc(x, y):
    return 2 * x + y

which is how we would more typically define this function.
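There are two small differences between the two forms. A lambda's body must be a single expression: it cannot contain statements such as return, a for loop, or an if statement (though a conditional expression like x if x > 0 else -x is allowed). Also, the function a lambda creates has the generic name '<lambda>', which is what appears in error messages and tracebacks. We can see the name through the function's __name__ attribute:

```python
abc = lambda x, y: 2 * x + y

# abc2 is the same function written with def, for comparison
def abc2(x, y):
    return 2 * x + y

print(abc(10, 3), abc2(10, 3))   # 23 23
print(abc.__name__)              # <lambda>
print(abc2.__name__)             # abc2
```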

As another example, suppose that we'd like to write a function that takes a string and returns its most frequent character. We may use max() with a key function that is a lambda expression:

# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    return max(d.keys(), key = lambda k: d[k])

We see here that a lambda expression may refer to a local variable in an enclosing scope. We could not define this key function outside freq(), since it would then have no access to the dictionary d.
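Combining this with the earlier section on methods as values gives an even shorter version. The dictionary method get, called with a key, returns the value for that key, so the bound method d.get can itself serve as the key function. Also, iterating over a dictionary yields its keys, so d.keys() can be shortened to d. In this sketch freq2 is an illustrative name, and the counting loop is condensed using d.get(c, 0), which returns 0 when c is not yet a key:

```python
# freq2 is an illustrative name.
# Return the character in s which occurs most frequently.
def freq2(s):
    d = {}
    for c in s:
        d[c] = d.get(c, 0) + 1    # get() returns 0 if c is not yet a key
    return max(d, key = d.get)    # d.get(k) is the count for key k

print(freq2('banana'))   # a
```

If several characters tie, max() returns the first maximal one it encounters, which for a dictionary is the one inserted first.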

nested functions

Python allows us to write nested functions, i.e. functions that are defined inside other functions or methods.

As a first example, consider the freq() function that we wrote above. Instead of a lambda expression, we could use a nested function:

# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    def keyval(k):
        return d[k]

    return max(d.keys(), key = keyval)

Notice that the function keyval() is nested inside the function freq(), and has access to the local variable 'd' defined inside freq().

As another example, suppose that we'd like to write a function replace_with_max() that takes a square matrix m and returns a matrix n in which each value in m is replaced with the maximum of its neighbors in all 4 directions. For example, if m is

2 4
5 9

then replace_with_max(m) will return

5 9
9 5

As a first attempt, we might write

def replace_with_max(m):
    size = len(m)
    
    # Make a matrix of dimensions (size x size) filled with zeroes
    n = [ size * [ 0 ] for _ in range(size) ]
    
    for r in range(size):
        for c in range(size):
            n[r][c] = max(m[r - 1][c], m[r + 1][c],
                          m[r][c - 1], m[r][c + 1])
                          
    return n

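However, we have a problem: if a square (r, c) is at the edge of the matrix, some of these references fall outside it. A reference such as m[r][c + 1] might go out of bounds and raise an IndexError. A reference such as m[r - 1][c] with r = 0 is even worse: Python interprets the negative index -1 as the last row, so it would silently read a value from the opposite side of the matrix.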

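To solve this problem, let's write a nested helper function get(i, j) that returns the element m[i][j] if the position (i, j) is inside the matrix, and otherwise returns -math.inf, i.e. -∞. Since every number is greater than -∞, positions outside the matrix never affect the maximum. Here is the improved function: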

import math

def replace_with_max(m):
    def get(i, j):
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf
    
    size = len(m)
    
    # Make a matrix of dimensions (size x size) filled with zeroes
    n = [ size * [ 0 ] for _ in range(size) ]
    
    for r in range(size):
        for c in range(size):
            n[r][c] = max(get(r - 1, c), get(r + 1, c),
                          get(r, c - 1), get(r, c + 1))
                          
    return n

Notice that the nested function get() can refer to the parameter 'm', and also to the local variable 'size' that is defined in its containing function replace_with_max().

updating outer variables in nested functions

In the previous example, we saw that the nested function get() can read the values of the variables 'm' and 'size' in the containing function. What if get() wants to update the value of such a variable? For example, suppose that we want to count the number of calls to get() made inside a single call to replace_with_max(). We could attempt to write

def replace_with_max(m):
    g = 0     # number of calls to get()

    def get(i, j):
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf

    # (the rest of replace_with_max() is as before)

However, that won't work because, as we have seen before, any variable that is assigned inside a function is local to that function by default in Python. So in the code above, Python will treat 'g' as a local variable of get(), and the first attempt to increment it will fail with an UnboundLocalError, since that local 'g' has not yet been given a value.

One possible solution would be to make 'g' global, and use a declaration 'global g' inside get(). However, that's a bit ugly since 'g' doesn't really need to be global. A better way is to declare g as nonlocal:

def replace_with_max(m):
    g = 0     # number of calls to get()

    def get(i, j):
        nonlocal g
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf

    # (the rest of replace_with_max() is as before)

Now the code will work. The nonlocal statement is somewhat like the global statement in that it declares that a variable is not local. The difference is that global declares that a variable is to be found in the global (i.e. top-level) scope, whereas nonlocal declares that a variable refers to a local variable of an enclosing function.
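Here is a complete version that puts these pieces together. As one possible way to report the count, this sketch returns it alongside the new matrix; the name replace_with_max_counted and the tuple return value are our own choices:

```python
import math

# replace_with_max_counted is an illustrative name: like replace_with_max(),
# but it also returns the number of calls made to get().
def replace_with_max_counted(m):
    size = len(m)
    g = 0     # number of calls to get()

    def get(i, j):
        nonlocal g    # g belongs to the enclosing function
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf

    # Make a matrix of dimensions (size x size) filled with zeroes
    n = [size * [0] for _ in range(size)]
    for r in range(size):
        for c in range(size):
            n[r][c] = max(get(r - 1, c), get(r + 1, c),
                          get(r, c - 1), get(r, c + 1))
    return n, g

n, calls = replace_with_max_counted([[2, 4], [5, 9]])
print(n)       # [[5, 9], [9, 5]]
print(calls)   # 16: four calls for each of the four cells
```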

functions as return values

A function can return a function. As an example, let's write a function add_n(n) that takes an integer n and returns a function that adds n to its argument. As one possible approach, we can define a nested function and then return it:

def add_n(n):
    def adder(x):
        return x + n
        
    return adder

Let's try it:

>>> f = add_n(10)
>>> f(5)
15

Alternatively, we can write add_n using a lambda:

def add_n(n):
    return lambda x: x + n

In these functions, we say that the parameter n has been captured by the returned function; a function that captures variables from an enclosing scope in this way is called a closure. Although n is a parameter to add_n(), it continues to exist even after add_n() has returned: the returned function can still access the value of n whenever it is called.
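Captured variables combine naturally with nonlocal. In this sketch, make_counter() (an illustrative name) returns a function that, each time it is called, increments and returns a count stored in a local variable of the enclosing call. Each call to make_counter() creates a fresh variable, so each returned function has its own independent count, much like separate instances of the Counter class above:

```python
# make_counter is an illustrative name.
def make_counter():
    count = 0

    def inc():
        nonlocal count    # update the captured variable, not a new local
        count += 1
        return count

    return inc

f = make_counter()
g = make_counter()
f()
f()
print(f())   # 3
print(g())   # 1, since g has its own count
```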

transforming a function

We can write a function that takes a function f as an argument and returns a transformed function based on f.

For example, let's write a function swap() that takes a function f of two arguments. swap() will return a function that is just like f, but takes its arguments in the opposite order. To achieve this, we can simply define a nested function and return it:

# Returns f with its arguments swapped.
def swap(f):
    def g(x, y):
        return f(y, x)

    return g

Let's try it:

def foo(x, y):
    return 1000 * x + y

>>> foo(2, 3)
2003
>>> h = swap(foo)
>>> h(2, 3)
3002
>>> h(3, 2)
2003

As another example, let's write a function twice() that takes a function f and returns a function g such that g(x) = f(f(x)) for any x.

def twice(f):
    def g(x):
        return f(f(x))

    return g

Let's try it:

>>> f = twice(lambda x: x + 10)
>>> f(5)
25

We can even pass the result of twice() back to the same function, yielding a function that applies the original function four times:

>>> f = twice(twice(lambda x: x + 10))
>>> f(5)
45

Alternatively, we can define twice() using a lambda:

def twice(f):
    return lambda x: f(f(x))
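twice() is a special case of a more general transformation, function composition: given functions f and g, compose(f, g) returns a function that computes f(g(x)). compose() is our own helper here, not a standard-library function. With it, twice(f) could be written as compose(f, f):

```python
# compose is an illustrative helper, not part of the standard library.
# Return the function x -> f(g(x)).
def compose(f, g):
    return lambda x: f(g(x))

def twice(f):
    return compose(f, f)

add_10 = lambda x: x + 10
double = lambda x: x * 2

print(compose(add_10, double)(5))   # 20: double first, then add 10
print(compose(double, add_10)(5))   # 30: add 10 first, then double
print(twice(add_10)(5))             # 25
```

Note that the order matters: compose(f, g) applies g first.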