A style guide is about consistency. Consistency with this style guide is important. Consistency within a project is more important. Consistency within one module or function is the most important.
PEP 8 -- Style Guide for Python Code
One thing to note is that the PEP 8 conventions are not absolute: consistency within a project takes precedence over PEP 8 itself. Every OpenMMLab project configures its code style settings in setup.cfg, and those settings should be followed. As an example, PEP 8 contains the following snippet:
# Correct:
hypot2 = x*x + y*y
# Wrong:
hypot2 = x * x + y * y
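Which of the two spacings a project actually ends up with is decided by its formatter configuration rather than by PEP 8 itself. As a purely illustrative sketch (the option names and values below are assumptions and differ from project to project), the relevant part of a setup.cfg might look like:

[yapf]
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true

[isort]
line_length = 79
known_first_party = mmcv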
"""A one line summary of the module or program, terminated by a period.
Leave one blank line. The rest of this docstring should contain an
overall description of the module or program. Optionally, it may also
contain a brief description of exported classes and functions and/or usage
examples.
Typical usage example:
foo = ClassFoo()
bar = foo.FunctionBar()
"""
class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable, optional): A callable method that
            processes a data batch. The interface of this method should be
            ``batch_processor(model, data, train_mode) -> dict``.
            Defaults to None.
        optimizer (dict | :obj:`torch.optim.Optimizer` | None): It can be
            either an optimizer (in most cases) or a dict of optimizers
            (in models that require more than one optimizer, e.g., GAN).
            Defaults to None.
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            Defaults to None. (The default value is just for backward
            compatibility)
        meta (dict, optional): A dict that records some important
            information such as environment info and seed, which will be
            logged in the logger hook. Defaults to None.
        max_epochs (int, optional): Total training epochs. Defaults to None.
        max_iters (int, optional): Total training iterations.
            Defaults to None.
    """
    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        ...
# Reference implementation
# This func is modified from `detectron2
# <https://github.com/facebookresearch/detectron2/blob/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9/detectron2/structures/masks.py#L387>`_.
# Copied code
# This code was copied from the `ubelt
# library <https://github.com/Erotemic/ubelt>`_.

# Citing a paper & adding formulas
class LabelSmoothLoss(nn.Module):
    r"""Initializer for the label smoothed cross entropy loss.

    Refers to `Rethinking the Inception Architecture for Computer Vision
    <https://arxiv.org/abs/1512.00567>`_.

    This decreases the gap between output scores and encourages
    generalization.
    Labels provided to forward can be one-hot like vectors (NxC) or class
    indices (Nx1).
    It also accepts a linear combination of one-hot like labels from mixup
    or cutmix, except for the multi-label task.

    Args:
        label_smooth_val (float): The degree of label smoothing.
        num_classes (int, optional): Number of classes. Defaults to None.
        mode (str): Refers to notes. Options are "original", "classy_vision",
            "multi_label". Defaults to "classy_vision".
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.

    Note:
        If the ``mode`` is "original", this will use the same label smooth
        method as the original paper:

        .. math::
            (1 - \epsilon) \delta_{k, y} + \frac{\epsilon}{K}

        where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is
        the ``num_classes`` and :math:`\delta_{k, y}` is the Dirac delta,
        which equals 1 for :math:`k = y` and 0 otherwise.

        If the ``mode`` is "classy_vision", this will use the same label
        smooth method as the `facebookresearch/ClassyVision
        <https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/losses/label_smoothing_loss.py>`_ repo:

        .. math::
            \frac{\delta_{k, y} + \epsilon/K}{1 + \epsilon}

        If the ``mode`` is "multi_label", this will accept labels from the
        multi-label task and smooth them as:

        .. math::
            (1 - 2\epsilon) \delta_{k, y} + \epsilon
    """
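To make the "original" mode formula above concrete, here is a minimal numeric sketch (the values are illustrative assumptions, not the class defaults):

import torch

# "original" label smoothing: (1 - eps) * one_hot + eps / K
label_smooth_val, num_classes = 0.1, 5
one_hot = torch.tensor([0., 0., 1., 0., 0.])
smoothed = (1 - label_smooth_val) * one_hot + label_smooth_val / num_classes
# -> tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200]); the row still sums to 1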
def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, an ImportError is raised. Defaults to False.

    Returns:
        List[module] | module | None: The imported modules.
        All these three lines in docstring will be compiled into the same
        line in readthedocs.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    ...
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        ...
        out_dir (str, optional): The root directory to save checkpoints. If
            not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of
            ``out_dir`` and the last level directory of ``runner.work_dir``.
            Defaults to None. `Changed in version 1.3.15.`
        ...
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Defaults to None. `New in version 1.3.15.`
    Warning:
        Before v1.3.15, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, in v1.3.15 and later, ``out_dir``
        indicates the root directory and the final path to save checkpoint is
        the concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is
        "/path/of/A" and the value of ``runner.work_dir`` is "/path/of/B",
        then the final path will be "/path/of/A/B".
    """
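A minimal sketch of the path concatenation described in the warning, assuming plain os.path operations (the hook's real implementation may differ):

import os.path as osp

out_dir = '/path/of/A'
work_dir = '/path/of/B'  # runner.work_dir
final_dir = osp.join(out_dir, osp.basename(work_dir))
# final_dir == '/path/of/A/B'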
If an argument or a return value is a dict whose fields need to be described individually, document it in the following format:
def func(x):
    r"""
    Args:
        x (None): A dict with 2 keys, ``padded_targets``, and ``targets``.

            - | ``targets`` (list[Tensor]): A list of tensors.
                Each tensor has the shape of :math:`(T_i)`. Each
                element is the index of a character.
            - | ``padded_targets`` (Tensor): A tensor of shape :math:`(N)`.
                Each item is the length of a word.

    Returns:
        dict: A dict with 2 keys, ``padded_targets``, and ``targets``.

            - | ``targets`` (list[Tensor]): A list of tensors.
                Each tensor has the shape of :math:`(T_i)`. Each
                element is the index of a character.
            - | ``padded_targets`` (Tensor): A tensor of shape :math:`(N)`.
                Each item is the length of a word.
    """
    return x
# We use a weighted dictionary search to find out where i is in
# the array. We extrapolate position based on the largest num
# in the array and the array size and then do binary search to
# get the exact number.
if i & (i-1) == 0: # True if i is 0 or a power of 2.
# Wrong:
# Now go through the b array and make sure whenever i occurs
# the next element is i+1
# Wrong:
if i & (i-1) == 0: # True if i bitwise and i-1 is 0.
# `_reversed_padding_repeated_twice` is the padding to be passed to
# `F.pad` if needed (e.g., for non-zero padding types that are
# implemented as two ops: padding + conv). `F.pad` accepts paddings in
# reverse order than the dimension.
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
    if parent is not None:
        self.build_func = parent.build_func
    else:
        self.build_func = build_from_cfg
else:
    self.build_func = build_func
def _save_ckpt(checkpoint, file):
    # The 1.6 release of PyTorch switched torch.save to use a new
    # zipfile-based file format. It will cause RuntimeError when a
    # checkpoint was saved in high version (PyTorch version>=1.6.0) but
    # loaded in low version (PyTorch version<1.6.0). More details at
    # https://github.com/open-mmlab/mmpose/issues/904
    if digit_version(TORCH_VERSION) >= digit_version('1.6.0'):
        torch.save(checkpoint, file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, file)
from typing import TypeVar, List

T = TypeVar('T')  # Can be anything
A = TypeVar('A', str, bytes)  # Must be str or bytes


def repeat(x: T, n: int) -> List[T]:
    """Return a list containing n references to x."""
    return [x]*n


def longest(x: A, y: A) -> A:
    """Return the longest of two strings."""
    return x if len(x) >= len(y) else y
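For example, running mypy on a small file like the following (a hypothetical test.py written here only to illustrate the kind of output shown below, not an excerpt from any project) would report errors of this kind, both on line 4:

def foo(a: int) -> float:
    return a * 0.5

bar: int = foo('2')  # wrong argument type and wrong assignment type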
test.py:4: error: Incompatible types in assignment (expression has type "float", variable has type "int")
test.py:4: error: Argument 1 to "foo" has incompatible type "str"; expected "int"
Found 2 errors in 1 file (checked 1 source file)