# Copyright Kevin Deldycke <kevin@deldycke.com> and contributors.
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
"""Group a collection of traits. Also referred to as families."""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass, field, replace
from functools import cached_property
from types import MappingProxyType
from typing import cast
from .trait import Trait
TYPE_CHECKING = False
if TYPE_CHECKING:
from collections.abc import Iterator
from ._types import _TNestedReferences
# Type of ``Group.members`` once normalized by ``Group.__post_init__``: a
# read-only mapping from trait ID to ``Trait`` instance.
_MembersMapping = MappingProxyType[str, Trait]
def _flatten(items: Iterable) -> Iterator:
"""Recursively flatten nested iterables (except strings).
Yields items from nested iterables one at a time, preserving order.
Strings are treated as atomic values, not iterable containers.
"""
for item in items:
if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
yield from _flatten(item)
else:
yield item
[docs]
@dataclass(frozen=True)
class Group:
"""A ``Group`` identifies a collection of ``Trait`` members.
Supports `set`-like operations (union, intersection, difference, etc.).
"""
id: str
"""Unique ID of the group."""
name: str
"""User-friendly description of a group."""
icon: str = field(repr=False, default="β")
"""Icon of the group."""
members: Iterable[Trait] = field(repr=False, default_factory=tuple)
"""Traits in this group.
Normalized to ``MappingProxyType[str, Trait]`` at init, providing O(1) lookup by ID.
"""
@property
def _members(self) -> _MembersMapping:
"""Typed access to members as ``MappingProxyType[str, Trait]``.
.. warning::
The ``members`` field is typed as ``Iterable[Trait]`` to accept any
iterable at construction time. After ``__post_init__``, it is always a
``MappingProxyType[str, Trait]``. This property provides a ``cast()`` to
that type, avoiding ``# type: ignore`` comments throughout the class.
"""
return cast(_MembersMapping, self.members)
def __post_init__(self):
"""Validate fields and normalize members to a sorted, deduplicated mapping."""
assert self.id, "Group ID cannot be empty."
assert self.name, "Group name cannot be empty."
assert self.icon, "Group icon cannot be empty."
# Accept either a MappingProxyType, dict, or iterable of Traits.
if isinstance(self.members, MappingProxyType):
traits = self.members.values()
elif isinstance(self.members, dict):
traits = self.members.values()
else:
traits = self.members
# Deduplicate and sort by ID, then build the immutable mapping.
sorted_traits = sorted(set(traits), key=lambda t: t.id)
object.__setattr__(
self,
"members",
MappingProxyType({t.id: t for t in sorted_traits}),
)
@property
def member_ids(self) -> frozenset[str]:
"""Set of member IDs that belong to this group."""
return frozenset(self._members.keys())
def __hash__(self) -> int:
"""Hash based on group ID and member IDs."""
return hash((self.id, self.member_ids))
@cached_property
def short_desc(self) -> str:
"""Returns the group name with its first letter in lowercase to be used as a
short description.
Mainly used to produce docstrings for function dynamically generated for each
group.
"""
return self.name[0].lower() + self.name[1:]
def __iter__(self) -> Iterator[Trait]:
"""Iterate over the members of the group."""
yield from self._members.values()
def __len__(self) -> int:
"""Return the number of members in the group."""
return len(self._members)
def __bool__(self) -> bool:
"""Return `True` if the group has members, `False` otherwise."""
return len(self._members) > 0
def __contains__(self, item: Trait | str) -> bool:
"""Test if ``Trait`` object or its ID is part of the group."""
if isinstance(item, str):
return item in self._members
return item.id in self._members and self._members[item.id] == item
def __getitem__(self, member_id: str) -> Trait:
"""Return the trait whose ID is ``member_id``."""
try:
return self._members[member_id]
except KeyError:
raise KeyError(f"No trait found whose ID is {member_id}") from None
[docs]
def items(self) -> Iterator[tuple[str, Trait]]:
"""Iterate over the traits of the group as key-value pairs."""
yield from self._members.items()
@staticmethod
def _extract_members(*other: _TNestedReferences) -> Iterator[Trait]:
"""Returns all traits found in ``other``.
``other`` can be an arbitrarily nested ``Iterable`` of ``Group``, ``Trait``, or
their IDs. ``None`` values and empty iterables are silently ignored.
.. caution::
Can returns duplicates.
"""
for item in _flatten(other):
match item:
case None:
continue
case Trait():
yield item
case Group():
yield from item._members.values()
case str():
# Prevent circular import.
from .operations import traits_from_ids
yield from traits_from_ids(item)
case _:
raise TypeError(f"Unsupported type: {type(item)}")
[docs]
def isdisjoint(self, other: _TNestedReferences) -> bool:
"""Return `True` if the group has no members in common with ``other``.
Groups are disjoint if and only if their intersection is an empty set.
``other`` can be an arbitrarily nested ``Iterable`` of ``Group`` and ``Trait``.
"""
return set(self._members.values()).isdisjoint(self._extract_members(other))
[docs]
def fullyintersects(self, other: _TNestedReferences) -> bool:
"""Return `True` if the group has all members in common with ``other``."""
return set(self._members.values()) == set(self._extract_members(other))
[docs]
def issubset(self, other: _TNestedReferences) -> bool:
"""Test whether every member in the group is in other."""
return set(self._members.values()).issubset(self._extract_members(other))
__le__ = issubset
def __lt__(self, other: _TNestedReferences) -> bool:
"""Test whether every member in the group is in other, but not all."""
return self <= other and set(self._members.values()) != set(
self._extract_members(other)
)
[docs]
def issuperset(self, other: _TNestedReferences) -> bool:
"""Test whether every member in other is in the group."""
return set(self._members.values()).issuperset(self._extract_members(other))
__ge__ = issuperset
def __gt__(self, other: _TNestedReferences) -> bool:
"""Test whether every member in other is in the group, but not all."""
return self >= other and set(self._members.values()) != set(
self._extract_members(other)
)
[docs]
def union(self, *others: _TNestedReferences) -> Group:
"""Return a new ``Group`` with members from the group and all others.
.. caution::
The new ``Group`` will inherits the metadata of the first one. All other
groups' metadata will be ignored.
"""
return Group(
self.id,
self.name,
self.icon,
tuple(
set(self._members.values()).union(
*(self._extract_members(other) for other in others)
)
),
)
__or__ = union
__ior__ = union
[docs]
def intersection(self, *others: _TNestedReferences) -> Group:
"""Return a new ``Group`` with members common to the group and all others.
.. caution::
The new ``Group`` will inherits the metadata of the first one. All other
groups' metadata will be ignored.
"""
return Group(
self.id,
self.name,
self.icon,
tuple(
set(self._members.values()).intersection(
*(self._extract_members(other) for other in others)
)
),
)
__and__ = intersection
__iand__ = intersection
[docs]
def difference(self, *others: _TNestedReferences) -> Group:
"""Return a new ``Group`` with members in the group that are not in the others.
.. caution::
The new ``Group`` will inherits the metadata of the first one. All other
groups' metadata will be ignored.
"""
return Group(
self.id,
self.name,
self.icon,
tuple(
set(self._members.values()).difference(
*(self._extract_members(other) for other in others)
)
),
)
__sub__ = difference
__isub__ = difference
[docs]
def symmetric_difference(self, other: _TNestedReferences) -> Group:
"""Return a new ``Group`` with members in either the group or other but not both.
.. caution::
The new ``Group`` will inherits the metadata of the first one. All other
groups' metadata will be ignored.
"""
return Group(
self.id,
self.name,
self.icon,
tuple(
set(self._members.values()).symmetric_difference(
self._extract_members(other)
)
),
)
__xor__ = symmetric_difference
__ixor__ = symmetric_difference
[docs]
def copy(
self,
id: str | None = None,
name: str | None = None,
icon: str | None = None,
members: Iterable[Trait] | None = None,
) -> Group:
"""Return a shallow copy of the group.
Fields can be overridden by passing new values as arguments.
"""
kwargs = {k: v for k, v in locals().items() if k != "self" and v is not None}
return replace(self, **kwargs)
[docs]
def add(self, member: Trait | str) -> Group:
"""Return a new ``Group`` with the specified trait added.
If the trait is already in the group, returns a copy unchanged.
Args:
member: A ``Trait`` object or trait ID string to add.
Returns:
A new ``Group`` instance with the trait added.
Raises:
ValueError: If the trait ID is not recognized.
"""
if isinstance(member, str):
# Prevent circular import.
from .operations import traits_from_ids
traits = traits_from_ids(member)
member = traits[0]
if member in self:
return self.copy()
return Group(
self.id,
self.name,
self.icon,
tuple(set(self._members.values()) | {member}),
)
[docs]
def remove(self, member: Trait | str) -> Group:
"""Return a new ``Group`` with the specified trait removed.
Raises ``KeyError`` if the trait is not in the group.
Args:
member: A ``Trait`` object or trait ID string to remove.
Returns:
A new ``Group`` instance with the trait removed.
Raises:
KeyError: If the trait is not in the group.
"""
member_id = member.id if isinstance(member, Trait) else member
if member_id not in self._members:
raise KeyError(f"Trait '{member_id}' is not in the group")
new_members = {
tid: trait for tid, trait in self._members.items() if tid != member_id
}
return Group(
self.id,
self.name,
self.icon,
tuple(new_members.values()),
)
[docs]
def discard(self, member: Trait | str) -> Group:
"""Return a new ``Group`` with the specified trait removed if present.
Unlike ``remove()``, this does not raise an error if the trait is not found.
Args:
member: A ``Trait`` object or trait ID string to remove.
Returns:
A new ``Group`` instance with the trait removed, or a copy if not present.
"""
member_id = member.id if isinstance(member, Trait) else member
if member_id not in self._members:
return self.copy()
new_members = {
tid: trait for tid, trait in self._members.items() if tid != member_id
}
return Group(
self.id,
self.name,
self.icon,
tuple(new_members.values()),
)
[docs]
def pop(self, member_id: str | None = None) -> tuple[Trait, Group]:
"""Remove and return a trait from the group.
Args:
member_id: Optional trait ID to remove. If not provided, removes an arbitrary
trait (specifically, the first one in iteration order).
Returns:
A tuple of (removed_trait, new_group).
Raises:
KeyError: If ``member_id`` is provided but not found in the group.
KeyError: If the group is empty.
"""
if not self._members:
raise KeyError("pop from an empty group")
if member_id is None:
# Pop arbitrary (first) member.
member_id = next(iter(self._members))
if member_id not in self._members:
raise KeyError(f"Trait '{member_id}' is not in the group")
popped_trait = self._members[member_id]
new_members = {
tid: trait for tid, trait in self._members.items() if tid != member_id
}
new_group = Group(
self.id,
self.name,
self.icon,
tuple(new_members.values()),
)
return popped_trait, new_group
[docs]
def clear(self) -> Group:
"""Return a new empty ``Group`` with the same metadata.
Returns:
A new ``Group`` instance with no members but same id, name, and icon.
"""
return Group(
self.id,
self.name,
self.icon,
tuple(),
)