Chireiden

地霊殿

地霊殿,充满幻想与希望的殿堂


夢も希望も無い、毎日がそんな生活だった。

Paddle开发入门

过年期间看到 Paddle 框架那边发了新的 issue,所以顺手修了几个,这里取几个例子来简单介绍其修复过程,

paddle.flip

事实上的首个动手目标,它的出错输入是这样的:

import paddle
import numpy as np
from paddle import flip
x = paddle.to_tensor(np.random.uniform(-10, 10, [1, 2, 3]).astype(np.int64)),
axis = paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, [3, 3]).astype(np.int32))
print(x)
print(axis)
flip(x, axis)

翻到 Paddle 源码看看

def flip(x, axis, name=None):
    """
    Reverse the order of a n-D tensor along given axis in axis.

    Args:
        x (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x
            should be float32, float64, int32, int64, bool.
        axis (list|tuple|int): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, Tensor or LoDTensor calculated by flip layer. The data type is same with input x.

    Examples:
        .. code-block:: python

          import paddle

          image_shape=(3, 2, 2)
          img = paddle.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape)
          tmp = paddle.flip(img, [0,1])
          print(tmp) # [[[10,11],[8, 9]], [[6, 7],[4, 5]], [[2, 3],[0, 1]]]

          out = paddle.flip(tmp,-1)
          print(out) # [[[11,10],[9, 8]], [[7, 6],[5, 4]], [[3, 2],[1, 0]]]
    """
    if isinstance(axis, int):
        axis = [axis]

    if in_dygraph_mode():
        return _C_ops.flip(x, axis)

思路很明显,flip 期望输入一个一维的数组,但是错误的案例里输入了二维的。因此解决方案就更明显了,不妨加一个维度检查。

加在哪呢?很直观的,既然上面的 python 代码里处理了输入单个整数时将其转换为整数,那就给它加上一个维度检查,形如

    if paddle.to_tensor(axis).ndim != 1:
        raise ValueError('维度不对,麻烦再检查检查,只要一维的')

非常合理,而且简单易懂,维度不为 1 就报错,于是事情就这样成了(并没有)。

因为 Good First Issue Lists 里面还有这样的对话。

您好,想问下是要在 Python 端加输入检查,还是要在C++端用 PADDLE_ENFORCE_EQ 做检查呢 ?

根据情况而定。能够在C++端加的,一定要在C++端加。如果C++端不具备条件的,则在Python端加。

也就是说,应该尽可能在 C++ 端加对应检查。

C++ 端加检查难吗?不难,顺着 flip 查找内核实现,就能看到 flip_kernel.cc 里面对应的代码:

template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int>& axis,
                DenseTensor* out);

PD_REGISTER_KERNEL(flip,
                   CPU,
                   ALL_LAYOUT,
                   phi::FlipKernel,
                   float,
                   double,
                   int32_t,
                   int64_t,
                   bool,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

也就是说,只要看 FlipKernel 就行了。

逻辑很合理,但是失败了,gdb 调试跑了一遍(gdb 7 支持了 python 调试),看到了和 #49922 相同的调用栈。

AddressSanitizer:DEADLYSIGNAL
=================================================================
==92083==ERROR: AddressSanitizer: SEGV on unknown address 0x000000000008 (pc 0x7faceb79eab3 bp 0x7ffe95a59990 sp 0x7ffe95a598e0 T0)
==92083==The signal is caused by a READ memory access.
==92083==Hint: address points to the zero page.
    #0 0x7faceb79eab3 in paddle::pybind::PyObject_CheckLongOrToLong(_object**) /home/work/yakun/paddle-2.4.0/Paddle/paddle/fluid/pybind/op_function_common.cc:69:8
    #1 0x7faceb7a4dd5 in paddle::pybind::CastPyArg2Ints(_object*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, long) /home/work/yakun/paddle-2.4.0/Paddle/paddle/fluid/pybind/op_function_common.cc:357:11
    #2 0x7facea7c4caf in paddle::pybind::eager_api_flip(_object*, _object*, _object*) /home/work/yakun/paddle-2.4.0/Paddle/paddle/fluid/pybind/eager_op_function.cc:1216:29
    #3 0x5025f3 in PyCFunction_Call (/usr/bin/python3.8+0x5025f3) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #4 0x500db4 in _PyObject_MakeTpCall (/usr/bin/python3.8+0x500db4) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #5 0x566223 in _PyEval_EvalFrameDefault (/usr/bin/python3.8+0x566223) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #6 0x55f470 in _PyEval_EvalCodeWithName (/usr/bin/python3.8+0x55f470) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #7 0x5016c5 in _PyFunction_Vectorcall (/usr/bin/python3.8+0x5016c5) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #8 0x560136 in _PyEval_EvalFrameDefault (/usr/bin/python3.8+0x560136) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #9 0x55f470 in _PyEval_EvalCodeWithName (/usr/bin/python3.8+0x55f470) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #10 0x55f102 in PyEval_EvalCode (/usr/bin/python3.8+0x55f102) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #11 0x62a1ef  (/usr/bin/python3.8+0x62a1ef) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #12 0x62a179  (/usr/bin/python3.8+0x62a179) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #13 0x47a7f2  (/usr/bin/python3.8+0x47a7f2) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #14 0x47a5cb in PyRun_SimpleFileExFlags (/usr/bin/python3.8+0x47a5cb) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #15 0x4247dc in _init (/usr/bin/python3.8+0x4247dc) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #16 0x5fb9b8 in Py_BytesMain (/usr/bin/python3.8+0x5fb9b8) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
    #17 0x7fad0bf9783f in __libc_start_main /build/glibc-S7Ft5T/glibc-2.23/csu/../csu/libc-start.c:291
    #18 0x5fb8b8 in _start (/usr/bin/python3.8+0x5fb8b8) (BuildId: 69b06f9a4b2e8428d7e32aa682c34a91dc0b961e)
AddressSanitizer can not provide additional info.
SUMMARY: AddressSanitizer: SEGV /home/work/yakun/paddle-2.4.0/Paddle/paddle/fluid/pybind/op_function_common.cc:69:8 in paddle::pybind::PyObject_CheckLongOrToLong(_object**)
==92083==ABORTING

可以看到,调用栈完全没进入到 FlipKernel 中,而是停在了 eager_api_flip,这个自动生成的文件里,并且更具体的,是 CastPyArg2Ints 这个函数。

这个函数简单来说,是将类型为 PyObjectpython 变量转换成 std::vector<int>cpp 值,首先判断输入变量是否可迭代,如果可迭代,就将每个元素转换为 intlong)。

而问题就出在这个转换过程:先检查是否能转换为 long,如果能,则进行转换,否则抛出错误。

bool PyObject_CheckLongOrToLong(PyObject** obj) {
  if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
      PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) ||  // NOLINT
      PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) ||  // NOLINT
      PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) {     // NOLINT
    return true;
  }

如果 obj 是(能转换到) long 并且 obj 不是 bool,则返回 true或者,如果 objVariableTensor,返回 true

  • itemlong 的时候
    • 检查能否转换 PyObject_CheckLongOrToLong 结果为 true, 检查通过
    • 进行 PyLong_AsLong得到一个 long 变量的指针
  • item 为 tensor 的时候
    • 检查能否转换 PyObject_CheckLongOrToLong 结果为 true(不确定内部逻辑通过原因), 检查通过
    • 进行 PyLong_AsLong得到一个空指针

因而出错。

简单来说,PyObject_CheckLongOrToLongVariableTensor 这种复杂类型的检查,不够详尽。

最后进行的修复策略也同样简单。

  if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
      PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) ||  // NOLINT
      PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) ||  // NOLINT
      PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) {     // NOLINT
      (PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) &&    // NOLINT
       (((TensorObject*)(*obj))->tensor.numel() == 1))) {        // NOLINT
    return true;
  }

即检查是否是仅存在单个元素,如果是,则表明其支持转换。

通过添加该检查,能够解决该问题。

但是通过敏锐的观察力,不难发现,对于其他数据类型,如 float 等,同样存在该问题。

同时,也不难发现,对于 Variable (static),同样存在该问题。

因此短期小目标是对这些检查代码进行修改,以提高其稳健性。

一步式操作:

git config --global url."https://ghproxy.com/https://github.com/".insteadOf "https://github.com/"
git config --global protocol.https.allow always

竹林里站着一个孩子,十来岁的孩子。

他什么也没做,只是站着,盯着竹子,满脸的愤懑。

他所在的竹林在陈家后山上,外人不得入内,所以,他自然也是陈家的人。

可惜,他的七品资质,在寻常人家里,能让邻里人前人后议论数月不止的七品资质,却让他快要做不了陈家的人了。

Continue
编程通解编程思想小说 Continue