Wednesday, November 4, 2009

Python: lambda, partial, and scopes

Python rarely surprises me, which is a good thing. Unfortunately, today I was surprised by how lambda works. It made me worry that I had no idea what I was doing at all. (And I guess I didn't!) A lambda captures the variables it closes over, not the values those variables hold when the lambda is created, so in a loop you can get some bizarre effects.

It turns out this is well documented, but I'm writing this as a note to myself if nothing else.

def adder(x, y):
    return x + y

# make a list of functions with lambda that close over i and call adder()
fns = [lambda x: adder(x, i) for i in range(10)]

# call each of our new functions with a constant value
# expected output: [13, 14, 15, 16, 17, 18, 19, 20, 21, 22]

[f(13) for f in fns]
# output: [22, 22, 22, 22, 22, 22, 22, 22, 22, 22]

# buh!?

Each function in fns refers to the variable i itself, and i is rebound to a new value on every pass through range(10). The lambdas don't look up i until they are called, which happens after the loop has finished, when i holds its last value, 9. So every function behaves like lambda x: adder(x, 9). Below are two "fixes".
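If you want to see the shared variable directly, you can peek at the functions' closures. Here's a small sketch, written for Python 3 (which postdates this post). I've moved the loop into a hypothetical helper, make_fns, so that i is a local variable and shows up as a closure cell; a for loop at module level would make i a global, and the lambdas would look it up in the module namespace instead. I've also used an explicit for loop rather than a list comprehension so the scoping is obvious.

```python
def adder(x, y):
    return x + y

def make_fns():
    # hypothetical helper: i is a local here, so the lambdas close over it
    fns = []
    for i in range(10):
        # each lambda refers to the variable i itself, not its current value
        fns.append(lambda x: adder(x, i))
    return fns

fns = make_fns()
print([f(13) for f in fns])  # [22, 22, 22, 22, 22, 22, 22, 22, 22, 22]

# every function holds the very same cell object for i...
print(all(f.__closure__[0] is fns[0].__closure__[0] for f in fns))  # True
# ...and after the loop finished, that cell contains i's last value
print(fns[0].__closure__[0].cell_contents)  # 9
```

One cell, ten functions: that's the whole bug in three lines of output.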

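Fix 1: capture the value of i as a default argument to the lambda. Default values are evaluated once, when the lambda is created, so each function keeps the value i had at that moment. Hacky, but supposedly idiomatic for this sort of issue.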

fns = [lambda x1, x2=i: adder(x1, x2) for i in range(10)]
[f(13) for f in fns]
# output: [13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
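You'll usually see this trick spelled i=i, reusing the same name for the parameter. Here's a sketch (Python 3) that also shows why it feels hacky: the captured value is just an ordinary optional parameter, so a caller can override it by passing an extra argument.

```python
def adder(x, y):
    return x + y

# The default i=i is evaluated when each lambda is created, so every
# function stores the value i had on that iteration.
fns = [lambda x, i=i: adder(x, i) for i in range(10)]
print([f(13) for f in fns])  # [13, 14, 15, 16, 17, 18, 19, 20, 21, 22]

# The stored default is visible on the function object.
print(fns[3].__defaults__)  # (3,)

# The catch: nothing stops a caller from passing a second argument.
print(fns[0](13, 100))  # 113
```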

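Fix 2: use functools.partial, which evaluates its arguments immediately and stores them. I prefer this. Note that partial(adder, i) binds i to adder's first parameter, x, so f(13) calls adder(i, 13); since addition is commutative, the result is the same.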

from functools import partial
fns = [partial(adder, i) for i in range(10)]
[f(13) for f in fns]
# output: [13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
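If the argument order matters, partial also accepts keyword arguments, so you can bind the second parameter explicitly. A sketch (Python 3) that does that and then inspects what each partial object stored, using the func, args, and keywords attributes that partial objects expose:

```python
from functools import partial

def adder(x, y):
    return x + y

# Bind y by name instead of relying on position; each partial object
# stores the value of i that was current when it was created.
fns = [partial(adder, y=i) for i in range(10)]
print([f(13) for f in fns])  # [13, 14, 15, 16, 17, 18, 19, 20, 21, 22]

# The stored function and arguments live on the partial object itself.
print(fns[3].func is adder)          # True
print(fns[3].args, fns[3].keywords)  # () {'y': 3}
```

Unlike the default-argument trick, there's no extra parameter for a caller to clobber by accident, though a caller could still override it by passing y=... explicitly.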

In retrospect, this really shouldn't be surprising. I think the main issue here is that in these examples i is just a plain old integer, which, due to my Java background, I tend to think of as a primitive Java int that gets copied by value. In Python, though, a name is just a binding to an object, and a function looks up the current binding of a free variable each time it runs, not when it is defined. For example:

foo = 42
def test():
    print(foo)
test() # outputs: 42

foo = 100
test() # outputs: 100 (!)
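The same call-time lookup is what bites the lambdas above, and a default argument is exactly what Fix 1 uses to dodge it. Here's a sketch contrasting the two, with hypothetical function names read_foo and read_foo_default (Python 3 syntax):

```python
foo = 42

def read_foo():
    # hypothetical helper: foo is a global name, looked up on every call
    return foo

def read_foo_default(value=foo):
    # hypothetical helper: the default was evaluated once, when this def ran
    return value

foo = 100
print(read_foo())          # 100
print(read_foo_default())  # 42
```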