单分派泛函数

Posted May 31, 2021 by clannadzsy  ‐  2 min read

singledispatch 是 PEP 443 中引入的装饰器,自 Python 3.4 起由 functools 模块提供

  1. 单分派:根据单个参数(在 singledispatch 中为第一个参数)的类型,以不同方式执行相同操作的行为
  2. 多分派:可根据多个参数的类型选择专门的函数的行为
  3. 泛函数:由针对不同类型实现同一操作的多个函数组合而成的函数

初步认识

from decimal import Decimal
from functools import singledispatch


@singledispatch
def fun(arg):
    """Fallback implementation of the generic function: print ``arg``.

    ``singledispatch`` selects an implementation from the type of the first
    argument; this body runs when no registered implementation matches.
    """
    print(arg)


@fun.register(str)
def _(arg):
    """Implementation of ``fun`` for ``str``: print the type, then the value."""
    arg_type = type(arg)
    print(arg_type, arg)


@fun.register(int)
def _(arg):
    """Implementation of ``fun`` for ``int``: print the type, then the value."""
    arg_type = type(arg)
    print(arg_type, arg)


@fun.register(list)
def _(arg):
    """Implementation of ``fun`` for ``list``: print the type, then the value."""
    arg_type = type(arg)
    print(arg_type, arg)


@fun.register(type(None))
def _(arg):
    """Implementation of ``fun`` for ``None``: print the type, then the value.

    ``type(None)`` is used as the registration key to obtain the ``NoneType``
    class.
    """
    arg_type = type(arg)
    print(arg_type, arg)


@fun.register(float)
@fun.register(Decimal)
def fun_num(arg):
    """Shared implementation of ``fun`` for ``float`` and ``Decimal``.

    The two stacked ``register`` decorators bind this one function to both
    types. Print the type of ``arg`` followed by its value.
    """
    value_type = type(arg)
    print(value_type, arg)


if __name__ == '__main__':
    # Each call is routed to the implementation registered for its argument type.
    fun("Hello, world.")
    fun(123)
    fun(['spam', 'spam', 'eggs', 'spam'])
    fun(None)
    fun(1.23)

    # Prints False: fun_num is the plain registered function, not the dispatcher fun
    print(fun_num is fun)

    # To access all registered implementations, use the read-only registry attribute
    print(fun.registry.keys())
    print(fun.registry[float])
    print(fun.registry[object])

    # Output (the function addresses differ from run to run)
    # <class 'str'> Hello, world.
    # <class 'int'> 123
    # <class 'list'> ['spam', 'spam', 'eggs', 'spam']
    # <class 'NoneType'> None
    # <class 'float'> 1.23
    # False
    # dict_keys([<class 'object'>, <class 'str'>, <class 'int'>, <class 'list'>, <class 'NoneType'>, <class 'decimal.Decimal'>, <class 'float'>])
    # <function fun_num at 0x00000155A9E9B790>
    # <function fun at 0x00000155A972F0D0>

简单应用

from functools import singledispatch


def check_type(func):
    """Decorate ``func`` so its first two positional arguments must share a type.

    The comparison is exact (``type(a) is type(b)``), so an instance of a
    subclass does not match an instance of its base class.

    Args:
        func: Callable that takes at least two positional arguments.

    Returns:
        A wrapper around ``func``. When the types of the first two positional
        arguments differ, the wrapper returns a fixed error-message string
        instead of calling ``func``; otherwise it returns ``func(*args)``.
    """
    # Imported locally so this decorator does not rely on the module's
    # top-level import list.
    from functools import wraps

    @wraps(func)  # keep func's __name__ / __doc__ visible on the wrapper
    def wrapper(*args):
        # With fewer than two arguments there is nothing to compare; defer to
        # func so it raises its own TypeError for the missing argument.
        if len(args) >= 2 and type(args[0]) is not type(args[1]):
            return 'Error: 参数类型不同'
        return func(*args)

    return wrapper


@singledispatch
def add(obj, new_obj):
    """Fallback for ``add``: reject types with no registered implementation.

    Args:
        obj: First operand; its type selects the implementation.
        new_obj: Second operand.

    Raises:
        TypeError: Always. The message names the unsupported type of ``obj``.
    """
    raise TypeError(f'add() does not support type {type(obj).__name__!r}')


@add.register(str)
@check_type
def _(obj, new_obj):
    """Implementation of ``add`` for ``str``: return ``obj`` followed by ``new_obj``."""
    return obj + new_obj


@add.register(list)
@check_type
def _(obj, new_obj):
    """Implementation of ``add`` for ``list``.

    Returns a new list holding the items of ``obj`` followed by the items of
    ``new_obj``. A fresh list is built instead of extending ``obj`` in place,
    so the caller's list is not modified by the call.
    """
    return [*obj, *new_obj]


@add.register(dict)
@check_type
def _(obj, new_obj):
    """Implementation of ``add`` for ``dict``.

    Returns a new dict with the entries of ``obj`` overlaid by the entries of
    ``new_obj`` (on duplicate keys the value from ``new_obj`` wins). A fresh
    dict is built instead of updating ``obj`` in place, so the caller's dict
    is not modified by the call.
    """
    return {**obj, **new_obj}


@add.register(tuple)
@check_type
def _(obj, new_obj):
    """Implementation of ``add`` for ``tuple``.

    Returns a new tuple with the items of ``obj`` followed by the items of
    ``new_obj``.
    """
    combined = (*obj, *new_obj)
    return combined


if __name__ == '__main__':
    # Matching operand types: each call reaches the implementation for that type.
    print(add('hello', ', world'))
    print(add([1, 2, 3], [4, 5, 6]))
    print(add({'name': 'xxx'}, {'age': 111}))
    print(add(('1', '2',), ('3', '4',)))

    # A list and a str cannot be combined: check_type returns the error string
    print(add([1, 2, 3], '4,5,6'))

    # Output
    # hello, world
    # [1, 2, 3, 4, 5, 6]
    # {'name': 'xxx', 'age': 111}
    # ('1', '2', '3', '4')
    # Error: 参数类型不同