python - 非常大的集合的 SQLAlchemy 集合成员资格

标签 python postgresql sqlalchemy

我的 SQL 查询可以非常简单地写成:

result = session.query(Table).filter(Table.my_key._in(key_set))

my_key 整数列被索引(主键),但 key_set 可能确实非常大,有数千万个值。

对于如此庞大的集合进行过滤,推荐的 SQLAlchemy 模式是什么?

有没有比行人更高效的内置东西:

result = [session.query(Table).get(key) for key in key_set]

最佳答案

在这种极端情况下,您最好首先考虑推荐的 SQL 解决方案是什么,然后在 SQLAlchemy 中实现它——如果需要,甚至可以使用原始 SQL。一种这样的解决方案是为 key_set 数据创建一个临时表并填充它。

为了测试类似您的设置的东西,我创建了以下模型

class Table(Base):
    __tablename__ = 'mytable'
    my_key = Column(Integer, primary_key=True)

并用 20,000,000 行填充它:

In [1]: engine.execute("""
   ...:     insert into mytable
   ...:     select generate_series(1, 20000001)
   ...:     """)

我还创建了一些帮助程序来测试临时表、填充和查询的不同组合。请注意,查询使用核心表,以绕过 ORM 及其机制——无论如何,对计时的贡献将是恒定的:

# testdb is just your usual SQLAlchemy imports, and some
# preconfigured engine options.
from testdb import *
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Executable, ClauseElement
from io import StringIO
from itertools import product

class Table(Base):
    __tablename__ = "mytable"
    my_key = Column(Integer, primary_key=True)

def with_session(f):
    def wrapper(*a, **kw):
        session = Session(bind=engine)
        try:
            return f(session, *a, **kw)

        finally:
            session.close()
    return wrapper

def all(_, query):
    return query.all()

def explain(analyze=False):
    def cont(session, query):
        results = session.execute(Explain(query.statement, analyze))
        return [l for l, in results]

    return cont

class Explain(Executable, ClauseElement):
    def __init__(self, stmt, analyze=False):
        self.stmt = stmt
        self.analyze = analyze

@compiles(Explain)
def visit_explain(element, compiler, **kw):
    stmt = "EXPLAIN "

    if element.analyze:
        stmt += "ANALYZE "

    stmt += compiler.process(element.stmt, **kw)
    return stmt

def create_tmp_tbl_w_insert(session, key_set, unique=False):
    session.execute("CREATE TEMPORARY TABLE x (k INTEGER NOT NULL)")
    x = table("x", column("k"))
    session.execute(x.insert().values([(k,) for k in key_set]))

    if unique:
        session.execute("CREATE UNIQUE INDEX ON x (k)")

    session.execute("ANALYZE x")
    return x

def create_tmp_tbl_w_copy(session, key_set, unique=False):
    session.execute("CREATE TEMPORARY TABLE x (k INTEGER NOT NULL)")
    # This assumes that the string representation of the Python values
    # is a valid representation for Postgresql as well. If this is not
    # the case, `cur.mogrify()` should be used.
    file = StringIO("".join([f"{k}\n" for k in key_set]))
    # HACK ALERT, get the DB-API connection object
    with session.connection().connection.connection.cursor() as cur:
        cur.copy_from(file, "x")

    if unique:
        session.execute("CREATE UNIQUE INDEX ON x (k)")

    session.execute("ANALYZE x")
    return table("x", column("k"))

tmp_tbl_factories = {
    "insert": create_tmp_tbl_w_insert,
    "insert (uniq)": lambda session, key_set: create_tmp_tbl_w_insert(session, key_set, unique=True),
    "copy": create_tmp_tbl_w_copy,
    "copy (uniq)": lambda session, key_set: create_tmp_tbl_w_copy(session, key_set, unique=True),
}

query_factories = {
    "in": lambda session, _, x: session.query(Table.__table__).
        filter(Table.my_key.in_(x.select().as_scalar())),
    "exists": lambda session, _, x: session.query(Table.__table__).
        filter(exists().where(x.c.k == Table.my_key)),
    "join": lambda session, _, x: session.query(Table.__table__).
        join(x, x.c.k == Table.my_key)
}

tests = {
    "test in": (
        lambda _s, _ks: None,
        lambda session, key_set, _: session.query(Table.__table__).
            filter(Table.my_key.in_(key_set))
    ),
    "test in expanding": (
        lambda _s, _kw: None,
        lambda session, key_set, _: session.query(Table.__table__).
            filter(Table.my_key.in_(bindparam('key_set', key_set, expanding=True)))
    ),
    **{
        f"test {ql} w/ {tl}": (tf, qf)
        for (tl, tf), (ql, qf)
        in product(tmp_tbl_factories.items(), query_factories.items())
    }
}

@with_session
def run_test(session, key_set, tmp_tbl_factory, query_factory, *, cont=all):
    x = tmp_tbl_factory(session, key_set)
    return cont(session, query_factory(session, key_set, x))

对于小键集,您拥有的简单 IN 查询与其他查询一样快,但是使用 100,000 的 key_set 时,涉及更多的解决方案开始获胜:

In [10]: for test, steps in tests.items():
    ...:     print(f"{test:<28}", end=" ")
    ...:     %timeit -r2 -n2 run_test(range(100000), *steps)
    ...:     
test in                      2.21 s ± 7.31 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test in expanding            630 ms ± 929 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
test in w/ insert            1.83 s ± 3.73 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test exists w/ insert        1.83 s ± 3.99 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test join w/ insert          1.86 s ± 3.76 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test in w/ insert (uniq)     1.87 s ± 6.67 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test exists w/ insert (uniq) 1.84 s ± 125 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
test join w/ insert (uniq)   1.85 s ± 2.8 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test in w/ copy              246 ms ± 1.18 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test exists w/ copy          243 ms ± 2.31 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test join w/ copy            258 ms ± 3.05 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test in w/ copy (uniq)       261 ms ± 1.39 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test exists w/ copy (uniq)   267 ms ± 8.24 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
test join w/ copy (uniq)     264 ms ± 1.16 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)

key_set 提高到 1,000,000:

In [11]: for test, steps in tests.items():
    ...:     print(f"{test:<28}", end=" ")
    ...:     %timeit -r2 -n1 run_test(range(1000000), *steps)
    ...:     
test in                      23.8 s ± 158 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test in expanding            6.96 s ± 3.02 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test in w/ insert            19.6 s ± 79.3 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test exists w/ insert        20.1 s ± 114 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test join w/ insert          19.5 s ± 7.93 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test in w/ insert (uniq)     19.5 s ± 45.4 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test exists w/ insert (uniq) 19.6 s ± 73.6 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test join w/ insert (uniq)   20 s ± 57.5 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test in w/ copy              2.53 s ± 49.9 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test exists w/ copy          2.56 s ± 1.96 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test join w/ copy            2.61 s ± 26.8 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test in w/ copy (uniq)       2.63 s ± 3.79 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
test exists w/ copy (uniq)   2.61 s ± 916 µs per loop (mean ± std. dev. of 2 runs, 1 loop each)
test join w/ copy (uniq)     2.6 s ± 5.31 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)

10,000,000 个 key 集,仅COPY 解决方案,因为其他解决方案占用了我所有的 RAM 并且在被杀死之前正在经历交换,暗示他们永远不会在这台机器上完成:

In [12]: for test, steps in tests.items():
    ...:     if "copy" in test:
    ...:         print(f"{test:<28}", end=" ")
    ...:         %timeit -r1 -n1 run_test(range(10000000), *steps)
    ...:     
test in w/ copy              28.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
test exists w/ copy          29.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
test join w/ copy            29.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
test in w/ copy (uniq)       28.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
test exists w/ copy (uniq)   27.5 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
test join w/ copy (uniq)     28.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

因此,对于小型 key 集(~100,000 或更少),使用什么并不重要,尽管与易用性相比,使用扩展 bindparam 在时间上明显是赢家,但是对于更大的集合,您可能需要考虑使用临时表和COPY

值得注意的是,对于大型集合,如果使用唯一索引,查询计划是相同的:

In [13]: print(*run_test(range(10000000),
    ...:                 tmp_tbl_factories["copy (uniq)"],
    ...:                 query_factories["in"],
    ...:                 cont=explain()), sep="\n")
Merge Join  (cost=45.44..760102.11 rows=9999977 width=4)
  Merge Cond: (mytable.my_key = x.k)
  ->  Index Only Scan using mytable_pkey on mytable  (cost=0.44..607856.88 rows=20000096 width=4)
  ->  Index Only Scan using x_k_idx on x  (cost=0.43..303939.09 rows=9999977 width=4)

In [14]: print(*run_test(range(10000000),
    ...:                 tmp_tbl_factories["copy (uniq)"],
    ...:                 query_factories["exists"],
    ...:                 cont=explain()), sep="\n")
Merge Join  (cost=44.29..760123.36 rows=9999977 width=4)
  Merge Cond: (mytable.my_key = x.k)
  ->  Index Only Scan using mytable_pkey on mytable  (cost=0.44..607856.88 rows=20000096 width=4)
  ->  Index Only Scan using x_k_idx on x  (cost=0.43..303939.09 rows=9999977 width=4)

In [15]: print(*run_test(range(10000000),
    ...:                 tmp_tbl_factories["copy (uniq)"],
    ...:                 query_factories["join"],
    ...:                 cont=explain()), sep="\n")
Merge Join  (cost=39.06..760113.29 rows=9999977 width=4)
  Merge Cond: (mytable.my_key = x.k)
  ->  Index Only Scan using mytable_pkey on mytable  (cost=0.44..607856.88 rows=20000096 width=4)
  ->  Index Only Scan using x_k_idx on x  (cost=0.43..303939.09 rows=9999977 width=4)

由于测试表是人工的,因此可以使用仅索引扫描。


最后,这里是“行人”方式的计时,粗略对比一下:

In [3]: for ksl in [100000, 1000000]:
   ...:     %time [session.query(Table).get(k) for k in range(ksl)]
   ...:     session.rollback()
   ...:     
CPU times: user 1min, sys: 1.76 s, total: 1min 1s
Wall time: 1min 13s
CPU times: user 9min 48s, sys: 17.3 s, total: 10min 5s
Wall time: 12min 1s

问题是使用 Query.get() 必然包括 ORM,而原来的比较没有。尽管如此,即使在使用本地数据库时,单独往返数据库的成本还是很明显的。

关于python - 非常大的集合的 SQLAlchemy 集合成员资格,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56761442/

相关文章:

python - numpy.loadtxt 比 open.....readlines() 慢得多

python - Django 用户登录表单用户为无

postgresql - 如何解决错误重复的sources.list?

spring - 如何使用 SSL 和 Spring Boot 连接到 PostgresQL?

python - Flask:如何检查一对多反向引用是否存在

python - SQLAlchemy 级联删除自动映射的通用数据模型架构

python - 如何关闭我在线程中使用的 Flask-SQLAlchemy 连接?

python /Django : RelatedObjectDoesNotExist: Cart has no user

python - 在 Python 中,如何获取列表列表的索引?

postgresql - CSRF token 在 docker pgadmin 中丢失错误