Files
kleinanzeigen-bot/src/kleinanzeigen_bot/model/update_check_state.py

196 lines
7.9 KiB
Python

# SPDX-FileCopyrightText: © Jens Bergmann and contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
# SPDX-ArtifactOfProjectHomePage: https://github.com/Second-Hand-Friends/kleinanzeigen-bot/
from __future__ import annotations
import datetime
import json
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pathlib import Path
from kleinanzeigen_bot.utils import dicts, loggers, misc, xdg_paths
from kleinanzeigen_bot.utils.pydantics import ContextualModel
# Module-level logger, named after this module
LOG = loggers.get_logger(__name__)
# Current version of the state file format; new states default to it and
# state data with a lower version is migrated up to it on load
CURRENT_STATE_VERSION = 1
# Maximum allowed update check interval in days; longer configured intervals
# make should_check() fall back to the channel's default interval
MAX_INTERVAL_DAYS = 30
class UpdateCheckState(ContextualModel):
    """State for update checking functionality.

    Holds the state file format version and the time of the last update check,
    and knows how to load/save itself and decide whether a new check is due.
    """

    # Format version of the state data; defaults to the current format
    version:int = CURRENT_STATE_VERSION
    # Time of the last update check, or None if no check has been recorded
    last_check:datetime.datetime | None = None

    @classmethod
    def _parse_timestamp(cls, timestamp_str:str) -> datetime.datetime | None:
        """Parse an ISO-format timestamp string and ensure it's in UTC.

        Args:
            timestamp_str: The timestamp string to parse.

        Returns:
            The parsed timestamp in UTC, or None if the value is not a string,
            cannot be parsed, or cannot be converted to UTC.
        """
        try:
            timestamp = datetime.datetime.fromisoformat(timestamp_str)
            if timestamp.tzinfo is None:
                # If no timezone info, assume UTC
                timestamp = timestamp.replace(tzinfo = datetime.timezone.utc)
            elif timestamp.tzinfo != datetime.timezone.utc:
                # Convert to UTC if in a different timezone
                timestamp = timestamp.astimezone(datetime.timezone.utc)
            return timestamp
        except (ValueError, TypeError, OverflowError) as e:
            # ValueError: malformed string; TypeError: non-string value taken from the
            # state file; OverflowError: date out of range when converting to UTC
            LOG.warning("Invalid timestamp format in state file: %s", e)
            return None

    @classmethod
    def load(cls, state_file:Path) -> UpdateCheckState:
        """Load the update check state from a file.

        A default state is returned when the file does not exist, is empty,
        contains no data, or its content cannot be parsed or validated.

        Args:
            state_file: The path to the state file.

        Returns:
            The loaded state.
        """
        if not state_file.exists():
            return cls()

        if state_file.stat().st_size == 0:
            return cls()

        try:
            data = dicts.load_dict(str(state_file))
            if not data:
                return cls()

            # Handle version migration
            version = data.get("version", 0)
            if version < CURRENT_STATE_VERSION:
                LOG.info("Migrating update check state from version %d to %d", version, CURRENT_STATE_VERSION)
                data = cls._migrate_state(data, version)

            # Parse last_check timestamp; a missing or null value means "never checked"
            # and is passed through unchanged instead of being reported as invalid
            if data.get("last_check") is not None:
                data["last_check"] = cls._parse_timestamp(data["last_check"])

            return cls.model_validate(data)
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            # TypeError covers wrongly typed fields, e.g. a non-numeric "version"
            LOG.warning("Failed to load update check state: %s", e)
            return cls()

    @classmethod
    def _migrate_state(cls, data:dict[str, Any], from_version:int) -> dict[str, Any]:
        """Migrate state data from an older version to the current version.

        The given dict is modified in place and also returned.

        Args:
            data: The state data to migrate.
            from_version: The version of the state data.

        Returns:
            The migrated state data.
        """
        # Version 0 to 1: Add version field
        if from_version == 0:
            data["version"] = CURRENT_STATE_VERSION
            LOG.debug("Migrated state from version 0 to 1: Added version field")

        return data

    def save(self, state_file:Path) -> None:
        """Save the update check state to a file.

        Saving is best-effort: failures are logged as warnings and not raised.

        Args:
            state_file: The path to the state file.
        """
        try:
            data = self.model_dump()
            last_check = data["last_check"]
            if last_check:
                # Ensure timestamp is in UTC before saving.
                # NOTE(review): astimezone() interprets a naive datetime as local time,
                # whereas _parse_timestamp() assumes UTC for naive values - confirm intended.
                if last_check.tzinfo != datetime.timezone.utc:
                    last_check = last_check.astimezone(datetime.timezone.utc)
                data["last_check"] = last_check.isoformat()
            xdg_paths.ensure_directory(state_file.parent, "update check state directory")
            dicts.save_dict(str(state_file), data)
        except PermissionError:
            LOG.warning("Permission denied when saving update check state to %s", state_file)
        except Exception as e:
            # Deliberately broad: a failed state save must not abort the caller
            LOG.warning("Failed to save update check state: %s", e)

    def update_last_check(self) -> None:
        """Update the last check time to now in UTC."""
        self.last_check = datetime.datetime.now(datetime.timezone.utc)

    def _validate_update_interval(self, interval:str) -> tuple[datetime.timedelta, bool, str]:
        """Validate the update check interval string.

        Args:
            interval: The interval string to validate.

        Returns:
            A tuple (timedelta, is_valid, reason); reason is an empty string
            when the interval is valid.
        """
        td = misc.parse_duration(interval)
        # A zero duration is never valid. An explicitly written zero (e.g. "0d", "0h", "0m", "0s", "0")
        # is reported as such; any other input that parsed to zero is reported as a format problem.
        if td.total_seconds() == 0:
            if interval.strip() in {"0d", "0h", "0m", "0s", "0"}:
                return td, False, "Interval is zero, which is not allowed."
            return td, False, "Invalid interval format or unsupported unit."
        if td.total_seconds() < 0:
            return td, False, "Negative interval is not allowed."
        return td, True, ""

    def should_check(self, interval:str, channel:str = "latest") -> bool:
        """Determine if an update check should be performed based on the provided interval.

        Args:
            interval: The interval string (e.g. '7d', '1d 12h', etc.)
            channel: The update channel ('latest' or 'preview') for fallback default interval.

        Returns:
            bool: True if an update check should be performed, False otherwise.

        Notes:
            - If interval is invalid, negative, zero, below 1d or above the maximum,
              falls back to the default interval for the channel (1d for 'preview', 7d otherwise).
            - Returns True if no previous check has been recorded.
            - Otherwise only returns True if more than the interval has passed since last_check.
            - Always compares in UTC; a last_check without timezone info is assumed to be UTC.
        """
        td, is_valid, reason = self._validate_update_interval(interval)
        total_days = td.total_seconds() / 86400
        # Tolerance so that float rounding does not reject exact boundary values
        epsilon = 1e-6

        fallback = False
        if not is_valid:
            if reason == "Interval is zero, which is not allowed.":
                LOG.warning("Interval is zero: %s. Minimum interval is 1d. Using default interval for this run.", interval)
            elif reason == "Invalid interval format or unsupported unit.":
                LOG.warning("Invalid interval format or unsupported unit: %s. Using default interval for this run.", interval)
            elif reason == "Negative interval is not allowed.":
                LOG.warning("Negative interval: %s. Minimum interval is 1d. Using default interval for this run.", interval)
            fallback = True
        elif total_days > MAX_INTERVAL_DAYS + epsilon:
            LOG.warning("Interval too long: %s. Maximum interval is 30d. Using default interval for this run.", interval)
            fallback = True
        elif total_days < 1 - epsilon:
            LOG.warning("Interval too short: %s. Minimum interval is 1d. Using default interval for this run.", interval)
            fallback = True

        if fallback:
            # Fallback to default interval based on channel
            if channel == "preview":
                td = misc.parse_duration("1d")
                LOG.warning("Falling back to default interval: 1d (preview channel). Please fix your config to avoid this warning.")
            else:
                td = misc.parse_duration("7d")
                LOG.warning("Falling back to default interval: 7d (latest channel). Please fix your config to avoid this warning.")

        last_check = self.last_check
        if not last_check:
            return True

        if last_check.tzinfo is None:
            # Subtracting a naive from an aware datetime raises TypeError;
            # treat naive values as UTC, consistent with _parse_timestamp()
            last_check = last_check.replace(tzinfo = datetime.timezone.utc)

        now = datetime.datetime.now(datetime.timezone.utc)
        elapsed = now - last_check
        # Compare using integer seconds to avoid microsecond-level flakiness
        return int(elapsed.total_seconds()) > int(td.total_seconds())