Source code for tlo.dependencies

"""Functions for getting, checking and sorting dependencies of ``Module`` subclasses."""

import importlib
import inspect
import os
import pkgutil
from typing import Any, Callable, Generator, Iterable, Mapping, Optional, Set, Type, Union

import tlo.methods
from tlo import Module


class ModuleDependencyError(Exception):
    """Signal that the dependencies between ``Module`` subclasses cannot be satisfied.

    Raised when a module dependency is missing or there are circular dependencies.
    """
class MultipleModuleInstanceError(Exception):
    """Signal that the same ``Module`` subclass was registered more than once.

    Raised when multiple instances of the same module are registered.
    """
# Signature shared by the dependency-getter callables (e.g. ``get_init_dependencies``):
# given a ``Module`` instance or subclass and the set of names of the modules
# present, return the set of names of the modules it depends on.
DependencyGetter = Callable[
    [Union[Module, Type[Module]], Set[str]],
    Set[str],
]
def get_init_dependencies(
    module: Union[Module, Type[Module]], module_names_present: Set[str]
) -> Set[str]:
    """Return the names of the modules ``module`` must be initialised after.

    This is every name in ``module.INIT_DEPENDENCIES`` together with those names
    in ``module.OPTIONAL_INIT_DEPENDENCIES`` that also appear in
    ``module_names_present``.

    :param module: ``Module`` subclass (or instance) to get dependencies for.
    :param module_names_present: Names of the ``Module`` subclasses that will be
        present in the simulation; used to select which optional initialisation
        dependencies apply.
    :return: Set of ``Module`` subclass names making up the initialisation
        dependencies of ``module``, including any optional ones that are present.
    """
    optional_and_present = module.OPTIONAL_INIT_DEPENDENCIES & module_names_present
    return module.INIT_DEPENDENCIES | optional_and_present
def get_all_dependencies(
    module: Union[Module, Type[Module]], module_names_present: Set[str]
) -> Set[str]:
    """Return the names of every module that ``module`` depends on.

    This combines the initialisation dependencies (as computed by
    ``get_init_dependencies``, so including optional ones that are present) with
    ``module.ADDITIONAL_DEPENDENCIES``.

    :param module: ``Module`` subclass (or instance) to get dependencies for.
    :param module_names_present: Names of the ``Module`` subclasses that will be
        present in the simulation; used to select which optional initialisation
        dependencies apply.
    :return: Set of ``Module`` subclass names making up all dependencies of
        ``module``, including any optional ones that are present.
    """
    init_dependencies = get_init_dependencies(module, module_names_present)
    return init_dependencies | module.ADDITIONAL_DEPENDENCIES
def get_all_required_dependencies(
    module: Union[Module, Type[Module]],
    module_names_present: Optional[Set[str]] = None,
) -> Set[str]:
    """Return the names of every non-optional module that ``module`` depends on.

    This is the union of ``module.INIT_DEPENDENCIES`` and
    ``module.ADDITIONAL_DEPENDENCIES``; optional dependencies are never included.

    :param module: ``Module`` subclass (or instance) to get dependencies for.
    :param module_names_present: Names of the ``Module`` subclasses that will be
        present in the simulation. Ignored here; accepted only so this function
        has the same interface as the other dependency-getter functions.
    :return: Set of ``Module`` subclass names making up the non-optional
        dependencies of ``module``.
    """
    required_at_init = module.INIT_DEPENDENCIES
    return required_at_init | module.ADDITIONAL_DEPENDENCIES
def topologically_sort_modules(
    module_instances: Iterable[Module],
    get_dependencies: DependencyGetter = get_init_dependencies,
) -> Generator[Module, None, None]:
    """Yield module instances so that each comes after all of its dependencies.

    A topological sort of a dependency graph is ordered such that any
    dependencies of a node are guaranteed to be yielded before the node itself.
    This implementation uses a depth-first search
    (https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search),
    visiting dependencies in sorted name order.

    If a dependency has no instance in ``module_instances``, but exactly one of
    the given instances lists it in its ``ALTERNATIVE_TO`` attribute, that
    instance is used in its place.

    As this is a generator, the errors below are raised during iteration, not
    when the function is called.

    :param module_instances: The module instances to topologically sort. They
        must include instances of all their (recursive) dependencies (or of a
        unique alternative to each).
    :param get_dependencies: Function which given a module gets the set of
        module dependencies. Defaults to returning the
        ``Module.INIT_DEPENDENCIES`` class attribute plus any optional
        initialisation dependencies that are present.
    :raises ModuleDependencyError: Raised when a module dependency (and any
        unique alternative to it) is missing from ``module_instances``, or a
        module has circular dependencies.
    :raises MultipleModuleInstanceError: Raised when multiple instances of the
        same module are passed in ``module_instances``.
    :return: Generator which yields module instances in topologically sorted
        order.
    """
    instances = list(module_instances)
    instance_by_name = {type(instance).__name__: instance for instance in instances}
    if len(instance_by_name) != len(instances):
        raise MultipleModuleInstanceError(
            'Multiple instances of one or more `Module` subclasses were passed to the '
            'Simulation.register method. If you are sure this is correct, you can '
            'disable this check (and the automatic dependency sorting) by setting '
            'sort_modules=False in Simulation.register.'
        )

    finished = set()
    in_progress = set()

    def resolve(name, dependency):
        # Return the name of the registered module satisfying ``dependency`` of
        # module ``name``: the dependency itself, or its unique alternative.
        if dependency in instance_by_name:
            return dependency
        candidates = [
            candidate_name
            for candidate_name, candidate in instance_by_name.items()
            if dependency in candidate.ALTERNATIVE_TO
        ]
        if len(candidates) == 1:
            return candidates[0]
        message = (
            f'Module {name} depends on {dependency} which is '
            'missing from modules to register'
        )
        if not candidates:
            message += f' as are any alternatives to {dependency}.'
        else:
            message += (
                ' and there are multiple alternatives '
                f'({candidates}) so which '
                'to use to resolve dependency is ambiguous.'
            )
        raise ModuleDependencyError(message)

    def visit(name):
        # Depth-first post-order traversal: dependencies first, then ``name``.
        if name in finished:
            return
        # Meeting a module that is still being expanded means we looped back.
        if name in in_progress:
            raise ModuleDependencyError(
                f'Module {name} has circular dependencies.'
            )
        in_progress.add(name)
        dependencies = get_dependencies(
            instance_by_name[name], instance_by_name.keys()
        )
        for dependency in sorted(dependencies):
            yield from visit(resolve(name, dependency))
        in_progress.remove(name)
        finished.add(name)
        yield instance_by_name[name]

    for instance in instances:
        yield from visit(type(instance).__name__)
def is_valid_tlo_module_subclass(obj: Any, excluded_modules: Set[str]) -> bool:
    """Decide whether ``obj`` is a strict ``Module`` subclass that is not excluded.

    :param obj: Object to check.
    :param excluded_modules: Names of ``Module`` subclasses for which the check
        is forced to return ``False``.
    :return: ``True`` if ``obj`` is a class, is a _strict_ subclass of ``Module``
        (i.e. not ``Module`` itself), and its name is not in ``excluded_modules``;
        ``False`` otherwise.
    """
    if not inspect.isclass(obj):
        return False
    if not issubclass(obj, Module) or obj is Module:
        return False
    return obj.__name__ not in excluded_modules
def get_module_class_map(excluded_modules: Set[str]) -> Mapping[str, Type[Module]]:
    """Map the names of the ``Module`` subclasses in ``tlo.methods`` to the classes.

    Every submodule found in the ``tlo.methods`` package directory is imported.
    Each of its members accepted by ``is_valid_tlo_module_subclass`` is recorded
    under the class's unqualified ``__name__``. This implies that the names of
    all ``Module`` subclasses must be unique.

    :param excluded_modules: Set of ``Module`` subclass names to leave out of
        the map.
    :return: A mapping from unqualified ``Module`` subclass names to the
        corresponding class objects.
    :raises RuntimeError: Raised if different ``Module`` subclasses with the same
        name are defined (and that name is not in ``excluded_modules``).
    """
    package_dir = os.path.dirname(inspect.getfile(tlo.methods))
    class_by_name = {}
    for module_info in pkgutil.iter_modules([package_dir]):
        submodule = importlib.import_module(f'tlo.methods.{module_info.name}')
        members = inspect.getmembers(
            submodule,
            lambda member: is_valid_tlo_module_subclass(member, excluded_modules),
        )
        # The attribute name may differ from the class name (e.g. an aliased
        # import), so key on ``__name__`` rather than the member name.
        for _, candidate in members:
            name = candidate.__name__
            # The same class is commonly re-exported by several submodules; only
            # a *different* class under an already-seen name is an error.
            if class_by_name.get(name) not in {None, candidate}:
                raise RuntimeError(
                    f'Multiple modules with name {name} are defined'
                )
            class_by_name[name] = candidate
    return class_by_name
def get_dependencies_and_initialise(
    *module_classes: Type[Module],
    module_class_map: Mapping[str, Type[Module]],
    excluded_module_classes: Optional[Set[Type[Module]]] = None,
    get_dependencies: DependencyGetter = get_init_dependencies,
    **module_class_kwargs
) -> Generator[Module, None, None]:
    """Generate a sequence of ``Module`` instances including all dependencies.

    The generated sequence of initialised ``Module`` subclass instances will
    correspond to all the (recursive) dependencies of the seed ``Module``
    subclasses in ``module_classes``.

    Each class is instantiated at most once. Each instance is yielded *before*
    the instances of its dependencies, so the sequence is not in dependency
    (topological) order. Dependencies are visited in sorted name order.
    Dependencies of excluded classes are not followed.

    :param module_classes: ``Module`` subclass(es) to seed dependency search with.
    :param module_class_map: Mapping from ``Module`` subclass names to classes.
    :param excluded_module_classes: Any ``Module`` subclasses to not yield
        instances of in the returned generator.
    :param get_dependencies: Function which given a module gets the set of module
        dependencies. Defaults to returning the ``Module.INIT_DEPENDENCIES`` class
        attribute plus any optional initialisation dependencies that are keys of
        ``module_class_map``.
    :param module_class_kwargs: Any keyword arguments to pass to initialisers for
        ``Module`` subclasses if present in their ``__init__`` method signature.
    :raises KeyError: Raised during iteration if a dependency name is not a key
        of ``module_class_map``.
    :return: Sequence of initialised ``Module`` subclass instances corresponding
        to all of the ``Module`` subclasses and their (recursive) dependencies in
        the seed ``module_classes``.
    """
    if excluded_module_classes is None:
        excluded_module_classes = set()
    visited = set()

    def initialise_module(module_class):
        # Forward only the keyword arguments this class's initialiser accepts.
        signature = inspect.signature(module_class)
        relevant_kwargs = {
            key: value
            for key, value in module_class_kwargs.items()
            if key in signature.parameters
        }
        bound_args = signature.bind(**relevant_kwargs)
        return module_class(*bound_args.args, **bound_args.kwargs)

    def depth_first_search(module_class):
        # Test both sets directly instead of building ``visited | excluded``,
        # which would copy the growing ``visited`` set on every call.
        if module_class in visited or module_class in excluded_module_classes:
            return
        visited.add(module_class)
        yield initialise_module(module_class)
        dependencies = get_dependencies(module_class, module_class_map.keys())
        for dependency_name in sorted(dependencies):
            yield from depth_first_search(module_class_map[dependency_name])

    for module_class in module_classes:
        yield from depth_first_search(module_class)
def check_dependencies_present(
    module_instances: Iterable[Module],
    get_dependencies: DependencyGetter = get_all_dependencies,
):
    """Check whether an iterable of modules contains the required dependencies.

    A dependency counts as satisfied if a module with that name is present, or
    if any present module lists the dependency in its ``ALTERNATIVE_TO``
    attribute. An empty iterable of modules has no dependencies and passes.

    :param module_instances: Iterable of ``Module`` subclass instances to check.
    :param get_dependencies: Callable which extracts the set of dependencies to
        check for from a module instance. Defaults to extracting all
        dependencies.
    :raises ModuleDependencyError: Raised if any dependencies are missing.
    """
    # Materialise once: the instances are iterated several times below.
    module_instances = list(module_instances)
    modules_present = {type(module).__name__ for module in module_instances}
    # ``set().union(...)`` (unlike the unbound ``set.union(...)``) accepts zero
    # arguments, so an empty ``module_instances`` does not raise ``TypeError``.
    # It also accepts any iterable (e.g. ``frozenset``) as an argument.
    modules_present_are_alternatives_to = set().union(
        *(module.ALTERNATIVE_TO for module in module_instances)
    )
    modules_required = set().union(
        *(get_dependencies(module, modules_present) for module in module_instances)
    )
    missing_dependencies = modules_required - modules_present
    missing_dependencies_without_alternatives_present = (
        missing_dependencies - modules_present_are_alternatives_to
    )
    if missing_dependencies_without_alternatives_present:
        raise ModuleDependencyError(
            'One or more required dependency is missing from the module list and no '
            'alternative to this / these modules are available either: '
            f'{missing_dependencies_without_alternatives_present}'
        )