Source code for syllabus.utils

from itertools import product, groupby
from typing import Union

import numpy as np


def decorate_all_functions(function_decorator):
    """Build a class decorator that wraps inherited callables.

    The returned decorator walks the direct bases of the decorated class
    (``cls.__bases__``) and inspects every entry in each base's own
    namespace. An entry is wrapped with ``function_decorator`` and stored on
    the decorated class when it is callable and the attribute resolved on
    the class compares equal to the attribute resolved on the base, i.e. the
    subclass has not overridden it.

    Note that the test is ``callable``, so any callable namespace entry of a
    base qualifies, not only plain functions.

    Args:
        function_decorator: Callable applied to each qualifying namespace
            entry; its return value is set as the class attribute.

    Returns:
        A decorator that takes a class, mutates it in place, and returns it.
    """

    def class_decorator(cls):
        for parent in cls.__bases__:
            for attr_name, raw_attr in vars(parent).items():
                inherited = getattr(parent, attr_name)
                resolved = getattr(cls, attr_name)
                if not callable(raw_attr):
                    continue
                # Leave anything the subclass overrides untouched.
                if resolved == inherited:
                    setattr(cls, attr_name, function_decorator(raw_attr))
        return cls

    return class_decorator
class UsageError(Exception):
    """Exception subclass with no added behaviour.

    NOTE(review): presumably raised when the library is used incorrectly;
    confirm against the call sites that raise it.
    """
def enumerate_axes(list_or_size: Union[np.ndarray, list, int]) -> tuple:
    """Enumerate every index (or index combination) for a size spec.

    * For an integer ``n`` (Python ``int`` or any NumPy integer scalar) the
      result is ``(0, 1, ..., n - 1)``.
    * For a ``list`` or ``np.ndarray`` each element is enumerated
      recursively and the Cartesian product of those enumerations is
      returned, so ``[2, 3]`` yields ``((0, 0), (0, 1), ..., (1, 2))``.

    Args:
        list_or_size: An integer size, or a list / array of (possibly
            nested) sizes.

    Returns:
        A tuple of ints for an integer input, otherwise a tuple of tuples
        with one entry per element of the input.

    Raises:
        NotImplementedError: If the argument is of any other type; the
            message is the offending type.
    """
    # np.integer covers every NumPy integer width (int32, int64, uint8, ...),
    # not just int64, so arrays of any integer dtype are handled.
    if isinstance(list_or_size, (int, np.integer)):
        return tuple(range(list_or_size))
    if isinstance(list_or_size, (list, np.ndarray)):
        per_axis = [enumerate_axes(size) for size in list_or_size]
        return tuple(product(*per_axis))
    raise NotImplementedError(f"{type(list_or_size)}")
def compress_ranges(nums):
    """Render a collection of numbers as a compact range string.

    Duplicates are dropped and the values sorted; each run of consecutive
    values is written as ``"start-end"`` and each isolated value as
    ``"value"``. Pieces are joined with ``", "``.

    Example: ``[1, 2, 3, 5, 7, 8]`` becomes ``"1-3, 5, 7-8"``.

    Args:
        nums: Iterable of hashable, sortable numbers.

    Returns:
        The formatted string; empty when ``nums`` is empty.
    """
    ordered = sorted(set(nums))

    def _offset(pair):
        # Within a run of consecutive values, value minus position is constant.
        position, value = pair
        return value - position

    pieces = []
    for _, run in groupby(enumerate(ordered), _offset):
        members = [value for _, value in run]
        first, last = members[0], members[-1]
        if first == last:
            pieces.append(f"{first}")
        else:
            pieces.append(f"{first}-{last}")
    return ", ".join(pieces)