In Python, functions are first-class values. That means that we can work with functions just like with other values such as integers and strings: we can refer to functions with variables, pass them as arguments, return them from other functions, and so on.
Here is a Python function that adds the numbers from 1 to 1,000,000:
def bigsum():
    sum = 0
    for i in range(1, 1_000_001):
        sum += i
    return sum
We can put this function into a variable f:
>>> f = bigsum
And now we can call f just like the original function bigsum:
>>> f()
500000500000
Let's write a function time_it that takes a function as an argument:
import time
def time_it(f):
    start = time.time()
    x = f()
    end = time.time()
    print(f'function ran in {end - start:.2f} seconds')
    return x
Given any function f, time_it runs f and measures the time that elapses while f is running. It prints this elapsed time, and then returns whatever f returned:
>>> time_it(bigsum)
function ran in 0.04 seconds
500000500000
This is a first example illustrating that it can be useful to pass functions to functions. As we will see, there are many other reasons why we might want to do this.
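As a variation on time_it, here is a minimal sketch of a timer that also forwards arguments to the function being timed. The names timed_call and sum_to are hypothetical, chosen only for this illustration, and returning the elapsed time instead of printing it is a design choice of the sketch. time.perf_counter() is a standard-library clock intended for measuring short intervals.

```python
import time

def timed_call(f, *args):
    # Hypothetical helper: call f(*args) and return (result, elapsed seconds).
    start = time.perf_counter()
    result = f(*args)
    elapsed = time.perf_counter() - start
    return result, elapsed

def sum_to(n):
    # Add the integers 1, ..., n.
    total = 0
    for i in range(1, n + 1):
        total += i
    return total

result, elapsed = timed_call(sum_to, 1_000)
print(result)                       # 500500
print(f'ran in {elapsed:.6f} seconds')
```

Because the caller receives the elapsed time as a value, it can decide for itself whether to print it, log it, or compare it with another measurement.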
Python's standard library contains the map() function, which takes a function f and an iterable (such as a list). It returns a new iterable in which f has been applied to every element of the original iterable. For example:
def square(x):
    return x * x

>>> for x in map(square, [2, 3, 4]):
...     print(x)
...
4
9
16
We may wish to collect the resulting values into a list:
>>> list(map(square, [2, 3, 4]))
[4, 9, 16]
We may achieve the same result using a list comprehension:
>>> [x * x for x in [2, 3, 4]]
[4, 9, 16]
Which is better, using map() or a comprehension? To some degree this is a matter of style. However, in some situations one approach or the other may be more compact.
Consider this program, which reads three integers from a single line of standard input:
words = input().split()
a = int(words[0])
b = int(words[1])
c = int(words[2])
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')
We may rewrite it using map():
a, b, c = map(int, input().split())
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')
In fact we may write a function similar to map() ourselves. Here's an implementation my_map() that takes arguments f (a function) and a (a list). The function applies f to each element of the list, and collects the results into a list that it returns:
def my_map(f, a):
    return [f(x) for x in a]
Here is how we might use my_map:
def double(x):
    return x * 2

>>> my_map(double, [10, 20, 30])
[20, 40, 60]
In the returned list, every value in the input list has been doubled.
Note that my_map() is not exactly like map(), since my_map() returns a list, whereas map() returns an iterable. In some situations map() may be more efficient than my_map(). For example, consider these calls to map() and my_map():
>>> from math import sqrt
>>> sum(map(sqrt, range(1_000_000)))
666666166.4588418
>>> sum(my_map(sqrt, range(1_000_000)))
666666166.4588418
The call to my_map() builds a list with 1,000,000 elements. However, the first expression above, which calls map(), runs in O(1) memory since this call to map() does not build a list. Instead, it returns an iterator that produces successive values of the sequence sqrt(0), sqrt(1), ..., sqrt(999_999).
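The following short sketch illustrates this laziness. It uses only the built-in functions map(), next() and list(): values are computed only when they are requested, and once the iterator has been consumed it produces nothing more.

```python
from math import sqrt

m = map(sqrt, [1, 4, 9])   # nothing has been computed yet
first = next(m)            # computes sqrt(1) only
rest = list(m)             # computes the remaining two values
again = list(m)            # the iterator is now exhausted

print(first)   # 1.0
print(rest)    # [2.0, 3.0]
print(again)   # []
```

This is also why we wrap map() in list() when we want to look at the results more than once.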
(You might ask: can we write a function that imitates the built-in map(), returning an iterator, not a list? The answer is yes, though we don't know how to do that yet. We would have to use a generator function or comprehension, which are features we might see in a later lecture.)
By the way, why is the sum above so close to (2 / 3)(1,000,000,000)? Answering this is an elementary exercise in integral calculus. :)
A related built-in function is filter(), which takes a function and an iterable such as a list, and returns a new iterable containing only the values for which the function returns true. For example:
def odd(x):
    return x % 2 == 1

>>> for x in filter(odd, [2, 4, 5, 7, 8, 9, 11]):
...     print(x)
...
5
7
9
11
Again, we may wish to collect the results into a list:
>>> list(filter(odd, [2, 4, 5, 7, 8, 9, 11]))
[5, 7, 9, 11]
Once again, we could achieve the same result using a list comprehension:
>>> [x for x in [2, 4, 5, 7, 8, 9, 11] if odd(x)]
[5, 7, 9, 11]
Let's write our own version of filter():
# Produce a list containing all elements of a for which f is true.
def my_filter(f, a):
    return [x for x in a if f(x)]
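Here is a small usage sketch for my_filter(). The function is repeated so that the example runs on its own. The second call passes str.isupper, a built-in string method, to show that any function of one argument will do.

```python
def my_filter(f, a):
    # Produce a list containing all elements of a for which f is true.
    return [x for x in a if f(x)]

def odd(x):
    return x % 2 == 1

odds = my_filter(odd, [2, 4, 5, 7, 8, 9, 11])
print(odds)    # [5, 7, 9, 11]

caps = my_filter(str.isupper, ['a', 'B', 'c', 'D'])
print(caps)    # ['B', 'D']
```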
Here's a function max_by that finds the maximum value in an input sequence, applying a function f to each element to yield a comparison key:
def max_by(seq, f):
    max_elem = None
    max_val = None
    for x in seq:
        v = f(x)
        if max_elem is None or v > max_val:
            max_elem = x
            max_val = v
    return max_elem
We can use max_by to find the longest list in a list of lists:
>>> max_by([[1, 7], [3, 4, 5], [2]], len)
[3, 4, 5]
Or we can use it to find the list whose last element is greatest:
def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]
This capability is so useful that it's also built into the standard library. The standard function max can take a keyword argument key holding a function that works exactly like the second argument to max_by:
>>> max([[1, 7], [3, 4, 5], [2]], key = len)
[3, 4, 5]
The built-in function sorted() and the sort() method take a similar argument 'key', so that you can sort by any attribute you like. For example:
>>> l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
>>> l.sort(key = len)
>>> l
[[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
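The example above uses the sort() method, which modifies the list in place. As a complementary sketch, here is sorted() with a key function. It returns a new list and leaves the original untouched. The word list and the helper last_letter are made up for this illustration.

```python
words = ['pear', 'fig', 'banana', 'kiwi']

# Sort by length. Python's sort is stable, so 'pear' stays ahead of 'kiwi'.
by_length = sorted(words, key=len)
print(by_length)   # ['fig', 'pear', 'kiwi', 'banana']

def last_letter(s):
    return s[-1]

# Sort by the final character of each word.
by_last = sorted(words, key=last_letter)
print(by_last)     # ['banana', 'fig', 'kiwi', 'pear']

print(words)       # ['pear', 'fig', 'banana', 'kiwi'] (unchanged)
```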
Let's write a similar function that sorts a list using bubble sort, with an arbitrary key function:
def sort_by(a, f):
    n = len(a)
    for i in range(n - 1, 0, -1):     # (n - 1), ..., 1
        for j in range(i):
            if f(a[j]) > f(a[j + 1]):
                a[j], a[j + 1] = a[j + 1], a[j]
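Here is a short usage sketch for sort_by(), with the function repeated so that the example is self-contained. It sorts the same list of lists first by length and then by the sum of each inner list, using the built-in functions len and sum as keys.

```python
def sort_by(a, f):
    n = len(a)
    for i in range(n - 1, 0, -1):     # (n - 1), ..., 1
        for j in range(i):
            if f(a[j]) > f(a[j + 1]):
                a[j], a[j + 1] = a[j + 1], a[j]

by_len = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
sort_by(by_len, len)
print(by_len)   # [[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]

by_sum = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
sort_by(by_sum, sum)
print(by_sum)   # [[8], [2, 7], [1, 3, 5, 2], [3, 10, 6]]
```

Like the sort() method, sort_by() rearranges the list in place and returns nothing.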
We have just seen that a Python variable may refer to a function. It may also refer to a method of a particular object.
For example, consider this class:
class Counter:
    def __init__(self):
        self.count = 0

    def inc(self):
        self.count += 1
Let's create a couple of instances of Counter, and a variable 'f' that refers to the 'inc' method of one of those instances:
>>> c = Counter()
>>> c.inc()
>>> c.inc()
>>> c.count
2
>>> d = Counter()
>>> d.count
0
>>> f = c.inc
When we call f(), it will increment the count in the object c:
>>> f()
>>> f()
>>> c.count
4
The value in 'd' remains unchanged, since f refers to the inc() method of c, not d:
>>> d.count
0
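Since a method bound to an object is itself a value, we can pass it to another function. The sketch below assumes a hypothetical helper repeat(f, n), which is not part of any library, and also shows that methods of built-in objects such as lists behave the same way.

```python
class Counter:
    def __init__(self):
        self.count = 0

    def inc(self):
        self.count += 1

def repeat(f, n):
    # Hypothetical helper: call f() n times.
    for _ in range(n):
        f()

c = Counter()
repeat(c.inc, 5)        # c.inc remembers that it belongs to c
print(c.count)          # 5

results = []
add = results.append    # a method bound to the list results
add(10)
add(20)
print(results)          # [10, 20]
```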
Let's return to the previous example where we were given a list of lists, and found the list whose last element is greatest:
def last(s):
    return s[-1]

>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]
It's a bit of a nuisance to have to define a separate function last here. Instead, we can use a lambda expression:
>>> max_by([[1, 7], [3, 4, 5], [2]], lambda l: l[-1])
[1, 7]
A lambda expression creates a function "on the fly", without giving it a name. In other words, a lambda expression creates an anonymous function.
A function created by a lambda expression is no different from any other function: we can call it, pass it as an argument, and so forth. Even though the function is initially anonymous, we can certainly put it into a variable:
>>> abc = lambda x, y: 2 * x + y
>>> abc(10, 3)
23
The assignment to abc above is basically equivalent to
def abc(x, y):
    return 2 * x + y
which is how we would more typically define this function.
A lambda function may even take no arguments at all:
>>> f = lambda: 14
>>> f()
14
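Lambda expressions combine naturally with the built-in functions from the earlier sections. The following sketch repeats the squaring and odd-number examples without defining the named helpers square and odd, and adds a sort by the second component of each pair. The list pairs is made up for this illustration.

```python
squares = list(map(lambda x: x * x, [2, 3, 4]))
print(squares)     # [4, 9, 16]

odds = list(filter(lambda x: x % 2 == 1, [2, 4, 5, 7, 8, 9, 11]))
print(odds)        # [5, 7, 9, 11]

pairs = [('b', 3), ('a', 7), ('c', 1)]
by_second = sorted(pairs, key=lambda p: p[1])
print(by_second)   # [('c', 1), ('b', 3), ('a', 7)]
```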
As another example, suppose that we'd like to write a function that takes a string and returns its most frequent character. We may use max() with a key function that is a lambda expression:
# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1
    return max(d.keys(), key = lambda k: d[k])
We see here that a lambda expression may refer to a local variable in an enclosing scope. We could not write this function outside the freq() function, since then it would not have access to the dictionary d.
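As one more small sketch of a lambda that refers to a variable in an enclosing scope, consider a function that returns a lambda. The name make_multiplier is hypothetical. Each returned function keeps access to the value of k that was passed when it was created.

```python
def make_multiplier(k):
    # The lambda refers to k, a parameter of the enclosing function.
    return lambda x: k * x

double = make_multiplier(2)
triple = make_multiplier(3)

print(double(10))   # 20
print(triple(10))   # 30
```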
As a final example, we may write a selection sort using min() with a lambda function rather than an inner loop:
def selection_sort(a):
    n = len(a)
    for i in range(n):
        j = min(range(i, n), key = lambda k: a[k])
        a[i], a[j] = a[j], a[i]
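The same idea extends to sorting by an arbitrary key. The sketch below is a hypothetical variant named selection_sort_by, in which the lambda refers both to the list a and to the key function f from the enclosing function.

```python
def selection_sort_by(a, f):
    n = len(a)
    for i in range(n):
        # Find the index in i .. n - 1 whose element has the smallest key.
        j = min(range(i, n), key=lambda k: f(a[k]))
        a[i], a[j] = a[j], a[i]

nums = [-7, 2, -1, 5]
selection_sort_by(nums, abs)   # sort by absolute value
print(nums)   # [-1, 2, 5, -7]
```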
Python allows us to write nested functions, i.e. functions that are defined inside other functions or methods.
As a first example, consider the freq() function that we wrote above to find the most frequent character in a string. Instead of a lambda expression, we could use a nested function:
# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1

    def keyval(k):
        return d[k]

    return max(d.keys(), key = keyval)
Notice that the function keyval() is nested inside the function freq(), and has access to the local variable d defined inside freq().
As another example, suppose that we'd like to write a function replace_with_max() that takes a square matrix m and returns a matrix n in which each value in m is replaced with the maximum of its neighbors in all 4 (horizontal or vertical) directions. For example, if m is
2 4
5 9
then replace_with_max(m) will return
5 9
9 5
As a first attempt, we might write
def replace_with_max(m):
    size = len(m)
    # Make a matrix of dimensions (size x size) filled with zeroes
    n = [ size * [ 0 ] for _ in range(size) ]
    for i in range(size):
        for j in range(size):
            n[i][j] = max(m[i - 1][j], m[i + 1][j], m[i][j - 1], m[i][j + 1])
    return n
However, we have a problem: if a square (i, j) is at the edge of the matrix, then an array reference such as m[i][j + 1] might go out of bounds.
To solve this problem, let's write a nested helper function get(i, j) that returns an array element if the position (i, j) is inside the matrix, otherwise (- math.inf), i.e. -∞. Here is the improved function:
import math

def replace_with_max(m):
    def get(i, j):
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return - math.inf

    size = len(m)
    # Make a matrix of dimensions (size x size) filled with zeroes.
    n = [ size * [ 0 ] for _ in range(size) ]
    for i in range(size):
        for j in range(size):
            n[i][j] = max(get(i - 1, j), get(i + 1, j), get(i, j - 1), get(i, j + 1))
    return n
Notice that the nested function get() can refer to the parameter m, and also to the local variable size that is defined in its containing function replace_with_max().
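To check the improved function, here is a runnable sketch that applies it to the 2 x 2 matrix from the example above and to a 3 x 3 matrix chosen for this illustration. The function is repeated, together with the import it needs, so that the block runs on its own.

```python
import math

def replace_with_max(m):
    def get(i, j):
        # Return m[i][j], or -infinity if (i, j) lies outside the matrix.
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf

    size = len(m)
    n = [size * [0] for _ in range(size)]
    for i in range(size):
        for j in range(size):
            n[i][j] = max(get(i - 1, j), get(i + 1, j),
                          get(i, j - 1), get(i, j + 1))
    return n

small = replace_with_max([[2, 4], [5, 9]])
print(small)   # [[5, 9], [9, 5]]

big = replace_with_max([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(big)     # [[4, 5, 6], [7, 8, 9], [8, 9, 8]]
```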
Nested functions are often convenient for writing recursive helper functions. For example, suppose that we have a class TreeSet that holds values in a binary search tree. We'd like to write a method contains(x) that returns True if the value x is present in the tree. We could write contains() iteratively (we did this in our algorithms lecture), but let's write it recursively here. We'll need a recursive function that takes a tree node as a parameter; in the recursive case it will call itself, passing either the node's left or right child. It would be a bit awkward to make this a method. We could write the function outside the TreeSet class, however it's convenient to nest it inside contains():
class Node:
    def __init__(self, val, left, right):
        self.val = val
        self.left = left
        self.right = right

class TreeSet:
    def __init__(self):
        self.root = None

    def contains(self, x):
        def f(node):
            if node is None:
                return False
            if x == node.val:
                return True
            return f(node.left if x < node.val else node.right)

        return f(self.root)
Notice that the nested function f() can access the parameter x in its parent function contains(), which is convenient. If we wrote the function outside the class, it would need to take x as a parameter.
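Here is a small sketch that exercises contains(). Since the TreeSet class above has no method for inserting values, the example builds a three-node tree by hand and assigns it to root directly. That is done only for illustration.

```python
class Node:
    def __init__(self, val, left, right):
        self.val = val
        self.left = left
        self.right = right

class TreeSet:
    def __init__(self):
        self.root = None

    def contains(self, x):
        def f(node):
            if node is None:
                return False
            if x == node.val:
                return True
            return f(node.left if x < node.val else node.right)

        return f(self.root)

# Build this binary search tree by hand:
#       5
#      / \
#     3   8
t = TreeSet()
t.root = Node(5, Node(3, None, None), Node(8, None, None))

print(t.contains(3))           # True
print(t.contains(7))           # False
print(TreeSet().contains(1))   # False (empty tree)
```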
In the replace_with_max() function we wrote in the previous section, we saw that the nested function get() can read the values of the variables m and size in the containing function. What if get() wants to update the value of such a variable? For example, suppose that we want to count the number of calls to get() made inside a single call to replace_with_max(). We could attempt to write
def replace_with_max(m):
    g = 0   # number of calls to get()

    def get(i, j):
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf
    …
However, that won't work because, as we have seen before, any variable that is updated inside a function is local by default in Python. And so in the code above, Python will think that g is a local variable inside get(), and will report an error when we first attempt to increment it.
One possible solution would be to make g global, and use a declaration global g inside get(). However, that's ugly since g doesn't really need to be global. A better way is to declare g as nonlocal:
def replace_with_max(m):
    g = 0   # number of calls to get()

    def get(i, j):
        nonlocal g
        g += 1
        if 0 <= i < size and 0 <= j < size:
            return m[i][j]
        else:
            return -math.inf
    …
Now the code will work. The nonlocal statement is somewhat like the global statement in that it declares that a variable is not local. The difference is that global declares that a variable is to be found in the global (i.e. top-level) scope, whereas nonlocal declares that a variable is a local variable in an enclosing function.
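The following self-contained sketch shows both behaviors side by side. The functions make_counter and make_broken_counter are hypothetical names used only here. The first uses nonlocal and works. The second omits it, so Python treats count as local to inc() and raises UnboundLocalError when the increment runs.

```python
def make_counter():
    count = 0              # local variable of make_counter

    def inc():
        nonlocal count     # refer to the count defined in make_counter
        count += 1
        return count

    return inc

def make_broken_counter():
    count = 0

    def inc():
        count += 1         # no nonlocal: count is treated as local to inc
        return count

    return inc

counter = make_counter()
print(counter())   # 1
print(counter())   # 2

raised = False
try:
    make_broken_counter()()
except UnboundLocalError:
    raised = True
print(raised)      # True
```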
In the freq() function we wrote in an earlier section, we have code that builds a dictionary holding the number of occurrences of each character in a string:
d = {}
for c in s:
    if c in d:
        d[c] += 1
    else:
        d[c] = 1
In this code it's a bit of a bother that we have to check whether each key c is already in the dictionary. As an easier alternative, we can use the defaultdict class that's built into Python's standard library. When we create a defaultdict, we provide it with a default value function. When we look up a key K in a defaultdict and it's not found, Python will call this function, passing no arguments. The function will return a default value, and Python will then automatically add a mapping from K to that value. For example:
>>> from collections import defaultdict
>>> d = defaultdict(lambda: 0)
>>> d['red']
0
>>> d['blue'] += 1
>>> d['blue']
1
>>> d
defaultdict(<function <lambda> at 0x7fb3b6370540>, {'red': 0, 'blue': 1})
Note that instead of "lambda: 0" we could just write "int", since calling the built-in int() function with no arguments returns 0, the default value of an integer:
>>> int()
0
Using a defaultdict, we can rewrite the character-counting code above more easily:
from collections import defaultdict

d = defaultdict(int)
for c in s:
    d[c] += 1
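Putting the pieces together, here is a sketch of the earlier freq() function rewritten with defaultdict(int). The second part, which groups words by their first letter using defaultdict(list), is an additional illustration with made-up data. It relies on the fact that calling list() with no arguments returns an empty list.

```python
from collections import defaultdict

# Return the character in s which occurs most frequently.
def freq(s):
    d = defaultdict(int)
    for c in s:
        d[c] += 1          # a missing key starts at int() == 0
    return max(d.keys(), key=lambda k: d[k])

print(freq('banana'))      # a

# Group words by first letter; a missing key starts as an empty list.
groups = defaultdict(list)
for w in ['apple', 'avocado', 'banana', 'cherry', 'blueberry']:
    groups[w[0]].append(w)

print(dict(groups))
# {'a': ['apple', 'avocado'], 'b': ['banana', 'blueberry'], 'c': ['cherry']}
```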