Week 9: Notes

functions as values

In Python, functions are first-class values. That means that we can work with functions just like with other values such as integers and strings: we can refer to functions with variables, pass them as arguments, return them from other functions, and so on.

Here is a Python function that adds the numbers from 1 to 1,000,000:

def bigsum():
    sum = 0
    for i in range(1, 1_000_001):
        sum += i
    return sum

We can put this function into a variable f:

>>> f = bigsum

And now we can call f just like the original function bigsum:

>>> f()
500000500000
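Because functions are values, they can also be stored in data structures such as lists and dictionaries. The sketch below is only an illustration; the names square, negate, ops and by_name are made up for this example.

```python
def square(x):
    return x * x

def negate(x):
    return -x

# A list of functions: each element is a function object, not the result of a call.
ops = [square, negate, abs]

# Call every function in the list on the same argument.
results = [op(-3) for op in ops]
print(results)    # [9, 3, 3]

# A dictionary mapping names to functions works as a simple dispatch table.
by_name = {'square': square, 'negate': negate}
print(by_name['square'](5))    # 25
```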

Let's write a function time_it that takes a function as an argument:

import time

def time_it(f):
    start = time.time()
    x = f()                 # run f, keeping whatever it returns
    end = time.time()
    print(f'function ran in {end - start:.2f} seconds')
    return x

Given any function f, time_it runs f and measures the time that elapses while f is running. It prints this elapsed time, and then returns whatever f returned:

>>> time_it(bigsum)
function ran in 0.04 seconds
500000500000

This is a first example illustrating that it can be useful to pass functions to functions. As we will see, there are many other reasons why we might want to do this.
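time.time() reports wall-clock time, and the Python documentation notes that not every system provides it with better than one-second precision. The standard library also has time.perf_counter(), a high-resolution clock intended for measuring short durations. A variant of time_it() using it might look like this; the name time_it2 is made up, and bigsum() is repeated (with its accumulator renamed to total) so the snippet runs on its own.

```python
import time

def time_it2(f):
    # perf_counter() has an undefined reference point, so only the
    # difference between two readings is meaningful.
    start = time.perf_counter()
    x = f()
    end = time.perf_counter()
    print(f'function ran in {end - start:.2f} seconds')
    return x

def bigsum():
    total = 0
    for i in range(1, 1_000_001):
        total += i
    return total

print(time_it2(bigsum))    # prints the timing line, then 500000500000
```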

map and filter

Python's standard library contains the map() function, which takes a function f and an iterable (such as a list). It returns an iterator that yields the result of applying f to each element of the original iterable. For example:

def square(x):
    return x * x

>>> for x in map(square, [2, 3, 4]):
...     print(x)
... 
4
9
16

We may wish to collect the resulting values into a list:

>>> list(map(square, [2, 3, 4]))
[4, 9, 16]

We may achieve the same result using a list comprehension:

>>> [x * x for x in [2, 3, 4]]
[4, 9, 16]

Which is better, using map() or a comprehension? To some degree this is a matter of style. However, in some situations one approach or the other may be more compact.

Consider this program, which reads three integers from a single line of standard input:

words = input().split()
a = int(words[0])
b = int(words[1])
c = int(words[2])
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')

We may rewrite it using map():

a, b, c = map(int, input().split())
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')
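map() is also handy in the opposite direction, when printing numbers on one line. The string method join() accepts only strings, so a common idiom converts each number with map(str, ...) first (nums is just sample data):

```python
nums = [3, 14, 15]

# ' '.join(nums) would raise TypeError, since the elements are ints, not strings.
line = ' '.join(map(str, nums))
print(line)    # 3 14 15
```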

In fact we may write a function similar to map() ourselves. Here's an implementation my_map() that takes arguments f (a function) and a (a list). The function applies f to each element of a, and collects the results into a list that it returns:

def my_map(f, a):
    # Apply f to every element of a, collecting the results in a new list.
    return [f(x) for x in a]

Here is how we might use my_map:

def double(x):
    return x * 2 

>>> my_map(double, [10, 20, 30])
[20, 40, 60]

In the returned list, every value in the input list has been doubled.

Note that my_map() is not exactly like map(), since my_map() returns a list, whereas map() returns an iterator. In some situations map() may be more efficient than my_map(). For example, consider these calls to map() and my_map():

>>> from math import sqrt
>>> sum(map(sqrt, range(1_000_000)))
666666166.4588418
>>> sum(my_map(sqrt, range(1_000_000)))
666666166.4588418

The call to my_map() builds a list with 1,000,000 elements. However, the first expression above, which calls map(), runs in O(1) memory: this call to map() does not build a list; instead, it returns an iterator that produces the successive values sqrt(0), sqrt(1), ..., sqrt(999_999).

(You might ask: can we write a function that imitates the built-in map(), returning an iterator, not a list? The answer is yes, though we don't know how to do that yet. We would have to use a generator function or a generator expression, which are features we might see in a later lecture.)
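As a preview, here is one way such a function might look as a generator function (a function whose body uses the yield keyword). This is only a sketch; the name my_imap is made up, and generators will be covered properly later.

```python
from math import sqrt

def my_imap(f, a):
    # Because the body contains 'yield', calling my_imap() immediately
    # returns an iterator; each value is computed only when requested.
    for x in a:
        yield f(x)

it = my_imap(sqrt, range(1_000_000))   # no list of 1,000,000 elements is built
print(next(it), next(it))              # 0.0 1.0
print(sum(my_imap(sqrt, range(1_000_000))))
```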

By the way, why is the sum above so close to (2 / 3)(1,000,000,000)? Answering this is an elementary exercise in integral calculus. :)

A related built-in function is filter(), which takes a function and an iterable such as a list, and returns an iterator that yields only the values for which the function returns a true value. For example:

def odd(x):
    return x % 2 == 1

>>> for x in filter(odd, [2, 4, 5, 7, 8, 9, 11]):
...     print(x)
... 
5
7
9
11

Again, we may wish to collect the results into a list:

>>> list(filter(odd, [2, 4, 5, 7, 8, 9, 11]))
[5, 7, 9, 11]

Once again, we could achieve the same result using a list comprehension:

>>> [x for x in [2, 4, 5, 7, 8, 9, 11] if odd(x)]
[5, 7, 9, 11]

Let's write our own version of filter():

# Produce a list containing all elements of a for which f is true.
def my_filter(f, a):
    return [x for x in a if f(x)]
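A quick check that my_filter() gives the same answers as the built-in, and that (like filter()) it accepts any iterable, not just a list. The function long_word is made up for this example.

```python
def odd(x):
    return x % 2 == 1

def long_word(s):
    return len(s) > 3

# Produce a list containing all elements of a for which f is true.
def my_filter(f, a):
    return [x for x in a if f(x)]

print(my_filter(odd, [2, 4, 5, 7, 8, 9, 11]))             # [5, 7, 9, 11]
print(my_filter(odd, range(10)))                          # [1, 3, 5, 7, 9]
print(my_filter(long_word, ['a', 'abcd', 'xyz', 'hello']))  # ['abcd', 'hello']
```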

max() and sort() with a key function

Here's a function max_by that finds the maximum value in an input sequence, applying a function f to each element to yield a comparison key:

def max_by(seq, f):
    max_elem = None
    max_val = None
    for x in seq:
        v = f(x)
        if max_elem is None or v > max_val:
            max_elem = x
            max_val = v
    return max_elem

We can use max_by to find the longest list in a list of lists:

>>> max_by([[1, 7], [3, 4, 5], [2]], len)
[3, 4, 5]

Or we can use it to find the list whose last element is greatest:

def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]

This capability is so useful that it's also built into the standard library. The standard function max can take a keyword argument key holding a function that works exactly like the second argument to max_by:

>>> max([[1, 7], [3, 4, 5], [2]], key = len)
[3, 4, 5]
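One difference worth knowing: on an empty sequence, max_by() as written above returns None, while the built-in max() raises ValueError unless you also pass its default keyword argument. The built-in min() accepts key and default in the same way. A small comparison:

```python
def max_by(seq, f):
    max_elem = None
    max_val = None
    for x in seq:
        v = f(x)
        if max_elem is None or v > max_val:
            max_elem = x
            max_val = v
    return max_elem

print(max_by([], len))                   # None
print(max([], key=len, default=None))    # None
try:
    max([], key=len)
except ValueError:
    print('max() of an empty sequence raises ValueError')

print(min([[1, 7], [3, 4, 5], [2]], key=len))    # [2]
```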

The built-in function sorted() and the sort() method take a similar argument 'key', so that you can sort by any criterion you like. For example:

>>> l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
>>> l.sort(key = len)
>>> l
[[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
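Unlike the sort() method, sorted() returns a new list and leaves its argument unchanged. Python's sort is also stable: elements with equal keys keep their original relative order, and the documentation guarantees this even with reverse=True. A small example with made-up data:

```python
words = ['pear', 'fig', 'banana', 'kiwi']

# sorted() builds a new list; words itself is not modified.
by_len = sorted(words, key=len)
print(by_len)    # ['fig', 'pear', 'kiwi', 'banana']
print(words)     # ['pear', 'fig', 'banana', 'kiwi']

# 'pear' and 'kiwi' both have length 4, so they stay in their original order.
longest_first = sorted(words, key=len, reverse=True)
print(longest_first)    # ['banana', 'pear', 'kiwi', 'fig']
```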

Let's write a similar function that sorts a list using bubble sort, with an arbitrary key function:

def sort_by(a, f):
    n = len(a)
    for i in range(n - 1, 0, -1):   # (n - 1), ..., 1
        for j in range(i):
            if f(a[j]) > f(a[j + 1]):
                a[j], a[j + 1] = a[j + 1], a[j]
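Note that sort_by() calls f twice for every comparison, so f runs O(n²) times in total, whereas the built-in sort computes each element's key only once. A variant that precomputes the keys and swaps them in step with the elements might look like this; sort_by_cached is a made-up name, and this is a sketch rather than a recommended way to sort.

```python
def sort_by_cached(a, f):
    keys = [f(x) for x in a]        # compute each key exactly once
    n = len(a)
    for i in range(n - 1, 0, -1):   # (n - 1), ..., 1
        for j in range(i):
            if keys[j] > keys[j + 1]:
                # Swap keys and elements together so they stay aligned.
                keys[j], keys[j + 1] = keys[j + 1], keys[j]
                a[j], a[j + 1] = a[j + 1], a[j]

l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
sort_by_cached(l, len)
print(l)    # [[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
```

Because it swaps only when one key is strictly greater than the next, this bubble sort (like the original sort_by()) is stable.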

methods as values

We have just seen that a Python variable may refer to a function. It may also refer to a method of a particular object.

For example, consider this class:

class Counter:
    def __init__(self):
        self.count = 0
    
    def inc(self):
        self.count += 1

Let's create a couple of instances of Counter, and a variable 'f' that refers to the 'inc' method of one of those instances:

>>> c = Counter()
>>> c.inc()
>>> c.inc()
>>> c.count
2
>>> d = Counter()
>>> d.count
0
>>> f = c.inc

When we call f(), it will increment the count in the object c:

>>> f()
>>> f()
>>> c.count
4

The value in 'd' remains unchanged, since f refers to the inc() method of c, not d:

>>> d.count
0
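A bound method such as c.inc remembers the object it came from; Python exposes that object as the method's __self__ attribute. By contrast, looking the method up on the class, as in Counter.inc, gives a plain function that expects the object as its first argument. A short illustration:

```python
class Counter:
    def __init__(self):
        self.count = 0

    def inc(self):
        self.count += 1

c = Counter()
d = Counter()

f = c.inc          # bound method: self is already fixed to c
g = Counter.inc    # plain function: self must be passed explicitly

f()
g(d)
g(d)
print(c.count, d.count)    # 1 2
print(f.__self__ is c)     # True
```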

lambda expressions

Let's return to the previous example where we were given a list of lists, and found the list whose last element is greatest:

def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]

It's a bit of a nuisance to have to define a separate function last here. Instead, we can use a lambda expression:

>>> max_by([[1, 7], [3, 4, 5], [2]], lambda l: l[-1])
[1, 7]

A lambda expression creates a function "on the fly", without giving it a name. In other words, a lambda expression creates an anonymous function.

A function created by a lambda expression is no different from any other function: we can call it, pass it as an argument, and so forth. Even though the function is initially anonymous, we can certainly put it into a variable:

>>> abc = lambda x, y: 2 * x + y
>>> abc(10, 3)
23

The assignment to abc above is basically equivalent to

def abc(x, y):
    return 2 * x + y

which is how we would more typically define this function.

A lambda function may even take no arguments at all:

>>> f = lambda: 14
>>> f()
14

As another example, suppose that we'd like to write a function that takes a string and returns its most frequent character. We may use max() with a key function that is a lambda expression:

# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    return max(d.keys(), key = lambda k: d[k])

We see here that a lambda expression may refer to a local variable in an enclosing scope. We could not write this function outside the freq() function, since then it would not have access to the dictionary d.
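Trying freq() on a word with a tie shows one detail of max(): when several keys are maximal, it returns the first one encountered. Since Python 3.7, dictionaries iterate in insertion order, so that is the tied character that appears earliest in s. The test words below are arbitrary.

```python
# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    # The lambda reads the dictionary d from the enclosing call of freq().
    return max(d.keys(), key=lambda k: d[k])

print(freq('hello'))          # l
print(freq('mississippi'))    # i  ('i' and 's' both occur 4 times; 'i' comes first)
```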

As a final example, we may write a selection sort using min() with a lambda function rather than an inner loop:

def selection_sort(a):
    n = len(a)
    for i in range(n):
        j = min(range(i, n), key = lambda k: a[k])
        a[i], a[j] = a[j], a[i]
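A quick usage check, together with a hypothetical variant (selection_sort_desc, not from the notes) showing that replacing min() with max() sorts into descending order:

```python
def selection_sort(a):
    n = len(a)
    for i in range(n):
        # Index of the smallest element among a[i], ..., a[n - 1].
        j = min(range(i, n), key=lambda k: a[k])
        a[i], a[j] = a[j], a[i]

# Hypothetical variant: pick the largest remaining element instead.
def selection_sort_desc(a):
    n = len(a)
    for i in range(n):
        j = max(range(i, n), key=lambda k: a[k])
        a[i], a[j] = a[j], a[i]

x = [3, 1, 4, 1, 5]
selection_sort(x)
print(x)    # [1, 1, 3, 4, 5]

y = [3, 1, 4, 1, 5]
selection_sort_desc(y)
print(y)    # [5, 4, 3, 1, 1]
```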

nested functions

Python allows us to write nested functions, i.e. functions that are defined inside other functions or methods.

As a first example, consider the freq() function that we wrote above to find the most frequent character in a string. Instead of a lambda expression, we could use a nested function:

# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    def keyval(k):
        return d[k]

    return max(d.keys(), key = keyval)

Notice that the function keyval() is nested inside the function freq(), and has access to the local variable d defined inside freq().

As another example, suppose that we'd like to write a function replace_with_max() that takes a square matrix m and returns a matrix n in which each value in m is replaced with the maximum of its neighbors in all 4 (horizontal or vertical) directions. For example, if m is

2 4
5 9

then replace_with_max(m) will return

5 9
9 5

As a first attempt, we might write

def replace_with_max(m):
    size = len(m)
    
    # Make a matrix of dimensions (size x size) filled with zeroes
    n = [ size * [ 0 ] for _ in range(size) ]
    
    for i in range(size):
        for j in range(size):
            n[i][j] = max(m[i - 1][j], m[i + 1][j],
                          m[i][j - 1], m[i][j + 1])
                          
    return n

However, we have a problem: if a square (i, j) is at the edge of the matrix, then an array reference such as m[i][j + 1] might go out of bounds.
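A quick check on the 2 x 2 example shows that the two kinds of edge fail differently. An index one past the end raises IndexError, but an index of -1 is legal in Python (it means the last element), so at the top or left edge the first attempt would silently read a value from the opposite side of the matrix instead of failing:

```python
m = [[2, 4],
     [5, 9]]

# Top-left corner: i - 1 == -1, which Python treats as "the last row".
i, j = 0, 0
print(m[i - 1][j])    # 5, taken from the bottom row; no error is raised

# Top-right corner: j + 1 == 2 is past the end of the row.
i, j = 0, 1
try:
    print(m[i][j + 1])
except IndexError as e:
    print('IndexError:', e)    # IndexError: list index out of range
```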

To solve this problem, let's write a nested helper function get(i, j) that returns the element m[i][j] if the position (i, j) is inside the matrix, and -math.inf (i.e. -∞) otherwise. Here is the improved function:

import math

def replace_with_max(m):
    def get(i, j):
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf
    
    size = len(m)
    
    # Make a matrix of dimensions (size x size) filled with zeroes.
    n = [ size * [ 0 ] for _ in range(size) ]
    
    for i in range(size):
        for j in range(size):
            n[i][j] = max(get(i - 1, j), get(i + 1, j),
                          get(i, j - 1), get(i, j + 1))
                          
    return n

Notice that the nested function get() can refer to the parameter m, and also to the local variable size that is defined in its containing function replace_with_max().
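Running the improved function on the 2 x 2 example from the start of this section, and on a 3 x 3 matrix, gives the expected results. The function is repeated so that the snippet runs on its own:

```python
import math

def replace_with_max(m):
    def get(i, j):
        # Positions outside the matrix count as -infinity,
        # so they never win the max() below.
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf

    size = len(m)
    n = [size * [0] for _ in range(size)]
    for i in range(size):
        for j in range(size):
            n[i][j] = max(get(i - 1, j), get(i + 1, j),
                          get(i, j - 1), get(i, j + 1))
    return n

print(replace_with_max([[2, 4], [5, 9]]))
# [[5, 9], [9, 5]]
print(replace_with_max([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
# [[4, 5, 6], [7, 8, 9], [8, 9, 8]]
```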

Nested functions are often convenient for writing recursive helper functions. For example, suppose that we have a class TreeSet that holds values in a binary search tree. We'd like to write a method contains(x) that returns True if the value x is present in the tree. We could write contains() iteratively (we did this in our algorithms lecture), but let's write it recursively here. We'll need a recursive function that takes a tree node as a parameter; in the recursive case it will call itself, passing either the node's left or right child. It would be a bit awkward to make this a method. We could write the function outside the TreeSet class; however, it's convenient to nest it inside contains():

class Node:
    def __init__(self, val, left, right):
        self.val = val
        self.left = left
        self.right = right

class TreeSet:
    def __init__(self):
        self.root = None

    def contains(self, x):
        def f(node):
            if node is None:
                return False
            if x == node.val:
                return True
            
            return f(node.left if x < node.val else node.right)
        
        return f(self.root)
    

Notice that the nested function f() can access the parameter x in its parent function contains(), which is convenient. If we wrote the function outside the class, it would need to take x as a parameter.
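To try contains(), we need some way to put values into the tree. The notes don't show one here, so the add() method below is a hypothetical sketch written in the same style: a nested recursive helper that returns the (possibly new) root of each subtree.

```python
class Node:
    def __init__(self, val, left, right):
        self.val = val
        self.left = left
        self.right = right

class TreeSet:
    def __init__(self):
        self.root = None

    def contains(self, x):
        def f(node):
            if node is None:
                return False
            if x == node.val:
                return True
            return f(node.left if x < node.val else node.right)

        return f(self.root)

    # Hypothetical method (not from the notes): insert x unless already present.
    def add(self, x):
        def f(node):
            if node is None:
                return Node(x, None, None)    # empty spot found: new leaf
            if x < node.val:
                node.left = f(node.left)
            elif x > node.val:
                node.right = f(node.right)
            return node                       # x == node.val: nothing to do

        self.root = f(self.root)

t = TreeSet()
for v in [5, 2, 8, 1, 9]:
    t.add(v)
print(t.contains(8), t.contains(3))    # True False
```

Like f() in contains(), the nested helper in add() reads x from its enclosing method instead of taking it as a parameter.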

updating outer variables in nested functions

In the replace_with_max() function we wrote in the previous section, we saw that the nested function get() can read the values of the variables m and size in the containing function. What if get() wants to update the value of such a variable? For example, suppose that we want to count the number of calls to get() made inside a single call to replace_with_max(). We could attempt to write

def replace_with_max(m):
    g = 0     # number of calls to get()

    def get(i, j):
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf
    

However, that won't work because, as we have seen before, any variable that is assigned inside a function is local by default in Python. And so in the code above, Python will think that g is a local variable inside get(), and the first attempt to increment it will fail with an UnboundLocalError, since that local g has no value yet.
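A minimal reproduction of the problem, using the made-up names outer() and inner(). Note that the error occurs only when inner() actually runs, not when it is defined:

```python
def outer():
    g = 0

    def inner():
        # Assigning to g makes it local to inner(), so g += 1 tries to
        # read a local g that has not been given a value yet.
        g += 1

    inner()
    return g

try:
    outer()
except UnboundLocalError as e:
    # The exact wording of the message differs between Python versions.
    print('UnboundLocalError:', e)
```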

One possible solution would be to make g global, and use a declaration global g inside get(). However, that's ugly since g doesn't really need to be global. A better way is to declare g as nonlocal:

def replace_with_max(m):
    g = 0     # number of calls to get()

    def get(i, j):
        nonlocal g
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf
    

Now the code will work. The nonlocal statement is somewhat like the global statement in that it declares that a variable is not local. The difference is that global declares that a variable is to be found in the global (i.e. top-level) scope, whereas nonlocal declares that a variable is a local variable in an enclosing function.
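Putting the pieces together, here is one complete, runnable version. Returning the pair (n, g) and the name replace_with_max_counting are choices made for this sketch. Each of the size x size cells makes 4 calls to get(), so a 2 x 2 matrix should give 16 calls.

```python
import math

def replace_with_max_counting(m):
    g = 0     # number of calls to get()

    def get(i, j):
        nonlocal g    # g is the variable of the enclosing function
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf

    size = len(m)
    n = [size * [0] for _ in range(size)]
    for i in range(size):
        for j in range(size):
            n[i][j] = max(get(i - 1, j), get(i + 1, j),
                          get(i, j - 1), get(i, j + 1))
    return n, g

n, calls = replace_with_max_counting([[2, 4], [5, 9]])
print(n, calls)    # [[5, 9], [9, 5]] 16
```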

dictionaries with defaults

In the freq() function we wrote in an earlier section, we have code that builds a dictionary holding the number of occurrences of each character in a string:

    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

In this code it's a bit of a bother that we have to check whether each key c is already in the dictionary. As an easier alternative, we can use the defaultdict class that's built into Python's standard library. When we create a defaultdict, we provide it with a default value function. When we look up a key K in a defaultdict and it's not found, Python will call this function, passing no arguments. The function will return a default value, and Python will then automatically add a mapping from K to that value. For example:

>>> from collections import defaultdict
>>> d = defaultdict(lambda: 0)
>>> d['red']
0
>>> d['blue'] += 1
>>> d['blue']
1
>>> d
defaultdict(<function <lambda> at 0x7fb3b6370540>, {'red': 0, 'blue': 1})

Note that instead of "lambda: 0" we could just write "int", since calling the built-in int() with no arguments returns 0:

>>> int()
0
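Any function that takes no arguments can serve as the default value function. For example, list() returns a new empty list, so defaultdict(list) is a convenient way to group values. The word list here is made-up sample data:

```python
from collections import defaultdict

words = ['apple', 'avocado', 'banana', 'blueberry', 'cherry']

# Each missing key starts out as a fresh empty list.
groups = defaultdict(list)
for w in words:
    groups[w[0]].append(w)

print(dict(groups))
# {'a': ['apple', 'avocado'], 'b': ['banana', 'blueberry'], 'c': ['cherry']}
```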

Using a defaultdict, we can rewrite the character-counting code above more easily:

from collections import defaultdict

d = defaultdict(int)
for c in s:
    d[c] += 1
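For completeness, here is the whole freq() function rewritten with a defaultdict; the rest of the function is unchanged.

```python
from collections import defaultdict

# Return the character in s which occurs most frequently.
def freq(s):
    d = defaultdict(int)    # a missing key starts at int(), i.e. 0
    for c in s:
        d[c] += 1
    return max(d.keys(), key=lambda k: d[k])

print(freq('hello'))    # l
```

One caution: merely reading d[k] for a missing key inserts it, as d['red'] did in the example above. When you only want to test whether a key is present, use k in d, which does not modify the dictionary.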