Python正确重载运算符的方法示例详解

前言

说到运算符重载相信大家都不陌生,运算符重载的作用是让用户定义的对象使用中缀运算符(如 + 和 |)或一元运算符(如 - 和 ~)。说得宽泛一些,在 Python 中,函数调用(())、属性访问(.)和元素访问 / 切片([])也是运算符。

我们为 Vector 类简略实现了几个运算符。__add__ 和 __mul__ 方法是为了展示如何使用特殊方法重载运算符,不过有些小问题被我们忽视了。此外,我们定义的Vector2d.__eq__ 方法认为 Vector(3, 4) == [3, 4] 是真的(True),这可能并不合理。下面来一起看看详细的介绍吧。

运算符重载基础

在某些圈子中,运算符重载的名声并不好。这个语言特性可能(已经)被滥用,让程序员困惑,导致缺陷和意料之外的性能瓶颈。但是,如果使用得当,API 会变得好用,代码会变得易于阅读。Python 施加了一些限制,做好了灵活性、可用性和安全性方面的平衡:

  • 不能重载内置类型的运算符
  • 不能新建运算符,只能重载现有的
  • 某些运算符不能重载——is、and、or 和 not(不过位运算符
  • &、| 和 ~ 可以)

前面的博文已经为 Vector 定义了一个中缀运算符,即 ==,这个运算符由__eq__ 方法支持。我们将改进 __eq__ 方法的实现,更好地处理不是Vector 实例的操作数。然而,在运算符重载方面,众多比较运算符(==、!=、>、<、>=、<=)是特例,因此我们首先将在 Vector 中重载四个算术运算符:一元运算符 - 和 +,以及中缀运算符 + 和 *。

一元运算符

  -(__neg__)

    一元取负算术运算符。如果 x 是 -2,那么 -x == 2。

  +(__pos__)

    一元取正算术运算符。通常,x == +x,但也有一些例外。如果好奇,请阅读“x 和 +x 何时不相等”附注栏。

  ~(__invert__)

    对整数按位取反,定义为 ~x == -(x+1)。如果 x 是 2,那么 ~x== -3。

支持一元运算符很简单,只需实现相应的特殊方法。这些特殊方法只有一个参数,self。然后,使用符合所在类的逻辑实现。不过,要遵守运算符的一个基本规则:始终返回一个新对象。也就是说,不能修改self,要创建并返回合适类型的新实例。

对 - 和 + 来说,结果可能是与 self 同属一类的实例。多数时候,+ 最好返回 self 的副本。abs(...) 的结果应该是一个标量。但是对 ~ 来说,很难说什么结果是合理的,因为可能不是处理整数的位,例如在ORM 中,SQL WHERE 子句应该返回反集。

def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反

 def __pos__(self):
  return Vector(self)      #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量

x 和 +x 何时不相等

每个人都觉得 x == +x,而且在 Python 中,几乎所有情况下都是这样。但是,我在标准库中找到两例 x != +x 的情况。

第一例与 decimal.Decimal 类有关。如果 x 是 Decimal 实例,在算术运算的上下文中创建,然后在不同的上下文中计算 +x,那么 x!= +x。例如,x 所在的上下文使用某个精度,而计算 +x 时,精度变了,例如下面的 🌰

算术运算上下文的精度变化可能导致 x 不等于 +x

>>> import decimal
>>> ctx = decimal.getcontext()                  #获取当前全局算术运算符的上下文引用
>>> ctx.prec = 40                          #把算术运算上下文的精度设为40
>>> one_third = decimal.Decimal('1') / decimal.Decimal('3') #使用当前精度计算1/3
>>> one_third
Decimal('0.3333333333333333333333333333333333333333')     #查看结果,小数点后的40个数字
>>> one_third == +one_third                    #one_third = +one_thied返回TRUE
True
>>> ctx.prec = 28                          #把精度降为28
>>> one_third == +one_third                    #one_third = +one_thied返回FalseFalse >>> +one_third Decimal('0.3333333333333333333333333333')   #查看+one_third,小术后的28位数字

虽然每个 +one_third 表达式都会使用 one_third 的值创建一个新 Decimal 实例,但是会使用当前算术运算上下文的精度。

x != +x 的第二例在 collections.Counter 的文档中(https://docs.python.org/3/library/collections.html#collections.Counter)。类实现了几个算术运算符,例如中缀运算符 +,作用是把两个Counter 实例的计数器加在一起。然而,从实用角度出发,Counter 相加时,负值和零值计数会从结果中剔除。而一元运算符 + 等同于加上一个空 Counter,因此它产生一个新的Counter 且仅保留大于零的计数器。

🌰  一元运算符 + 得到一个新 Counter 实例,但是没有零值和负值计数器

>>> from collections import Counter
>>> ct = Counter('abracadabra')
>>> ct['r'] = -3
>>> ct['d'] = 0
>>> ct
Counter({'a': 5, 'r': -3, 'b': 2, 'c': 1, 'd': 0})
>>> +ct
Counter({'a': 5, 'b': 2, 'c': 1})

重载向量加法运算符+

两个欧几里得向量加在一起得到的是一个新向量,它的各个分量是两个向量中相应的分量之和。比如说:

>>> v1 = Vector([3, 4, 5])
>>> v2 = Vector([6, 7, 8])
>>> v1 + v2
Vector([9.0, 11.0, 13.0])
>>> v1 + v2 == Vector([3+6, 4+7, 5+8])
True

确定这些基本的要求之后,__add__ 方法的实现短小精悍,🌰 如下

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

还可以把Vector 加到元组或任何生成数字的可迭代对象上:

# 在Vector类中定义 

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

 def __radd__(self, other):            #会直接委托给__add__
  return self + other

__radd__ 通常就这么简单:直接调用适当的运算符,在这里就是委托__add__。任何可交换的运算符都能这么做。处理数字和向量时,+ 可以交换,但是拼接序列时不行。

重载标量乘法运算符*

Vector([1, 2, 3]) * x 是什么意思?如果 x 是数字,就是计算标量积(scalar product),结果是一个新 Vector 实例,各个分量都会乘以x——这也叫元素级乘法(elementwise multiplication)。

>>> v1 = Vector([1, 2, 3])
>>> v1 * 10
Vector([10.0, 20.0, 30.0])
>>> 11 * v1
Vector([11.0, 22.0, 33.0])

涉及 Vector 操作数的积还有一种,叫两个向量的点积(dotproduct);如果把一个向量看作 1×N 矩阵,把另一个向量看作 N×1 矩阵,那么就是矩阵乘法。NumPy 等库目前的做法是,不重载这两种意义的 *,只用 * 计算标量积。例如,在 NumPy 中,点积使用numpy.dot() 函数计算。

回到标量积的话题。我们依然先实现最简可用的 __mul__ 和 __rmul__方法:

def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

这两个方法确实可用,但是提供不兼容的操作数时会出问题。scalar参数的值要是数字,与浮点数相乘得到的积是另一个浮点数(因为Vector 类在内部使用浮点数数组)。因此,不能使用复数,但可以是int、bool(int 的子类),甚至 fractions.Fraction 实例等标量。

提供了点积所需的 @ 记号(例如,a @ b 是 a 和 b 的点积)。@ 运算符由特殊方法 __matmul__、__rmatmul__ 和__imatmul__ 提供支持,名称取自“matrix multiplication”(矩阵乘法)

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

下面是相应特殊方法的代码:

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

众多比较运算符

Python 解释器对众多比较运算符(==、!=、>、<、>=、<=)的处理与前文类似,不过在两个方面有重大区别。

  • 正向和反向调用使用的是同一系列方法。例如,对 == 来说,正向和反向调用都是 __eq__ 方法,只是把参数对调了;而正向的 __gt__ 方法调用的是反向的 __lt__方法,并把参数对调。
  • 对 == 和 != 来说,如果反向调用失败,Python 会比较对象的 ID,而不抛出 TypeError。

众多比较运算符:正向方法返回NotImplemented的话,调用反向方法


分组


中缀运算符


正向方法调用


反向方法调用


后备机制


相等性


a == b


a.__eq__(b)


b.__eq__(a)


返回 id(a) == id(b)


a != b


a.__ne__(b)


b.__ne__(a)


返回 not (a == b)


排序


a > b


a.__gt__(b)


b.__lt__(a)


抛出 TypeError


a < b


a.__lt__(b)


b.__gt__(a)


抛出 TypeError


a >= b


a.__ge__(b)


b.__le__(a)


抛出 TypeError


a <= b


a.__le__(b)


b.__ge__(a)


抛出T ypeError

看下面的🌰

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools

class Vector:
 typecode = 'd'

 def __init__(self, components):
  self._components = array(self.typecode, components)

 def __iter__(self):
  return iter(self._components)

 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)

 def __str__(self):
  return str(tuple(self))

 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))

 def __eq__(self, other):
  return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))

 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

 def __radd__(self, other):            #会直接委托给__add__
  return self + other

 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented

 def __rmatmul__(self, other):
  return self @ other

 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反

 def __pos__(self):
  return Vector(self)       #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量

 def __bool__(self):
  return bool(abs(self))

 def __len__(self):
  return len(self._components)

 def __getitem__(self, index):
  cls = type(self)

  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))

 shorcut_names = 'xyzt'

 def __getattr__(self, name):
  cls = type(self)

  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))

 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a

 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))

 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))

 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)

va = Vector([1.0, 2.0, 3.0])
vb = Vector(range(1, 4))
print('va == vb:', va == vb)     #两个具有相同数值分量的 Vector 实例是相等的
t3 = (1, 2, 3)
print('va == t3:', va == t3)

print('[1, 2] == (1, 2):', [1, 2] == (1, 2))

上面代码执行返回的结果为:

va == vb: True
va == t3: True
[1, 2] == (1, 2): False

从 Python 自身来找线索,我们发现 [1,2] == (1, 2) 的结果是False。因此,我们要保守一点,做些类型检查。如果第二个操作数是Vector 实例(或者 Vector 子类的实例),那么就使用 __eq__ 方法的当前逻辑。否则,返回 NotImplemented,让 Python 处理。

🌰 vector_v8.py:改进 Vector 类的 __eq__ 方法

  def __eq__(self, other):
   if isinstance(other, Vector):          #判断对比的是否和Vector同属一个实例
    return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
   else:
    return NotImplemented           #否则,返回NotImplemented

改进以后的代码执行结果:

>>> va = Vector([1.0, 2.0, 3.0])
>>> vb = Vector(range(1, 4))
>>> va == vb
True
>>> t3 = (1, 2, 3)
>>> va == t3
False

增量赋值运算符

  Vector 类已经支持增量赋值运算符 += 和 *= 了,示例如下

🌰  增量赋值不会修改不可变目标,而是新建实例,然后重新绑定

>>> v1 = Vector([1, 2, 3])
>>> v1_alias = v1             # 复制一份,供后面审查Vector([1, 2, 3])对象
>>> id(v1)                 # 记住一开始绑定给v1的Vector实例的ID
>>> v1 += Vector([4, 5, 6])       # 增量加法运算
>>> v1                    # 结果与预期相符
Vector([5.0, 7.0, 9.0])
>>> id(v1)                 # 但是创建了新的Vector实例
>>> v1_alias                # 审查v1_alias,确认原来的Vector实例没被修改
Vector([1.0, 2.0, 3.0])
>>> v1 *= 11                # 增量乘法运算
>>> v1                   # 同样,结果与预期相符,但是创建了新的Vector实例
Vector([55.0, 77.0, 99.0])
>>> id(v1)

完整代码:

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools

class Vector:
 typecode = 'd'

 def __init__(self, components):
  self._components = array(self.typecode, components)

 def __iter__(self):
  return iter(self._components)

 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)

 def __str__(self):
  return str(tuple(self))

 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))

 def __eq__(self, other):
  if isinstance(other, Vector):
   return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
  else:
   return NotImplemented          

 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)
  return Vector(a + b for a, b in pairs)        

 def __radd__(self, other):
  return self + other

 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented

 def __rmatmul__(self, other):
  return self @ other

 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   

 def __pos__(self):
  return Vector(self)       

 def __bool__(self):
  return bool(abs(self))

 def __len__(self):
  return len(self._components)

 def __getitem__(self, index):
  cls = type(self)

  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))

 shorcut_names = 'xyzt'

 def __getattr__(self, name):
  cls = type(self)

  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))

 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a

 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))

 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))

 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作能带来一定的帮助,如果有疑问大家可以留言交流,谢谢大家对我们的支持。

时间: 2017-08-24

Python运算符重载详解及实例代码

Python运算符重载 Python语言提供了运算符重载功能,增强了语言的灵活性,这一点与C++有点类似又有些不同.鉴于它的特殊性,今天就来讨论一下Python运算符重载. Python语言本身提供了很多魔法方法,它的运算符重载就是通过重写这些Python内置魔法方法实现的.这些魔法方法都是以双下划线开头和结尾的,类似于__X__的形式,python通过这种特殊的命名方式来拦截操作符,以实现重载.当Python的内置操作运用于类对象时,Python会去搜索并调用对象中指定的方法完成操作. 类可以

Python运算符重载用法实例分析

本文实例讲述了Python运算符重载用法.分享给大家供大家参考.具体如下: 在Python语言中提供了类似于C++的运算符重在功能: 一下为Python运算符重在调用的方法如下: Method         Overloads         Call for __init__        构造函数         X=Class() __del__         析构函数         对象销毁 __add__         +                 X+Y,X+=Y __

Python运算符重载用法实例

本文实例讲述了Python运算符重载用法.分享给大家供大家参考.具体分析如下: python中,我们在定义类的时候,可以通过实现一些函数来实现重载运算符. 例子如下: # -*- coding:utf-8 -*- ''''' Created on 2013-3-21 @author: naughty ''' class Test(object): def __init__(self, value): self.value = value def __add__(self, x): return

python 运算符 供重载参考

二元运算符 特殊方法 + __add__,__radd__ - __sub__,__rsub__ * __mul__,__rmul__ / __div__,__rdiv__,__truediv__,__rtruediv__ // __floordiv__,__rfloordiv__ % __mod__,__rmod__ ** __pow__,__rpow__ << __lshift__,__rlshift__ >> __rshift__,__rrshift__ & __an

C++实践Time类中的运算符重载参考方法

[项目-Time类中的运算符重载] 实现Time类中的运算符重载. class CTime { private: unsigned short int hour; // 时 unsigned short int minute; // 分 unsigned short int second; // 秒 public: CTime(int h=0,int m=0,int s=0); void setTime(int h,int m,int s); void display(); //二目的比较运算符

Python中操作符重载用法分析

本文实例讲述了Python中操作符重载用法.分享给大家供大家参考,具体如下: 类可以重载python的操作符 操作符重载使我们的对象与内置的一样.__X__的名字的方法是特殊的挂钩(hook),python通过这种特殊的命名来拦截操作符,以实现重载. python在计算操作符时会自动调用这样的方法,例如: 如果对象继承了__add__方法,当它出现在+表达式中时会调用这个方法.通过重载,用户定义的对象就像内置的一样. 在类中重载操作符 1.操作符重载使得类能拦截标准的python操作. 2.类可

详解C++中的函数调用和下标以及成员访问运算符的重载

函数调用 使用括号调用的函数调用运算符是二元运算符. 语法 primary-expression ( expression-list ) 备注 在此上下文中,primary-expression 为第一个操作数,并且 expression-list(可能为参数的空列表)为第二个操作数.函数调用运算符用于需要大量参数的操作.这之所以有效,是因为 expression-list 是列表而非单一操作数.函数调用运算符必须是非静态成员函数. 函数调用运算符在重载时不会修改函数的调用方式:相反,它会在运算

老生常谈python中的重载

在一些静态语言中,大都存在有一个重载的概念.这是在OOP(面对对象编程)中一个必不可少的一个行为. 所谓重载,就是多个相同函数名的函数,根据传入的参数个数,参数类型而执行不同的功能.所以函数重载实质上是为了解决编程中参数可变不统一的问题. python 中的重载   在python中,具有重载的思想却没有重载的概念.所以有的人说python这么语言并不支持函数重载,有的人说python具有重载功能.实际上python编程中具有重载的目的缺无重载的行为,或者说是python并不需要重载!   py

C++ 流插入和流提取运算符的重载的实现

01 流插入<<运算符的重载 C++ 在输出内容时,最常用的方式: std::cout << 1 <<"hello"; 问题: 那这条语句为什么能成立呢? cout 是什么?"<<" 运算符能用在 cout 上呢? 原因: 实际上,cout 是在 iostream 头文件中定义的 ostream 类的对象. "<<" 能够用在 cout 上是因为,在 ostream 类对 "&

C++ 自增、自减运算符的重载和性能分析小结

01 ++.--运算符重载函数的格式 自增运算符和自减运算符是有前置和后置之分的,如: a++ // 后置自增运算符 ++a // 前置自增运算符 b-- // 后置自减运算符 --b // 前置自减运算符 为了区分所重载的是前置运算符还是后置运算符,C++规定: 前置运算符作为一元运算符重载,重载为成员函数的格式如下: T & operator++(); // 前置自增运算符的重载函数,函数参数是空 T & operator--(); // 前置自减运算符的重载函数,函数参数是空 后置运