From fce4fb9b8840b84f0f3d966e54fbef73ec602aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Andr=C3=A9=20Lureau?= Date: Tue, 2 Sep 2025 12:35:50 +0400 Subject: inventory: add type annotations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marc-André Lureau --- lcitool/inventory.py | 70 +++++++++++++++++++++++++++++++++++----------------- pyproject.toml | 1 + 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/lcitool/inventory.py b/lcitool/inventory.py index ea65bdd..379456a 100644 --- a/lcitool/inventory.py +++ b/lcitool/inventory.py @@ -10,6 +10,10 @@ from pathlib import Path from lcitool import util, LcitoolError from lcitool.packages import package_names_by_type +from lcitool.config import Config +from lcitool.projects import Projects +from lcitool.targets import BuildTarget, Targets +from typing import Any, Dict, List, Optional, Union log = logging.getLogger(__name__) @@ -17,39 +21,53 @@ log = logging.getLogger(__name__) class InventoryError(LcitoolError): """Global exception type for the inventory module.""" - def __init__(self, message): + def __init__(self, message: str): super().__init__(message, "Inventory") class Inventory: + def __init__( + self, + targets: Targets, + config: Config, + inventory_path: Optional[Path] = None, + ): + self._targets = targets + self._config = config + self._host_facts: Optional[Dict[str, Dict[str, Any]]] = None + self._ansible_inventory: Optional[Dict[str, Dict[str, Any]]] = None + self._inventory_path = inventory_path + @property - def ansible_inventory(self): + def ansible_inventory( + self, + ) -> Dict[ + str, + Dict[str, Any], + ]: if self._ansible_inventory is None: self._ansible_inventory = self._get_ansible_inventory() + assert isinstance(self._ansible_inventory, dict) return self._ansible_inventory @property - def host_facts(self): + def host_facts(self) -> Dict[str, Dict[str, Any]]: if self._host_facts is None: self._host_facts = self._load_host_facts() + assert isinstance(self._host_facts, dict) return self._host_facts @property - def hosts(self): + def hosts(self) -> List[str]: return list(self.host_facts.keys()) - def __init__(self, targets, config, inventory_path=None): - self._targets = targets - self._config = config - self._host_facts = None - self._ansible_inventory = None - self._inventory_path = inventory_path - - def _get_ansible_inventory(self): + def _get_ansible_inventory( + self, + ) -> Any: from lcitool.ansible_wrapper import AnsibleWrapper, AnsibleWrapperError - inventory_sources = [] + inventory_sources: List[Union[Path, Dict[str, Any]]] = [] if self._inventory_path is None: self._inventory_path = Path(util.get_config_dir(), "inventory") @@ -76,10 +94,10 @@ class Inventory: return inventory - def _get_libvirt_inventory(self): + def _get_libvirt_inventory(self) -> Dict[str, Any]: from lcitool.libvirt_wrapper import LibvirtWrapper - inventory = {"all": {"children": {}}} + inventory: Dict[str, Any] = {"all": {"children": {}}} children = inventory["all"]["children"] for host, target in LibvirtWrapper().hosts.items(): @@ -89,11 +107,13 @@ class Inventory: return inventory - def _load_host_facts(self): - facts = {} - groups = {} + def _load_host_facts( + self, + ) -> Dict[str, Dict[str, Any]]: + facts: Dict[str, Dict[str, Any]] = {} + groups: Dict[str, Any] = {} - def _rec(inventory, group_name): + def _rec(inventory: Dict[str, Any], group_name: str) -> None: for key, subinventory in inventory.items(): if key == "hosts": if ( @@ -150,7 +170,7 @@ class Inventory: return facts - def expand_hosts(self, pattern): + def expand_hosts(self, pattern: str) -> List[str]: try: return util.expand_pattern(pattern, self.hosts, "hosts") except InventoryError as ex: @@ -159,10 +179,14 @@ class Inventory: log.debug(f"Failed to load expand '{pattern}'") raise InventoryError(f"Failed to expand '{pattern}': {ex}") - def get_host_target_name(self, host): - return self.host_facts[host]["target"] + def get_host_target_name(self, host: str) -> str: + target = self.host_facts[host]["target"] + assert isinstance(target, str) + return target - def get_group_vars(self, target, projects, projects_expanded): + def get_group_vars( + self, target: BuildTarget, projects: Projects, projects_expanded: List[str] + ) -> Dict[str, Union[Dict[str, str], str, List[str]]]: # resolve the package mappings to actual package names internal_wanted_projects = ["base", "developer", "vm"] if self._config.values["install"]["cloud_init"]: diff --git a/pyproject.toml b/pyproject.toml index 82bb19d..8be152e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ files = [ "lcitool/containers/*.py", "lcitool/install/*.py", "lcitool/gitlab.py", + "lcitool/inventory.py", "lcitool/libvirt_wrapper.py", "lcitool/logger.py", "lcitool/projects.py", -- cgit v1.1