使用3D NumPy数组的掩码->返回2D数组数组、掩码、NumPy、amp

2023-09-03 11:53:14 作者:╭糖豆⌒

我所拥有的:

import numpy as np
np.random.seed(42)
dlen = 250000
data = np.random.rand(dlen, 3, 3)
mask = np.random.choice([0, 1, 2], dlen)

我想要的:

[[0.37454012 0.95071431 0.73199394], 
 [0.83244264 0.21233911 0.18182497], 
 [0.13949386 0.29214465 0.36636184], 
 [0.94888554 0.96563203 0.80839735], 
 [0.44015249 0.12203823 0.49517691],
 ....
(250000, 3)
numpy数组与布尔数组 numpy数组的基础

我尝试使用的内容:

data[:,mask,:]

{MemoryError}Unable to allocate 1.36 TiB for an array with shape (250000, 250000, 3) and data type float64

给出正确结果但看起来奇怪的是什么:

data[np.arange(data.shape[0]), mask, :]

那么使用此口罩的正确方式是什么?

更新。: 掩码应选择具有指定索引的列。形状为[2,3,3]的数组示例:

array = [[[5 6 7], [7 8 9], [2 3 4]],
         [[2 1 0], [7 6 5], [7 6 5]]]
mask = [1 0]
result = [[7 8 9], 
          [2 1 0]]

推荐答案

data[np.arange(data.shape[0]), mask, :]

这是可行的,因为它是multi-dimensional index array

当我在这里使用术语掩码时,我想到了布尔索引。您的整数掩码可以转换为布尔掩码,以便以您想要的方式使用它。

>>> data.shape                 
(250000, 3, 3)
>>> mask.shape
(250000,)
>>> q = mask[:,None] == [0,1,2]
>>> q.shape
(250000, 3)
>>> q[:5]        
array([[ True, False, False],
       [False,  True, False],
       [False,  True, False],
       [False, False,  True],
       [False,  True, False]])
>>> r = data[q]
>>> r.shape
(250000, 3)
>>> r[:10]
array([[0.37454012, 0.95071431, 0.73199394],
       [0.83244264, 0.21233911, 0.18182497],
       [0.13949386, 0.29214465, 0.36636184],
       [0.94888554, 0.96563203, 0.80839735],
       [0.44015249, 0.12203823, 0.49517691],
       [0.66252228, 0.31171108, 0.52006802],
       [0.59789998, 0.92187424, 0.0884925 ],
       [0.14092422, 0.80219698, 0.07455064],
       [0.00552212, 0.81546143, 0.70685734],
       [0.31098232, 0.32518332, 0.72960618]])
>>>

您可以使用第二个维度长度来使其更通用一些:

q = mask[:,None] == np.arange(data.shape[1])
>>> q[:5]                                        
array([[ True, False, False], 
       [False,  True, False], 
       [False,  True, False], 
       [False, False,  True], 
       [False,  True, False]])

如果您控制掩码的构造,您可能希望将其构造为布尔数组。

如果这是新代码,您可能需要升级到兼容版本的Numpy并使用新的random generator。

 
精彩推荐
图片推荐