"""
Copyright (C) 2022 Sebastian Thomschke and contributors
SPDX-License-Identifier: AGPL-3.0-or-later
"""
import copy, decimal, json, logging, os, re, secrets, sys, traceback, time
from importlib.resources import read_text as get_resource_as_string
from collections.abc import Callable, Sized
from datetime import datetime
from types import FrameType, ModuleType, TracebackType
from typing import Any, Final, TypeVar

import coloredlogs, inflect
from ruamel.yaml import YAML
from selenium.webdriver.chrome.webdriver import WebDriver

LOG_ROOT:Final[logging.Logger] = logging.getLogger()
LOG:Final[logging.Logger] = logging.getLogger("kleinanzeigen_bot.utils")

# https://mypy.readthedocs.io/en/stable/generics.html#generic-functions
T = TypeVar('T')


def abspath(relative_path:str, relative_to:str | None = None) -> str:
    """
    Makes a given relative path absolute based on another file/folder

    :param relative_path: the path to resolve; returned unchanged if it is already absolute
    :param relative_to: optional file or folder used as the base; if it is an existing file,
        its parent directory is used. If omitted, the current working directory is the base.
    :return: the resulting path
    """
    if os.path.isabs(relative_path):
        return relative_path

    if not relative_to:
        return os.path.abspath(relative_path)

    if os.path.isfile(relative_to):
        relative_to = os.path.dirname(relative_to)

    return os.path.normpath(os.path.join(relative_to, relative_path))


def ensure(condition:Any | bool | Callable[[], bool], error_message:str, timeout:float = 5, poll_requency:float = 0.5) -> None:
    """
    Asserts that a condition holds. A non-callable condition is evaluated once for truthiness;
    a callable condition is polled until it returns a truthy value or the timeout elapsed.

    :param condition: a value, or a callable without arguments that is invoked repeatedly
    :param error_message: message of the raised AssertionError
    :param timeout: timespan in seconds until when the condition must become `True`, default is 5 seconds
    :param poll_requency: sleep interval between calls in seconds, default is 0.5 seconds
    :raises AssertionError: if condition did not come `True` within given timespan
    """
    if not isinstance(condition, Callable):  # type: ignore[arg-type] # https://github.com/python/mypy/issues/6864
        if condition:
            return
        raise AssertionError(error_message)

    if timeout < 0:
        raise AssertionError("[timeout] must be >= 0")
    if poll_requency < 0:
        raise AssertionError("[poll_requency] must be >= 0")

    # monotonic clock: the measured timespan must not be affected by system clock adjustments
    start_at = time.monotonic()
    while not condition():  # type: ignore[operator]
        elapsed = time.monotonic() - start_at
        if elapsed >= timeout:
            raise AssertionError(error_message)
        time.sleep(poll_requency)


def is_frozen() -> bool:
    """
    Returns the value of the `sys.frozen` attribute, or `False` if the attribute is not set.

    >>> is_frozen()
    False
    """
    return getattr(sys, "frozen", False)


def apply_defaults(
    target:dict[Any, Any],
    defaults:dict[Any, Any],
    ignore:Callable[[Any, Any], bool] = lambda _k, _v: False,
    override:Callable[[Any, Any], bool] = lambda _k, _v: False
) -> dict[Any, Any]:
    """
    Recursively copies entries of `defaults` into `target`. The `target` dict is modified in place.

    :param target: the dict to complete
    :param defaults: the default values; values are deep-copied before being stored in `target`
    :param ignore: predicate `(key, default_value)`; if it returns `True` a missing key is NOT added
    :param override: predicate `(key, current_value)`; if it returns `True` an existing
        non-dict value is replaced by the default value
    :return: the `target` dict

    >>> apply_defaults({}, {"foo": "bar"})
    {'foo': 'bar'}
    >>> apply_defaults({"foo": "foo"}, {"foo": "bar"})
    {'foo': 'foo'}
    >>> apply_defaults({"foo": ""}, {"foo": "bar"})
    {'foo': ''}
    >>> apply_defaults({}, {"foo": "bar"}, ignore = lambda k, _: k == "foo")
    {}
    >>> apply_defaults({"foo": ""}, {"foo": "bar"}, override = lambda _, v: v == "")
    {'foo': 'bar'}
    >>> apply_defaults({"foo": None}, {"foo": "bar"}, override = lambda _, v: v == "")
    {'foo': None}
    >>> apply_defaults({"a": {"foo": ""}}, {"a": {"foo": "bar"}}, override = lambda _, v: v == "")
    {'a': {'foo': 'bar'}}
    """
    for key, default_value in defaults.items():
        if key in target:
            if isinstance(target[key], dict) and isinstance(default_value, dict):
                # both predicates must be handed down, otherwise they only apply to the top level
                apply_defaults(target[key], default_value, ignore = ignore, override = override)
            elif override(key, target[key]):
                target[key] = copy.deepcopy(default_value)
        elif not ignore(key, default_value):
            target[key] = copy.deepcopy(default_value)
    return target


def safe_get(a_map:dict[Any, Any], *keys:str) -> Any:
    """
    Looks up a value in nested dicts by following the given keys one after another.

    :return: the value found, or `None` if a key is missing or an intermediate value
        cannot be indexed. A falsy `a_map` is returned as-is without any lookup.

    >>> safe_get({"foo": {}}, "foo", "bar") is None
    True
    >>> safe_get({"foo": {"bar": "some_value"}}, "foo", "bar")
    'some_value'
    """
    if a_map:
        for key in keys:
            try:
                a_map = a_map[key]
            except (KeyError, TypeError):
                return None
    return a_map


def configure_console_logging() -> None:
    """
    Attaches two colored console handlers to the root logger:
    records up to level INFO are written to stdout, records of level WARNING and above to stderr.
    """
    stdout_log = logging.StreamHandler(sys.stdout)
    stdout_log.setLevel(logging.DEBUG)
    stdout_log.setFormatter(coloredlogs.ColoredFormatter("[%(levelname)s] %(message)s"))
    # upper bound: everything above INFO is handled by the stderr handler below
    stdout_log.addFilter(lambda rec: rec.levelno <= logging.INFO)
    LOG_ROOT.addHandler(stdout_log)

    stderr_log = logging.StreamHandler(sys.stderr)
    stderr_log.setLevel(logging.WARNING)
    stderr_log.setFormatter(coloredlogs.ColoredFormatter("[%(levelname)s] %(message)s"))
    LOG_ROOT.addHandler(stderr_log)


def on_exception(ex_type:type[BaseException], ex_value:Any, ex_traceback:TracebackType | None) -> None:
    """
    Exception hook with the signature of `sys.excepthook` that reports exceptions via the logger.

    - KeyboardInterrupt is delegated to the default hook `sys.__excepthook__`
    - the full traceback is logged if DEBUG logging is enabled or if the exception is an
      AttributeError, ImportError, NameError or TypeError
    - for an AssertionError only its message is logged
    - for everything else the exception type name and the message are logged
    """
    if issubclass(ex_type, KeyboardInterrupt):
        sys.__excepthook__(ex_type, ex_value, ex_traceback)
    elif LOG.isEnabledFor(logging.DEBUG) or isinstance(ex_value, (AttributeError, ImportError, NameError, TypeError)):
        LOG.error("".join(traceback.format_exception(ex_type, ex_value, ex_traceback)))
    elif isinstance(ex_value, AssertionError):
        LOG.error(ex_value)
    else:
        LOG.error("%s: %s", ex_type.__name__, ex_value)


def on_exit() -> None:
    """
    Flushes all handlers of the root logger.
    """
    for handler in LOG_ROOT.handlers:
        handler.flush()


def on_sigint(_sig:int, _frame:FrameType | None) -> None:
    """
    Handler with the signature expected by `signal.signal`: logs a warning and exits with status 0.
    """
    LOG.warning("Aborted on user request.")
    sys.exit(0)


def pause(min_ms:int = 200, max_ms:int = 2000) -> None:
    """
    Sleeps for a random duration.

    :param min_ms: minimum duration in milliseconds (inclusive)
    :param max_ms: maximum duration in milliseconds (exclusive); if it is not greater
        than `min_ms` the duration is exactly `min_ms`
    """
    if max_ms <= min_ms:
        duration = min_ms
    else:
        duration = secrets.randbelow(max_ms - min_ms) + min_ms
    # longer pauses are announced on INFO level, short ones only on DEBUG level
    LOG.log(logging.INFO if duration > 1500 else logging.DEBUG, " ... pausing for %d ms ...", duration)
    time.sleep(duration / 1000)


def pluralize(word:str, count:int | Sized, prefix:bool = True) -> str:
    """
    Returns the form of the noun matching the given count.

    :param word: the noun in singular form
    :param count: a number, or a sized object whose length is used as the number
    :param prefix: if `True` the result is prefixed with the number

    >>> pluralize("field", 1)
    '1 field'
    >>> pluralize("field", 2)
    '2 fields'
    >>> pluralize("field", 2, prefix = False)
    'fields'
    """
    # the inflect engine is created on first use and cached as a function attribute
    if not hasattr(pluralize, "inflect"):
        pluralize.inflect = inflect.engine()  # type: ignore[attr-defined]
    if isinstance(count, Sized):
        count = len(count)
    plural:str = pluralize.inflect.plural_noun(word, count)  # type: ignore[attr-defined]
    if prefix:
        return f"{count} {plural}"
    return plural


def load_dict(filepath:str, content_label:str = "") -> dict[str, Any]:
    """
    Loads a JSON or YAML file, see `load_dict_if_exists`.

    :raises FileNotFoundError: if `load_dict_if_exists` returned `None`
    """
    data = load_dict_if_exists(filepath, content_label)
    if data is None:
        raise FileNotFoundError(filepath)
    return data


def load_dict_if_exists(filepath:str, content_label:str = "") -> dict[str, Any] | None:
    """
    Loads a JSON (`*.json`) or YAML (`*.yaml`, `*.yml`) file.

    :param filepath: path of the file to load
    :param content_label: optional description of the content, only used in the log message
    :return: the parsed content or `None` if the file does not exist
    :raises ValueError: if the file extension is not supported
    """
    filepath = os.path.abspath(filepath)
    LOG.info("Loading %s[%s]...", f"{content_label} from " if content_label else "", filepath)

    _, file_ext = os.path.splitext(filepath)
    if file_ext not in (".json", ".yaml", ".yml"):
        raise ValueError(f'Unsupported file type. The file name "{filepath}" must end with *.json, *.yaml, or *.yml')

    if not os.path.exists(filepath):
        return None

    with open(filepath, encoding = "utf-8") as file:
        return json.load(file) if filepath.endswith(".json") else YAML().load(file)


def load_dict_from_module(module:ModuleType, filename:str, content_label:str = "") -> dict[str, Any]:
    """
    Loads a JSON (`*.json`) or YAML (`*.yaml`, `*.yml`) resource located in the given module/package.

    :param module: the module/package containing the resource
    :param filename: name of the resource
    :param content_label: optional description of the content, only used in the log message
    :raises FileNotFoundError
    :raises ValueError: if the file extension is not supported
    """
    LOG.debug("Loading %s[%s.%s]...", f"{content_label} from " if content_label else "", module.__name__, filename)

    _, file_ext = os.path.splitext(filename)
    if file_ext not in (".json", ".yaml", ".yml"):
        raise ValueError(f'Unsupported file type. The file name "{filename}" must end with *.json, *.yaml, or *.yml')

    content = get_resource_as_string(module, filename)
    return json.loads(content) if filename.endswith(".json") else YAML().load(content)


def save_dict(filepath:str, content:dict[str, Any]) -> None:
    """
    Writes the given dict to a file, as JSON if the path ends with `.json`, otherwise as YAML.

    :param filepath: path of the file to write, an existing file is overwritten
    :param content: the data to serialize
    """
    filepath = os.path.abspath(filepath)
    LOG.info("Saving [%s]...", filepath)
    with open(filepath, "w", encoding = "utf-8") as file:
        if filepath.endswith(".json"):
            file.write(json.dumps(content, indent = 2, ensure_ascii = False))
        else:
            yaml = YAML()
            yaml.indent(mapping = 2, sequence = 4, offset = 2)
            yaml.allow_duplicate_keys = False
            yaml.explicit_start = False
            yaml.dump(content, file)


def parse_decimal(number:float | int | str) -> decimal.Decimal:
    """
    Converts a number to a Decimal. Strings may use `.` or `,` as separators:
    the last separator is treated as the decimal separator, all other separators are dropped.

    :raises decimal.DecimalException: if the value cannot be interpreted as a number

    >>> parse_decimal(5)
    Decimal('5')
    >>> parse_decimal(5.5)
    Decimal('5.5')
    >>> parse_decimal("5.5")
    Decimal('5.5')
    >>> parse_decimal("5,5")
    Decimal('5.5')
    >>> parse_decimal("1.005,5")
    Decimal('1005.5')
    >>> parse_decimal("1,005.5")
    Decimal('1005.5')
    """
    try:
        return decimal.Decimal(number)
    except decimal.InvalidOperation as ex:
        # fallback: the text after the last separator is the fractional part
        parts = re.split(r"[.,]", str(number))
        try:
            return decimal.Decimal("".join(parts[:-1]) + "." + parts[-1])
        except decimal.InvalidOperation:
            raise decimal.DecimalException(f"Invalid number format: {number}") from ex


def parse_datetime(date:datetime | str | None) -> datetime | None:
    """
    Converts an ISO formatted string to a datetime. `None` and datetime objects are returned as-is.

    >>> parse_datetime(datetime(2020, 1, 1, 0, 0))
    datetime.datetime(2020, 1, 1, 0, 0)
    >>> parse_datetime("2020-01-01T00:00:00")
    datetime.datetime(2020, 1, 1, 0, 0)
    >>> parse_datetime(None)

    """
    if date is None:
        return None
    if isinstance(date, datetime):
        return date
    return datetime.fromisoformat(date)


def smooth_scroll_page(driver: WebDriver, scroll_length: int = 10, scroll_speed: int = 10000, scroll_back_top: bool = False) -> None:
    """
    Scrolls the current page of a web driver session.

    :param driver: the web driver session
    :param scroll_length: the length of a single scroll iteration, determines smoothness of scrolling, lower is smoother
    :param scroll_speed: the speed of scrolling, higher is faster
    :param scroll_back_top: whether to scroll the page back to the top after scrolling to the bottom
    :raises ValueError: if `scroll_length` or `scroll_speed` is not greater than zero
    """
    # a non-positive step would never reach the bottom, a non-positive speed yields an invalid sleep time
    if scroll_length <= 0:
        raise ValueError("[scroll_length] must be > 0")
    if scroll_speed <= 0:
        raise ValueError("[scroll_speed] must be > 0")

    step_delay = scroll_length / scroll_speed
    current_y_pos = 0
    bottom_y_pos: int = driver.execute_script('return document.body.scrollHeight;')  # get bottom position by JS
    while current_y_pos < bottom_y_pos:  # scroll in steps until bottom reached
        current_y_pos += scroll_length
        driver.execute_script(f'window.scrollTo(0, {current_y_pos});')  # scroll one step
        time.sleep(step_delay)

    if scroll_back_top:  # scroll back to top in same style
        while current_y_pos > 0:
            current_y_pos -= scroll_length
            driver.execute_script(f'window.scrollTo(0, {current_y_pos});')
            time.sleep(step_delay / 2)  # double speed


def extract_ad_id_from_ad_link(url: str) -> int:
    """
    Extracts the ID of an ad, given by its reference link.
    The ID is expected at the start of the last path segment, followed by `-` or the end of the URL.

    :param url: the URL to the ad page
    :return: the ad ID, a (ten-digit) integer number, or -1 if no ID could be extracted
    """
    num_part = url.split('/')[-1]  # suffix
    id_part = num_part.split('-')[0]

    try:
        return int(id_part)
    except ValueError:
        # reported via the logger instead of a bare print so that it honors the logging configuration
        LOG.warning('The ad ID could not be extracted from the given ad reference!')
        return -1