# Part of Odoo. See LICENSE file for full copyright and licensing details.
import itertools
from collections.abc import Iterable, Iterator

from .sql import SQL, make_identifier


def _sql_from_table(alias: str, table: SQL) -> SQL:
    """ Return a FROM clause element from ``alias`` and ``table``.

    When ``table`` is exactly the identifier of ``alias``, the table is
    returned unchanged, which avoids a redundant ``"foo" AS "foo"``.
    """
    if (alias_identifier := SQL.identifier(alias)) == table:
        return table
    return SQL("%s AS %s", table, alias_identifier)


def _sql_from_join(kind: SQL, alias: str, table: SQL, condition: SQL) -> SQL:
    """ Return a FROM clause element for a JOIN, rendered as
    ``<kind> <table> [AS <alias>] ON (<condition>)``.
    """
    return SQL("%s %s ON (%s)", kind, _sql_from_table(alias, table), condition)


# supported join kinds, keyed by their upper-cased name (see Query.add_join)
_SQL_JOINS = {
    "JOIN": SQL("JOIN"),
    "LEFT JOIN": SQL("LEFT JOIN"),
}


def _generate_table_alias(src_table_alias: str, link: str) -> str:
    """ Generate a standard table alias name. An alias is generated as following:

    - the base is the source table name (that can already be an alias)
    - then, the joined table is added in the alias using a 'link field name'
      that is used to render unique aliases for a given path
    - the name is shortcut if it goes beyond PostgreSQL's identifier limits

    .. code-block:: pycon

        >>> _generate_table_alias('res_users', link='parent_id')
        'res_users__parent_id'

    :param str src_table_alias: alias of the source table
    :param str link: field name
    :return str: alias
    """
    return make_identifier(f"{src_table_alias}__{link}")


class Query:
    """ Simple implementation of a query object, managing tables with aliases,
    join clauses (with aliases, condition and parameters), where clauses (with
    parameters), order, limit and offset.

    :param env: model environment (for lazy evaluation)
    :param alias: name or alias of the table
    :param table: a table expression (``str`` or ``SQL`` object), optional
    """

    def __init__(self, env, alias: str, table: SQL | None = None):
        # environment used to execute the query lazily (see get_result_ids)
        self._env = env

        # FROM tables {alias: table(SQL)}; the first one is the main table
        self._tables: dict[str, SQL] = {
            alias: table if table is not None else SQL.identifier(alias),
        }

        # joins {alias: (kind(SQL), table(SQL), condition(SQL))}
        self._joins: dict[str, tuple[SQL, SQL, SQL]] = {}

        # holds the list of WHERE conditions (to be joined with 'AND')
        self._where_clauses: list[SQL] = []

        # groupby, having, order, limit, offset
        self.groupby: SQL | None = None
        self.having: SQL | None = None
        self._order: SQL | None = None
        self.limit: int | None = None
        self.offset: int | None = None

        # memoized result; reset to None whenever the query is modified
        self._ids: tuple[int, ...] | None = None

    def make_alias(self, alias: str, link: str) -> str:
        """ Return an alias based on ``alias`` and ``link``. """
        return _generate_table_alias(alias, link)

    def add_table(self, alias: str, table: SQL | None = None):
        """ Add a table with a given alias to the from clause.

        :param alias: alias of the table; must not be used yet in the query
        :param table: table expression; defaults to the identifier ``alias``
        """
        assert alias not in self._tables and alias not in self._joins, f"Alias {alias!r} already in {self}"
        self._tables[alias] = table if table is not None else SQL.identifier(alias)
        self._ids = None

    def add_join(self, kind: str, alias: str, table: str | SQL | None, condition: SQL):
        """ Add a join clause with the given alias, table and condition.

        :param kind: ``"JOIN"`` or ``"LEFT JOIN"`` (case-insensitive)
        :param alias: alias of the joined table
        :param table: table name or expression; defaults to ``alias``
        :param condition: the ON condition of the join

        Adding the same join twice under the same alias is allowed, provided
        the kind, table and condition are identical.
        """
        sql_kind = _SQL_JOINS.get(kind.upper())
        assert sql_kind is not None, f"Invalid JOIN type {kind!r}"
        assert alias not in self._tables, f"Alias {alias!r} already used"
        table = table or alias
        if isinstance(table, str):
            table = SQL.identifier(table)

        if alias in self._joins:
            assert self._joins[alias] == (sql_kind, table, condition)
        else:
            self._joins[alias] = (sql_kind, table, condition)
            self._ids = None

    def add_where(self, where_clause: str | SQL, where_params=()):
        """ Add a condition to the where clause.

        :param where_clause: SQL condition (code string or ``SQL`` object)
        :param where_params: parameters for the placeholders of ``where_clause``
        """
        self._where_clauses.append(SQL(where_clause, *where_params))  # pylint: disable = sql-injection
        self._ids = None

    def _join_on(self, kind: str, lhs_alias: str, lhs_column: str,
                 rhs_table: str | SQL, rhs_column: str, link: str) -> str:
        """ Add a ``kind`` join on ``lhs_alias.lhs_column = rhs_alias.rhs_column``
        and return ``rhs_alias``, the alias generated from ``lhs_alias`` and ``link``.
        Shared implementation of :meth:`join` and :meth:`left_join`.
        """
        assert lhs_alias in self._tables or lhs_alias in self._joins, f"Alias {lhs_alias!r} not in {self}"
        rhs_alias = self.make_alias(lhs_alias, link)
        condition = SQL(
            "%s = %s",
            SQL.identifier(lhs_alias, lhs_column),
            SQL.identifier(rhs_alias, rhs_column),
        )
        self.add_join(kind, rhs_alias, rhs_table, condition)
        return rhs_alias

    def join(self, lhs_alias: str, lhs_column: str, rhs_table: str | SQL, rhs_column: str, link: str) -> str:
        """
        Perform a join between a table already present in the current Query object and
        another table.  This method is essentially a shortcut for methods :meth:`~.make_alias`
        and :meth:`~.add_join`.

        :param str lhs_alias: alias of a table already defined in the current Query object.
        :param str lhs_column: column of `lhs_alias` to be used for the join's ON condition.
        :param str rhs_table: name (or SQL expression) of the table to join to `lhs_alias`.
        :param str rhs_column: column of the joined table to be used for the join's ON condition.
        :param str link: used to generate the alias for the joined table, this string should
            represent the relationship (the link) between both tables.
        :return str: the alias of the joined table
        """
        return self._join_on('JOIN', lhs_alias, lhs_column, rhs_table, rhs_column, link)

    def left_join(self, lhs_alias: str, lhs_column: str, rhs_table: str | SQL, rhs_column: str, link: str) -> str:
        """ Add a LEFT JOIN to the current table (if necessary), and return the
        alias corresponding to ``rhs_table``.

        See the documentation of :meth:`join` for a better overview of the
        arguments and what they do.
        """
        return self._join_on('LEFT JOIN', lhs_alias, lhs_column, rhs_table, rhs_column, link)

    @property
    def order(self) -> SQL | None:
        """ The ORDER BY expression of the query, or ``None``. """
        return self._order

    @order.setter
    def order(self, value: SQL | str | None):
        self._order = SQL(value) if value is not None else None  # pylint: disable = sql-injection

    @property
    def table(self) -> str:
        """ Return the query's main table, i.e., the first one in the FROM clause. """
        return next(iter(self._tables))

    @property
    def from_clause(self) -> SQL:
        """ Return the FROM clause of ``self``, without the FROM keyword. """
        tables = SQL(", ").join(itertools.starmap(_sql_from_table, self._tables.items()))
        if not self._joins:
            return tables
        items = (
            tables,
            *(
                _sql_from_join(kind, alias, table, condition)
                for alias, (kind, table, condition) in self._joins.items()
            ),
        )
        return SQL(" ").join(items)

    @property
    def where_clause(self) -> SQL:
        """ Return the WHERE condition of ``self``, without the WHERE keyword. """
        return SQL(" AND ").join(self._where_clauses)

    def is_empty(self) -> bool:
        """ Return whether the query is known to return nothing. """
        return self._ids == ()

    def select(self, *args: str | SQL) -> SQL:
        """ Return the SELECT query as an ``SQL`` object.

        :param args: the selected expressions; defaults to the ``id`` column
            of the main table
        """
        sql_args = map(SQL, args) if args else [SQL.identifier(self.table, 'id')]
        return SQL(
            "%s%s%s%s%s%s%s%s",
            SQL("SELECT %s", SQL(", ").join(sql_args)),
            SQL(" FROM %s", self.from_clause),
            SQL(" WHERE %s", self.where_clause) if self._where_clauses else SQL(),
            SQL(" GROUP BY %s", self.groupby) if self.groupby else SQL(),
            SQL(" HAVING %s", self.having) if self.having else SQL(),
            SQL(" ORDER BY %s", self._order) if self._order else SQL(),
            SQL(" LIMIT %s", self.limit) if self.limit else SQL(),
            SQL(" OFFSET %s", self.offset) if self.offset else SQL(),
        )

    def subselect(self, *args: str | SQL) -> SQL:
        """ Similar to :meth:`.select`, but for sub-queries.
        This one avoids the ORDER BY clause when possible,
        and includes parentheses around the subquery.
        """
        if self._ids is not None and not args:
            # inject the known result instead of the subquery
            if not self._ids:
                # in case we have nothing, we want to use a sub_query with no records
                # because an empty tuple leads to a syntax error
                # and a tuple containing just None creates issues for `NOT IN`
                return SQL("(SELECT 1 WHERE FALSE)")
            return SQL("%s", self._ids)

        if self.limit or self.offset:
            # in this case, the ORDER BY clause is necessary
            return SQL("(%s)", self.select(*args))

        sql_args = map(SQL, args) if args else [SQL.identifier(self.table, 'id')]
        return SQL(
            "(%s%s%s)",
            SQL("SELECT %s", SQL(", ").join(sql_args)),
            SQL(" FROM %s", self.from_clause),
            SQL(" WHERE %s", self.where_clause) if self._where_clauses else SQL(),
        )

    def get_result_ids(self) -> tuple[int, ...]:
        """ Return the result of ``self.select()`` as a tuple of ids. The result
        is memoized for future use, which avoids making the same query twice.
        """
        if self._ids is None:
            self._ids = tuple(id_ for id_, in self._env.execute_query(self.select()))
        return self._ids

    def set_result_ids(self, ids: Iterable[int], ordered: bool = True) -> None:
        """ Set up the query to return the lines given by ``ids``. The parameter
        ``ordered`` tells whether the query must be ordered to match exactly the
        sequence ``ids``.
        """
        assert not (self._joins or self._where_clauses or self.limit or self.offset), \
            "Method set_result_ids() can only be called on a virgin Query"
        ids = tuple(ids)
        if not ids:
            self.add_where("FALSE")
        elif ordered:
            # This guarantees that self.select() returns the results in the
            # expected order of ids:
            #   SELECT "stuff".id
            #   FROM "stuff"
            #   JOIN (SELECT * FROM unnest(%s) WITH ORDINALITY) AS "stuff__ids"
            #       ON ("stuff"."id" = "stuff__ids"."unnest")
            #   ORDER BY "stuff__ids"."ordinality"
            alias = self.join(
                self.table, 'id',
                SQL('(SELECT * FROM unnest(%s) WITH ORDINALITY)', list(ids)),
                'unnest', 'ids',
            )
            self.order = SQL.identifier(alias, 'ordinality')
        else:
            self.add_where(SQL("%s IN %s", SQL.identifier(self.table, 'id'), ids))
        # the adders above reset _ids; memoize the known result afterwards
        self._ids = ids

    def __str__(self) -> str:
        """ Return a debugging representation with the SELECT code and its parameters. """
        sql = self.select()
        return f"<Query: {sql.code!r} with params: {sql.params!r}>"

    def __bool__(self):
        """ Return whether the query has results (executes it if needed). """
        return bool(self.get_result_ids())

    def __len__(self) -> int:
        """ Return the number of results.

        When the result is not memoized yet, a COUNT query is executed instead
        of fetching the ids; its result is not memoized.
        """
        if self._ids is None:
            if self.limit or self.offset:
                # optimization: generate a SELECT FROM, and then count the rows
                sql = SQL("SELECT COUNT(*) FROM (%s) t", self.select(""))
            else:
                sql = self.select('COUNT(*)')
            return self._env.execute_query(sql)[0][0]
        return len(self.get_result_ids())

    def __iter__(self) -> Iterator[int]:
        """ Iterate over the result ids (executes the query if needed). """
        return iter(self.get_result_ids())