Let's briefly review the concept of iterables and
sequences in Python. An object is iterable if you can loop
over it with the 'for' statement. An object is a sequence if
you can access its elements by integer index, i.e. using the syntax
s[i]. All sequences are iterable, but
not all iterables are sequences.
We've now seen these kinds of sequences: lists, tuples, strings, and ranges.
We've also seen these kinds of iterables which are not sequences: sys.stdin, file objects returned by open(), sets, and dictionaries.
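As a small sketch of the distinction (the variable names here are only for illustration), a list supports both looping and indexing, while a set supports looping but raises a TypeError when indexed. The try/except statement is used only to catch that error:

```python
nums = [10, 20, 30]
second = nums[1]          # a list is a sequence, so indexing works

colors = {'red', 'green', 'blue'}
count = 0
for c in colors:          # a set is iterable (in no guaranteed order)
    count += 1

try:
    colors[0]             # a set is not a sequence: this raises TypeError
    indexable = True
except TypeError:
    indexable = False

print(second, count, indexable)   # 20 3 False
```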
Note that if you iterate over a dictionary you get its keys:
d = { 'red' : 1, 'green' : 2, 'blue' : 3 }
for x in d:
    print(x)
produces the output
red
green
blue
As we saw last week, a dictionary has methods keys(), values(), and items() that produce the keys, values, and key-value pairs in the dictionary. These methods all return iterable objects.
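Here is a brief sketch of the three methods, reusing the dictionary from above. Wrapping each result in list() is only done here to make the contents visible:

```python
d = {'red': 1, 'green': 2, 'blue': 3}

keys = list(d.keys())       # ['red', 'green', 'blue']
values = list(d.values())   # [1, 2, 3]
pairs = list(d.items())     # [('red', 1), ('green', 2), ('blue', 3)]

# items() is convenient in a for loop: each pair unpacks into two variables.
total = 0
for name, n in d.items():
    total += n

print(keys, values, pairs, total)
```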
A list comprehension is an expression that loops over a sequence (or, more generally, any iterable) and collects a series of computed values into a list. List comprehensions are powerful and convenient.
For example, consider this loop that builds a list of all perfect squares from 1 to 400:
squares = []
for i in range(1, 21):
    squares.append(i * i)
We may replace the three lines above by a single line with a list comprehension:
squares = [i * i for i in range(1, 21)]
In general, a list comprehension may have the form
[ <expression> for <var> in <sequence> ]
A comprehension of this form will loop over the given <sequence>. On each loop iteration, it sets <var> to the value of an element of the sequence, then evaluates the given <expression>. All results are collected into a list.
Here are some more examples of list comprehensions. This comprehension builds a list of the numbers from 1 to 10 paired with their squares:
>>> [(i, i * i) for i in range(1, 11)]
[(1, 1), (2, 4), (3, 9), (4, 16), (5, 25), (6, 36), (7, 49), (8, 64), (9, 81), (10, 100)]
We may add 1 to each element in a list:
>>> l = [2, 5, 7, 10]
>>> [i + 1 for i in l]
[3, 6, 8, 11]
Let's write a program that will read a single input line containing a number of integers, separated by spaces. The program will print the sum of the integers. Using a list comprehension, we can write this in a single line:
print(sum([int(w) for w in input().split()]))
Consider Project Euler's Problem 6 (find the difference between the sum of the squares of the first one hundred natural numbers and the square of the sum). Here's a solution using a list comprehension:
sum_of_squares = sum([x * x for x in range(1, 101)])
square_of_sum = sum(range(1, 101)) ** 2
answer = square_of_sum - sum_of_squares
As another example, here's a file animals_en_cz containing English and Czech animal names:
bear medvěd
bird pták
cat kočka
cow kráva
dog pes
goat kozel
horse kůň
mouse myš
pig prase
sheep ovce
Using a list comprehension, we may read it into a dictionary in a single line of code:
with open('animals_en_cz') as f:
    d = dict([line.split() for line in f])
This works because the split() method splits each line above into a
2-element list such as ['bear', 'medvěd'].
We can pass a list of these lists to the dict() constructor, which
interprets each 2-element list as a key-value pair.
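The same idea can be tried without a file. In this sketch a hard-coded list of strings (an assumption made for illustration) stands in for the lines that the file object would produce:

```python
# Stand-in for the lines of the file; each ends with a newline, as file lines do.
lines = ['bear medvěd\n', 'bird pták\n', 'cat kočka\n']

pairs = [line.split() for line in lines]   # split() also discards the newline
d = dict(pairs)                            # each 2-element list becomes key -> value

print(pairs)      # [['bear', 'medvěd'], ['bird', 'pták'], ['cat', 'kočka']]
print(d['cat'])   # kočka
```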
Finally, here's a function that builds a nested list representing a matrix of zeroes:
def empty(rows, cols):
    return [cols * [0] for r in range(rows)]
A list comprehension may have an if clause containing an arbitrary condition. Only list elements that satisfy the condition are included in the generated list.
For example, this comprehension collects all characters in the string 'watermelon' that are in the second half of the alphabet:
>>> [c for c in 'watermelon' if c >= 'n']
['w', 't', 'r', 'o', 'n']
Here's a 1-line solution to Project Euler's Problem 1 (find the sum of all the multiples of 3 or 5 below 1000):
>>> sum([i for i in range(1000) if i % 3 == 0 or i % 5 == 0])
233168
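A comprehension with an if clause corresponds to a loop with an if statement inside it. The following sketch shows both forms side by side, using a smaller range so the result is easy to check by hand:

```python
# Comprehension with an if clause.
multiples = [i for i in range(20) if i % 3 == 0 or i % 5 == 0]

# The same computation written as an explicit loop.
result = []
for i in range(20):
    if i % 3 == 0 or i % 5 == 0:
        result.append(i)

print(multiples)   # [0, 3, 5, 6, 9, 10, 12, 15, 18]
```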
A list comprehension may have more than one 'for' clause. The 'for' clauses represent a nested loop. All values generated by the inner loop are collected into the resulting list.
For example, this comprehension generates all pairs of values (x, y), where 0 ≤ x, y < 3:
>>> [(x, y) for x in range(3) for y in range(3)]
[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
The following comprehension is similar, but in it 'y' only iterates up to the value 'x', so it only generates pairs where y < x:
>>> [(x, y) for x in range(3) for y in range(x)]
[(1, 0), (2, 0), (2, 1)]
The comprehension above is equivalent to the following:
l = []
for x in range(3):
    for y in range(x):
        l.append( (x, y) )
Notice that when there are multiple 'for' clauses in a single comprehension:
The first 'for' clause represents the outer loop.
The comprehension generates a 1-dimensional list, not a list of lists.
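To illustrate the second point, the sketch below contrasts a single comprehension with two 'for' clauses against one comprehension nested inside another. Only the nested form produces a list of lists:

```python
# Two 'for' clauses in one comprehension: a flat list of pairs.
flat = [(x, y) for x in range(2) for y in range(3)]

# A comprehension inside a comprehension: one inner list per value of x.
nested = [[(x, y) for y in range(3)] for x in range(2)]

print(flat)     # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
print(nested)   # [[(0, 0), (0, 1), (0, 2)], [(1, 0), (1, 1), (1, 2)]]
```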
Let's write a function to flatten a 2-dimensional matrix, i.e. return a 1-dimensional list of its elements:
def flatten(m):
    return [ x for row in m for x in row ]
For example:
>>> mat = [ [2, 4, 6],
...         [1, 3, 7],
...         [8, 9, 1] ]
>>> flatten(mat)
[2, 4, 6, 1, 3, 7, 8, 9, 1]
Let's now write a program that will read any number of lines of input, each containing any number of integers separated by whitespace. The program will print the sum of all the numbers on all the lines:
import sys
nums = [int(w) for line in sys.stdin for w in line.split()]
print(sum(nums))
It's possible to mix 'for' and 'if' clauses in a comprehension. For example, here's code to generate all pairs (x, y) where 0 ≤ x, y < 3, but skipping those with x = 1:
>>> [(x, y) for x in range(3) if x != 1 for y in range(3)]
[(0, 0), (0, 1), (0, 2), (2, 0), (2, 1), (2, 2)]
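When 'for' and 'if' clauses are mixed, they nest in the order in which they are written. As a sketch, the comprehension above can be unrolled into loops like this:

```python
pairs = [(x, y) for x in range(3) if x != 1 for y in range(3)]

# Equivalent nested loops: the 'if' sits between the outer and inner 'for',
# so the inner loop never runs at all when x == 1.
result = []
for x in range(3):
    if x != 1:
        for y in range(3):
            result.append((x, y))

print(pairs == result)   # True
```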
As a larger example, here's a comprehension that finds all triples of integers (a, b, c) such that a² + b² = c², with 0 ≤ a < b < c ≤ 20:
>>> [(a, b, c) for c in range(21)
...            for b in range(c)
...            for a in range(b)
...            if a * a + b * b == c * c]
[(3, 4, 5), (6, 8, 10), (5, 12, 13), (9, 12, 15), (8, 15, 17), (12, 16, 20)]
In Python, functions are first-class values. That means that we can work with functions just like with other values such as integers and strings: we can refer to functions with variables, pass them as arguments, return them from other functions, and so on.
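As a small sketch of what "first-class" means in practice, the functions below (double, negate and choose are hypothetical names invented for this illustration) are stored in a list and returned from another function:

```python
def double(x):
    return 2 * x

def negate(x):
    return -x

# Return one of the two functions above, chosen by name.
def choose(name):
    if name == 'double':
        return double
    return negate

ops = [double, negate]            # functions stored in a list
results = [op(5) for op in ops]   # call each function in turn

g = choose('double')              # a function returned from a function
print(results, g(21))             # [10, -5] 42
```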
Here is a Python function that adds the numbers from 1 to 1,000,000:
def bigsum():
    sum = 0
    for i in range(1, 1_000_001):
        sum += i
    return sum
We can put this function into a variable f:
>>> f = bigsum
And now we can call f just like the original function bigsum:
>>> f()
500000500000
Let's write a function time_it that
takes a function as an argument:
import time
def time_it(f):
    start = time.time()
    x = f()
    end = time.time()
    print(f'function ran in {end - start:.2f} seconds')
    return x
Given any function f, time_it runs f and
measures the time that elapses while f is running. It prints
this elapsed time, and then returns whatever f returned:
>>> time_it(bigsum)
function ran in 0.04 seconds
500000500000
This is a first example illustrating that it can be useful to pass functions to functions. As we will see, there are many other reasons why we might want to do this.
Python's standard library contains the map() function, which takes a function f and an iterable (such as a list). It returns a new iterable in which f has been applied to every element of the original iterable. For example:
def square(x):
    return x * x
>>> for x in map(square, [2, 3, 4]):
...     print(x)
...
4
9
16
We may wish to collect the resulting values into a list:
>>> list(map(square, [2, 3, 4]))
[4, 9, 16]
We may achieve the same result using a list comprehension:
>>> [x * x for x in [2, 3, 4]]
[4, 9, 16]
Which is better, using map() or a comprehension? To some degree this is a matter of style. However, in some situations one approach or the other may be more compact.
Consider this program, which reads three integers from a single line of standard input:
words = input().split()
a = int(words[0])
b = int(words[1])
c = int(words[2])
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')
We may rewrite it using map():
a, b, c = map(int, input().split())
print(f'a = {a}, b = {b}, c = {c}, sum = {a + b + c}')
In fact we may write a function similar to map() ourselves. Here's an implementation my_map() that takes arguments f (a function) and a (a list). The function applies f to each element of the list, and collects the results into a list that it returns:
def my_map(f, a):
    return [f(x) for x in a]
Here is how we might use my_map:
def double(x):
    return x * 2
>>> my_map(double, [10, 20, 30])
[20, 40, 60]
In the returned list, every value in the input list has been doubled.
Note that my_map() is not exactly like map(), since my_map() returns a list, whereas map() returns an iterable. In some situations map() may be more efficient than my_map(). For example, consider these calls to map() and my_map():
>>> from math import sqrt
>>> sum(map(sqrt, range(1_000_000)))
666666166.4588418
>>> sum(my_map(sqrt, range(1_000_000)))
666666166.4588418
The call to my_map() builds a list with 1,000,000 elements. However, the first expression above, which calls map(), runs in O(1) memory, since this call to map() does not build a list. Instead, it returns an iterator that produces successive values of the sequence sqrt(0), sqrt(1), ..., sqrt(999_999).
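One way to observe that map() computes values only on demand is to pass it a function that records each call. In this sketch, noisy_square is a hypothetical helper, and the built-in next() is used to ask the iterator for a single value:

```python
calls = []

def noisy_square(x):      # hypothetical helper that records each call
    calls.append(x)
    return x * x

m = map(noisy_square, [1, 2, 3])
before = len(calls)       # map() has not called the function yet

first = next(m)           # request one value; only now is noisy_square(1) run
after_one = len(calls)

rest = list(m)            # consume the remaining values
print(before, first, after_one, rest)   # 0 1 1 [4, 9]
```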
(You might ask: can we write a function that imitates the built-in map(), returning an iterator, not a list? The answer is yes, though we don't know how to do that yet. We would have to use a generator function or comprehension, which are features we might see in a later lecture.)
By the way, why is the sum above so close to (2 / 3)(1,000,000,000)? Answering this is an elementary exercise in integral calculus. :)
A related built-in function is filter(), which takes a function and an iterable such as a list, and returns a new iterable containing only the values for which the function returns true. For example:
def odd(x):
    return x % 2 == 1
>>> for x in filter(odd, [2, 4, 5, 7, 8, 9, 11]):
...     print(x)
...
5
7
9
11
Again, we may wish to collect the results into a list:
>>> list(filter(odd, [2, 4, 5, 7, 8, 9, 11]))
[5, 7, 9, 11]
Once again, we could achieve the same result using a list comprehension:
>>> [x for x in [2, 4, 5, 7, 8, 9, 11] if odd(x)]
[5, 7, 9, 11]
Let's write our own version of filter():
# Produce a list containing all elements of a for which f is true.
def my_filter(f, a):
    return [x for x in a if f(x)]
Here's a function max_by that finds the maximum value in an input sequence, applying a function f to each element to yield a comparison key:
def max_by(seq, f):
    first = True
    max_elem = None
    max_val = None
    for x in seq:
        v = f(x)
        if first or v > max_val:
            max_elem = x
            max_val = v
            first = False
    return max_elem
We can use max_by to find the longest list in a list of lists:
>>> max_by([[1, 7], [3, 4, 5], [2]], len)
[3, 4, 5]
Or we can use it to find the list whose last element is greatest:
def last(s):
    return s[-1]
>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]
This capability is so useful that it's also built into the standard
library. The standard function max can
take a keyword argument key holding a
function that works exactly like the second argument to max_by:
>>> max([[1, 7], [3, 4, 5], [2]], key = len)
[3, 4, 5]
The built-in function sorted() and the sort() method take a similar argument 'key', so that you can sort by any attribute you like. For example:
>>> l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]
>>> l.sort(key = len)
>>> l
[[8], [2, 7], [3, 10, 6], [1, 3, 5, 2]]
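The built-in sorted() accepts the same 'key' argument but returns a new list instead of modifying the original. Here is a brief sketch that sorts the same data by last element, with the helper last defined again so that the example is self-contained:

```python
l = [[2, 7], [1, 3, 5, 2], [3, 10, 6], [8]]

def last(s):
    return s[-1]

by_last = sorted(l, key=last)   # sorted() builds and returns a new list

print(by_last)   # [[1, 3, 5, 2], [3, 10, 6], [2, 7], [8]]
print(l)         # the original list is unchanged
```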
Let's write a similar function that sorts a list using bubble sort, with an arbitrary key function:
def sort_by(a, f):
    n = len(a)
    for i in range(n - 1, 0, -1):    # (n - 1), ..., 1
        for j in range(i):
            if f(a[j]) > f(a[j + 1]):
                a[j], a[j + 1] = a[j + 1], a[j]
Let's return to the previous example where we were given a list of lists, and found the list whose last element is greatest:
def last(s):
    return s[-1]
>>> max_by([[1, 7], [3, 4, 5], [2]], last)
[1, 7]
It's a bit of a nuisance to have to define a separate function last
here. Instead, we can use a lambda expression:
>>> max_by([[1, 7], [3, 4, 5], [2]], lambda l: l[-1])
[1, 7]
A lambda expression creates a function "on the fly", without giving it a name. In other words, a lambda expression creates an anonymous function.
A function created by a lambda expression is no different from any other function: we can call it, pass it as an argument, and so forth. Even though the function is initially anonymous, we can certainly put it into a variable:
>>> abc = lambda x, y: 2 * x + y
>>> abc(10, 3)
23
The assignment to abc above is basically equivalent to
def abc(x, y):
    return 2 * x + y
which is how we would more typically define this function.
A lambda function may even take no arguments at all:
>>> f = lambda: 14
>>> f()
14
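Lambda expressions combine naturally with the functions we saw earlier that take a function as an argument. The following sketch passes lambdas to map(), filter() and sorted(), using sample data chosen only for illustration:

```python
nums = [2, 4, 5, 7, 8]

squares = list(map(lambda x: x * x, nums))        # apply to every element
odds = list(filter(lambda x: x % 2 == 1, nums))   # keep only the odd values
words = sorted(['pear', 'fig', 'banana'], key=lambda w: len(w))   # sort by length

print(squares)   # [4, 16, 25, 49, 64]
print(odds)      # [5, 7]
print(words)     # ['fig', 'pear', 'banana']
```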
As another example, suppose that we'd like to write a function that takes a string and returns its most frequent character. We may use max() with a key function that is a lambda expression:
# Return the character in s which occurs most frequently.
def freq(s):
    d = {}
    for c in s:
        if c in d:
            d[c] += 1
        else:
            d[c] = 1
    return max(d.keys(), key = lambda k: d[k])
We see here that a lambda expression may refer to a local variable in an enclosing scope. We could not write this function outside the freq() function, since then it would not have access to the dictionary d.
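The same access to an enclosing local variable is available to a named function defined inside another function. The sketch below is a variant of freq (not the version above) in which an inner function count, a hypothetical name, replaces the lambda, and the dictionary method get() shortens the counting loop:

```python
def freq(s):
    d = {}
    for c in s:
        d[c] = d.get(c, 0) + 1   # get() returns 0 if c is not yet a key

    # A named inner function can play the same role as the lambda:
    # it also sees the local variable d of the enclosing call.
    def count(k):
        return d[k]

    return max(d.keys(), key=count)

print(freq('banana'))   # a
print(freq('hello'))    # l
```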
As a final example, we may write a selection sort using min() with a lambda function rather than an inner loop:
def selection_sort(a):
    n = len(a)
    for i in range(n):
        j = min(range(i, n), key = lambda k: a[k])
        a[i], a[j] = a[j], a[i]
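Here is a short usage sketch, with the function repeated so that the example runs on its own. The call to min() returns the index of the smallest remaining element, since the key function maps each index k to the value a[k]:

```python
def selection_sort(a):
    n = len(a)
    for i in range(n):
        # index of the smallest element among a[i], ..., a[n - 1]
        j = min(range(i, n), key=lambda k: a[k])
        a[i], a[j] = a[j], a[i]

nums = [5, 2, 9, 1, 7]
selection_sort(nums)     # sorts the list in place and returns None
print(nums)              # [1, 2, 5, 7, 9]
```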