# Source code for dgenerate.filecache

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import datetime
import json
import os
import pathlib
import sqlite3
import typing
import uuid
import re

import fake_useragent
import filelock
import pyrfc6266
import requests
import tqdm

import dgenerate.memory as _memory
import dgenerate.messages as _messages

# Module documentation string, assigned explicitly because it appears after the imports.
__doc__ = "\nOn disk file cache implementation and primitives.\n"


class WebFileCacheOfflineModeException(Exception):
    """
    Raised when the web cache operates in offline mode and the
    requested file cannot be found in the cache.
    """
class KeyValueStore:
    """
    A key-value store using SQLite3 for storage.

    The store is a context manager. Entering it acquires a file lock located
    at ``db_path + ".lock"``, opens the database and makes sure the ``store``
    table exists. Leaving the outermost context commits, closes the
    connection and releases the lock. Contexts may be nested on the same
    instance, only the outermost one opens and closes the database. Every
    public operation enters the context itself, so explicit ``with`` blocks
    are only needed to group several operations under one lock / commit.
    """

    def __init__(self, db_path: str):
        """
        Initialize the key-value store.

        The parent directory of ``db_path`` is created if it is missing.
        The database itself is not opened until the store is entered.

        :param db_path: The path to the SQLite3 database file.
        """
        db_dir = pathlib.Path(db_path).parent
        if not db_dir.exists():
            db_dir.mkdir(parents=True, exist_ok=True)

        self.db_path = db_path
        self.connection = None
        self.cursor = None
        self.file_lock = filelock.FileLock(db_path + ".lock")

        # Nesting depth of ``with`` blocks currently active on this instance.
        self._lock_counter = 0

    def __enter__(self):
        """
        Enter a context managed by this key-value store.

        :return: This key-value store.
        """
        if self._lock_counter == 0:
            self.file_lock.acquire()
            try:
                self.connection = sqlite3.connect(self.db_path)
                self.cursor = self.connection.cursor()
                self.cursor.execute(
                    "CREATE TABLE IF NOT EXISTS store (key TEXT PRIMARY KEY, value TEXT, creation_date TIMESTAMP)")
            except Exception:
                # Opening failed: undo everything so the lock is not left held.
                if self.connection is not None:
                    self.connection.close()
                self.connection = None
                self.cursor = None
                self.file_lock.release()
                raise
        self._lock_counter += 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Exit a context managed by this key-value store.

        When the outermost context is left, pending changes are committed,
        the connection is closed and the file lock is released. The
        connection is closed and the lock released even if the commit fails.
        Exceptions raised inside the context are never suppressed.
        """
        self._lock_counter -= 1
        if self._lock_counter != 0:
            return

        try:
            try:
                self.connection.commit()
            finally:
                self.connection.close()
        finally:
            # Always hand the lock back, otherwise a failed commit would
            # block every later user of the database.
            self.connection = None
            self.cursor = None
            self.file_lock.release()

    def get(self, key: str, default=None):
        """
        Get the value associated with a key.

        :param key: The key to get the value for.
        :param default: The default value to return if the key is not found.
        :return: The value associated with the key, or the default value if the key is not found.
        """
        with self:
            self.cursor.execute("SELECT value FROM store WHERE key=?", (key,))
            result = self.cursor.fetchone()
            if result is None:
                return default
            return result[0]

    def __getitem__(self, key: str) -> str:
        """
        Get the value associated with a key.

        :param key: The key to get the value for.
        :return: The value associated with the key.
        :raises KeyError: If the key is not found.
        """
        with self:
            self.cursor.execute("SELECT value FROM store WHERE key=?", (key,))
            result = self.cursor.fetchone()
            if result is None:
                raise KeyError(key)
            return result[0]

    def __setitem__(self, key: str, value: str):
        """
        Set the value for a key, replacing any existing value and
        resetting the creation date of the entry to the current time.

        :param key: The key to set the value for.
        :param value: The value to set.
        """
        with self:
            # Bind the timestamp as text explicitly. This is the same
            # "YYYY-MM-DD HH:MM:SS[.ffffff]" form that sqlite3's default
            # datetime adapter (deprecated since Python 3.12) writes, so
            # existing rows remain comparable.
            creation_date = datetime.datetime.now().isoformat(" ")
            self.cursor.execute(
                "REPLACE INTO store (key, value, creation_date) VALUES (?, ?, ?)",
                (key, value, creation_date))

    def __delitem__(self, key: str):
        """
        Delete a key and its associated value.

        :param key: The key to delete.
        :raises KeyError: If the key is not found.
        """
        with self:
            self.cursor.execute("DELETE FROM store WHERE key=?", (key,))
            # No affected row means the key was never there.
            if self.cursor.rowcount == 0:
                raise KeyError(key)

    def __contains__(self, key: str) -> bool:
        """
        Check if a key is in the store.

        :param key: The key to check.
        :return: ``True`` if the key is in the store, ``False`` otherwise.
        """
        with self:
            self.cursor.execute("SELECT 1 FROM store WHERE key=?", (key,))
            return self.cursor.fetchone() is not None

    def __iter__(self) -> typing.Iterator[tuple[str, str]]:
        """
        Iterate over the keys and values in the store.

        The store stays entered (and therefore locked) until the
        iterator is exhausted or closed.

        :return: An iterator over ``(key, value)`` tuples.
        """
        with self:
            # A dedicated cursor keeps the iteration intact when other
            # store operations run on the shared cursor in the meantime.
            cursor = self.connection.cursor()
            try:
                cursor.execute("SELECT key, value FROM store")
                yield from cursor
            finally:
                cursor.close()

    def keys(self) -> typing.Iterator[str]:
        """
        Get all keys in the store.

        The store stays entered (and therefore locked) until the
        iterator is exhausted or closed.

        :return: An iterator over the keys in the store.
        """
        with self:
            cursor = self.connection.cursor()
            try:
                cursor.execute("SELECT key FROM store")
                for row in cursor:
                    yield row[0]
            finally:
                cursor.close()

    def items(self) -> typing.Iterator[str]:
        """
        Get all values in the store.

        Note that, unlike ``dict.items``, this yields the values only
        and not ``(key, value)`` pairs, iterate the store itself for pairs.

        :return: An iterator over the values in the store.
        """
        with self:
            cursor = self.connection.cursor()
            try:
                cursor.execute("SELECT value FROM store")
                for row in cursor:
                    yield row[0]
            finally:
                cursor.close()

    def delete_older_than(self, timedelta: datetime.timedelta) -> list[tuple[str, str]]:
        """
        Delete all keys and their associated values that were created more than a certain time ago.

        :param timedelta: The age of the keys to delete.
        :return: The keys and values that were deleted, as ``(key, value)`` tuples.
        """
        with self:
            try:
                cutoff_date = datetime.datetime.now() - timedelta
            except OverflowError:
                # The requested age reaches back further than datetime
                # can represent, use the earliest representable date.
                cutoff_date = datetime.datetime.min

            # Same text form that __setitem__ stores, see the note there.
            cutoff = cutoff_date.isoformat(" ")

            self.cursor.execute("SELECT key, value FROM store WHERE creation_date < ?", (cutoff,))
            deleted_rows = self.cursor.fetchall()
            self.cursor.execute("DELETE FROM store WHERE creation_date < ?", (cutoff,))
            return deleted_rows
class CachedFile:
    """An entry of a :py:class:`.FileCache`: where the file lives on disk plus its metadata."""

    path: str
    """
    Location of the file on disk.
    """

    metadata: dict[str, str]
    """
    Optional metadata that is stored for the file in the database.
    """

    def __init__(self, data_dict):
        """
        :param data_dict: file data dict parsed from the cache database,
            it must provide the ``path`` and ``metadata`` keys.
        """
        # Copy the two known fields over in declaration order.
        for field in ('path', 'metadata'):
            setattr(self, field, data_dict[field])
class FileCache:
    """
    A cache system that stores files and their metadata.

    File contents live in ``cache_dir``. A :py:class:`.KeyValueStore` maps
    each key to a JSON document holding the file ``path`` and its ``metadata``.
    """

    def __init__(self, db_path: str, cache_dir: str):
        """
        Initializes the :py:class:`.FileCache` object with a key-value store located at
        ``db_path`` and a cache directory at ``cache_dir``. If the cache directory
        doesn't exist, it creates it.

        :param db_path: The path to the key-value store database.
        :param cache_dir: The directory where the cache files are stored.
        """
        self.kv_store = KeyValueStore(db_path)
        self.cache_dir = cache_dir
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir, exist_ok=True)

    def __iter__(self) -> typing.Iterator[tuple[str, CachedFile]]:
        """
        Allows iteration over the key-value pairs in the key-value store,
        yielding each key and its corresponding :py:class:`.CachedFile` object.

        :return: An iterator over ``(key, CachedFile)`` tuples.
        """
        with self.kv_store:
            for key, raw_entry in self.kv_store:
                yield key, CachedFile(json.loads(raw_entry))

    def __delitem__(self, key):
        """
        Deletes the item with the specified key from the key-value store.
        This also deletes the associated file in the cache.

        :param key: The key associated with the file.
        :raises KeyError: If the key is not found.
        """
        with self.kv_store:
            # The store lookup raises KeyError for an unknown key.
            file = CachedFile(json.loads(self.kv_store[key]))
            try:
                os.unlink(file.path)
            except OSError:
                # Best effort: the database entry is dropped even when
                # the file is already gone or cannot be removed.
                pass
            del self.kv_store[key]

    def items(self) -> typing.Iterator[CachedFile]:
        """
        Yields all items in the key-value store as :py:class:`.CachedFile` objects.

        Only the files are yielded, without their keys, iterate the
        cache itself to receive ``(key, CachedFile)`` tuples.
        """
        with self.kv_store:
            for _, raw_entry in self.kv_store:
                yield CachedFile(json.loads(raw_entry))

    def keys(self) -> typing.Iterator[str]:
        """
        Yields all keys in the key-value store.
        """
        with self.kv_store:
            yield from self.kv_store.keys()

    def _generate_unique_filename(self, ext):
        """
        Generates a unique filename with the specified extension in the cache directory.

        :param ext: File extension with or without leading dots, ``None`` or
            a blank string for no extension.
        :return: A path inside the cache directory that does not exist yet.
        """
        if ext is None or not ext.strip():
            ext = ''
        else:
            ext = '.' + ext.lstrip('.')

        while True:
            file_path = os.path.join(self.cache_dir, str(uuid.uuid4())) + ext
            if not os.path.exists(file_path):
                return file_path

    def __enter__(self):
        """
        Allows the :py:class:`.FileCache` object to be used in a with statement,
        ensuring that the key-value store is properly opened.
        """
        self.kv_store.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Ensures that the key-value store is properly closed after being used in a with statement.
        """
        self.kv_store.__exit__(exc_type, exc_val, exc_tb)

    def delete_older_than(self, timedelta: datetime.timedelta) -> typing.Iterator[tuple[str, CachedFile]]:
        """
        Deletes items from the key-value store that are older than the specified timedelta,
        yielding each key and its corresponding :py:class:`.CachedFile` object.

        Only the database entries are deleted, the files on disk are left
        alone. This is a generator, nothing is deleted before it is iterated.

        :param timedelta: The age of the entries to delete.
        :return: An iterator over the deleted ``(key, CachedFile)`` tuples.
        """
        for key, value in self.kv_store.delete_older_than(timedelta):
            yield key, CachedFile(json.loads(value))

    def add(self,
            key: str,
            file_data: bytes | bytearray | memoryview | typing.Iterable[bytes],
            metadata: typing.Dict[str, str] | None = None,
            ext: str | None = None) \
            -> CachedFile:
        """
        Adds a file to the cache. If a file with the same key already exists,
        it overwrites the existing file. Otherwise, it creates a new file with a unique filename.

        The data is written to a temporary file next to the destination and
        moved into place once it is complete, so a failed or interrupted
        write never leaves a truncated file behind under the cached path.

        :param key: The key associated with the file.
        :param file_data: The data of the file as a single bytes-like object, or an iterable of binary chunks.
        :param metadata: The metadata of the file.
        :param ext: The extension of the file, only used when the key is new.
        :return: A :py:class:`.CachedFile` object representing the added file.
        """
        with self.kv_store as kv:
            existing = kv.get(key)
            if existing is not None:
                file_path = json.loads(existing)['path']
            else:
                file_path = self._generate_unique_filename(ext)

        # A single blob is treated as one chunk. Iterating a bytearray or
        # memoryview would otherwise produce integers instead of bytes.
        if isinstance(file_data, (bytes, bytearray, memoryview)):
            chunks = (file_data,)
        else:
            chunks = file_data

        temp_path = f'{file_path}.{uuid.uuid4().hex}.part'
        try:
            with open(temp_path, 'wb') as f:
                for chunk in chunks:
                    f.write(chunk)
            os.replace(temp_path, file_path)
        except BaseException:
            # Do not leave partial data behind, then report the failure.
            try:
                os.unlink(temp_path)
            except OSError:
                pass
            raise

        with self.kv_store as kv:
            entry_data = {'path': file_path, 'metadata': metadata}
            kv[key] = json.dumps(entry_data)
            return CachedFile(entry_data)

    def get(self, key) -> CachedFile | None:
        """
        Retrieves the :py:class:`.CachedFile` object for the specified key from the key-value store,
        or returns None if the key does not exist.

        :param key: The key associated with the file.
        :return: A :py:class:`.CachedFile` object representing the file, or ``None`` if the key does not exist.
        """
        with self.kv_store as kv:
            raw_entry = kv.get(key)
            if raw_entry is None:
                return None
            return CachedFile(json.loads(raw_entry))
class WebFileCache(FileCache):
    """
    A cache system that stores files and their metadata downloaded from the web.

    Entries are keyed by URL and expire after ``expiry_delta``.
    """

    def __init__(self,
                 db_path: str,
                 cache_dir: str,
                 expiry_delta: datetime.timedelta = datetime.timedelta(hours=12)):
        """
        Initializes the :py:class:`.WebFileCache` object with a key-value store located at ``db_path``,
        a cache directory at ``cache_dir``, and an expiry delta. If the cache directory doesn't exist,
        it creates it. It also attempts to clear old files.

        If clearing old files fails with a database error, the database and
        all cached files are removed so the cache starts out empty.

        :param db_path: The path to the key-value store database.
        :param cache_dir: The directory where the cache files are stored.
        :param expiry_delta: The time delta for file expiry.
        """
        super().__init__(db_path, cache_dir)
        self.expiry_delta = expiry_delta
        self._local_files_only = False
        try:
            self._clear_old_files()
        except sqlite3.Error:
            self._remove_cache_files_except_locks()

    @property
    def local_files_only(self) -> bool:
        """
        Get the local_files_only mode status.

        :return: ``True`` if local_files_only mode is enabled, ``False`` otherwise.
        """
        return self._local_files_only

    @local_files_only.setter
    def local_files_only(self, value: bool):
        """
        Set the local_files_only mode status.

        :param value: ``True`` to enable local_files_only mode, ``False`` to disable it.
        """
        self._local_files_only = value

    def _remove_cache_files_except_locks(self):
        """
        Removes the database and all cache files except for lock files.

        Sub-directories that end up empty are removed as well, the
        cache directory itself is always kept so the cache stays usable.
        """
        with self.kv_store.file_lock:
            try:
                os.unlink(self.kv_store.db_path)
            except FileNotFoundError:
                # Nothing to remove if the database was never created.
                pass

            stack = [self.cache_dir]
            while stack:
                base = stack.pop()
                with os.scandir(base) as entries:
                    for entry in entries:
                        if entry.is_file():
                            if not entry.name.endswith('.lock'):
                                os.remove(entry.path)
                        elif entry.is_dir():
                            stack.append(entry.path)
                if base != self.cache_dir and not os.listdir(base):
                    os.rmdir(base)

    def _clear_old_files(self):
        """
        Clears files that are older than the expiry delta, removing
        both their database entries and the files on disk.
        """
        for _, cached_file in self.delete_older_than(self.expiry_delta):
            try:
                os.unlink(cached_file.path)
            except FileNotFoundError:
                pass

    def request_mimetype(self, url, local_files_only: bool = False) -> str:
        """
        Requests the mimetype of a file at a URL. If the file exists in the cache, a known mimetype
        is returned without connecting to the internet. Otherwise, it connects to the internet
        to retrieve the mimetype. This action does not update the cache.

        :raise HTTPError: On http status errors.
        :raise WebFileCacheOfflineModeException: If local_files_only mode is enabled and the file is not found in the cache.

        :param url: The URL of the file.
        :param local_files_only: If ``True``, do not make a request, only check the cache.
        :return: The mimetype of the file, ``'unknown'`` if the server
            does not send a content type.
        """
        with self:
            exists = self.get(url)
            if exists is not None:
                return exists.metadata['mime-type']

        # Check if we're in offline mode (either from property or parameter)
        if self._local_files_only or local_files_only:
            raise WebFileCacheOfflineModeException(
                f'Web cache is in offline mode, and the '
                f'file for "{url}" was not found in the local cache.'
            )

        headers = {'User-Agent': fake_useragent.UserAgent().chrome}

        with requests.get(url, headers=headers, stream=True, timeout=5) as req:
            req.raise_for_status()
            # Same fallback as download() for a missing Content-Type header.
            return req.headers.get('content-type', 'unknown')

    @staticmethod
    def is_downloadable_url(string) -> bool:
        """
        Does a string represent a URL that can be downloaded by this web cache implementation?

        :param string: the string
        :return: ``True`` or ``False``
        """
        return string.startswith(('http://', 'https://'))

    def download(self,
                 url,
                 mime_acceptable_desc: str | None = None,
                 mimetype_is_supported: typing.Callable[[str], bool] | None = None,
                 unknown_mimetype_exception=ValueError,
                 overwrite: bool = False,
                 tqdm_pbar=tqdm.tqdm,
                 local_files_only: bool = False) -> CachedFile:
        """
        Downloads a file and/or returns a file path from the cache. If the mimetype of the file is not supported,
        it raises an exception.

        Expired cache entries are cleared before the cache is consulted.

        :raise requests.RequestException: Can raise any exception
            raised by ``requests.get`` for request related errors.
        :raise WebFileCacheOfflineModeException: If local_files_only mode is enabled and the file is not found in the cache.

        :param url: The URL of the file.
        :param mime_acceptable_desc: A description of acceptable mimetypes for use in exceptions.
        :param mimetype_is_supported: A function that determines if a mimetype is supported for downloading.
        :param unknown_mimetype_exception: The exception type to raise when an unknown mimetype is encountered.
        :param overwrite: Always overwrite any previously cached file? This has no effect in
            offline mode, where nothing can be downloaded and a cached file is returned if present.
        :param tqdm_pbar: tqdm progress bar type, if set to `None` no progress bar will be used. Defaults to `tqdm.tqdm`
        :param local_files_only: If ``True``, do not attempt to download files, only check cache.
        :return: A :py:class:`.CachedFile` object for the cached or freshly downloaded file.
        """
        self._clear_old_files()

        # Check if we're in offline mode (either from property or parameter)
        is_offline = self._local_files_only or local_files_only

        # Offline mode always consults the cache: a download is impossible,
        # so the exception below must only be raised for files that really
        # are missing from the cache.
        if is_offline or not overwrite:
            with self:
                cached_file = self.get(url)
                if cached_file is not None and os.path.exists(cached_file.path):
                    return cached_file

        # If we're in offline mode and the file wasn't found in cache, raise an exception
        if is_offline:
            raise WebFileCacheOfflineModeException(
                f'Web cache is in offline mode, and the '
                f'file for "{url}" was not found in the local cache.'
            )

        with requests.get(url,
                          headers={'User-Agent': fake_useragent.UserAgent().chrome},
                          stream=True,
                          timeout=5) as response:
            response.raise_for_status()

            mime_type = response.headers.get('content-type', 'unknown')

            # Without a checker every mimetype is accepted.
            if mimetype_is_supported is not None and not mimetype_is_supported(mime_type):
                raise unknown_mimetype_exception(
                    f'Unknown mimetype "{mime_type}" from URL "{url}". '
                    f'Expected: {mime_acceptable_desc}')

            metadata = {'mime-type': mime_type}
            filename = pyrfc6266.requests_response_to_filename(response)
            _, ext = os.path.splitext(filename)

            total_size = int(response.headers.get('content-length', 0))
            chunk_size = _memory.calculate_chunk_size(total_size)

            if tqdm_pbar is not None:
                _messages.log(f'Downloading: "{url}"', underline=True)

            def file_data_generator():
                if chunk_size == total_size:
                    # The whole body fits into one chunk, read it at once.
                    yield response.content
                elif tqdm_pbar is None:
                    yield from response.iter_content(chunk_size=chunk_size)
                else:
                    with tqdm_pbar(total=total_size if total_size != 0 else None,
                                   unit='iB',
                                   unit_scale=True) as progress_bar:
                        for chunk in response.iter_content(chunk_size=chunk_size):
                            progress_bar.update(len(chunk))
                            yield chunk

            # Add the downloaded file to the cache. This has to happen while the
            # response is still open, the generator reads from it lazily.
            return self.add(url, file_data_generator(), metadata, ext)