python - 如何有效地在 Python 中对参数组合进行网格搜索?

标签 python algorithm performance combinatorics

问题

对于计算工程模型,我想对所有可行的参数组合进行网格搜索。每个参数都有一定的可能范围,例如(0 … 100) 并且参数组合必须满足条件 a+b+c=100。一个例子:

ranges = {
    'a': (95, 99), 
    'b': (1, 4), 
    'c': (1, 2)}
increment = 1.0
target = 100.0

因此满足条件 a+b+c=100 的组合是:

[(95, 4, 1), (95, 3, 2), (96, 2, 2), (96, 3, 1), (97, 1, 2), (97, 2, 1), (98, 1, 1)]  

该算法应该使用任意数量的参数、范围长度和增量运行。

我的解决方案(到目前为止)

我想出的解决方案都是暴力破解问题。这意味着计算所有组合,然后丢弃不满足给定条件的组合:

def solution1(ranges, increment, target):
    combinations = []
    for parameter in ranges:
        combinations.append(list(np.arange(ranges[parameter][0], ranges[parameter][1], increment)))
        # np.arange() is exclusive of the upper bound, let's fix that
        if combinations[-1][-1] != ranges[parameter][1]:
            combinations[-1].append(ranges[parameter][1])
    combinations = list(itertools.product(*combinations))
    df = pd.DataFrame(combinations, columns=ranges.keys())
    # using np.isclose() so that the algorithm works for floats
    return df[np.isclose(df.sum(axis=1), target)]

由于我在使用 solution1() 时遇到了 RAM 问题,所以我使用了 itertools.product 作为迭代器。

def solution2(ranges, increment, target):
    combinations = []
    for parameter in ranges:
        combinations.append(list(np.arange(ranges[parameter][0], ranges[parameter][1], increment)))
        # np.arange() is exclusive of the upper bound, let's fix that
        if combinations[-1][-1] != ranges[parameter][1]:
            combinations[-1].append(ranges[parameter][1])
    result = []
    for combination in itertools.product(*combinations):
        # using np.isclose() so that the algorithm works for floats
        if np.isclose(sum(combination), target):
            result.append(combination)
    df = pd.DataFrame(result, columns=ranges.keys())
    return df

但是,这很快需要几天的时间来计算。因此,这两种解决方案都不适用于大量参数和范围。例如,我要解决的一组是(已经解压的 combinations 变量):

[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0], [22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], [0.0, 1.0, 2.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0], [0.0]]

这导致 solution1() 使用 >40 GB 的内存,solution2() 的计算时间 >400 小时。

问题

您是否看到了更快或更智能的解决方案,即不尝试暴力解决问题?

P.S.:我不能 100% 确定这个问题是否更适合其他 Stackexchange 网站之一。如果您认为应该移动,请在评论中提出建议,我将在此处删除。

最佳答案

这是一个递归的解决方案:

a = [95, 100]
b = [1, 4]
c = [1, 2]

Params = (a, b, c)

def GetValidParamValues(Params, constriantSum, prevVals):
    validParamValues = []
    if (len(Params) == 1):
        if (constriantSum >= Params[0][0] and constriantSum <= Params[0][1]):
            validParamValues.append(constriantSum)
        for v in validParamValues:
            print(prevVals + v)
        return
    sumOfLowParams = sum([Params[x][0] for x in range(1, len(Params))])
    sumOfHighParams = sum([Params[x][1] for x in range(1, len(Params))])
    lowEnd = max(Params[0][0], constriantSum - sumOfHighParams)
    highEnd = min(Params[0][1], constriantSum - sumOfLowParams) + 1
    if (len(Params) == 2):
        for av in range(lowEnd, highEnd):
            bv  = constriantSum - av
            if (bv <= Params[1][1]):
                validParamValues.append([av, bv])
        for v in validParamValues:
            print(prevVals + v)
        return
    for av in range(lowEnd, highEnd):
        nexPrevVals = prevVals + [av]
        subSeParams = Params[1:]
        GetValidParamValues(subSeParams, constriantSum - av, nexPrevVals)


GetValidParamValues(Params, 100)

想法是,如果有 2 个参数,ab,我们可以通过传递 a 的值来列出所有有效对code>,然后获取 (ai, S - ai) 并检查 S-ai 是否为 b 的有效值。

这是改进,因为我们可以提前计算 ai 的哪些值将使 S-ai 成为 b 的有效值,所以我们从不检查不起作用的值。

当params个数大于2时,我们可以再次查看ai的每一个有效值,我们知道其他数的和一定是S - ai。所以我们唯一需要的是其他数字添加到 S - ai 的所有可能的方法,这是同样的问题,少了一个参数。因此,通过使用递归,我们可以让它一直下降到大小 2 并解决它。

关于python - 如何有效地在 Python 中对参数组合进行网格搜索?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51426623/

相关文章:

python - 用于内化和翻译的通用 Python 库

mongodb - 如何在mongodb中查找几乎相似的记录?

c++ - 如何编写自定义词典比较器 C++

performance - 抓取站点并为每个 URL 编译有效负载统计信息的工具?

python - Python中查找字典键和值的标签分配

Python 无缓冲模式在 Windows 中导致问题

python - 将多个元组添加到单个字典键而不合并元组?

algorithm - 最小-最大算法的复杂度

函数内定义的 Python 编译器和常量

python - 无法拆除 Pandas read_csv 使用的临时文件