Odoo18-Base/odoo/tools/set_expression.py
2025-01-06 10:57:38 +07:00

560 lines
20 KiB
Python

from __future__ import annotations
import ast
from abc import ABC, abstractmethod
import typing
if typing.TYPE_CHECKING:
from collections.abc import Collection, Iterable
class SetDefinitions:
""" A collection of set definitions, where each set is defined by an id, a
name, its supersets, and the sets that are disjoint with it. This object
is used as a factory to create set expressions, which are combinations of
named sets with union, intersection and complement.
"""
__slots__ = ('__leaves',)
def __init__(self, definitions: dict[int, dict]):
""" Initialize the object with ``definitions``, a dict which maps each
set id to a dict with optional keys ``"ref"`` (value is the set's name),
``"supersets"`` (value is a collection of set ids), and ``"disjoints"``
(value is a collection of set ids).
Here is an example of set definitions, with natural numbers (N), integer
numbers (Z), rational numbers (Q), real numbers (R), imaginary numbers
(I) and complex numbers (C)::
{
1: {"ref": "N", "supersets": [2]},
2: {"ref": "Z", "supersets": [3]},
3: {"ref": "Q", "supersets": [4]},
4: {"ref": "R", "supersets": [6]},
5: {"ref": "I", "supersets": [6], "disjoints": [4]},
6: {"ref": "C"},
}
"""
self.__leaves: dict[int | str, Leaf] = {}
for leaf_id, info in definitions.items():
ref = info['ref']
assert ref != '*', "The set reference '*' is reserved for the universal set."
leaf = Leaf(leaf_id, ref)
self.__leaves[leaf_id] = leaf
self.__leaves[ref] = leaf
# compute transitive closure of subsets and supersets
subsets = {leaf.id: leaf.subsets for leaf in self.__leaves.values()}
supersets = {leaf.id: leaf.supersets for leaf in self.__leaves.values()}
for leaf_id, info in definitions.items():
for greater_id in info.get('supersets', ()):
# transitive closure: smaller_ids <= leaf_id <= greater_id <= greater_ids
smaller_ids = subsets[leaf_id]
greater_ids = supersets[greater_id]
for smaller_id in smaller_ids:
supersets[smaller_id].update(greater_ids)
for greater_id in greater_ids:
subsets[greater_id].update(smaller_ids)
# compute transitive closure of disjoint relation
disjoints = {leaf.id: leaf.disjoints for leaf in self.__leaves.values()}
for leaf_id, info in definitions.items():
for distinct_id in info.get('disjoints', set()):
# all subsets[leaf_id] are disjoint from all subsets[distinct_id]
left_ids = subsets[leaf_id]
right_ids = subsets[distinct_id]
for left_id in left_ids:
disjoints[left_id].update(right_ids)
for right_id in right_ids:
disjoints[right_id].update(left_ids)
@property
def empty(self) -> SetExpression:
return EMPTY_UNION
@property
def universe(self) -> SetExpression:
return UNIVERSAL_UNION
def parse(self, refs: str, raise_if_not_found: bool = True) -> SetExpression:
""" Return the set expression corresponding to ``refs``
:param str refs: comma-separated list of set references
optionally preceded by ``!`` (negative item). The result is
an union between positive item who intersect every negative
group.
(e.g. ``base.group_user,base.group_portal,!base.group_system``)
"""
positives: list[Leaf] = []
negatives: list[Leaf] = []
for xmlid in refs.split(','):
if xmlid.startswith('!'):
negatives.append(~self.__get_leaf(xmlid[1:], raise_if_not_found))
else:
positives.append(self.__get_leaf(xmlid, raise_if_not_found))
if positives:
return Union(Inter([leaf] + negatives) for leaf in positives)
else:
return Union([Inter(negatives)])
def from_ids(self, ids: Iterable[int], keep_subsets: bool = False) -> SetExpression:
""" Return the set expression corresponding to given set ids. """
if keep_subsets:
ids = set(ids)
ids = [leaf_id for leaf_id in ids if not any((self.__leaves[leaf_id].subsets - {leaf_id}) & ids)]
return Union(Inter([self.__leaves[leaf_id]]) for leaf_id in ids)
def from_key(self, key: str) -> SetExpression:
""" Return the set expression corresponding to the given key. """
# union_tuple = tuple(tuple(tuple(leaf_id, negative), ...), ...)
union_tuple = ast.literal_eval(key)
return Union([
Inter([
~leaf if negative else leaf
for leaf_id, negative in inter_tuple
for leaf in [self.__get_leaf(leaf_id, raise_if_not_found=False)]
], optimal=True)
for inter_tuple in union_tuple
], optimal=True)
def get_id(self, ref: LeafIdType) -> LeafIdType | None:
""" Return a set id from its reference, or ``None`` if it does not exist. """
if ref == '*':
return UNIVERSAL_LEAF.id
leaf = self.__leaves.get(ref)
return None if leaf is None else leaf.id
def __get_leaf(self, ref: str | int, raise_if_not_found: bool = True) -> Leaf:
""" Return the group object from the string.
:param str ref: the ref of a leaf
"""
if ref == '*':
return UNIVERSAL_LEAF
if not raise_if_not_found and ref not in self.__leaves:
return Leaf(UnknownId(ref), ref)
return self.__leaves[ref]
class SetExpression(ABC):
""" An object that represents a combination of named sets with union,
intersection and complement.
"""
@abstractmethod
def is_empty(self) -> bool:
""" Returns whether ``self`` is the empty set, that contains nothing. """
raise NotImplementedError()
@abstractmethod
def is_universal(self) -> bool:
""" Returns whether ``self`` is the universal set, that contains all possible elements. """
raise NotImplementedError()
@abstractmethod
def invert_intersect(self, factor: SetExpression) -> SetExpression | None:
""" Performs the inverse operation of intersection (a sort of factorization)
such that: ``self == result & factor``.
"""
raise NotImplementedError()
@abstractmethod
def matches(self, user_group_ids: Iterable[int]) -> bool:
""" Return whether the given group ids are included to ``self``. """
raise NotImplementedError()
@property
@abstractmethod
def key(self) -> str:
""" Return a unique identifier for the expression. """
raise NotImplementedError()
@abstractmethod
def __and__(self, other: SetExpression) -> SetExpression:
raise NotImplementedError()
@abstractmethod
def __or__(self, other: SetExpression) -> SetExpression:
raise NotImplementedError()
@abstractmethod
def __invert__(self) -> SetExpression:
raise NotImplementedError()
@abstractmethod
def __eq__(self, other) -> bool:
raise NotImplementedError()
@abstractmethod
def __le__(self, other: SetExpression) -> bool:
raise NotImplementedError()
@abstractmethod
def __lt__(self, other: SetExpression) -> bool:
raise NotImplementedError()
@abstractmethod
def __hash__(self):
raise NotImplementedError()
class Union(SetExpression):
""" Implementation of a set expression, that represents it as a union of
intersections of named sets or their complement.
"""
def __init__(self, inters: Iterable[Inter] = (), optimal=False):
if inters and not optimal:
inters = self.__combine((), inters)
self.__inters = sorted(inters, key=lambda inter: inter.key)
self.__key = str(tuple(inter.key for inter in self.__inters))
self.__hash = hash(self.__key)
@property
def key(self) -> str:
return self.__key
@staticmethod
def __combine(inters: Iterable[Inter], inters_to_add: Iterable[Inter]) -> list[Inter]:
""" Combine some existing union of intersections with extra intersections. """
result = list(inters)
todo = list(inters_to_add)
while todo:
inter_to_add = todo.pop()
if inter_to_add.is_universal():
return [UNIVERSAL_INTER]
if inter_to_add.is_empty():
continue
for index, inter in enumerate(result):
merged = inter._union_merge(inter_to_add)
if merged is not None:
result.pop(index)
todo.append(merged)
break
else:
result.append(inter_to_add)
return result
def is_empty(self) -> bool:
""" Returns whether ``self`` is the empty set, that contains nothing. """
return not self.__inters
def is_universal(self) -> bool:
""" Returns whether ``self`` is the universal set, that contains all possible elements. """
return any(item.is_universal() for item in self.__inters)
def invert_intersect(self, factor: SetExpression) -> Union | None:
""" Performs the inverse operation of intersection (a sort of factorization)
such that: ``self == result & factor``.
"""
if factor == self:
return UNIVERSAL_UNION
rfactor = ~factor
if rfactor.is_empty() or rfactor.is_universal():
return None
rself = ~self
assert isinstance(rfactor, Union)
inters = [inter for inter in rself.__inters if inter not in rfactor.__inters]
if len(rself.__inters) - len(inters) != len(rfactor.__inters):
# not possible to invert the intersection
return None
rself_value = Union(inters)
return ~rself_value
def __and__(self, other: SetExpression) -> Union:
assert isinstance(other, Union)
if self.is_universal():
return other
if other.is_universal():
return self
if self.is_empty() or other.is_empty():
return EMPTY_UNION
if self == other:
return self
return Union(
self_inter & other_inter
for self_inter in self.__inters
for other_inter in other.__inters
)
def __or__(self, other: SetExpression) -> Union:
assert isinstance(other, Union)
if self.is_empty():
return other
if other.is_empty():
return self
if self.is_universal() or other.is_universal():
return UNIVERSAL_UNION
if self == other:
return self
inters = self.__combine(self.__inters, other.__inters)
return Union(inters, optimal=True)
def __invert__(self) -> Union:
if self.is_empty():
return UNIVERSAL_UNION
if self.is_universal():
return EMPTY_UNION
# apply De Morgan's laws
inverses_of_inters = [
# ~(A & B) = ~A | ~B
Union(Inter([~leaf]) for leaf in inter.leaves)
for inter in self.__inters
]
result = inverses_of_inters[0]
# ~(A | B) = ~A & ~B
for inverse in inverses_of_inters[1:]:
result = result & inverse
return result
def matches(self, user_group_ids) -> bool:
if self.is_empty() or not user_group_ids:
return False
if self.is_universal():
return True
user_group_ids = set(user_group_ids)
return any(inter.matches(user_group_ids) for inter in self.__inters)
def __bool__(self):
raise NotImplementedError()
def __eq__(self, other) -> bool:
return isinstance(other, Union) and self.__key == other.__key
def __le__(self, other: SetExpression) -> bool:
if not isinstance(other, Union):
return False
if self.__key == other.__key:
return True
if self.is_universal() or other.is_empty():
return False
if other.is_universal() or self.is_empty():
return True
return all(
any(self_inter <= other_inter for other_inter in other.__inters)
for self_inter in self.__inters
)
def __lt__(self, other: SetExpression) -> bool:
return self != other and self.__le__(other)
def __str__(self):
""" Returns an intersection union representation of groups using user-readable references.
e.g. ('base.group_user' & 'base.group_multi_company') | ('base.group_portal' & ~'base.group_multi_company') | 'base.group_public'
"""
if self.is_empty():
return "~*"
def leaf_to_str(leaf):
return f"{'~' if leaf.negative else ''}{leaf.ref!r}"
def inter_to_str(inter, wrapped=False):
result = " & ".join(leaf_to_str(leaf) for leaf in inter.leaves) or "*"
return f"({result})" if wrapped and len(inter.leaves) > 1 else result
wrapped = len(self.__inters) > 1
return " | ".join(inter_to_str(inter, wrapped) for inter in self.__inters)
def __repr__(self):
return repr(self.__str__())
def __hash__(self):
return self.__hash
class Inter:
""" Part of the implementation of a set expression, that represents an
intersection of named sets or their complement.
"""
__slots__ = ('key', 'leaves')
def __init__(self, leaves: Iterable[Leaf] = (), optimal=False):
if leaves and not optimal:
leaves = self.__combine((), leaves)
self.leaves: list[Leaf] = sorted(leaves, key=lambda leaf: leaf.key)
self.key: tuple[tuple[LeafIdType, bool], ...] = tuple(leaf.key for leaf in self.leaves)
@staticmethod
def __combine(leaves: Iterable[Leaf], leaves_to_add: Iterable[Leaf]) -> list[Leaf]:
""" Combine some existing intersection of leaves with extra leaves. """
result = list(leaves)
for leaf_to_add in leaves_to_add:
for index, leaf in enumerate(result):
if leaf.isdisjoint(leaf_to_add): # leaf & leaf_to_add = empty
return [EMPTY_LEAF]
if leaf <= leaf_to_add: # leaf & leaf_to_add = leaf
break
if leaf_to_add <= leaf: # leaf & leaf_to_add = leaf_to_add
result[index] = leaf_to_add
break
else:
if not leaf_to_add.is_universal():
result.append(leaf_to_add)
return result
def is_empty(self) -> bool:
return any(item.is_empty() for item in self.leaves)
def is_universal(self) -> bool:
""" Returns whether ``self`` is the universal set, that contains all possible elements. """
return not self.leaves
def matches(self, user_group_ids) -> bool:
return all(leaf.matches(user_group_ids) for leaf in self.leaves)
def _union_merge(self, other: Inter) -> Inter | None:
""" Return the union of ``self`` with another intersection, if it can be
represented as an intersection. Otherwise return ``None``.
"""
# the following covers cases like (A & B) | A -> A
if self.is_universal() or other <= self:
return self
if self <= other:
return other
# combine complementary parts: (A & ~B) | (A & B) -> A
if len(self.leaves) == len(other.leaves):
opposite_index = None
# we use the property that __leaves are ordered
for index, self_leaf, other_leaf in zip(range(len(self.leaves)), self.leaves, other.leaves):
if self_leaf.id != other_leaf.id:
return None
if self_leaf.negative != other_leaf.negative:
if opposite_index is not None:
return None # we already have two opposite leaves
opposite_index = index
if opposite_index is not None:
leaves = list(self.leaves)
leaves.pop(opposite_index)
return Inter(leaves, optimal=True)
return None
def __and__(self, other: Inter) -> Inter:
if self.is_empty() or other.is_empty():
return EMPTY_INTER
if self.is_universal():
return other
if other.is_universal():
return self
leaves = self.__combine(self.leaves, other.leaves)
return Inter(leaves, optimal=True)
def __eq__(self, other) -> bool:
return isinstance(other, Inter) and self.key == other.key
def __le__(self, other: Inter) -> bool:
return self.key == other.key or all(
any(self_leaf <= other_leaf for self_leaf in self.leaves)
for other_leaf in other.leaves
)
def __lt__(self, other: Inter) -> bool:
return self != other and self <= other
def __hash__(self):
return hash(self.key)
class Leaf:
""" Part of the implementation of a set expression, that represents a named
set or its complement.
"""
__slots__ = ('disjoints', 'id', 'inverse', 'key', 'negative', 'ref', 'subsets', 'supersets')
def __init__(self, leaf_id: LeafIdType, ref: str | int | None = None, negative: bool = False):
self.id = leaf_id
self.ref = ref or str(leaf_id)
self.negative = bool(negative)
self.key: tuple[LeafIdType, bool] = (leaf_id, self.negative)
self.subsets: set[LeafIdType] = {leaf_id} # all the leaf ids that are <= self
self.supersets: set[LeafIdType] = {leaf_id} # all the leaf ids that are >= self
self.disjoints: set[LeafIdType] = set() # all the leaf ids disjoint from self
self.inverse: Leaf | None = None
def __invert__(self) -> Leaf:
if self.inverse is None:
self.inverse = Leaf(self.id, self.ref, negative=not self.negative)
self.inverse.inverse = self
self.inverse.subsets = self.subsets
self.inverse.supersets = self.supersets
self.inverse.disjoints = self.disjoints
return self.inverse
def is_empty(self) -> bool:
return self.ref == '*' and self.negative
def is_universal(self) -> bool:
return self.ref == '*' and not self.negative
def isdisjoint(self, other: Leaf) -> bool:
if self.negative:
return other <= ~self
elif other.negative:
return self <= ~other
else:
return self.id in other.disjoints
def matches(self, user_group_ids: Collection[int]) -> bool:
return (self.id not in user_group_ids) if self.negative else (self.id in user_group_ids)
def __eq__(self, other) -> bool:
return isinstance(other, Leaf) and self.key == other.key
def __le__(self, other: Leaf) -> bool:
if self.is_empty() or other.is_universal():
return True
elif self.is_universal() or other.is_empty():
return False
elif self.negative:
return other.negative and ~other <= ~self
elif other.negative:
return self.id in other.disjoints
else:
return self.id in other.subsets
def __lt__(self, other: Leaf) -> bool:
return self != other and self <= other
def __hash__(self):
return hash(self.key)
class UnknownId(str):
""" Special id object for unknown leaves. It behaves as being strictly
greater than any other kind of id.
"""
__slots__ = ()
def __lt__(self, other) -> bool:
if isinstance(other, UnknownId):
return super().__lt__(other)
return False
def __gt__(self, other) -> bool:
if isinstance(other, UnknownId):
return super().__gt__(other)
return True
LeafIdType = int | typing.Literal["*"] | UnknownId
# constants
UNIVERSAL_LEAF = Leaf('*')
EMPTY_LEAF = ~UNIVERSAL_LEAF
EMPTY_INTER = Inter([EMPTY_LEAF])
UNIVERSAL_INTER = Inter()
EMPTY_UNION = Union()
UNIVERSAL_UNION = Union([UNIVERSAL_INTER])