# Copyright Kevin Deldycke <kevin@deldycke.com> and contributors.
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
"""Trait base class for architectures, platforms, CI systems, and more.
A trait represents a distinguishing characteristic of a runtime environment.
Each trait has a unique ID, a human-readable name, an icon, and the ability
to detect if it matches the current environment.
Data associated with traits can be aggressively cached and frozen, as they're only
computed based on environment-dependent values.
"""
from __future__ import annotations
import platform
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import cached_property
import distro
import extra_platforms
# The flag is set by hand rather than imported from ``typing``, so ``typing`` is
# never imported at runtime. Static type checkers are expected to special-case
# this name and analyze the guarded branch; at runtime it is always skipped.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Any
def _recursive_update(
a: dict[str, Any], b: dict[str, Any], strict: bool = False
) -> dict[str, Any]:
"""Like standard ``dict.update()``, but recursive so sub-dict gets updated.
Ignore elements present in ``b`` but not in ``a``. Unless ``strict`` is set to
``True``, in which case a ``ValueError`` exception will be raised.
"""
for k, v in b.items():
if isinstance(v, dict) and isinstance(a.get(k), dict):
a[k] = _recursive_update(a[k], v, strict=strict)
# Ignore elements unregistered in the template structure.
elif k in a:
a[k] = b[k]
elif strict:
raise ValueError(f"Parameter {k!r} found in second dict but not in first.")
return a
def _remove_blanks(
tree: dict,
remove_none: bool = True,
remove_dicts: bool = True,
remove_str: bool = True,
) -> dict:
"""Returns a copy of a dict without items whose values blanks.
Are considered blanks:
- ``None`` values
- empty strings
- empty ``dict``
The removal of each of these class can be skipped by setting ``remove_*``
parameters.
Dictionarries are inspected recursively and their own blank values are removed.
"""
result = {}
for key, value in tree.items():
# Skip None values if configured.
if remove_none and value is None:
continue
# Recursively process nested dicts.
if isinstance(value, dict):
cleaned = _remove_blanks(value, remove_none, remove_dicts, remove_str)
# Skip empty dicts if configured.
if remove_dicts and not cleaned:
continue
result[key] = cleaned
# Skip empty strings if configured.
elif remove_str and isinstance(value, str) and not value:
continue
else:
result[key] = value
return result
@dataclass(frozen=True)
class Trait(ABC):
    """Base class for system traits like platforms and architectures.

    A trait is a distinguishing characteristic of the runtime environment that can
    be detected and identified. Examples include:

    - Operating systems (macOS, Linux, Windows)
    - CPU architectures (x86_64, ARM64)
    - CI/CD environments (GitHub Actions, GitLab CI)

    Each trait has:

    - A unique ID for programmatic use
    - A human-readable name
    - An icon for visual representation
    - A URL to official documentation
    - A ``current`` property that detects if this trait matches the runtime

    Instances are frozen: regular attribute assignment is rejected, which is why
    ``__post_init__`` goes through ``object.__setattr__`` to set derived fields.
    """

    id: str
    """Unique ID of the trait."""

    name: str
    """User-friendly name of the trait."""

    icon: str = field(repr=False, default="❓")
    """Icon of the trait."""

    url: str = field(repr=False, default="")
    """URL to the trait's official website or documentation."""

    detection_func_id: str = field(repr=False, init=False)
    """ID of the detection function for this trait.

    The detection function is expected to be named ``is_<id>()`` and located at the
    root of the ``extra_platforms`` module.
    """

    def __post_init__(self) -> None:
        """Validate and normalize trait fields.

        - Ensure the trait ID, name, icon and URL are not empty.
        - Ensure the URL starts with ``https://``.
        - Set the detection function ID based on the trait ID.
        - Populate the docstring.

        Raises ``AssertionError`` if a check fails. The checks are plain
        ``assert`` statements, so they are skipped when Python runs with ``-O``.
        """
        assert self.id, f"{self.__class__.__name__} ID cannot be empty."
        assert self.name, f"{self.__class__.__name__} name cannot be empty."
        assert self.icon, f"{self.__class__.__name__} icon cannot be empty."
        # ``url`` defaults to an empty string only so it can be declared after
        # ``icon``, which has a default; it is effectively mandatory.
        assert self.url, f"{self.__class__.__name__} URL cannot be empty."
        assert self.url.startswith("https://"), "URL must start with https://."

        # The dataclass is frozen: bypass its ``__setattr__`` for derived values.
        object.__setattr__(self, "detection_func_id", f"is_{self.id}")
        object.__setattr__(self, "__doc__", f"Identify {self.name}.")

    @cached_property
    def short_desc(self) -> str:
        """Returns a short description of the trait.

        Mainly used to produce docstrings for functions dynamically generated for
        each group.
        """
        return self.name

    @cached_property
    def current(self) -> bool:
        """Returns whether the current environment matches this trait.

        The detection function is dynamically looked up based on the trait ID, and
        is expected to be found at the root of the ``extra_platforms`` module.

        Raises ``NotImplementedError`` if a detection function cannot be found.

        .. hint::
            This is a property to avoid calling all detection heuristics on
            ``Trait`` objects creation, which happens at module import time.
        """
        func = getattr(extra_platforms, self.detection_func_id, None)
        if not func:
            raise NotImplementedError(
                f"Detection function {self.detection_func_id}() is not implemented."
            )
        return func()  # type: ignore[no-any-return]

    @abstractmethod
    def info(self) -> dict:
        """Returns all trait attributes that can be gathered.

        Subclasses should override this to include trait-specific information.
        """
        ...

    def _base_info(self) -> dict[str, str | bool | None]:
        """Returns the base info dictionary common to all traits.

        Reading ``current`` here triggers the detection function lookup, so this
        may raise ``NotImplementedError`` like ``current`` does.
        """
        return {
            "id": self.id,
            "name": self.name,
            "icon": self.icon,
            "url": self.url,
            "current": self.current,
        }
@dataclass(frozen=True)
class Architecture(Trait):
    """A CPU architecture identifies a `processor instruction set
    <https://en.wikipedia.org/wiki/Instruction_set_architecture>`_.

    It has a unique ID, a human-readable name, and boolean to flag current
    architecture.
    """

    icon: str = field(repr=False, default="▣")
    """Icon of the architecture."""

    def __post_init__(self) -> None:
        """Validate and normalize architecture fields."""
        super().__post_init__()
        # Customize the docstring for architectures. The dataclass is frozen, so
        # the attribute is set through ``object.__setattr__``.
        object.__setattr__(self, "__doc__", f"Identify {self.name} architecture.")

    def info(self) -> dict[str, str | bool | None]:
        """Returns all architecture attributes we can gather.

        ``machine`` and ``processor`` are only populated when this architecture
        is the current one; otherwise they stay ``None``. Empty strings reported
        by the ``platform`` module are normalized to ``None`` as well.
        """
        info: dict[str, str | bool | None] = {
            **self._base_info(),
            "machine": None,
            "processor": None,
        }
        if self.current:
            info["machine"] = platform.machine() or None
            info["processor"] = platform.processor() or None
        return info
@dataclass(frozen=True)
class CI(Trait):
    """A CI/CD environment identifies a continuous integration platform.

    It has a unique ID, a human-readable name, and boolean to flag current CI.
    """

    icon: str = field(repr=False, default="♲")
    """Icon of the CI environment."""

    def __post_init__(self) -> None:
        """Validate and normalize CI fields."""
        super().__post_init__()
        # Customize the docstring for CI environments. The dataclass is frozen,
        # so the attribute is set through ``object.__setattr__``.
        object.__setattr__(self, "__doc__", f"Identify {self.name} CI environment.")

    def info(self) -> dict[str, str | bool | None]:
        """Returns all CI attributes we can gather.

        No CI-specific data is added: the result is a fresh dictionary holding
        only the base fields shared by all traits.
        """
        return {
            **self._base_info(),
        }