
lec7-12

Miscellaneous notes from lec7-12.

Higher-order Functions

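A higher-order function takes in a function as an argument (i.e., at least one of its parameters is itself a function) or returns a function as its result.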
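
A minimal sketch of both kinds. The two helpers are hypothetical illustrations, not course code: apply_twice takes a function as an argument, and make_adder returns one.

```python
def apply_twice(f, x):
    """Takes a function f as an argument and applies it to x twice."""
    return f(f(x))


def make_adder(n):
    """Returns a new function that adds n to its argument."""
    def adder(k):
        return k + n  # n is looked up in make_adder's frame
    return adder


print(apply_twice(lambda x: x * 2, 3))  # 12
print(make_adder(3)(4))                 # 7
```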

  • lab03:

    >>> def f1(n):
    ...     def f2(x):
    ...         if x == 0:
    ...             return 'cake'
    ...         else:
    ...             print('The cake is a lie.')
    ...             n -= x
    ...             return f2(n)
    ...     return f2
    ...
    >>> f1(2)(2)
    The cake is a lie.
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "<stdin>", line 7, in f2
    UnboundLocalError: cannot access local variable 'n' where it is not associated with a value
    

  • hw03:

    def count_change(total, money):
        """Return the number of ways to make change for total,
        under the currency system described by money.
    
        >>> def chinese_yuan(ith):
        ...     if ith == 1:
        ...         return 100
        ...     if ith == 2:
        ...         return 50
        ...     if ith == 3:
        ...         return 20
        ...     if ith == 4:
        ...         return 10
        ...     if ith == 5:
        ...         return 5
        ...     if ith == 6:
        ...         return 1
        >>> def us_cent(ith):
        ...     if ith == 1:
        ...         return 25
        ...     if ith == 2:
        ...         return 10
        ...     if ith == 3:
        ...         return 5
        ...     if ith == 4:
        ...         return 1
        >>> count_change(15, chinese_yuan)
        6
        >>> count_change(49, chinese_yuan)
        44
        >>> count_change(49, us_cent)
        39
        >>> count_change(49, lambda x: 2 ** (6 - x) if x <= 6 else None)
        692
        >>> from construct_check import check
        >>> # ban iteration
        >>> check(HW_SOURCE_FILE, 'count_change', ['While', 'For'])
        True
        """
        def helper(total, ith):
            denomination = money(ith)
            if denomination is None:
                return 0
            if total == 0:
                return 1
            if total < 0:
                return 0
    
            # Two cases:
            # 1. Use the current denomination (at least once) and keep using it.
            # 2. Skip the current denomination and move on to the next one.
            return helper(total - denomination, ith) + helper(total, ith + 1)
        return helper(total, 1)
    

  • hw03, final problem:

    def multiadder(n):
        """Return a function that takes N arguments, one at a time, and adds them.
    
        >>> f = multiadder(3)
        >>> f(5)(6)(7) # 5 + 6 + 7
        18
        >>> multiadder(1)(5)
        5
        >>> multiadder(2)(5)(6) # 5 + 6
        11
        >>> multiadder(4)(5)(6)(7)(8) # 5 + 6 + 7 + 8
        26
        >>> from construct_check import check
        >>> # Make sure multiadder is a pure function.
        >>> check(HW_SOURCE_FILE, 'multiadder',
        ...       ['Nonlocal', 'Global'])
        True
        """
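        if n == 1:
            return lambda x: x
        else:
            # Take x, then y, and pass x + y to multiadder(n - 1) as its first
            # argument; multiadder(1) is the identity, so n == 2 returns x + y.
            return lambda x: lambda y: multiadder(n - 1)(x + y)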
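
For comparison, a sketch of a variant that threads a running total through the recursion instead of calling multiadder(n - 1) on a partial sum. The name multiadder_acc and its total parameter are hypothetical; like the original, it uses no nonlocal or global.

```python
def multiadder_acc(n, total=0):
    """Return a function that takes the remaining n arguments one at a time
    and returns total plus their sum."""
    def take(x):
        if n == 1:
            return total + x  # last argument: return the finished sum
        # Otherwise wait for the remaining n - 1 arguments with an updated total.
        return multiadder_acc(n - 1, total + x)
    return take


print(multiadder_acc(3)(5)(6)(7))  # 18
```

Because nothing is mutated, a partially applied function such as f = multiadder_acc(3)(5) can be called again and again with different remaining arguments.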
    
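The two-case recursion in count_change (hw03) becomes visible if it yields the actual combinations instead of counting them. This is a sketch with a hypothetical list_change generator; chinese_yuan is rewritten here as a tuple lookup so the block runs on its own, returning the same denominations as in the count_change docstring.

```python
YUAN = (100, 50, 20, 10, 5, 1)


def chinese_yuan(ith):
    """Return the ith (1-based) denomination, or None when there are no more."""
    return YUAN[ith - 1] if ith <= len(YUAN) else None


def list_change(total, money, ith=1):
    """Yield every way to make change for total as a list of coins,
    largest denominations first."""
    if total == 0:
        yield []  # exactly one way: add no more coins
        return
    denomination = money(ith)
    if total < 0 or denomination is None:
        return    # dead end: no valid combination from here
    # Case 1: use the current denomination at least once (and keep it available).
    for rest in list_change(total - denomination, money, ith):
        yield [denomination] + rest
    # Case 2: skip the current denomination and move on to the next one.
    yield from list_change(total, money, ith + 1)


for way in list_change(15, chinese_yuan):
    print(way)
```

This prints six combinations, starting with [10, 5], which matches count_change(15, chinese_yuan) returning 6.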

ADT: Abstract Data Type
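
The notes below use tree, label and branches without defining them. A minimal sketch, assuming the list-based representation from the CS61A lectures (a tree is a list whose first element is the label and whose remaining elements are the branches); the actual course code may differ in details such as its validity checks.

```python
def tree(label, branches=[]):
    """Construct a tree from a root label and a list of branches."""
    for branch in branches:
        assert is_tree(branch), 'branches must be trees'
    # list(branches) copies the argument, so the shared default [] is never mutated.
    return [label] + list(branches)


def label(tree):
    """Return the root label of tree."""
    return tree[0]


def branches(tree):
    """Return the list of branches of tree."""
    return tree[1:]


def is_tree(tree):
    """Return True if tree is a non-empty list whose branches are all trees."""
    if type(tree) != list or len(tree) < 1:
        return False
    for branch in branches(tree):
        if not is_tree(branch):
            return False
    return True


def is_leaf(tree):
    """Return True if tree has no branches."""
    return not branches(tree)


number_tree = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
print(label(branches(number_tree)[1]))  # 3
```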

  • Tree:

    number_tree = tree(1, [
                            tree(2),
                            tree(3, [
                                    tree(4),
                                    tree(5)]),
                            tree(6, [
                                    tree(7)])])
    

    Drawn out, it looks like this:

          1
        / | \
       2  3   6
         / \   \
        4   5   7
    

    To extract the 3 from it, write:

    label(branches(number_tree)[1])
    

    We can also define a function that prints the whole tree with its structure clearly visible:

    def print_tree(t, indent=0):
        """Print a representation of this tree in which each node is
        indented by two spaces times its depth from the root.
    
        >>> print_tree(tree(1))
        1
        >>> print_tree(tree(1, [tree(2)]))
        1
          2
        >>> numbers = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
        >>> print_tree(numbers)
        1
          2
          3
            4
            5
          6
            7
        """
        print("  " * indent + str(label(t)))
        for b in branches(t):
            print_tree(b, indent + 1)
    

    Mutability: a list is mutable, whereas a tuple is immutable (its elements cannot be reassigned).

  • hw04: add_trees:

    def add_trees(t1, t2):
        """
        >>> numbers = tree(1,
        ...                [tree(2,
        ...                      [tree(3),
        ...                       tree(4)]),
        ...                 tree(5,
        ...                      [tree(6,
        ...                            [tree(7)]),
        ...                       tree(8)])])
        >>> print_tree(add_trees(numbers, numbers))
        2
          4
            6
            8
          10
            12
              14
            16
        >>> print_tree(add_trees(tree(2), tree(3, [tree(4), tree(5)])))
        5
          4
          5
        >>> print_tree(add_trees(tree(2, [tree(3)]), tree(2, [tree(3), tree(4)])))
        4
          6
          4
        >>> print_tree(add_trees(tree(2, [tree(3, [tree(4), tree(5)])]), \
        tree(2, [tree(3, [tree(4)]), tree(5)])))
        4
          6
            8
            5
          5
        """
    
        # Defensive base cases: if one tree is missing, use the other as-is.
        if t1 is None:
            return t2
        if t2 is None:
            return t1

        # The new root label is the sum of the two root labels.
        nl = label(t1) + label(t2)

        branches1 = branches(t1)
        branches2 = branches(t2)
        shared = min(len(branches1), len(branches2))

        new = []

        # Branches at the same position in both trees are added recursively.
        for i in range(shared):
            new.append(add_trees(branches1[i], branches2[i]))

        # Extra branches of whichever tree has more are copied over unchanged.
        for i in range(shared, len(branches1)):
            new.append(branches1[i])

        for i in range(shared, len(branches2)):
            new.append(branches2[i])

        return tree(nl, new)
    

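  • proj02, minimum_mewtations: I got this wrong from the start. I assumed the strings had to be trimmed from both ends toward the middle, which led to this badly wrong answer: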

    def minimum_mewtations(typed, source, limit):
        if typed == '' or source == '':  # Fill in the condition
            # BEGIN
            return max(len(typed), len(source))
            # END
        elif typed == source:
            return 0
        elif typed[-1] == source[-1]:  # Feel free to remove or add additional cases
            # BEGIN
            return minimum_mewtations(typed[:-1], source[:-1], limit)
            # END
        elif typed[0] == source[0]:
            return minimum_mewtations(typed[1:], source[1:], limit)
    
        else:
            add = len(source) - len(typed) if len(typed) < len(source) else 0 # Fill in these lines
            remove = len(typed) - len(source) if len(typed) > len(source) else 0
            substitute = min(minimum_mewtations(typed[1:], source[1:], limit-1)+1, 
                            minimum_mewtations(typed[:-1], source[1:], limit-2)+2,
                            minimum_mewtations(typed[1:], source[:-1], limit-2)+2,
                            minimum_mewtations(typed[:-1], source[:-1], limit-1)+1) 
            # BEGIN
            return add + remove + substitute
            # END
    

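    This version fails the following test, because it never checks limit and therefore keeps recursing after the limit has been used up: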

    Problem 7 > Suite 1 > Case 6
    
    >>> from cats import minimum_mewtations, autocorrect
    >>> import tests.construct_check as test
    >>> # ***Check that the recursion stops when the limit is reached***
    >>> import trace, io
    >>> from contextlib import redirect_stdout
    >>> with io.StringIO() as buf, redirect_stdout(buf):
    ...     trace.Trace(trace=True).runfunc(minimum_mewtations, "someawe", "awesome", 3)
    ...     output = buf.getvalue()
    >>> len([line for line in output.split('\n') if 'funcname' in line]) < 1000
    False
    
    # Error: expected
    #     True
    # but got
    #     False
    

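    The correct version checks limit first and stops as soon as it drops below 0; it then compares only the first characters and, on a mismatch, tries an add, a remove, and a substitution, each costing one edit: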

    def minimum_mewtations(typed, source, limit):
        """A diff function that computes the edit distance from TYPED to SOURCE.
        This function takes in a string TYPED, a string SOURCE, and a number LIMIT.
        Arguments:
            typed: a typed word
            source: a source word
            limit: a number representing an upper bound on the number of edits
        >>> big_limit = 10
        >>> minimum_mewtations("cats", "scat", big_limit)       # cats -> scats -> scat
        2
        >>> minimum_mewtations("purng", "purring", big_limit)   # purng -> purrng -> purring
        2
        >>> minimum_mewtations("ckiteus", "kittens", big_limit) # ckiteus -> kiteus -> kitteus -> kittens
        3
        """
        # BEGIN PROBLEM 7
        if limit < 0:
            return float('inf')
        if typed == '' or source == '':  # Fill in the condition
            # BEGIN
            return max(len(typed), len(source))
            # END
        if typed[0] == source[0]:
            return minimum_mewtations(typed[1:], source[1:], limit)
    
        else:
            add = 1 + minimum_mewtations(typed, source[1:], limit - 1) # Fill in these lines
            remove = 1 + minimum_mewtations(typed[1:], source, limit - 1)
            substitute = 1 + minimum_mewtations(typed[1:], source[1:], limit - 1) 
            # BEGIN
            return min(add, remove, substitute)
            # END
        # END PROBLEM 7
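
Why the limit check matters: the Problem 7 > Suite 1 > Case 6 test counts traced calls, and the correct version stays small because every branch that exceeds limit returns immediately. A sketch that counts calls to a copy of the correct recursion (count_calls is a hypothetical helper; it reuses the algorithm above and adds a counter through nonlocal):

```python
def count_calls(typed, source, limit):
    """Run a copy of the minimum_mewtations recursion and return
    (result, number of calls made)."""
    calls = 0

    def mewtations(typed, source, limit):
        nonlocal calls
        calls += 1
        if limit < 0:  # budget exhausted: prune this branch
            return float('inf')
        if typed == '' or source == '':
            return max(len(typed), len(source))
        if typed[0] == source[0]:
            return mewtations(typed[1:], source[1:], limit)
        add = 1 + mewtations(typed, source[1:], limit - 1)
        remove = 1 + mewtations(typed[1:], source, limit - 1)
        substitute = 1 + mewtations(typed[1:], source[1:], limit - 1)
        return min(add, remove, substitute)

    result = mewtations(typed, source, limit)
    return result, calls


print(count_calls("cats", "scat", 10)[0])  # 2
small = count_calls("someawe", "awesome", 3)
large = count_calls("someawe", "awesome", 10)
print(small[1], large[1])  # the number of calls grows quickly with limit
```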
    
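A short check of the mutability note in the tree section: an element of a list can be reassigned in place, while the same assignment on a tuple raises TypeError. This is a generic illustration, not course code.

```python
numbers_list = [1, 2, 3]
numbers_list[0] = 10          # lists are mutable: in-place assignment works
print(numbers_list)           # [10, 2, 3]

numbers_tuple = (1, 2, 3)
try:
    numbers_tuple[0] = 10     # tuples are immutable
except TypeError as e:
    print('TypeError:', e)    # 'tuple' object does not support item assignment
```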

Nonlocal

  • The statement nonlocal balance tells Python that balance is not local to the current frame: Python looks it up in the enclosing function's frame, and assigning to balance rebinds it there.

    For example:

    def make_withdraw(balance):
        """Returns a function which can withdraw
        some amount from balance
    
        >>> withdraw = make_withdraw(50)
        >>> withdraw(25)
        25
        >>> withdraw(25)
        0
        """
        def withdraw(amount):
            nonlocal balance
            if amount > balance:
                return "Insufficient funds"
            balance = balance - amount
            return balance
        return withdraw
    
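    Now withdraw can update balance in make_withdraw's frame. Without the nonlocal statement, the assignment balance = balance - amount would make balance local to withdraw, so the comparison amount > balance would raise UnboundLocalError.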
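
The same rule explains the lab03 error in the higher-order function section: n -= x assigns to n inside f2, so Python treats n as local to f2 for the whole body, and reading it before the assignment fails. A sketch of a fixed version, assuming the intent was to update f1's n, adds a nonlocal declaration:

```python
def f1(n):
    def f2(x):
        nonlocal n  # n now refers to the variable in f1's frame
        if x == 0:
            return 'cake'
        else:
            print('The cake is a lie.')
            n -= x
            return f2(n)
    return f2


print(f1(2)(2))
# The cake is a lie.
# cake
```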

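Each call to make_withdraw creates a new frame, so every returned withdraw function keeps its own balance. A quick sketch (make_withdraw is copied from the example above so the block runs alone; the account names are hypothetical):

```python
def make_withdraw(balance):
    def withdraw(amount):
        nonlocal balance
        if amount > balance:
            return "Insufficient funds"
        balance = balance - amount
        return balance
    return withdraw


alice = make_withdraw(100)
bob = make_withdraw(100)
print(alice(30))   # 70
print(bob(10))     # 90: bob's balance lives in a different frame
print(alice(80))   # Insufficient funds
```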
Iterator
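
A for loop is built on the two functions used in the examples below: it calls iter() once on the object, then calls next() until StopIteration is raised. A rough sketch of that protocol (manual_for is a hypothetical name; the real loop is implemented inside the interpreter):

```python
def manual_for(iterable):
    """Collect the items a for loop over iterable would visit."""
    items = []
    it = iter(iterable)  # a for loop calls iter() once at the start
    while True:
        try:
            items.append(next(it))  # ...then next() on each step
        except StopIteration:       # ...and stops quietly at the end
            break
    return items


lst = [4, 3, 2, 1]
list_iter = iter(lst)
print(manual_for(list_iter))  # [4, 3, 2, 1]
print(manual_for(list_iter))  # []: iter() on an iterator returns the same, exhausted iterator
print(manual_for(lst))        # [4, 3, 2, 1]: a list hands out a fresh iterator each time
```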

  • Example:

    >>> lst = [1, 2, 3, 4]
    >>> next(lst)             # Calling next on an iterable
    TypeError: 'list' object is not an iterator
    >>> list_iter = iter(lst) # Creates an iterator for the list
    >>> list_iter
    <list_iterator object ...>
    >>> next(list_iter)       # Calling next on an iterator
    1
    >>> next(list_iter)       # Calling next on the same iterator
    2
    >>> next(iter(list_iter)) # Calling iter on an iterator returns itself
    3
    >>> list_iter2 = iter(lst)
    >>> next(list_iter2)      # Second iterator has new state
    1
    >>> next(list_iter)       # First iterator is unaffected by second iterator
    4
    >>> next(list_iter)       # No elements left!
    StopIteration
    >>> lst                   # Original iterable is unaffected
    [1, 2, 3, 4]
    

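  • An iterator can be exhausted: once it reaches the end, it has no elements left, so a second for loop over the same iterator prints nothing.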

    >>> list_iter = iter([4, 3, 2, 1])
    >>> for e in list_iter:
    ...     print(e)
    4
    3
    2
    1
    >>> for e in list_iter:
    ...     print(e)
    

  • lab05:

    # What would Python display? If you get stuck, try it out in the Python interpreter!
    
    >>> r = range(6)
    >>> r_iter = iter(r)
    >>> next(r_iter)
    ? 0
    -- OK! --
    
    >>> [x + 1 for x in r]
    ? [1, 2, 3, 4, 5, 6]
    -- OK! --
    
    >>> [x + 1 for x in r_iter]
    ? [2, 3, 4, 5, 6]
    -- OK! --
    
    >>> next(r_iter)
    ? StopIteration
    -- OK! --
    

    Explanation:

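    1. r = range(6) and r_iter = iter(r): r is a range object representing the sequence [0, 1, 2, 3, 4, 5], and r_iter is an iterator created from r.

    2. next(r_iter) → 0: the first call to next() returns the first element, 0, and the iterator's internal position moves to index 1.

    3. [x + 1 for x in r] → [1, 2, 3, 4, 5, 6]: this uses r (the range object) directly, not the iterator, so the list comprehension traverses the whole range again from the start: [0+1, 1+1, 2+1, 3+1, 4+1, 5+1].

    4. [x + 1 for x in r_iter] → [2, 3, 4, 5, 6]: this uses r_iter (the iterator). The key point is that the iterator has already consumed the first element (0) and now points at index 1, so the list comprehension starts from the remaining elements: [1+1, 2+1, 3+1, 4+1, 5+1]. This exhausts the iterator.

    5. next(r_iter) → StopIteration: the iterator has been fully consumed, so there are no elements left and StopIteration is raised to signal the end of iteration.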

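    Key concepts

    Properties of iterators: an iterator is stateful and remembers its current position; it can only move forward and cannot be reset or rewound; it is single-use, so once exhausted it cannot be reused.

    Range objects vs. iterators: a range object can be iterated over many times (each pass starts from the beginning), whereas an iterator can be traversed only once.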

  • Using filter inside a generator:

    def sieve(t):
        """Suppose the smallest number from t is p. Sieves out all the
        numbers that are divisible by p (except p itself) and recursively
        sieves out all the multiples of the next smallest number from the
        rest of the sequence.

        >>> list(sieve(iter([3, 4, 5, 6, 7, 8, 9])))
        [3, 4, 5, 7]
        >>> list(sieve(iter([2, 3, 4, 5, 6, 7, 8])))
        [2, 3, 5, 7]
        >>> list(sieve(iter([1, 2, 3, 4, 5])))
        [1]
        """
        try:
            s = next(t)
        except StopIteration:
            return
        yield s

        # filter is lazy: it returns an iterator over the elements of t not divisible by s.
        fil = filter(lambda x: x % s != 0, t)
        yield from sieve(fil)  # re-yield every value produced by the recursive call
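
Because both filter and generators are lazy, sieve also works on an infinite iterator, as long as only finitely many values are requested. A sketch using itertools.count and itertools.islice from the standard library (sieve is repeated so the block runs on its own). Each prime found adds one more nested generator and filter, so this is a demonstration of laziness rather than an efficient prime generator.

```python
from itertools import count, islice


def sieve(t):
    """Lazily yield the first element s of t, then sieve the rest of t
    with every multiple of s removed."""
    try:
        s = next(t)
    except StopIteration:
        return
    yield s
    fil = filter(lambda x: x % s != 0, t)  # lazy: nothing is computed yet
    yield from sieve(fil)


primes = sieve(count(2))         # count(2) yields 2, 3, 4, ... forever
print(list(islice(primes, 10)))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```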