np.expand_dims: AxisError: axis 4 is out of bounds for array of dimension 4

Published 2023-10-09 22:41:42 Author: emanlee

np.expand_dims

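For a 2-D input array:

axis = 0: the new [] is added around the whole array

axis = 1: each row is wrapped in []

axis = 2: each element is wrapped in []

In general, axis is the position of the new length-1 axis in the result, so for an n-D input the valid values are -(n+1) to n.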
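A quick way to see the three cases is to compare shapes on a small 2-D array. The 2x3 input below is an arbitrary example, not data from the post.

```python
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])  # shape (2, 3)

outer = np.expand_dims(a, axis=0)  # [] around the whole array
rows = np.expand_dims(a, axis=1)   # [] around each row
elems = np.expand_dims(a, axis=2)  # [] around each element

print(outer.shape)  # (1, 2, 3)
print(rows.shape)   # (2, 1, 3)
print(elems.shape)  # (2, 3, 1)
print(rows[0])      # [[1 2 3]]: the first row inside its own []
print(elems[0])     # shape (3, 1): each element of the first row in its own []
```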

x_train = np.expand_dims(X, axis=4)
---------------------------------------------------------------------------
AxisError                                 Traceback (most recent call last)
Cell In[5], line 10
      8 #X[:, [0, 2], :] = X[:, [2, 0], :]
      9 X, Y = shuffle(X, Y, random_state=0)
---> 10 x_train = np.expand_dims(X, axis=4)
     11 y_train = Y
     13 #calculate class weights

File <__array_function__ internals>:180, in expand_dims(*args, **kwargs)

File /home/software/anaconda3/envs/mydlenv/lib/python3.8/site-packages/numpy/lib/shape_base.py:597, in expand_dims(a, axis)
    594     axis = (axis,)
    596 out_ndim = len(axis) + a.ndim
--> 597 axis = normalize_axis_tuple(axis, out_ndim)
    599 shape_it = iter(a.shape)
    600 shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]

File /home/software/anaconda3/envs/mydlenv/lib/python3.8/site-packages/numpy/core/numeric.py:1397, in normalize_axis_tuple(axis, ndim, argname, allow_duplicate)
   1395         pass
   1396 # Going via an iterator directly is slower than via list comprehension.
-> 1397 axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
   1398 if not allow_duplicate and len(set(axis)) != len(axis):
   1399     if argname:

File /home/software/anaconda3/envs/mydlenv/lib/python3.8/site-packages/numpy/core/numeric.py:1397, in <listcomp>(.0)
   1395         pass
   1396 # Going via an iterator directly is slower than via list comprehension.
-> 1397 axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
   1398 if not allow_duplicate and len(set(axis)) != len(axis):
   1399     if argname:

AxisError: axis 4 is out of bounds for array of dimension 4
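The dimension in the message is that of the result, not of X: the shape_base.py frame above computes out_ndim = len(axis) + a.ndim and checks axis against it. So "array of dimension 4" means X itself is 3-D, and the valid axis values are -4 to 3. A minimal sketch of the fix, assuming X really is 3-D (the shape (8, 10, 3) is a made-up stand-in for the real data): append the trailing axis with axis=-1 or axis=X.ndim, or use the equivalent X[..., np.newaxis].

```python
import numpy as np

# Hypothetical stand-in for the real X: any 3-D array reproduces the error.
X = np.zeros((8, 10, 3), dtype=np.float32)

# Valid axis values are -(X.ndim + 1) .. X.ndim, here -4 .. 3.
try:
    np.expand_dims(X, axis=4)
except ValueError as err:  # AxisError subclasses ValueError and IndexError
    print(err)  # axis 4 is out of bounds for array of dimension 4

# Three equivalent ways to append a trailing length-1 axis:
x_train = np.expand_dims(X, axis=-1)
x_train_2 = np.expand_dims(X, axis=X.ndim)
x_train_3 = X[..., np.newaxis]

print(x_train.shape)  # (8, 10, 3, 1)
```

If X were actually 4-D, axis=4 would be accepted, so printing X.shape before the call shows which case applies.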

http://www.xavierdupre.fr/app/mlprodict/helpsphinx/onnxops/onnx__Unsqueeze.html

import numpy as np

x = np.random.randn(3, 4, 5).astype(np.float32)

for i in range(x.ndim):  # i = 0, 1, 2
    # insert a new length-1 axis at position i
    y = np.expand_dims(x, axis=i)
    print(i, y.shape)  # (1, 3, 4, 5), (3, 1, 4, 5), (3, 4, 1, 5)
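The loop only tries axis 0 to x.ndim - 1, but expand_dims also accepts axis = x.ndim (a new last axis) and negative values counted from the end of the result. This small sketch probes one value past each end of the valid range for the same 3-D x and records which values NumPy rejects.

```python
import numpy as np

x = np.random.randn(3, 4, 5).astype(np.float32)

accepted = {}
rejected = []
# Probe one step past the valid range -(x.ndim + 1) .. x.ndim on both sides.
for axis in range(-x.ndim - 2, x.ndim + 2):
    try:
        accepted[axis] = np.expand_dims(x, axis=axis).shape
    except ValueError:  # AxisError subclasses ValueError
        rejected.append(axis)

for axis, shape in accepted.items():
    print(axis, shape)
print("rejected:", rejected)  # rejected: [-5, 4]
```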

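import numpy as np

x = np.random.randn(3, 4, 5).astype(np.float32)

print(x.ndim)  # 3
# Reproduces the error: the result would be 4-D, so valid axes are -4 .. 3.
# AxisError: axis 4 is out of bounds for array of dimension 4
y = np.expand_dims(x, axis=4)
y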
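Because out_ndim = len(axis) + a.ndim counts every inserted axis (see the shape_base.py frame in the traceback), axis=4 is no longer out of range for this 3-D x once two axes are inserted in one call, much like the multi-axis case of the ONNX Unsqueeze operator linked above. A sketch, assuming NumPy 1.18 or later, where expand_dims accepts a tuple of axes:

```python
import numpy as np

# Assumes NumPy >= 1.18 (tuple values for axis).
x = np.random.randn(3, 4, 5).astype(np.float32)

# Two new axes: the result has 3 + 2 = 5 dims, so positions 0 .. 4 are valid.
y = np.expand_dims(x, axis=(3, 4))
print(y.shape)  # (3, 4, 5, 1, 1)

# Axes in the tuple refer to positions in the result, not in x.
z = np.expand_dims(x, axis=(0, 4))
print(z.shape)  # (1, 3, 4, 5, 1)
```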

REF

https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html