# Part of Odoo. See LICENSE file for full copyright and licensing details.
import itertools
from collections.abc import Iterable, Iterator

from .sql import SQL, make_identifier


def _sql_from_table(alias: str, table: SQL) -> SQL:
    """ Build the FROM clause element for ``table`` known under ``alias``.

    When ``table`` is exactly the quoted identifier of ``alias``, the table is
    returned alone; otherwise the result is ``<table> AS <alias>``.
    """
    quoted_alias = SQL.identifier(alias)
    if quoted_alias == table:
        # the table is referred to by its own name: no "AS" needed
        return table
    return SQL("%s AS %s", table, quoted_alias)


def _sql_from_join(kind: SQL, alias: str, table: SQL, condition: SQL) -> SQL:
    """ Build the FROM clause element ``<kind> <table> ON (<condition>)``,
    where the table part is rendered by :func:`_sql_from_table`.
    """
    joined_table = _sql_from_table(alias, table)
    return SQL("%s %s ON (%s)", kind, joined_table, condition)


# Supported JOIN kinds, keyed by their upper-case keyword. Query.add_join()
# looks up ``kind.upper()`` here and rejects any kind that is not listed.
_SQL_JOINS = {
    "JOIN": SQL("JOIN"),
    "LEFT JOIN": SQL("LEFT JOIN"),
}


def _generate_table_alias(src_table_alias: str, link: str) -> str:
    """ Generate a standard alias name for a table joined from another one.

    The alias is built as ``<src_table_alias>__<link>``:

        - ``src_table_alias`` is the name of the source table (which may
          itself already be an alias);
        - ``link`` is a 'link field name' identifying the joined table, so
          that a given path of links always yields the same alias;
        - the result goes through ``make_identifier``, which is relied upon to
          shorten names exceeding PostgreSQL's identifier limits.

        .. code-block:: pycon

            >>> _generate_table_alias('res_users', link='parent_id')
            'res_users__parent_id'

        :param str src_table_alias: alias of the source table
        :param str link: field name
        :return str: alias
    """
    raw_alias = f"{src_table_alias}__{link}"
    return make_identifier(raw_alias)


class Query:
    """ Simple implementation of a query object, managing tables with aliases,
    join clauses (with aliases, condition and parameters), where clauses (with
    parameters), group by, having, order, limit and offset.

    :param env: model environment (for lazy evaluation)
    :param alias: name or alias of the table
    :param table: a table expression (``SQL`` object), optional; when omitted,
        the table is the identifier ``alias`` itself
    """

    def __init__(self, env, alias: str, table: (SQL | None) = None):
        # environment, used to execute the query when its result is requested
        self._env = env

        # tables of the FROM clause {alias: table(SQL)}; the first entry is
        # the query's main table (see property ``table``)
        self._tables: dict[str, SQL] = {
            alias: table if table is not None else SQL.identifier(alias),
        }

        # joins {alias: (kind(SQL), table(SQL), condition(SQL))}
        self._joins: dict[str, tuple[SQL, SQL, SQL]] = {}

        # holds the list of WHERE conditions (to be joined with 'AND')
        self._where_clauses: list[SQL] = []

        # groupby, having, order, limit, offset
        self.groupby: SQL | None = None
        self.having: SQL | None = None
        self._order: SQL | None = None
        self.limit: int | None = None
        self.offset: int | None = None

        # memoized result; reset to None whenever a table, join or condition
        # is added
        self._ids: tuple[int, ...] | None = None

    def make_alias(self, alias: str, link: str) -> str:
        """ Return an alias based on ``alias`` and ``link``. """
        return _generate_table_alias(alias, link)

    def add_table(self, alias: str, table: (SQL | None) = None):
        """ Add a table with a given alias to the from clause.

        :param alias: alias of the table; it must not be used yet, neither by
            a table nor by a join
        :param table: table expression, defaults to the identifier ``alias``
        """
        assert alias not in self._tables and alias not in self._joins, f"Alias {alias!r} already in {self}"
        self._tables[alias] = table if table is not None else SQL.identifier(alias)
        self._ids = None

    def add_join(self, kind: str, alias: str, table: str | SQL | None, condition: SQL):
        """ Add a join clause with the given alias, table and condition.

        :param kind: join kind, case-insensitive: ``'JOIN'`` or ``'LEFT JOIN'``
        :param alias: alias of the joined table; it must not be the alias of a
            table of the FROM clause
        :param table: joined table; a ``str`` is turned into an identifier,
            and a falsy value means the identifier ``alias``
        :param condition: the join's ON condition

        Adding a join on an alias that is already joined is accepted only if
        the kind, table and condition are the same as the existing join; in
        that case nothing changes.
        """
        sql_kind = _SQL_JOINS.get(kind.upper())
        assert sql_kind is not None, f"Invalid JOIN type {kind!r}"
        assert alias not in self._tables, f"Alias {alias!r} already used"
        table = table or alias
        if isinstance(table, str):
            table = SQL.identifier(table)

        if alias in self._joins:
            # same alias: the join must be strictly identical
            assert self._joins[alias] == (sql_kind, table, condition)
        else:
            self._joins[alias] = (sql_kind, table, condition)
            self._ids = None

    def add_where(self, where_clause: str | SQL, where_params=()):
        """ Add a condition to the where clause.

        :param where_clause: the condition, combined with the existing ones
            with ``AND``
        :param where_params: parameters of ``where_clause``
        """
        self._where_clauses.append(SQL(where_clause, *where_params)) # pylint: disable = sql-injection
        self._ids = None

    def _join_on_columns(self, kind: str, lhs_alias: str, lhs_column: str,
                         rhs_table: str | SQL, rhs_column: str, link: str) -> str:
        """ Join ``rhs_table`` to ``lhs_alias`` with a join of the given
        ``kind`` on ``lhs_alias.lhs_column = rhs_alias.rhs_column``, and
        return the alias generated for ``rhs_table``.

        Shared implementation of :meth:`join` and :meth:`left_join`.
        """
        assert lhs_alias in self._tables or lhs_alias in self._joins, "Alias %r not in %s" % (lhs_alias, str(self))
        rhs_alias = self.make_alias(lhs_alias, link)
        condition = SQL("%s = %s", SQL.identifier(lhs_alias, lhs_column), SQL.identifier(rhs_alias, rhs_column))
        self.add_join(kind, rhs_alias, rhs_table, condition)
        return rhs_alias

    def join(self, lhs_alias: str, lhs_column: str, rhs_table: str | SQL, rhs_column: str, link: str) -> str:
        """
        Perform a join between a table already present in the current Query object and
        another table.  This method is essentially a shortcut for methods :meth:`~.make_alias`
        and :meth:`~.add_join`.

        :param str lhs_alias: alias of a table already defined in the current Query object.
        :param str lhs_column: column of `lhs_alias` to be used for the join's ON condition.
        :param rhs_table: name of the table (or table expression as ``SQL``) to join to `lhs_alias`.
        :param str rhs_column: column of the joined table to be used for the join's ON condition.
        :param str link: used to generate the alias for the joined table, this string should
            represent the relationship (the link) between both tables.
        :return: the alias of the joined table
        """
        return self._join_on_columns('JOIN', lhs_alias, lhs_column, rhs_table, rhs_column, link)

    def left_join(self, lhs_alias: str, lhs_column: str, rhs_table: str | SQL, rhs_column: str, link: str) -> str:
        """ Add a LEFT JOIN to the current table (if necessary), and return the
        alias corresponding to ``rhs_table``.

        See the documentation of :meth:`join` for a better overview of the
        arguments and what they do.
        """
        return self._join_on_columns('LEFT JOIN', lhs_alias, lhs_column, rhs_table, rhs_column, link)

    @property
    def order(self) -> SQL | None:
        """ The ORDER BY expression of the query (without the keyword), or
        ``None``. Any value assigned to it is wrapped into an ``SQL`` object.
        """
        return self._order

    @order.setter
    def order(self, value: SQL | str | None):
        self._order = SQL(value) if value is not None else None # pylint: disable = sql-injection

    @property
    def table(self) -> str:
        """ Return the query's main table, i.e., the first one in the FROM clause. """
        return next(iter(self._tables))

    @property
    def from_clause(self) -> SQL:
        """ Return the FROM clause of ``self``, without the FROM keyword. """
        # plain tables are separated by commas, joins follow separated by spaces
        tables = SQL(", ").join(itertools.starmap(_sql_from_table, self._tables.items()))
        if not self._joins:
            return tables
        items = (
            tables,
            *(
                _sql_from_join(kind, alias, table, condition)
                for alias, (kind, table, condition) in self._joins.items()
            ),
        )
        return SQL(" ").join(items)

    @property
    def where_clause(self) -> SQL:
        """ Return the WHERE condition of ``self``, without the WHERE keyword. """
        return SQL(" AND ").join(self._where_clauses)

    def is_empty(self) -> bool:
        """ Return whether the query is known to return nothing, i.e., its
        memoized result is an empty tuple. No query is executed.
        """
        return self._ids == ()

    def select(self, *args: str | SQL) -> SQL:
        """ Return the SELECT query as an ``SQL`` object.

        :param args: the expressions to select; by default, the column ``id``
            of the main table
        """
        sql_args = map(SQL, args) if args else [SQL.identifier(self.table, 'id')]
        # every optional clause is only rendered when it is set (truthy)
        return SQL(
            "%s%s%s%s%s%s%s%s",
            SQL("SELECT %s", SQL(", ").join(sql_args)),
            SQL(" FROM %s", self.from_clause),
            SQL(" WHERE %s", self.where_clause) if self._where_clauses else SQL(),
            SQL(" GROUP BY %s", self.groupby) if self.groupby else SQL(),
            SQL(" HAVING %s", self.having) if self.having else SQL(),
            SQL(" ORDER BY %s", self._order) if self._order else SQL(),
            SQL(" LIMIT %s", self.limit) if self.limit else SQL(),
            SQL(" OFFSET %s", self.offset) if self.offset else SQL(),
        )

    def subselect(self, *args: str | SQL) -> SQL:
        """ Similar to :meth:`.select`, but for sub-queries.
            This one avoids the ORDER BY clause when possible,
            and includes parentheses around the subquery.

            When the result is already known (memoized) and no ``args`` are
            given, the known ids are injected instead of the subquery.
        """
        if self._ids is not None and not args:
            # inject the known result instead of the subquery
            if not self._ids:
                # in case we have nothing, we want to use a sub_query with no records
                # because an empty tuple leads to a syntax error
                # and a tuple containing just None creates issues for `NOT IN`
                return SQL("(SELECT 1 WHERE FALSE)")
            return SQL("%s", self._ids)

        if self.limit or self.offset:
            # in this case, the ORDER BY clause is necessary
            return SQL("(%s)", self.select(*args))

        # Only ORDER BY is dropped here: GROUP BY and HAVING determine which
        # rows are returned, so they must be kept like in select().
        sql_args = map(SQL, args) if args else [SQL.identifier(self.table, 'id')]
        return SQL(
            "(%s%s%s%s%s)",
            SQL("SELECT %s", SQL(", ").join(sql_args)),
            SQL(" FROM %s", self.from_clause),
            SQL(" WHERE %s", self.where_clause) if self._where_clauses else SQL(),
            SQL(" GROUP BY %s", self.groupby) if self.groupby else SQL(),
            SQL(" HAVING %s", self.having) if self.having else SQL(),
        )

    def get_result_ids(self) -> tuple[int, ...]:
        """ Return the result of ``self.select()`` as a tuple of ids. The result
        is memoized for future use, which avoids making the same query twice.
        """
        if self._ids is None:
            # each row is expected to hold a single column: the id
            self._ids = tuple(id_ for id_, in self._env.execute_query(self.select()))
        return self._ids

    def set_result_ids(self, ids: Iterable[int], ordered: bool = True) -> None:
        """ Set up the query to return the lines given by ``ids``. The parameter
        ``ordered`` tells whether the query must be ordered to match exactly the
        sequence ``ids``.

        The query must not have any join, condition, limit or offset yet. The
        given ids are memoized as the result of the query.
        """
        assert not (self._joins or self._where_clauses or self.limit or self.offset), \
            "Method set_result_ids() can only be called on a virgin Query"
        ids = tuple(ids)
        if not ids:
            self.add_where("FALSE")
        elif ordered:
            # This guarantees that self.select() returns the results in the
            # expected order of ids:
            #   SELECT "stuff".id
            #   FROM "stuff"
            #   JOIN (SELECT * FROM unnest(%s) WITH ORDINALITY) AS "stuff__ids"
            #       ON ("stuff"."id" = "stuff__ids"."unnest")
            #   ORDER BY "stuff__ids"."ordinality"
            alias = self.join(
                self.table, 'id',
                SQL('(SELECT * FROM unnest(%s) WITH ORDINALITY)', list(ids)), 'unnest',
                'ids',
            )
            self.order = SQL.identifier(alias, 'ordinality')
        else:
            self.add_where(SQL("%s IN %s", SQL.identifier(self.table, 'id'), ids))
        # must come last: the calls above reset the memoized result
        self._ids = ids

    def __str__(self) -> str:
        sql = self.select()
        return f"<Query: {sql.code!r} with params: {sql.params!r}>"

    def __bool__(self) -> bool:
        """ Return whether the query returns something; this may execute the
        query (see :meth:`get_result_ids`).
        """
        return bool(self.get_result_ids())

    def __len__(self) -> int:
        """ Return the number of rows of the query. When the result is not
        memoized, a ``COUNT(*)`` query is executed and nothing is memoized.
        """
        if self._ids is None:
            if self.limit or self.offset:
                # optimization: generate a SELECT FROM, and then count the rows
                # (the inner query selects nothing, only its rows are counted)
                sql = SQL("SELECT COUNT(*) FROM (%s) t", self.select(""))
            else:
                sql = self.select('COUNT(*)')
            return self._env.execute_query(sql)[0][0]
        return len(self.get_result_ids())

    def __iter__(self) -> Iterator[int]:
        """ Iterate over the ids returned by the query (see :meth:`get_result_ids`). """
        return iter(self.get_result_ids())