有很多方法可以展平嵌套列表。我在这里复制一个解决方案仅供引用:
def flatten(x):
result = []
for el in x:
if hasattr(el, "__iter__") and not isinstance(el, basestring):
result.extend(flatten(el))
else:
result.append(el)
return result
我感兴趣的是逆运算,它将列表重建为其原始格式。例如:
L = [[array([[ 24, -134],[ -67, -207]])],
[array([[ 204, -45],[ 99, -118]])],
[array([[ 43, -154],[-122, 168]]), array([[ 33, -110],[ 147, -26],[ -49, -122]])]]
# flattened version
L_flat = [24, -134, -67, -207, 204, -45, 99, -118, 43, -154, -122, 168, 33, -110, 147, -26, -49, -122]
是否有一种有效的方法来展平、保存索引并重建其原始格式?
请注意,列表的深度可以是任意的,形状可能不规则,并且将包含不同维度的数组。
当然,flattening 函数也应该更改为存储列表的结构和 numpy
数组的形状。
最佳答案
我一直在寻找一种解决方案来展平和展平 numpy 数组的嵌套列表,但只发现了这个未回答的问题,所以我想到了这个:
def _flatten(values):
if isinstance(values, np.ndarray):
yield values.flatten()
else:
for value in values:
yield from _flatten(value)
def flatten(values):
# flatten nested lists of np.ndarray to np.ndarray
return np.concatenate(list(_flatten(values)))
def _unflatten(flat_values, prototype, offset):
if isinstance(prototype, np.ndarray):
shape = prototype.shape
new_offset = offset + np.product(shape)
value = flat_values[offset:new_offset].reshape(shape)
return value, new_offset
else:
result = []
for value in prototype:
value, offset = _unflatten(flat_values, value, offset)
result.append(value)
return result, offset
def unflatten(flat_values, prototype):
# unflatten np.ndarray to nested lists with structure of prototype
result, offset = _unflatten(flat_values, prototype, 0)
assert(offset == len(flat_values))
return result
例子:
a = [
np.random.rand(1),
[
np.random.rand(2, 1),
np.random.rand(1, 2, 1),
],
[[]],
]
b = flatten(a)
# 'c' will have values of 'b' and structure of 'a'
c = unflatten(b, a)
输出:
a:
[array([ 0.26453544]), [array([[ 0.88273824],
[ 0.63458643]]), array([[[ 0.84252894],
[ 0.91414218]]])], [[]]]
b:
[ 0.26453544 0.88273824 0.63458643 0.84252894 0.91414218]
c:
[array([ 0.26453544]), [array([[ 0.88273824],
[ 0.63458643]]), array([[[ 0.84252894],
[ 0.91414218]]])], [[]]]
许可证:WTFPL
关于python - 展平和展平 numpy 数组的嵌套列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/27982432/