mirror of
https://github.com/Estom/notes.git
synced 2026-02-06 12:04:05 +08:00
373 lines
11 KiB
Markdown
373 lines
11 KiB
Markdown
# 基本单位
|
|
|
|
```python
|
|
import math
|
|
|
|
import numpy as np
|
|
|
|
import matplotlib.units as units
|
|
import matplotlib.ticker as ticker
|
|
from matplotlib.cbook import iterable
|
|
|
|
|
|
class ProxyDelegate(object):
    """Descriptor that hands out a freshly built proxy on attribute access.

    Parameters
    ----------
    fn_name : str
        Name of the method the proxy stands in for.
    proxy_type : callable
        Factory called as ``proxy_type(fn_name, obj)`` each time the
        attribute is read.
    """

    def __init__(self, fn_name, proxy_type):
        self.fn_name, self.proxy_type = fn_name, proxy_type

    def __get__(self, obj, objtype=None):
        # *obj* is the instance the attribute was read from (None when it is
        # read from the class); it is forwarded to the factory unchanged.
        build_proxy = self.proxy_type
        return build_proxy(self.fn_name, obj)
|
|
|
|
|
|
class TaggedValueMeta(type):
    """Metaclass that installs proxy descriptors for missing methods.

    For every name in the class's ``_proxies`` mapping that the class does
    not already expose, a ``ProxyDelegate`` built from the mapped proxy type
    is attached under that name.
    """

    def __init__(self, name, bases, dict):
        for fn_name, proxy_type in self._proxies.items():
            # Attributes the class (or one of its bases) already provides
            # are left untouched.
            if not hasattr(self, fn_name):
                setattr(self, fn_name, ProxyDelegate(fn_name, proxy_type))
|
|
|
|
|
|
class PassThroughProxy(object):
    """Forward a method call, unchanged, to ``obj.proxy_target``.

    Parameters
    ----------
    fn_name : str
        Name of the method to look up on the target.
    obj : object
        Object whose ``proxy_target`` attribute receives the call.
    """

    def __init__(self, fn_name, obj):
        self.target = obj.proxy_target
        self.fn_name = fn_name

    def __call__(self, *args):
        # The method is resolved at call time, not at construction time.
        method = getattr(self.target, self.fn_name)
        return method(*args)
|
|
|
|
|
|
class ConvertArgsProxy(PassThroughProxy):
    """Proxy that converts every argument to the owner's unit before the call.

    The forwarded method receives the bare values; its result is returned
    as-is.
    """

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        tagged = []
        for arg in args:
            try:
                item = arg.convert_to(self.unit)
            except AttributeError:
                # No convert_to(): treat the argument as already being
                # expressed in this proxy's unit.
                item = TaggedValue(arg, self.unit)
            tagged.append(item)
        bare_values = tuple(item.get_value() for item in tagged)
        return super().__call__(*bare_values)
|
|
|
|
|
|
class ConvertReturnProxy(PassThroughProxy):
    """Proxy that tags the forwarded call's result with the owner's unit."""

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        result = super().__call__(*args)
        # NotImplemented must reach the caller untouched so Python can try
        # the reflected operation.
        if result is NotImplemented:
            return NotImplemented
        return TaggedValue(result, self.unit)
|
|
|
|
|
|
class ConvertAllProxy(PassThroughProxy):
    """Proxy that unit-converts its arguments and tags the result.

    Each argument that offers ``convert_to`` is converted to the owner's
    unit and replaced by its bare value; the forwarded method's result is
    wrapped in a ``TaggedValue`` whose unit is chosen by ``unit_resolver``.
    ``NotImplemented`` is returned whenever the operands cannot be combined.
    """

    def __init__(self, fn_name, obj):
        PassThroughProxy.__init__(self, fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        converted_args = []
        arg_units = [self.unit]
        for a in args:
            has_unit = hasattr(a, 'get_unit')
            can_convert = hasattr(a, 'convert_to')

            if has_unit and not can_convert:
                # If this arg has a unit type but no conversion ability,
                # this operation is prohibited.
                return NotImplemented

            if can_convert:
                try:
                    a = a.convert_to(self.unit)
                except Exception:
                    # Best effort: when the conversion fails the operand
                    # keeps its own unit, which is recorded below so the
                    # resolver can reject the mismatch. Only Exception is
                    # caught so KeyboardInterrupt/SystemExit still propagate.
                    pass
                arg_units.append(a.get_unit())
                converted_args.append(a.get_value())
            else:
                # Reaching here means the argument has neither convert_to
                # nor get_unit, so it carries no unit.
                converted_args.append(a)
                arg_units.append(None)

        ret = PassThroughProxy.__call__(self, *tuple(converted_args))
        if ret is NotImplemented:
            return NotImplemented
        ret_unit = unit_resolver(self.fn_name, arg_units)
        if ret_unit is NotImplemented:
            return NotImplemented
        return TaggedValue(ret, ret_unit)
|
|
|
|
|
|
class TaggedValue(metaclass=TaggedValueMeta):
    """A value paired with the unit it is expressed in.

    The names listed in ``_proxies`` are handed to the metaclass, which maps
    each one to the proxy type used to implement it. Non-dunder attributes
    that the wrapped value provides, and that this class does not define
    itself, are looked up on the wrapped value.

    Parameters
    ----------
    value : object
        The wrapped value.
    unit : object
        The unit of *value*; ``convert_to`` expects it to provide
        ``convert_value_to(value, unit)``.
    """

    _proxies = {'__add__': ConvertAllProxy,
                '__sub__': ConvertAllProxy,
                '__mul__': ConvertAllProxy,
                '__rmul__': ConvertAllProxy,
                '__cmp__': ConvertAllProxy,
                '__lt__': ConvertAllProxy,
                '__gt__': ConvertAllProxy,
                '__len__': PassThroughProxy}

    def __new__(cls, value, unit):
        """Create the instance, preferring a subclass that mixes in type(value).

        Falls back to a plain *cls* instance when the mixed subclass cannot
        be created or instantiated (``TypeError``). Only the class that is
        actually instantiated is registered with ``units.registry``, so
        discarded subclasses do not pile up in the registry.
        """
        value_class = type(value)
        try:
            # generate a new subclass for value
            target_cls = type('TaggedValue_of_%s' % (value_class.__name__),
                              (cls, value_class),
                              {})
            instance = object.__new__(target_cls)
        except TypeError:
            target_cls = cls
            instance = object.__new__(cls)
        if target_cls not in units.registry:
            units.registry[target_cls] = basicConverter
        return instance

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit
        # The proxies forward their calls to this object.
        self.proxy_target = self.value

    def __getattribute__(self, name):
        # Dunder lookups always use the normal mechanism.
        if name.startswith('__'):
            return object.__getattribute__(self, name)
        variable = object.__getattribute__(self, 'value')
        # Delegate to the wrapped value unless this class defines the name.
        if hasattr(variable, name) and name not in self.__class__.__dict__:
            return getattr(variable, name)
        return object.__getattribute__(self, name)

    def __array__(self, dtype=object, copy=None):
        """Return the wrapped value as an array of *dtype*.

        *copy* is accepted because NumPy 2 passes it; it is ignored, since
        ``astype`` always produces a new array here.
        """
        return np.asarray(self.value).astype(dtype)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        """Re-tag *array* with this value's unit.

        *context* and *return_scalar* are optional because NumPy may call
        this hook with only the array, and NumPy 2 adds *return_scalar*.
        """
        return TaggedValue(array, self.unit)

    def __repr__(self):
        return 'TaggedValue({!r}, {!r})'.format(self.value, self.unit)

    def __str__(self):
        return str(self.value) + ' in ' + str(self.unit)

    def __len__(self):
        return len(self.value)

    def __iter__(self):
        # Return a generator expression rather than use `yield`, so that
        # TypeError is raised by iter(self) if appropriate when checking for
        # iterability.
        return (TaggedValue(inner, self.unit) for inner in self.value)

    def get_compressed_copy(self, mask):
        """Return a TaggedValue holding only the entries not masked by *mask*."""
        new_value = np.ma.masked_array(self.value, mask=mask).compressed()
        return TaggedValue(new_value, self.unit)

    def convert_to(self, unit):
        """Return this value expressed in *unit*.

        Returns ``self`` when *unit* equals the current unit or is falsy.
        """
        if unit == self.unit or not unit:
            return self
        new_value = self.unit.convert_value_to(self.value, unit)
        return TaggedValue(new_value, unit)

    def get_value(self):
        """Return the wrapped value."""
        return self.value

    def get_unit(self):
        """Return the unit of the wrapped value."""
        return self.unit
|
|
|
|
|
|
class BasicUnit(object):
    """A named unit with a table of conversions to other units.

    Parameters
    ----------
    name : str
        Short name, used by ``repr``.
    fullname : str, optional
        Long name, used by ``str``; defaults to *name*.
    """

    def __init__(self, name, fullname=None):
        self.name = name
        if fullname is None:
            fullname = name
        self.fullname = fullname
        # Maps a target unit to a callable converting a value into it.
        self.conversions = dict()

    def __repr__(self):
        return 'BasicUnit(%s)' % self.name

    def __str__(self):
        return self.fullname

    def __call__(self, value):
        """Return *value* tagged with this unit."""
        return TaggedValue(value, self)

    def __mul__(self, rhs):
        """Tag *rhs* with the unit resulting from ``self * rhs``.

        A plain *rhs* is tagged with this unit. When *rhs* already carries a
        unit, the resulting unit is chosen by ``unit_resolver`` and
        ``NotImplemented`` is returned if it rejects the combination.
        """
        value = rhs
        unit = self
        if hasattr(rhs, 'get_unit'):
            value = rhs.get_value()
            unit = rhs.get_unit()
            unit = unit_resolver('__mul__', (self, unit))
        if unit is NotImplemented:
            return NotImplemented
        return TaggedValue(value, unit)

    def __rmul__(self, lhs):
        return self*lhs

    def __array_wrap__(self, array, context=None, return_scalar=False):
        """Tag *array* with this unit.

        *context* and *return_scalar* are optional because NumPy may call
        this hook with only the array, and NumPy 2 adds *return_scalar*.
        """
        return TaggedValue(array, self)

    def __array__(self, t=None, context=None, copy=None):
        """Return a one-element array holding 1, cast to *t* when given.

        *copy* is accepted because NumPy 2 passes it; a new array is built
        on every call, so it is ignored.
        """
        ret = np.array([1])
        if t is not None:
            return ret.astype(t)
        return ret

    def add_conversion_factor(self, unit, factor):
        """Register the conversion to *unit* as multiplication by *factor*."""
        def convert(x):
            return x*factor
        self.conversions[unit] = convert

    def add_conversion_fn(self, unit, fn):
        """Register *fn* as the conversion from this unit to *unit*."""
        self.conversions[unit] = fn

    def get_conversion_fn(self, unit):
        """Return the conversion registered for *unit* (KeyError if none)."""
        return self.conversions[unit]

    def convert_value_to(self, value, unit):
        """Return *value* converted to *unit* (KeyError if not registered)."""
        conversion_fn = self.conversions[unit]
        return conversion_fn(value)

    def get_unit(self):
        """Return this unit itself."""
        return self
|
|
|
|
|
|
class UnitResolver(object):
    """Pick the unit that results from an operation on several operands.

    Calling an instance with an operation name and a sequence of units
    returns the resulting unit, or ``NotImplemented`` when the operation is
    unknown or the units cannot be combined.
    """

    def addition_rule(self, units):
        """Return the first unit if all neighbouring units are equal."""
        neighbours = zip(units[:-1], units[1:])
        if any(left != right for left, right in neighbours):
            return NotImplemented
        return units[0]

    def multiplication_rule(self, units):
        """Return the only truthy unit; reject more than one."""
        present = list(filter(None, units))
        if len(present) > 1:
            return NotImplemented
        return present[0]

    # Operation name -> rule; the rules are stored as plain functions and
    # are therefore called with the resolver passed explicitly.
    op_dict = {
        **dict.fromkeys(('__mul__', '__rmul__'), multiplication_rule),
        **dict.fromkeys(('__add__', '__radd__', '__sub__', '__rsub__'),
                        addition_rule),
    }

    def __call__(self, operation, units):
        rule = self.op_dict.get(operation)
        if rule is None:
            return NotImplemented
        return rule(self, units)
|
|
|
|
|
|
# Shared resolver instance used by BasicUnit.__mul__ and ConvertAllProxy.
unit_resolver = UnitResolver()

# Unit objects.
cm = BasicUnit('cm', 'centimeters')
inch = BasicUnit('inch', 'inches')
radians = BasicUnit('rad', 'radians')
degrees = BasicUnit('deg', 'degrees')
secs = BasicUnit('s', 'seconds')
hertz = BasicUnit('Hz', 'Hertz')
minutes = BasicUnit('min', 'minutes')

# Lengths: inches <-> centimeters.
inch.add_conversion_factor(cm, 2.54)
cm.add_conversion_factor(inch, 1 / 2.54)

# Angles: radians <-> degrees.
radians.add_conversion_factor(degrees, 180.0 / np.pi)
degrees.add_conversion_factor(radians, np.pi / 180.0)

# Seconds convert one way only: to Hertz (reciprocal) and to minutes.
secs.add_conversion_fn(hertz, lambda seconds: 1.0 / seconds)
secs.add_conversion_factor(minutes, 1 / 60.0)
|
|
|
|
|
|
# radians formatting
|
|
def rad_fn(x, pos=None):
    """Format *x* (in radians) as a multiple of pi/2 in mathtext.

    *pos* is accepted for the tick-formatter call signature and is unused.
    """
    # Number of half-pi steps, nudged away from zero by 0.25 and truncated.
    nudge = 0.25 if x >= 0 else -0.25
    n = int((x / np.pi) * 2.0 + nudge)

    # Small multiples get a compact label without a numeric coefficient.
    compact = {
        0: '0',
        1: r'$\pi/2$',
        2: r'$\pi$',
        -1: r'$-\pi/2$',
        -2: r'$-\pi$',
    }
    if n in compact:
        return compact[n]
    if n % 2 == 0:
        return r'$%s\pi$' % (n//2,)
    return r'$%s\pi/2$' % (n,)
|
|
|
|
|
|
class BasicUnitConverter(units.ConversionInterface):
    """Matplotlib unit converter for BasicUnit-tagged data.

    Iterability is tested with ``np.iterable`` instead of
    ``matplotlib.cbook.iterable``, which was deprecated in Matplotlib 3.1;
    both report whether ``iter()`` accepts the object.
    """

    @staticmethod
    def axisinfo(unit, axis):
        """Return the AxisInfo for *unit*, or None if none can be built.

        Radians get ticks at multiples of pi/2 formatted by ``rad_fn``,
        degrees get a degree-sign formatter, and any other unit only
        contributes an axis label.
        """
        if unit == radians:
            return units.AxisInfo(
                majloc=ticker.MultipleLocator(base=np.pi/2),
                majfmt=ticker.FuncFormatter(rad_fn),
                label=unit.fullname,
            )
        elif unit == degrees:
            return units.AxisInfo(
                majloc=ticker.AutoLocator(),
                majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                label=unit.fullname,
            )
        elif unit is not None:
            if hasattr(unit, 'fullname'):
                return units.AxisInfo(label=unit.fullname)
            elif hasattr(unit, 'unit'):
                # *unit* is itself a tagged object; label with its unit.
                return units.AxisInfo(label=unit.unit.fullname)
        return None

    @staticmethod
    def convert(val, unit, axis):
        """Return *val* converted to *unit* as bare value(s).

        Values that are already numeric are returned unchanged; an iterable
        is converted element by element into a list.
        """
        # NOTE(review): ConversionInterface.is_numlike may be unavailable on
        # recent Matplotlib releases - confirm against the targeted version.
        if units.ConversionInterface.is_numlike(val):
            return val
        if np.iterable(val):
            return [thisval.convert_to(unit).get_value() for thisval in val]
        return val.convert_to(unit).get_value()

    @staticmethod
    def default_units(x, axis):
        """Return the unit of the first element of *x*, else ``x.unit``."""
        if np.iterable(x):
            # Only the first element is inspected.
            for thisx in x:
                return thisx.unit
        return x.unit
|
|
|
|
|
|
def cos(x):
    """Return the cosine of a unit-tagged angle, or of each angle in *x*.

    Every value is converted to radians first. An iterable *x* yields a list
    of floats; a single value yields one float. Iterability is tested with
    ``np.iterable`` instead of ``matplotlib.cbook.iterable``, which was
    deprecated in Matplotlib 3.1.
    """
    if np.iterable(x):
        return [math.cos(val.convert_to(radians).get_value()) for val in x]
    return math.cos(x.convert_to(radians).get_value())
|
|
|
|
|
|
# One shared converter instance serves both unit-carrying classes.
basicConverter = BasicUnitConverter()
for _unit_carrier in (BasicUnit, TaggedValue):
    units.registry[_unit_carrier] = basicConverter
del _unit_carrier  # keep the loop variable out of the module namespace
|
|
```
|
|
|
|
## 下载这个示例
|
|
|
|
- [下载python源码: basic_units.py](https://matplotlib.org/_downloads/basic_units.py)
|
|
- [下载Jupyter notebook: basic_units.ipynb](https://matplotlib.org/_downloads/basic_units.ipynb) |