master
Dr. Sascha Woitschetzki 2024-03-12 09:24:03 +07:00
parent 406604d1fc
commit f18e636354
40 changed files with 2679 additions and 0 deletions

@ -0,0 +1,28 @@
Copyright 2018 Pallets
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,64 @@
Metadata-Version: 2.1
Name: cachelib
Version: 0.12.0
Summary: A collection of cache libraries in the same API interface.
Home-page: https://github.com/pallets-eco/cachelib/
Maintainer: Pallets
Maintainer-email: contact@palletsprojects.com
License: BSD-3-Clause
Project-URL: Donate, https://palletsprojects.com/donate
Project-URL: Documentation, https://cachelib.readthedocs.io/
Project-URL: Changes, https://cachelib.readthedocs.io/changes/
Project-URL: Source Code, https://github.com/pallets-eco/cachelib/
Project-URL: Issue Tracker, https://github.com/pallets-eco/cachelib/issues/
Project-URL: Twitter, https://twitter.com/PalletsTeam
Project-URL: Chat, https://discord.gg/pallets
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE.rst
CacheLib
========
A collection of cache libraries in the same API interface. Extracted
from Werkzeug.
Installing
----------
Install and update using `pip`_:
.. code-block:: text
$ pip install -U cachelib
.. _pip: https://pip.pypa.io/en/stable/getting-started/
Donate
------
The Pallets organization develops and supports Flask and the libraries
it uses. In order to grow the community of contributors and users, and
allow the maintainers to devote more time to the projects, `please
donate today`_.
.. _please donate today: https://palletsprojects.com/donate
Links
-----
- Documentation: https://cachelib.readthedocs.io/
- Changes: https://cachelib.readthedocs.io/changes/
- PyPI Releases: https://pypi.org/project/cachelib/
- Source Code: https://github.com/pallets/cachelib/
- Issue Tracker: https://github.com/pallets/cachelib/issues/
- Twitter: https://twitter.com/PalletsTeam
- Chat: https://discord.gg/pallets

@ -0,0 +1,27 @@
cachelib-0.12.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
cachelib-0.12.0.dist-info/LICENSE.rst,sha256=zUGBIIEtwmJiga4CfoG2SCKdFmtaynRyzs1RADjTbn0,1475
cachelib-0.12.0.dist-info/METADATA,sha256=5rWdhpMckpSSZve1XYviRLBz_oHi5lAGknNHmWJ5V8g,1960
cachelib-0.12.0.dist-info/RECORD,,
cachelib-0.12.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
cachelib-0.12.0.dist-info/top_level.txt,sha256=AYC4q8wgGd_hR_F2YcDkmtQm41gv9-5AThKuQtNPEXk,9
cachelib/__init__.py,sha256=LqvUrhckxpFJyQcJ1eWsDhYZjzqjpovzipSTfw1dvjE,575
cachelib/__pycache__/__init__.cpython-310.pyc,,
cachelib/__pycache__/base.cpython-310.pyc,,
cachelib/__pycache__/dynamodb.cpython-310.pyc,,
cachelib/__pycache__/file.cpython-310.pyc,,
cachelib/__pycache__/memcached.cpython-310.pyc,,
cachelib/__pycache__/mongodb.cpython-310.pyc,,
cachelib/__pycache__/redis.cpython-310.pyc,,
cachelib/__pycache__/serializers.cpython-310.pyc,,
cachelib/__pycache__/simple.cpython-310.pyc,,
cachelib/__pycache__/uwsgi.cpython-310.pyc,,
cachelib/base.py,sha256=3_B-cB1VEh_x-VzH9g3qvzdqCxDX2ywDzQ7a_aYFJlE,6731
cachelib/dynamodb.py,sha256=fSmp8G7V0yBcRC2scdIhz8d0D2-9OMZEwQ9AcBONyC8,8512
cachelib/file.py,sha256=V8uPVfgn5YK7PcCvQlignH5QCTdupEpJ9chucr3XVmM,11678
cachelib/memcached.py,sha256=KyUN4wblVPf2XNLYk15kwN9QTfkFK6jrpVGrj4NAoFA,7160
cachelib/mongodb.py,sha256=b9l8fTKMFm8hAXFn748GKertHUASSVDBgPgrtGaZ6cA,6901
cachelib/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
cachelib/redis.py,sha256=hSKV9fVD7gzk1X_B3Ac9XJYN3G3F6eYBoreqDo9pces,6295
cachelib/serializers.py,sha256=MXk1moN6ljOPUFQ0E0D129mZlrDDwuQ5DvInNkitvoI,3343
cachelib/simple.py,sha256=8UPp95_oc3bLeW_gzFUzdppIKV8hV_o_yQYoKb8gVMk,3422
cachelib/uwsgi.py,sha256=4DX3C9QGvB6mVcg1d7qpLIEkI6bccuq-8M6I_YbPicY,2563

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.42.0)
Root-Is-Purelib: true
Tag: py3-none-any

@ -0,0 +1,22 @@
from cachelib.base import BaseCache
from cachelib.base import NullCache
from cachelib.dynamodb import DynamoDbCache
from cachelib.file import FileSystemCache
from cachelib.memcached import MemcachedCache
from cachelib.mongodb import MongoDbCache
from cachelib.redis import RedisCache
from cachelib.simple import SimpleCache
from cachelib.uwsgi import UWSGICache
__all__ = [
"BaseCache",
"NullCache",
"SimpleCache",
"FileSystemCache",
"MemcachedCache",
"RedisCache",
"UWSGICache",
"DynamoDbCache",
"MongoDbCache",
]
__version__ = "0.12.0"

@ -0,0 +1,185 @@
import typing as _t
class BaseCache:
"""Baseclass for the cache systems. All the cache systems implement this
API or a superset of it.
:param default_timeout: the default timeout (in seconds) that is used if
no timeout is specified on :meth:`set`. A timeout
of 0 indicates that the cache never expires.
"""
def __init__(self, default_timeout: int = 300):
self.default_timeout = default_timeout
def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
if timeout is None:
timeout = self.default_timeout
return timeout
def get(self, key: str) -> _t.Any:
"""Look up key in the cache and return the value for it.
:param key: the key to be looked up.
:returns: The value if it exists and is readable, else ``None``.
"""
return None
def delete(self, key: str) -> bool:
"""Delete `key` from the cache.
:param key: the key to delete.
:returns: Whether the key existed and has been deleted.
:rtype: boolean
"""
return True
def get_many(self, *keys: str) -> _t.List[_t.Any]:
"""Returns a list of values for the given keys.
For each key an item in the list is created::
foo, bar = cache.get_many("foo", "bar")
Has the same error handling as :meth:`get`.
:param keys: The function accepts multiple keys as positional
arguments.
"""
return [self.get(k) for k in keys]
def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
"""Like :meth:`get_many` but return a dict::
d = cache.get_dict("foo", "bar")
foo = d["foo"]
bar = d["bar"]
:param keys: The function accepts multiple keys as positional
arguments.
"""
return dict(zip(keys, self.get_many(*keys))) # noqa: B905
def set(
self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
) -> _t.Optional[bool]:
"""Add a new key/value to the cache (overwrites value, if key already
exists in the cache).
:param key: the key to set
:param value: the value for the key
:param timeout: the cache timeout for the key in seconds (if not
specified, it uses the default timeout). A timeout of
0 indicates that the cache never expires.
:returns: ``True`` if key has been updated, ``False`` for backend
errors. Pickling errors, however, will raise a subclass of
``pickle.PickleError``.
:rtype: boolean
"""
return True
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
"""Works like :meth:`set` but does not overwrite the values of already
existing keys.
:param key: the key to set
:param value: the value for the key
:param timeout: the cache timeout for the key in seconds (if not
specified, it uses the default timeout). A timeout of
0 indicates that the cache never expires.
:returns: Same as :meth:`set`, but also ``False`` for already
existing keys.
:rtype: boolean
"""
return True
def set_many(
self, mapping: _t.Dict[str, _t.Any], timeout: _t.Optional[int] = None
) -> _t.List[_t.Any]:
"""Sets multiple keys and values from a mapping.
:param mapping: a mapping with the keys/values to set.
:param timeout: the cache timeout for the key in seconds (if not
specified, it uses the default timeout). A timeout of
0 indicates that the cache never expires.
:returns: A list containing all keys successfully set
:rtype: boolean
"""
set_keys = []
for key, value in mapping.items():
if self.set(key, value, timeout):
set_keys.append(key)
return set_keys
def delete_many(self, *keys: str) -> _t.List[_t.Any]:
"""Deletes multiple keys at once.
:param keys: The function accepts multiple keys as positional
arguments.
:returns: A list containing all successfully deleted keys
:rtype: boolean
"""
deleted_keys = []
for key in keys:
if self.delete(key):
deleted_keys.append(key)
return deleted_keys
def has(self, key: str) -> bool:
"""Checks if a key exists in the cache without returning it. This is a
cheap operation that bypasses loading the actual data on the backend.
:param key: the key to check
"""
raise NotImplementedError(
"%s doesn't have an efficient implementation of `has`. That "
"means it is impossible to check whether a key exists without "
"fully loading the key's data. Consider using `self.get` "
"explicitly if you don't care about performance."
)
def clear(self) -> bool:
"""Clears the cache. Keep in mind that not all caches support
completely clearing the cache.
:returns: Whether the cache has been cleared.
:rtype: boolean
"""
return True
def inc(self, key: str, delta: int = 1) -> _t.Optional[int]:
"""Increments the value of a key by `delta`. If the key does
not yet exist it is initialized with `delta`.
For supporting caches this is an atomic operation.
:param key: the key to increment.
:param delta: the delta to add.
:returns: The new value or ``None`` for backend errors.
"""
value = (self.get(key) or 0) + delta
return value if self.set(key, value) else None
def dec(self, key: str, delta: int = 1) -> _t.Optional[int]:
"""Decrements the value of a key by `delta`. If the key does
not yet exist it is initialized with `-delta`.
For supporting caches this is an atomic operation.
:param key: the key to increment.
:param delta: the delta to subtract.
:returns: The new value or `None` for backend errors.
"""
value = (self.get(key) or 0) - delta
return value if self.set(key, value) else None
class NullCache(BaseCache):
"""A cache that doesn't cache. This can be useful for unit testing.
:param default_timeout: a dummy parameter that is ignored but exists
for API compatibility with other caches.
"""
def has(self, key: str) -> bool:
return False

@ -0,0 +1,226 @@
import datetime
import typing as _t
from cachelib.base import BaseCache
from cachelib.serializers import DynamoDbSerializer
CREATED_AT_FIELD = "created_at"
RESPONSE_FIELD = "response"
class DynamoDbCache(BaseCache):
"""
Implementation of cachelib.BaseCache that uses an AWS DynamoDb table
as the backend.
Your server process will require dynamodb:GetItem and dynamodb:PutItem
IAM permissions on the cache table.
Limitations: DynamoDB table items are limited to 400 KB in size. Since
this class stores cached items in a table, the max size of a cache entry
will be slightly less than 400 KB, since the cache key and expiration
time fields are also part of the item.
:param table_name: The name of the DynamoDB table to use
:param default_timeout: Set the timeout in seconds after which cache entries
expire
:param key_field: The name of the hash_key attribute in the DynamoDb
table. This must be a string attribute.
:param expiration_time_field: The name of the table attribute to store the
expiration time in. This will be an int
attribute. The timestamp will be stored as
seconds past the epoch. If you configure
this as the TTL field, then DynamoDB will
automatically delete expired entries.
:param key_prefix: A prefix that should be added to all keys.
"""
serializer = DynamoDbSerializer()
def __init__(
self,
table_name: _t.Optional[str] = "python-cache",
default_timeout: int = 300,
key_field: _t.Optional[str] = "cache_key",
expiration_time_field: _t.Optional[str] = "expiration_time",
key_prefix: _t.Optional[str] = None,
**kwargs: _t.Any
):
super().__init__(default_timeout)
try:
import boto3 # type: ignore
except ImportError as err:
raise RuntimeError("no boto3 module found") from err
self._table_name = table_name
self._key_field = key_field
self._expiration_time_field = expiration_time_field
self.key_prefix = key_prefix or ""
self._dynamo = boto3.resource("dynamodb", **kwargs)
self._attr = boto3.dynamodb.conditions.Attr
try:
self._table = self._dynamo.Table(table_name)
self._table.load()
# catch this exception (triggered if the table doesn't exist)
except Exception:
table = self._dynamo.create_table(
AttributeDefinitions=[
{"AttributeName": key_field, "AttributeType": "S"}
],
TableName=table_name,
KeySchema=[
{"AttributeName": key_field, "KeyType": "HASH"},
],
BillingMode="PAY_PER_REQUEST",
)
table.wait_until_exists()
dynamo = boto3.client("dynamodb", **kwargs)
dynamo.update_time_to_live(
TableName=table_name,
TimeToLiveSpecification={
"Enabled": True,
"AttributeName": expiration_time_field,
},
)
self._table = self._dynamo.Table(table_name)
self._table.load()
def _utcnow(self) -> _t.Any:
"""Return a tz-aware UTC datetime representing the current time"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
def _get_item(self, key: str, attributes: _t.Optional[list] = None) -> _t.Any:
"""
Get an item from the cache table, optionally limiting the returned
attributes.
:param key: The cache key of the item to fetch
:param attributes: An optional list of attributes to fetch. If not
given, all attributes are fetched. The
expiration_time field will always be added to the
list of fetched attributes.
:return: The table item for key if it exists and is not expired, else
None
"""
kwargs = {}
if attributes:
if self._expiration_time_field not in attributes:
attributes = list(attributes) + [self._expiration_time_field]
kwargs = dict(ProjectionExpression=",".join(attributes))
response = self._table.get_item(Key={self._key_field: key}, **kwargs)
cache_item = response.get("Item")
if cache_item:
now = int(self._utcnow().timestamp())
if cache_item.get(self._expiration_time_field, now + 100) > now:
return cache_item
return None
def get(self, key: str) -> _t.Any:
"""
Get a cache item
:param key: The cache key of the item to fetch
:return: cache value if not expired, else None
"""
cache_item = self._get_item(self.key_prefix + key)
if cache_item:
response = cache_item[RESPONSE_FIELD]
value = self.serializer.loads(response)
return value
return None
def delete(self, key: str) -> bool:
"""
Deletes an item from the cache. This is a no-op if the item doesn't
exist
:param key: Key of the item to delete.
:return: True if the key existed and was deleted
"""
try:
self._table.delete_item(
Key={self._key_field: self.key_prefix + key},
ConditionExpression=self._attr(self._key_field).exists(),
)
return True
except self._dynamo.meta.client.exceptions.ConditionalCheckFailedException:
return False
def _set(
self,
key: str,
value: _t.Any,
timeout: _t.Optional[int] = None,
overwrite: _t.Optional[bool] = True,
) -> _t.Any:
"""
Store a cache item, with the option to not overwrite existing items
:param key: Cache key to use
:param value: a serializable object
:param timeout: The timeout in seconds for the cached item, to override
the default
:param overwrite: If true, overwrite any existing cache item with key.
If false, the new value will only be stored if no
non-expired cache item exists with key.
:return: True if the new item was stored.
"""
timeout = self._normalize_timeout(timeout)
now = self._utcnow()
kwargs = {}
if not overwrite:
# Cause the put to fail if a non-expired item with this key
# already exists
cond = self._attr(self._key_field).not_exists() | self._attr(
self._expiration_time_field
).lte(int(now.timestamp()))
kwargs = dict(ConditionExpression=cond)
try:
dump = self.serializer.dumps(value)
item = {
self._key_field: key,
CREATED_AT_FIELD: now.isoformat(),
RESPONSE_FIELD: dump,
}
if timeout > 0:
expiration_time = now + datetime.timedelta(seconds=timeout)
item[self._expiration_time_field] = int(expiration_time.timestamp())
self._table.put_item(Item=item, **kwargs)
return True
except Exception:
return False
def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
return self._set(self.key_prefix + key, value, timeout=timeout, overwrite=True)
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
return self._set(self.key_prefix + key, value, timeout=timeout, overwrite=False)
def has(self, key: str) -> bool:
return (
self._get_item(self.key_prefix + key, [self._expiration_time_field])
is not None
)
def clear(self) -> bool:
paginator = self._dynamo.meta.client.get_paginator("scan")
pages = paginator.paginate(
TableName=self._table_name, ProjectionExpression=self._key_field
)
with self._table.batch_writer() as batch:
for page in pages:
for item in page["Items"]:
batch.delete_item(Key=item)
return True

@ -0,0 +1,333 @@
import errno
import logging
import os
import platform
import stat
import struct
import tempfile
import typing as _t
from contextlib import contextmanager
from hashlib import md5
from pathlib import Path
from time import sleep
from time import time
from cachelib.base import BaseCache
from cachelib.serializers import FileSystemSerializer
class FileSystemCache(BaseCache):
"""A cache that stores the items on the file system. This cache depends
on being the only user of the `cache_dir`. Make absolutely sure that
nobody but this cache stores files there or otherwise the cache will
randomly delete files therein.
:param cache_dir: the directory where cache files are stored.
:param threshold: the maximum number of items the cache stores before
it starts deleting some. A threshold value of 0
indicates no threshold.
:param default_timeout: the default timeout that is used if no timeout is
specified on :meth:`~BaseCache.set`. A timeout of
0 indicates that the cache never expires.
:param mode: the file mode wanted for the cache files, default 0600
:param hash_method: Default hashlib.md5. The hash method used to
generate the filename for cached results.
"""
#: used for temporary files by the FileSystemCache
_fs_transaction_suffix = ".__wz_cache"
#: keep amount of files in a cache element
_fs_count_file = "__wz_cache_count"
serializer = FileSystemSerializer()
def __init__(
self,
cache_dir: str,
threshold: int = 500,
default_timeout: int = 300,
mode: _t.Optional[int] = None,
hash_method: _t.Any = md5,
):
BaseCache.__init__(self, default_timeout)
self._path = cache_dir
self._threshold = threshold
self._hash_method = hash_method
# Mode set by user takes precedence. If no mode has
# been given, we need to set the correct default based
# on user platform.
self._mode = mode
if self._mode is None:
self._mode = self._get_compatible_platform_mode()
try:
os.makedirs(self._path)
except OSError as ex:
if ex.errno != errno.EEXIST:
raise
# If there are many files and a zero threshold,
# the list_dir can slow initialisation massively
if self._threshold != 0:
self._update_count(value=len(list(self._list_dir())))
def _get_compatible_platform_mode(self) -> int:
mode = 0o600 # nix systems
if platform.system() == "Windows":
mode = stat.S_IWRITE
return mode
@property
def _file_count(self) -> int:
return self.get(self._fs_count_file) or 0
def _update_count(
self, delta: _t.Optional[int] = None, value: _t.Optional[int] = None
) -> None:
# If we have no threshold, don't count files
if self._threshold == 0:
return
if delta:
new_count = self._file_count + delta
else:
new_count = value or 0
self.set(self._fs_count_file, new_count, mgmt_element=True)
def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
timeout = BaseCache._normalize_timeout(self, timeout)
if timeout != 0:
timeout = int(time()) + timeout
return int(timeout)
def _is_mgmt(self, name: str) -> bool:
fshash = self._get_filename(self._fs_count_file).split(os.sep)[-1]
return name == fshash or name.endswith(self._fs_transaction_suffix)
def _list_dir(self) -> _t.Generator[str, None, None]:
"""return a list of (fully qualified) cache filenames"""
return (
os.path.join(self._path, fn)
for fn in os.listdir(self._path)
if not self._is_mgmt(fn)
)
def _over_threshold(self) -> bool:
return self._threshold != 0 and self._file_count > self._threshold
def _remove_expired(self, now: float) -> None:
for fname in self._list_dir():
try:
with self._safe_stream_open(fname, "rb") as f:
expires = struct.unpack("I", f.read(4))[0]
if expires != 0 and expires < now:
os.remove(fname)
self._update_count(delta=-1)
except FileNotFoundError:
pass
except (OSError, EOFError, struct.error):
logging.warning(
"Exception raised while handling cache file '%s'",
fname,
exc_info=True,
)
def _remove_older(self) -> bool:
exp_fname_tuples = []
for fname in self._list_dir():
try:
with self._safe_stream_open(fname, "rb") as f:
timestamp = struct.unpack("I", f.read(4))[0]
exp_fname_tuples.append((timestamp, fname))
except FileNotFoundError:
pass
except (OSError, EOFError, struct.error):
logging.warning(
"Exception raised while handling cache file '%s'",
fname,
exc_info=True,
)
fname_sorted = (
fname for _, fname in sorted(exp_fname_tuples, key=lambda item: item[0])
)
for fname in fname_sorted:
try:
os.remove(fname)
self._update_count(delta=-1)
except FileNotFoundError:
pass
except OSError:
logging.warning(
"Exception raised while handling cache file '%s'",
fname,
exc_info=True,
)
return False
if not self._over_threshold():
break
return True
def _prune(self) -> None:
if self._over_threshold():
now = time()
self._remove_expired(now)
# if still over threshold
if self._over_threshold():
self._remove_older()
def clear(self) -> bool:
for i, fname in enumerate(self._list_dir()):
try:
os.remove(fname)
except FileNotFoundError:
pass
except OSError:
logging.warning(
"Exception raised while handling cache file '%s'",
fname,
exc_info=True,
)
self._update_count(delta=-i)
return False
self._update_count(value=0)
return True
def _get_filename(self, key: str) -> str:
if isinstance(key, str):
bkey = key.encode("utf-8") # XXX unicode review
bkey_hash = self._hash_method(bkey).hexdigest()
else:
raise TypeError(f"Key must be a string, received type {type(key)}")
return os.path.join(self._path, bkey_hash)
def get(self, key: str) -> _t.Any:
filename = self._get_filename(key)
try:
with self._safe_stream_open(filename, "rb") as f:
pickle_time = struct.unpack("I", f.read(4))[0]
if pickle_time == 0 or pickle_time >= time():
return self.serializer.load(f)
except FileNotFoundError:
pass
except (OSError, EOFError, struct.error):
logging.warning(
"Exception raised while handling cache file '%s'",
filename,
exc_info=True,
)
return None
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
filename = self._get_filename(key)
if not os.path.exists(filename):
return self.set(key, value, timeout)
return False
def set(
self,
key: str,
value: _t.Any,
timeout: _t.Optional[int] = None,
mgmt_element: bool = False,
) -> bool:
# Management elements have no timeout
if mgmt_element:
timeout = 0
# Don't prune on management element update, to avoid loop
else:
self._prune()
timeout = self._normalize_timeout(timeout)
filename = self._get_filename(key)
overwrite = os.path.isfile(filename)
try:
fd, tmp = tempfile.mkstemp(
suffix=self._fs_transaction_suffix, dir=self._path
)
with os.fdopen(fd, "wb") as f:
f.write(struct.pack("I", timeout))
self.serializer.dump(value, f)
self._run_safely(os.replace, tmp, filename)
self._run_safely(os.chmod, filename, self._mode)
fsize = Path(filename).stat().st_size
except OSError:
logging.warning(
"Exception raised while handling cache file '%s'",
filename,
exc_info=True,
)
return False
else:
# Management elements should not count towards threshold
if not overwrite and not mgmt_element:
self._update_count(delta=1)
return fsize > 0 # function should fail if file is empty
def delete(self, key: str, mgmt_element: bool = False) -> bool:
try:
os.remove(self._get_filename(key))
except FileNotFoundError: # if file doesn't exist we consider it deleted
return True
except OSError:
logging.warning("Exception raised while handling cache file", exc_info=True)
return False
else:
# Management elements should not count towards threshold
if not mgmt_element:
self._update_count(delta=-1)
return True
def has(self, key: str) -> bool:
filename = self._get_filename(key)
try:
with self._safe_stream_open(filename, "rb") as f:
pickle_time = struct.unpack("I", f.read(4))[0]
if pickle_time == 0 or pickle_time >= time():
return True
else:
return False
except FileNotFoundError: # if there is no file there is no key
return False
except (OSError, EOFError, struct.error):
logging.warning(
"Exception raised while handling cache file '%s'",
filename,
exc_info=True,
)
return False
def _run_safely(self, fn: _t.Callable, *args: _t.Any, **kwargs: _t.Any) -> _t.Any:
"""On Windows os.replace, os.chmod and open can yield
permission errors if executed by two different processes."""
if platform.system() == "Windows":
output = None
wait_step = 0.001
max_sleep_time = 10.0
total_sleep_time = 0.0
while total_sleep_time < max_sleep_time:
try:
output = fn(*args, **kwargs)
except PermissionError:
sleep(wait_step)
total_sleep_time += wait_step
wait_step *= 2
else:
break
else:
output = fn(*args, **kwargs)
return output
@contextmanager
def _safe_stream_open(self, path: str, mode: str) -> _t.Generator:
fs = self._run_safely(open, path, mode)
if fs is None:
raise OSError
try:
yield fs
finally:
fs.close()

@ -0,0 +1,196 @@
import re
import typing as _t
from time import time
from cachelib.base import BaseCache
_test_memcached_key = re.compile(r"[^\x00-\x21\xff]{1,250}$").match
class MemcachedCache(BaseCache):
"""A cache that uses memcached as backend.
The first argument can either be an object that resembles the API of a
:class:`memcache.Client` or a tuple/list of server addresses. In the
event that a tuple/list is passed, Werkzeug tries to import the best
available memcache library.
This cache looks into the following packages/modules to find bindings for
memcached:
- ``pylibmc``
- ``google.appengine.api.memcached``
- ``memcached``
- ``libmc``
Implementation notes: This cache backend works around some limitations in
memcached to simplify the interface. For example unicode keys are encoded
to utf-8 on the fly. Methods such as :meth:`~BaseCache.get_dict` return
the keys in the same format as passed. Furthermore all get methods
silently ignore key errors to not cause problems when untrusted user data
is passed to the get methods which is often the case in web applications.
:param servers: a list or tuple of server addresses or alternatively
a :class:`memcache.Client` or a compatible client.
:param default_timeout: the default timeout that is used if no timeout is
specified on :meth:`~BaseCache.set`. A timeout of
0 indicates that the cache never expires.
:param key_prefix: a prefix that is added before all keys. This makes it
possible to use the same memcached server for different
applications. Keep in mind that
:meth:`~BaseCache.clear` will also clear keys with a
different prefix.
"""
def __init__(
self,
servers: _t.Any = None,
default_timeout: int = 300,
key_prefix: _t.Optional[str] = None,
):
BaseCache.__init__(self, default_timeout)
if servers is None or isinstance(servers, (list, tuple)):
if servers is None:
servers = ["127.0.0.1:11211"]
self._client = self.import_preferred_memcache_lib(servers)
if self._client is None:
raise RuntimeError("no memcache module found")
else:
# NOTE: servers is actually an already initialized memcache
# client.
self._client = servers
self.key_prefix = key_prefix
def _normalize_key(self, key: str) -> str:
if self.key_prefix:
key = self.key_prefix + key
return key
def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
timeout = BaseCache._normalize_timeout(self, timeout)
if timeout > 0:
timeout = int(time()) + timeout
return timeout
def get(self, key: str) -> _t.Any:
key = self._normalize_key(key)
# memcached doesn't support keys longer than that. Because often
# checks for so long keys can occur because it's tested from user
# submitted data etc we fail silently for getting.
if _test_memcached_key(key):
return self._client.get(key)
def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
key_mapping = {}
for key in keys:
encoded_key = self._normalize_key(key)
if _test_memcached_key(key):
key_mapping[encoded_key] = key
_keys = list(key_mapping)
d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any]
if self.key_prefix:
rv = {}
for key, value in d.items():
rv[key_mapping[key]] = value
if len(rv) < len(keys):
for key in keys:
if key not in rv:
rv[key] = None
return rv
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
key = self._normalize_key(key)
timeout = self._normalize_timeout(timeout)
return bool(self._client.add(key, value, timeout))
def set(
self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
) -> _t.Optional[bool]:
key = self._normalize_key(key)
timeout = self._normalize_timeout(timeout)
return bool(self._client.set(key, value, timeout))
def get_many(self, *keys: str) -> _t.List[_t.Any]:
d = self.get_dict(*keys)
return [d[key] for key in keys]
def set_many(
self, mapping: _t.Dict[str, _t.Any], timeout: _t.Optional[int] = None
) -> _t.List[_t.Any]:
new_mapping = {}
for key, value in mapping.items():
key = self._normalize_key(key)
new_mapping[key] = value
timeout = self._normalize_timeout(timeout)
failed_keys = self._client.set_multi(
new_mapping, timeout
) # type: _t.List[_t.Any]
k_normkey = zip(mapping.keys(), new_mapping.keys()) # noqa: B905
return [k for k, nkey in k_normkey if nkey not in failed_keys]
def delete(self, key: str) -> bool:
key = self._normalize_key(key)
if _test_memcached_key(key):
return bool(self._client.delete(key))
return False
def delete_many(self, *keys: str) -> _t.List[_t.Any]:
new_keys = []
for key in keys:
key = self._normalize_key(key)
if _test_memcached_key(key):
new_keys.append(key)
self._client.delete_multi(new_keys)
return [k for k in new_keys if not self.has(k)]
def has(self, key: str) -> bool:
key = self._normalize_key(key)
if _test_memcached_key(key):
return bool(self._client.append(key, ""))
return False
def clear(self) -> bool:
return bool(self._client.flush_all())
def inc(self, key: str, delta: int = 1) -> _t.Optional[int]:
key = self._normalize_key(key)
value = (self._client.get(key) or 0) + delta
return value if self.set(key, value) else None
def dec(self, key: str, delta: int = 1) -> _t.Optional[int]:
key = self._normalize_key(key)
value = (self._client.get(key) or 0) - delta
return value if self.set(key, value) else None
def import_preferred_memcache_lib(self, servers: _t.Any) -> _t.Any:
"""Returns an initialized memcache client. Used by the constructor."""
try:
import pylibmc # type: ignore
except ImportError:
pass
else:
return pylibmc.Client(servers)
try:
from google.appengine.api import memcache # type: ignore
except ImportError:
pass
else:
return memcache.Client()
try:
import memcache # type: ignore
except ImportError:
pass
else:
return memcache.Client(servers)
try:
import libmc # type: ignore
except ImportError:
pass
else:
return libmc.Client(servers)

@ -0,0 +1,202 @@
import datetime
import logging
import typing as _t
from cachelib.base import BaseCache
from cachelib.serializers import BaseSerializer
class MongoDbCache(BaseCache):
"""
Implementation of cachelib.BaseCache that uses mongodb collection
as the backend.
Limitations: maximum MongoDB document size is 16mb
:param client: mongodb client or connection string
:param db: mongodb database name
:param collection: mongodb collection name
:param default_timeout: Set the timeout in seconds after which cache entries
expire
:param key_prefix: A prefix that should be added to all keys.
"""
serializer = BaseSerializer()
def __init__(
self,
client: _t.Any = None,
db: _t.Optional[str] = "cache-db",
collection: _t.Optional[str] = "cache-collection",
default_timeout: int = 300,
key_prefix: _t.Optional[str] = None,
**kwargs: _t.Any
):
super().__init__(default_timeout)
try:
import pymongo # type: ignore
except ImportError:
logging.warning("no pymongo module found")
if client is None or isinstance(client, str):
client = pymongo.MongoClient(host=client)
self.client = client[db][collection]
index_info = self.client.index_information()
all_keys = {
subkey[0] for value in index_info.values() for subkey in value["key"]
}
if "id" not in all_keys:
self.client.create_index("id", unique=True)
self.key_prefix = key_prefix or ""
self.collection = collection
def _utcnow(self) -> _t.Any:
"""Return a tz-aware UTC datetime representing the current time"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
def _expire_records(self) -> _t.Any:
res = self.client.delete_many({"expiration": {"$lte": self._utcnow()}})
return res
def get(self, key: str) -> _t.Any:
"""
Get a cache item
:param key: The cache key of the item to fetch
:return: cache value if not expired, else None
"""
self._expire_records()
record = self.client.find_one({"id": self.key_prefix + key})
value = None
if record:
value = self.serializer.loads(record["val"])
return value
def delete(self, key: str) -> bool:
"""
Deletes an item from the cache. This is a no-op if the item doesn't
exist
:param key: Key of the item to delete.
:return: True if the key existed and was deleted
"""
res = self.client.delete_one({"id": self.key_prefix + key})
deleted = bool(res.deleted_count > 0)
return deleted
def _set(
self,
key: str,
value: _t.Any,
timeout: _t.Optional[int] = None,
overwrite: _t.Optional[bool] = True,
) -> _t.Any:
"""
Store a cache item, with the option to not overwrite existing items
:param key: Cache key to use
:param value: a serializable object
:param timeout: The timeout in seconds for the cached item, to override
the default
:param overwrite: If true, overwrite any existing cache item with key.
If false, the new value will only be stored if no
non-expired cache item exists with key.
:return: True if the new item was stored.
"""
timeout = self._normalize_timeout(timeout)
now = self._utcnow()
if not overwrite:
# fail if a non-expired item with this key
# already exists
if self.has(key):
return False
dump = self.serializer.dumps(value)
record = {"id": self.key_prefix + key, "val": dump}
if timeout > 0:
record["expiration"] = now + datetime.timedelta(seconds=timeout)
self.client.update_one({"id": self.key_prefix + key}, {"$set": record}, True)
return True
def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
self._expire_records()
return self._set(key, value, timeout=timeout, overwrite=True)
def set_many(
self, mapping: _t.Dict[str, _t.Any], timeout: _t.Optional[int] = None
) -> _t.List[_t.Any]:
self._expire_records()
from pymongo import UpdateOne
operations = []
now = self._utcnow()
timeout = self._normalize_timeout(timeout)
for key, val in mapping.items():
dump = self.serializer.dumps(val)
record = {"id": self.key_prefix + key, "val": dump}
if timeout > 0:
record["expiration"] = now + datetime.timedelta(seconds=timeout)
operations.append(
UpdateOne({"id": self.key_prefix + key}, {"$set": record}, upsert=True),
)
result = self.client.bulk_write(operations)
keys = list(mapping.keys())
if result.bulk_api_result["nUpserted"] != len(keys):
query = self.client.find(
{"id": {"$in": [self.key_prefix + key for key in keys]}}
)
keys = []
for item in query:
keys.append(item["id"])
return keys
def get_many(self, *keys: str) -> _t.List[_t.Any]:
results = self.get_dict(*keys)
values = []
for key in keys:
values.append(results.get(key, None))
return values
def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
self._expire_records()
query = self.client.find(
{"id": {"$in": [self.key_prefix + key for key in keys]}}
)
results = dict.fromkeys(keys, None)
for item in query:
value = self.serializer.loads(item["val"])
results[item["id"][len(self.key_prefix) :]] = value
return results
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
self._expire_records()
return self._set(key, value, timeout=timeout, overwrite=False)
def has(self, key: str) -> bool:
self._expire_records()
record = self.get(key)
return record is not None
def delete_many(self, *keys: str) -> _t.List[_t.Any]:
self._expire_records()
res = list(keys)
filter = {"id": {"$in": [self.key_prefix + key for key in keys]}}
result = self.client.delete_many(filter)
if result.deleted_count != len(keys):
existing_keys = [
item["id"][len(self.key_prefix) :] for item in self.client.find(filter)
]
res = [item for item in keys if item not in existing_keys]
return res
def clear(self) -> bool:
self.client.drop()
return True

@ -0,0 +1,159 @@
import typing as _t
from cachelib.base import BaseCache
from cachelib.serializers import RedisSerializer
class RedisCache(BaseCache):
"""Uses the Redis key-value store as a cache backend.
The first argument can be either a string denoting address of the Redis
server or an object resembling an instance of a redis.Redis class.
Note: Python Redis API already takes care of encoding unicode strings on
the fly.
:param host: address of the Redis server or an object which API is
compatible with the official Python Redis client (redis-py).
:param port: port number on which Redis server listens for connections.
:param password: password authentication for the Redis server.
:param db: db (zero-based numeric index) on Redis Server to connect.
:param default_timeout: the default timeout that is used if no timeout is
specified on :meth:`~BaseCache.set`. A timeout of
0 indicates that the cache never expires.
:param key_prefix: A prefix that should be added to all keys.
Any additional keyword arguments will be passed to ``redis.Redis``.
"""
_read_client: _t.Any = None
_write_client: _t.Any = None
serializer = RedisSerializer()
def __init__(
self,
host: _t.Any = "localhost",
port: int = 6379,
password: _t.Optional[str] = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None,
**kwargs: _t.Any,
):
BaseCache.__init__(self, default_timeout)
if host is None:
raise ValueError("RedisCache host parameter may not be None")
if isinstance(host, str):
try:
import redis
except ImportError as err:
raise RuntimeError("no redis module found") from err
if kwargs.get("decode_responses", None):
raise ValueError("decode_responses is not supported by RedisCache.")
self._write_client = self._read_client = redis.Redis(
host=host, port=port, password=password, db=db, **kwargs
)
else:
self._read_client = self._write_client = host
self.key_prefix = key_prefix or ""
def _get_prefix(self) -> str:
return (
self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix()
)
def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
"""Normalize timeout by setting it to default of 300 if
not defined (None) or -1 if explicitly set to zero.
:param timeout: timeout to normalize.
"""
timeout = BaseCache._normalize_timeout(self, timeout)
if timeout == 0:
timeout = -1
return timeout
def get(self, key: str) -> _t.Any:
return self.serializer.loads(
self._read_client.get(f"{self._get_prefix()}{key}")
)
def get_many(self, *keys: str) -> _t.List[_t.Any]:
if self.key_prefix:
prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
else:
prefixed_keys = list(keys)
return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)]
def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
timeout = self._normalize_timeout(timeout)
dump = self.serializer.dumps(value)
if timeout == -1:
result = self._write_client.set(
name=f"{self._get_prefix()}{key}", value=dump
)
else:
result = self._write_client.setex(
name=f"{self._get_prefix()}{key}", value=dump, time=timeout
)
return result
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
timeout = self._normalize_timeout(timeout)
dump = self.serializer.dumps(value)
created = self._write_client.setnx(
name=f"{self._get_prefix()}{key}", value=dump
)
# handle case where timeout is explicitly set to zero
if created and timeout != -1:
self._write_client.expire(name=f"{self._get_prefix()}{key}", time=timeout)
return created
def set_many(
self, mapping: _t.Dict[str, _t.Any], timeout: _t.Optional[int] = None
) -> _t.List[_t.Any]:
timeout = self._normalize_timeout(timeout)
# Use transaction=False to batch without calling redis MULTI
# which is not supported by twemproxy
pipe = self._write_client.pipeline(transaction=False)
for key, value in mapping.items():
dump = self.serializer.dumps(value)
if timeout == -1:
pipe.set(name=f"{self._get_prefix()}{key}", value=dump)
else:
pipe.setex(name=f"{self._get_prefix()}{key}", value=dump, time=timeout)
results = pipe.execute()
return [k for k, was_set in zip(mapping.keys(), results) if was_set]
def delete(self, key: str) -> bool:
return bool(self._write_client.delete(f"{self._get_prefix()}{key}"))
def delete_many(self, *keys: str) -> _t.List[_t.Any]:
if not keys:
return []
if self.key_prefix:
prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
else:
prefixed_keys = [k for k in keys]
self._write_client.delete(*prefixed_keys)
return [k for k in prefixed_keys if not self.has(k)]
def has(self, key: str) -> bool:
return bool(self._read_client.exists(f"{self._get_prefix()}{key}"))
def clear(self) -> bool:
status = 0
if self.key_prefix:
keys = self._read_client.keys(self._get_prefix() + "*")
if keys:
status = self._write_client.delete(*keys)
else:
status = self._write_client.flushdb()
return bool(status)
def inc(self, key: str, delta: int = 1) -> _t.Any:
return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta)
def dec(self, key: str, delta: int = 1) -> _t.Any:
return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta)

@ -0,0 +1,112 @@
import logging
import pickle
import typing as _t
class BaseSerializer:
"""This is the base interface for all default serializers.
BaseSerializer.load and BaseSerializer.dump will
default to pickle.load and pickle.dump. This is currently
used only by FileSystemCache which dumps/loads to/from a file stream.
"""
def _warn(self, e: pickle.PickleError) -> None:
logging.warning(
f"An exception has been raised during a pickling operation: {e}"
)
def dump(
self, value: int, f: _t.IO, protocol: int = pickle.HIGHEST_PROTOCOL
) -> None:
try:
pickle.dump(value, f, protocol)
except (pickle.PickleError, pickle.PicklingError) as e:
self._warn(e)
def load(self, f: _t.BinaryIO) -> _t.Any:
try:
data = pickle.load(f)
except pickle.PickleError as e:
self._warn(e)
return None
else:
return data
"""BaseSerializer.loads and BaseSerializer.dumps
work on top of pickle.loads and pickle.dumps. Dumping/loading
strings and byte strings is the default for most cache types.
"""
def dumps(self, value: _t.Any, protocol: int = pickle.HIGHEST_PROTOCOL) -> bytes:
try:
serialized = pickle.dumps(value, protocol)
except (pickle.PickleError, pickle.PicklingError) as e:
self._warn(e)
return serialized
def loads(self, bvalue: bytes) -> _t.Any:
try:
data = pickle.loads(bvalue)
except pickle.PickleError as e:
self._warn(e)
return None
else:
return data
"""Default serializers for each cache type.
The following classes can be used to further customize
serialiation behaviour. Alternatively, any serializer can be
overriden in order to use a custom serializer with a different
strategy altogether.
"""
class UWSGISerializer(BaseSerializer):
"""Default serializer for UWSGICache."""
class SimpleSerializer(BaseSerializer):
"""Default serializer for SimpleCache."""
class FileSystemSerializer(BaseSerializer):
"""Default serializer for FileSystemCache."""
class RedisSerializer(BaseSerializer):
"""Default serializer for RedisCache."""
def dumps(self, value: _t.Any, protocol: int = pickle.HIGHEST_PROTOCOL) -> bytes:
"""Dumps an object into a string for redis, using pickle by default."""
return b"!" + pickle.dumps(value, protocol)
def loads(self, value: _t.Optional[bytes]) -> _t.Any:
"""The reversal of :meth:`dump_object`. This might be called with
None.
"""
if value is None:
return None
if value.startswith(b"!"):
try:
return pickle.loads(value[1:])
except pickle.PickleError:
return None
try:
return int(value)
except ValueError:
# before 0.8 we did not have serialization. Still support that.
return value
class DynamoDbSerializer(RedisSerializer):
"""Default serializer for DynamoDbCache."""
def loads(self, value: _t.Any) -> _t.Any:
"""The reversal of :meth:`dump_object`. This might be called with
None.
"""
value = value.value
return super().loads(value)

@ -0,0 +1,100 @@
import typing as _t
from time import time
from cachelib.base import BaseCache
from cachelib.serializers import SimpleSerializer
class SimpleCache(BaseCache):
"""Simple memory cache for single process environments. This class exists
mainly for the development server and is not 100% thread safe. It tries
to use as many atomic operations as possible and no locks for simplicity
but it could happen under heavy load that keys are added multiple times.
:param threshold: the maximum number of items the cache stores before
it starts deleting some.
:param default_timeout: the default timeout that is used if no timeout is
specified on :meth:`~BaseCache.set`. A timeout of
0 indicates that the cache never expires.
"""
serializer = SimpleSerializer()
def __init__(
self,
threshold: int = 500,
default_timeout: int = 300,
):
BaseCache.__init__(self, default_timeout)
self._cache: _t.Dict[str, _t.Any] = {}
self._threshold = threshold or 500 # threshold = 0
def _over_threshold(self) -> bool:
return len(self._cache) > self._threshold
def _remove_expired(self, now: float) -> None:
toremove = [k for k, (expires, _) in self._cache.items() if expires < now]
for k in toremove:
self._cache.pop(k, None)
def _remove_older(self) -> None:
k_ordered = (
k for k, v in sorted(self._cache.items(), key=lambda item: item[1][0])
)
for k in k_ordered:
self._cache.pop(k, None)
if not self._over_threshold():
break
def _prune(self) -> None:
if self._over_threshold():
now = time()
self._remove_expired(now)
# remove older items if still over threshold
if self._over_threshold():
self._remove_older()
def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
timeout = BaseCache._normalize_timeout(self, timeout)
if timeout > 0:
timeout = int(time()) + timeout
return timeout
def get(self, key: str) -> _t.Any:
try:
expires, value = self._cache[key]
if expires == 0 or expires > time():
return self.serializer.loads(value)
except KeyError:
return None
def set(
self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
) -> _t.Optional[bool]:
expires = self._normalize_timeout(timeout)
self._prune()
self._cache[key] = (expires, self.serializer.dumps(value))
return True
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
expires = self._normalize_timeout(timeout)
self._prune()
item = (expires, self.serializer.dumps(value))
if key in self._cache:
return False
self._cache.setdefault(key, item)
return True
def delete(self, key: str) -> bool:
return self._cache.pop(key, None) is not None
def has(self, key: str) -> bool:
try:
expires, value = self._cache[key]
return bool(expires == 0 or expires > time())
except KeyError:
return False
def clear(self) -> bool:
self._cache.clear()
return not bool(self._cache)

@ -0,0 +1,83 @@
import platform
import typing as _t
from cachelib.base import BaseCache
from cachelib.serializers import UWSGISerializer
class UWSGICache(BaseCache):
"""Implements the cache using uWSGI's caching framework.
.. note::
This class cannot be used when running under PyPy, because the uWSGI
API implementation for PyPy is lacking the needed functionality.
:param default_timeout: The default timeout in seconds.
:param cache: The name of the caching instance to connect to, for
example: mycache@localhost:3031, defaults to an empty string, which
means uWSGI will cache in the local instance. If the cache is in the
same instance as the werkzeug app, you only have to provide the name of
the cache.
"""
serializer = UWSGISerializer()
def __init__(
self,
default_timeout: int = 300,
cache: str = "",
):
BaseCache.__init__(self, default_timeout)
if platform.python_implementation() == "PyPy":
raise RuntimeError(
"uWSGI caching does not work under PyPy, see "
"the docs for more details."
)
try:
import uwsgi # type: ignore
self._uwsgi = uwsgi
except ImportError as err:
raise RuntimeError(
"uWSGI could not be imported, are you running under uWSGI?"
) from err
self.cache = cache
def get(self, key: str) -> _t.Any:
rv = self._uwsgi.cache_get(key, self.cache)
if rv is None:
return
return self.serializer.loads(rv)
def delete(self, key: str) -> bool:
return bool(self._uwsgi.cache_del(key, self.cache))
def set(
self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
) -> _t.Optional[bool]:
result = self._uwsgi.cache_update(
key,
self.serializer.dumps(value),
self._normalize_timeout(timeout),
self.cache,
) # type: bool
return result
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
return bool(
self._uwsgi.cache_set(
key,
self.serializer.dumps(value),
self._normalize_timeout(timeout),
self.cache,
)
)
def clear(self) -> bool:
return bool(self._uwsgi.cache_clear(self.cache))
def has(self, key: str) -> bool:
return self._uwsgi.cache_exists(key, self.cache) is not None

@ -0,0 +1,28 @@
Copyright 2014 Pallets Community Ecosystem
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,61 @@
Metadata-Version: 2.1
Name: Flask-Session
Version: 0.6.0
Summary: Server-side session support for Flask
Author-email: Shipeng Feng <fsp261@gmail.com>
Maintainer-email: Pallets Community Ecosystem <contact@palletsprojects.com>
Requires-Python: >=3.7
Description-Content-Type: text/x-rst
Classifier: Development Status :: 4 - Beta
Classifier: Environment :: Web Environment
Classifier: Framework :: Flask
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Topic :: Internet :: WWW/HTTP :: Session
Classifier: Topic :: Internet :: WWW/HTTP :: WSGI
Classifier: Topic :: Internet :: WWW/HTTP :: WSGI :: Application
Classifier: Topic :: Software Development :: Libraries :: Application Frameworks
Requires-Dist: flask>=2.2
Requires-Dist: cachelib
Project-URL: Changes, https://flask-session.readthedocs.io/changes.html
Project-URL: Chat, https://discord.gg/pallets
Project-URL: Documentation, https://flask-session.readthedocs.io
Project-URL: Issue Tracker, https://github.com/pallets-eco/flask-session/issues/
Project-URL: Source Code, https://github.com/pallets-eco/flask-session/
Flask-Session
=============
Flask-Session is an extension for Flask that adds support for server-side sessions to
your application.
.. image:: https://github.com/pallets-eco/flask-session/actions/workflows/test.yaml/badge.svg?branch=development
:target: https://github.com/pallets-eco/flask-session/actions/workflows/test.yaml?query=workflow%3ACI+branch%3Adeveloment
:alt: Tests
.. image:: https://readthedocs.org/projects/flask-session/badge/?version=stable&style=flat
:target: https://flask-session.readthedocs.io
:alt: docs
.. image:: https://img.shields.io/github/license/pallets-eco/flask-session
:target: ./LICENSE
:alt: BSD-3 Clause License
.. image:: https://img.shields.io/pypi/v/flask-session.svg?
:target: https://pypi.org/project/flask-session
:alt: PyPI
.. image:: https://img.shields.io/badge/dynamic/json?query=info.requires_python&label=python&url=https%3A%2F%2Fpypi.org%2Fpypi%2Fflask-session%2Fjson
:target: https://pypi.org/project/Flask-Session/
:alt: PyPI - Python Version
.. image:: https://img.shields.io/github/v/release/pallets-eco/flask-session?include_prereleases&label=latest-prerelease
:target: https://github.com/pallets-eco/flask-session/releases
:alt: pre-release
.. image:: https://codecov.io/gh/pallets-eco/flask-session/branch/master/graph/badge.svg?token=yenl5fzxxr
:target: https://codecov.io/gh/pallets-eco/flask-session
:alt: codecov

@ -0,0 +1,10 @@
flask_session-0.6.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
flask_session-0.6.0.dist-info/LICENSE.rst,sha256=avK7glmtsxOGN0YcECHYPdjG2EMHdV_HnTtZP0uD4RE,1495
flask_session-0.6.0.dist-info/METADATA,sha256=tF1yWEoeJTuiyERanugn5n5Ad8gJopzRLN8v5pdx8zg,2665
flask_session-0.6.0.dist-info/RECORD,,
flask_session-0.6.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
flask_session-0.6.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
flask_session/__init__.py,sha256=TlU8TXAiVYnP1jDI5NUD5Jy4FHZdH3E1kvhOFCxxla8,4553
flask_session/__pycache__/__init__.cpython-310.pyc,,
flask_session/__pycache__/sessions.cpython-310.pyc,,
flask_session/sessions.py,sha256=EdobpyuF0pK0d-iNKF6j9LaZVSYO64Jxqbx7uTC7Gww,25596

@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: flit 3.9.0
Root-Is-Purelib: true
Tag: py3-none-any

@ -0,0 +1,134 @@
import os
from .sessions import (
FileSystemSessionInterface,
MemcachedSessionInterface,
MongoDBSessionInterface,
NullSessionInterface,
RedisSessionInterface,
SqlAlchemySessionInterface,
)
__version__ = "0.6.0"
class Session:
"""This class is used to add Server-side Session to one or more Flask
applications.
There are two usage modes. One is initialize the instance with a very
specific Flask application::
app = Flask(__name__)
Session(app)
The second possibility is to create the object once and configure the
application later::
sess = Session()
def create_app():
app = Flask(__name__)
sess.init_app(app)
return app
By default Flask-Session will use :class:`NullSessionInterface`, you
really should configurate your app to use a different SessionInterface.
.. note::
You can not use ``Session`` instance directly, what ``Session`` does
is just change the :attr:`~flask.Flask.session_interface` attribute on
your Flask applications.
"""
def __init__(self, app=None):
self.app = app
if app is not None:
self.init_app(app)
def init_app(self, app):
"""This is used to set up session for your app object.
:param app: the Flask app object with proper configuration.
"""
app.session_interface = self._get_interface(app)
def _get_interface(self, app):
config = app.config.copy()
# Flask-session specific settings
config.setdefault("SESSION_TYPE", "null")
config.setdefault("SESSION_PERMANENT", True)
config.setdefault("SESSION_USE_SIGNER", False)
config.setdefault("SESSION_KEY_PREFIX", "session:")
config.setdefault("SESSION_ID_LENGTH", 32)
# Redis settings
config.setdefault("SESSION_REDIS", None)
# Memcached settings
config.setdefault("SESSION_MEMCACHED", None)
# Filesystem settings
config.setdefault(
"SESSION_FILE_DIR", os.path.join(os.getcwd(), "flask_session")
)
config.setdefault("SESSION_FILE_THRESHOLD", 500)
config.setdefault("SESSION_FILE_MODE", 384)
# MongoDB settings
config.setdefault("SESSION_MONGODB", None)
config.setdefault("SESSION_MONGODB_DB", "flask_session")
config.setdefault("SESSION_MONGODB_COLLECT", "sessions")
# SQLAlchemy settings
config.setdefault("SESSION_SQLALCHEMY", None)
config.setdefault("SESSION_SQLALCHEMY_TABLE", "sessions")
config.setdefault("SESSION_SQLALCHEMY_SEQUENCE", None)
config.setdefault("SESSION_SQLALCHEMY_SCHEMA", None)
config.setdefault("SESSION_SQLALCHEMY_BIND_KEY", None)
common_params = {
"key_prefix": config["SESSION_KEY_PREFIX"],
"use_signer": config["SESSION_USE_SIGNER"],
"permanent": config["SESSION_PERMANENT"],
"sid_length": config["SESSION_ID_LENGTH"],
}
if config["SESSION_TYPE"] == "redis":
session_interface = RedisSessionInterface(
config["SESSION_REDIS"], **common_params
)
elif config["SESSION_TYPE"] == "memcached":
session_interface = MemcachedSessionInterface(
config["SESSION_MEMCACHED"], **common_params
)
elif config["SESSION_TYPE"] == "filesystem":
session_interface = FileSystemSessionInterface(
config["SESSION_FILE_DIR"],
config["SESSION_FILE_THRESHOLD"],
config["SESSION_FILE_MODE"],
**common_params,
)
elif config["SESSION_TYPE"] == "mongodb":
session_interface = MongoDBSessionInterface(
config["SESSION_MONGODB"],
config["SESSION_MONGODB_DB"],
config["SESSION_MONGODB_COLLECT"],
**common_params,
)
elif config["SESSION_TYPE"] == "sqlalchemy":
session_interface = SqlAlchemySessionInterface(
app,
config["SESSION_SQLALCHEMY"],
config["SESSION_SQLALCHEMY_TABLE"],
config["SESSION_SQLALCHEMY_SEQUENCE"],
config["SESSION_SQLALCHEMY_SCHEMA"],
config["SESSION_SQLALCHEMY_BIND_KEY"],
**common_params,
)
else:
session_interface = NullSessionInterface()
return session_interface

@ -0,0 +1,697 @@
import secrets
import time
from abc import ABC
try:
import cPickle as pickle
except ImportError:
import pickle
from datetime import datetime, timezone
from flask.sessions import SessionInterface as FlaskSessionInterface
from flask.sessions import SessionMixin
from itsdangerous import BadSignature, Signer, want_bytes
from werkzeug.datastructures import CallbackDict
def total_seconds(td):
return td.days * 60 * 60 * 24 + td.seconds
class ServerSideSession(CallbackDict, SessionMixin):
"""Baseclass for server-side based sessions."""
def __bool__(self) -> bool:
return bool(dict(self)) and self.keys() != {"_permanent"}
def __init__(self, initial=None, sid=None, permanent=None):
def on_update(self):
self.modified = True
CallbackDict.__init__(self, initial, on_update)
self.sid = sid
if permanent:
self.permanent = permanent
self.modified = False
class RedisSession(ServerSideSession):
pass
class MemcachedSession(ServerSideSession):
pass
class FileSystemSession(ServerSideSession):
pass
class MongoDBSession(ServerSideSession):
pass
class SqlAlchemySession(ServerSideSession):
pass
class SessionInterface(FlaskSessionInterface):
def _generate_sid(self, session_id_length):
return secrets.token_urlsafe(session_id_length)
def __get_signer(self, app):
if not hasattr(app, "secret_key") or not app.secret_key:
raise KeyError("SECRET_KEY must be set when SESSION_USE_SIGNER=True")
return Signer(app.secret_key, salt="flask-session", key_derivation="hmac")
def _unsign(self, app, sid):
signer = self.__get_signer(app)
sid_as_bytes = signer.unsign(sid)
sid = sid_as_bytes.decode()
return sid
def _sign(self, app, sid):
signer = self.__get_signer(app)
sid_as_bytes = want_bytes(sid)
return signer.sign(sid_as_bytes).decode("utf-8")
class NullSessionInterface(SessionInterface):
"""Used to open a :class:`flask.sessions.NullSession` instance.
If you do not configure a different ``SESSION_TYPE``, this will be used to
generate nicer error messages. Will allow read-only access to the empty
session but fail on setting.
"""
def open_session(self, app, request):
return None
class ServerSideSessionInterface(SessionInterface, ABC):
"""Used to open a :class:`flask.sessions.ServerSideSessionInterface` instance."""
def __init__(self, db, key_prefix, use_signer=False, permanent=True, sid_length=32):
self.db = db
self.key_prefix = key_prefix
self.use_signer = use_signer
self.permanent = permanent
self.sid_length = sid_length
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
def set_cookie_to_response(self, app, session, response, expires):
session_id = self._sign(app, session.sid) if self.use_signer else session.sid
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
httponly = self.get_cookie_httponly(app)
secure = self.get_cookie_secure(app)
samesite = None
if self.has_same_site_capability:
samesite = self.get_cookie_samesite(app)
response.set_cookie(
app.config["SESSION_COOKIE_NAME"],
session_id,
expires=expires,
httponly=httponly,
domain=domain,
path=path,
secure=secure,
samesite=samesite,
)
def open_session(self, app, request):
sid = request.cookies.get(app.config["SESSION_COOKIE_NAME"])
if not sid:
sid = self._generate_sid(self.sid_length)
return self.session_class(sid=sid, permanent=self.permanent)
if self.use_signer:
try:
sid = self._unsign(app, sid)
except BadSignature:
sid = self._generate_sid(self.sid_length)
return self.session_class(sid=sid, permanent=self.permanent)
return self.fetch_session(sid)
def fetch_session(self, sid):
raise NotImplementedError()
class RedisSessionInterface(ServerSideSessionInterface):
"""Uses the Redis key-value store as a session backend. (`redis-py` required)
:param redis: A ``redis.Redis`` instance.
:param key_prefix: A prefix that is added to all Redis store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
.. versionadded:: 0.6
The `sid_length` parameter was added.
.. versionadded:: 0.2
The `use_signer` parameter was added.
"""
serializer = pickle
session_class = RedisSession
def __init__(self, redis, key_prefix, use_signer, permanent, sid_length):
if redis is None:
from redis import Redis
redis = Redis()
self.redis = redis
super().__init__(redis, key_prefix, use_signer, permanent, sid_length)
def fetch_session(self, sid):
# Get the saved session (value) from the database
prefixed_session_id = self.key_prefix + sid
value = self.redis.get(prefixed_session_id)
# If the saved session still exists and hasn't auto-expired, load the session data from the document
if value is not None:
try:
session_data = self.serializer.loads(value)
return self.session_class(session_data, sid=sid)
except pickle.UnpicklingError:
return self.session_class(sid=sid, permanent=self.permanent)
# If the saved session does not exist, create a new session
return self.session_class(sid=sid, permanent=self.permanent)
def save_session(self, app, session, response):
if not self.should_set_cookie(app, session):
return
# Get the domain and path for the cookie from the app config
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
# If the session is empty, do not save it to the database or set a cookie
if not session:
# If the session was deleted (empty and modified), delete the saved session from the database and tell the client to delete the cookie
if session.modified:
self.redis.delete(self.key_prefix + session.sid)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Get the new expiration time for the session
expiration_datetime = self.get_expiration_time(app, session)
# Serialize the session data
serialized_session_data = self.serializer.dumps(dict(session))
# Update existing or create new session in the database
self.redis.set(
name=self.key_prefix + session.sid,
value=serialized_session_data,
ex=total_seconds(app.permanent_session_lifetime),
)
# Set the browser cookie
self.set_cookie_to_response(app, session, response, expiration_datetime)
class MemcachedSessionInterface(ServerSideSessionInterface):
"""A Session interface that uses memcached as backend. (`pylibmc` or `python-memcached` or `pymemcache` required)
:param client: A ``memcache.Client`` instance.
:param key_prefix: A prefix that is added to all Memcached store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
.. versionadded:: 0.6
The `sid_length` parameter was added.
.. versionadded:: 0.2
The `use_signer` parameter was added.
"""
serializer = pickle
session_class = MemcachedSession
def __init__(self, client, key_prefix, use_signer, permanent, sid_length):
if client is None:
client = self._get_preferred_memcache_client()
self.client = client
super().__init__(client, key_prefix, use_signer, permanent, sid_length)
def _get_preferred_memcache_client(self):
clients = [
("pylibmc", ["127.0.0.1:11211"]),
("memcache", ["127.0.0.1:11211"]),
("pymemcache.client.base", "127.0.0.1:11211"),
]
for module_name, server in clients:
try:
module = __import__(module_name)
ClientClass = module.Client
return ClientClass(server)
except ImportError:
continue
raise ImportError("No memcache module found")
def _get_memcache_timeout(self, timeout):
"""
Memcached deals with long (> 30 days) timeouts in a special
way. Call this function to obtain a safe value for your timeout.
"""
if timeout > 2592000: # 60*60*24*30, 30 days
# Switch to absolute timestamps.
timeout += int(time.time())
return timeout
def fetch_session(self, sid):
# Get the saved session (item) from the database
prefixed_session_id = self.key_prefix + sid
item = self.client.get(prefixed_session_id)
# If the saved session still exists and hasn't auto-expired, load the session data from the document
if item is not None:
try:
session_data = self.serializer.loads(want_bytes(item))
return self.session_class(session_data, sid=sid)
except pickle.UnpicklingError:
return self.session_class(sid=sid, permanent=self.permanent)
# If the saved session does not exist, create a new session
return self.session_class(sid=sid, permanent=self.permanent)
def save_session(self, app, session, response):
if not self.should_set_cookie(app, session):
return
# Get the domain and path for the cookie from the app config
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
# Generate a prefixed session id from the session id as a storage key
prefixed_session_id = self.key_prefix + session.sid
# If the session is empty, do not save it to the database or set a cookie
if not session:
# If the session was deleted (empty and modified), delete the saved session from the database and tell the client to delete the cookie
if session.modified:
self.client.delete(prefixed_session_id)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Get the new expiration time for the session
expiration_datetime = self.get_expiration_time(app, session)
# Serialize the session data
serialized_session_data = self.serializer.dumps(dict(session))
# Update existing or create new session in the database
self.client.set(
prefixed_session_id,
serialized_session_data,
self._get_memcache_timeout(total_seconds(app.permanent_session_lifetime)),
)
# Set the browser cookie
self.set_cookie_to_response(app, session, response, expiration_datetime)
class FileSystemSessionInterface(ServerSideSessionInterface):
"""Uses the :class:`cachelib.file.FileSystemCache` as a session backend.
:param cache_dir: the directory where session files are stored.
:param threshold: the maximum number of items the session stores before it
starts deleting some.
:param mode: the file mode wanted for the session files, default 0600
:param key_prefix: A prefix that is added to FileSystemCache store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
.. versionadded:: 0.6
The `sid_length` parameter was added.
.. versionadded:: 0.2
The `use_signer` parameter was added.
"""
session_class = FileSystemSession
def __init__(
self,
cache_dir,
threshold,
mode,
key_prefix,
use_signer,
permanent,
sid_length,
):
from cachelib.file import FileSystemCache
self.cache = FileSystemCache(cache_dir, threshold=threshold, mode=mode)
super().__init__(self.cache, key_prefix, use_signer, permanent, sid_length)
def fetch_session(self, sid):
# Get the saved session (item) from the database
prefixed_session_id = self.key_prefix + sid
item = self.cache.get(prefixed_session_id)
# If the saved session exists and has not auto-expired, load the session data from the item
if item is not None:
return self.session_class(item, sid=sid)
# If the saved session does not exist, create a new session
return self.session_class(sid=sid, permanent=self.permanent)
def save_session(self, app, session, response):
if not self.should_set_cookie(app, session):
return
# Get the domain and path for the cookie from the app config
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
# Generate a prefixed session id from the session id as a storage key
prefixed_session_id = self.key_prefix + session.sid
# If the session is empty, do not save it to the database or set a cookie
if not session:
# If the session was deleted (empty and modified), delete the saved session from the database and tell the client to delete the cookie
if session.modified:
self.cache.delete(prefixed_session_id)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Get the new expiration time for the session
expiration_datetime = self.get_expiration_time(app, session)
# Serialize the session data (or just cast into dictionary in this case)
session_data = dict(session)
# Update existing or create new session in the database
self.cache.set(
prefixed_session_id,
session_data,
total_seconds(app.permanent_session_lifetime),
)
# Set the browser cookie
self.set_cookie_to_response(app, session, response, expiration_datetime)
class MongoDBSessionInterface(ServerSideSessionInterface):
"""A Session interface that uses mongodb as backend. (`pymongo` required)
:param client: A ``pymongo.MongoClient`` instance.
:param db: The database you want to use.
:param collection: The collection you want to use.
:param key_prefix: A prefix that is added to all MongoDB store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
.. versionadded:: 0.6
The `sid_length` parameter was added.
.. versionadded:: 0.2
The `use_signer` parameter was added.
"""
serializer = pickle
session_class = MongoDBSession
def __init__(
self,
client,
db,
collection,
key_prefix,
use_signer,
permanent,
sid_length,
):
import pymongo
if client is None:
client = pymongo.MongoClient()
self.client = client
self.store = client[db][collection]
self.use_deprecated_method = int(pymongo.version.split(".")[0]) < 4
super().__init__(self.store, key_prefix, use_signer, permanent, sid_length)
def fetch_session(self, sid):
# Get the saved session (document) from the database
prefixed_session_id = self.key_prefix + sid
document = self.store.find_one({"id": prefixed_session_id})
# If the expiration time is less than or equal to the current time (expired), delete the document
if document is not None:
expiration_datetime = document.get("expiration")
# tz_aware mongodb fix
expiration_datetime_tz_aware = expiration_datetime.replace(
tzinfo=timezone.utc
)
now_datetime_tz_aware = datetime.utcnow().replace(tzinfo=timezone.utc)
if expiration_datetime is None or (
expiration_datetime_tz_aware <= now_datetime_tz_aware
):
if self.use_deprecated_method:
self.store.remove({"id": prefixed_session_id})
else:
self.store.delete_one({"id": prefixed_session_id})
document = None
# If the saved session still exists after checking for expiration, load the session data from the document
if document is not None:
try:
session_data = self.serializer.loads(want_bytes(document["val"]))
return self.session_class(session_data, sid=sid)
except pickle.UnpicklingError:
return self.session_class(sid=sid, permanent=self.permanent)
# If the saved session does not exist, create a new session
return self.session_class(sid=sid, permanent=self.permanent)
def save_session(self, app, session, response):
if not self.should_set_cookie(app, session):
return
# Get the domain and path for the cookie from the app config
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
# Generate a prefixed session id from the session id as a storage key
prefixed_session_id = self.key_prefix + session.sid
# If the session is empty, do not save it to the database or set a cookie
if not session:
# If the session was deleted (empty and modified), delete the saved session from the database and tell the client to delete the cookie
if session.modified:
if self.use_deprecated_method:
self.store.remove({"id": prefixed_session_id})
else:
self.store.delete_one({"id": prefixed_session_id})
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Get the new expiration time for the session
expiration_datetime = self.get_expiration_time(app, session)
# Serialize the session data
serialized_session_data = self.serializer.dumps(dict(session))
# Update existing or create new session in the database
if self.use_deprecated_method:
self.store.update(
{"id": prefixed_session_id},
{
"id": prefixed_session_id,
"val": serialized_session_data,
"expiration": expiration_datetime,
},
True,
)
else:
self.store.update_one(
{"id": prefixed_session_id},
{
"$set": {
"id": prefixed_session_id,
"val": serialized_session_data,
"expiration": expiration_datetime,
}
},
True,
)
# Set the browser cookie
self.set_cookie_to_response(app, session, response, expiration_datetime)
class SqlAlchemySessionInterface(ServerSideSessionInterface):
"""Uses the Flask-SQLAlchemy from a flask app as a session backend.
:param app: A Flask app instance.
:param db: A Flask-SQLAlchemy instance.
:param table: The table name you want to use.
:param key_prefix: A prefix that is added to all store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
:param sequence: The sequence to use for the primary key if needed.
:param schema: The db schema to use
:param bind_key: The db bind key to use
.. versionadded:: 0.6
The `sid_length`, `sequence`, `schema` and `bind_key` parameters were added.
.. versionadded:: 0.2
The `use_signer` parameter was added.
"""
serializer = pickle
session_class = SqlAlchemySession
def __init__(
self,
app,
db,
table,
sequence,
schema,
bind_key,
key_prefix,
use_signer,
permanent,
sid_length,
):
if db is None:
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy(app)
self.db = db
self.sequence = sequence
self.schema = schema
self.bind_key = bind_key
super().__init__(self.db, key_prefix, use_signer, permanent, sid_length)
# Create the Session database model
class Session(self.db.Model):
__tablename__ = table
if self.schema is not None:
__table_args__ = {"schema": self.schema, "keep_existing": True}
else:
__table_args__ = {"keep_existing": True}
if self.bind_key is not None:
__bind_key__ = self.bind_key
# Set the database columns, support for id sequences
if sequence:
id = self.db.Column(
self.db.Integer, self.db.Sequence(sequence), primary_key=True
)
else:
id = self.db.Column(self.db.Integer, primary_key=True)
session_id = self.db.Column(self.db.String(255), unique=True)
data = self.db.Column(self.db.LargeBinary)
expiry = self.db.Column(self.db.DateTime)
def __init__(self, session_id, data, expiry):
self.session_id = session_id
self.data = data
self.expiry = expiry
def __repr__(self):
return "<Session data %s>" % self.data
with app.app_context():
self.db.create_all()
self.sql_session_model = Session
def fetch_session(self, sid):
# Get the saved session (record) from the database
store_id = self.key_prefix + sid
record = self.sql_session_model.query.filter_by(session_id=store_id).first()
# If the expiration time is less than or equal to the current time (expired), delete the document
if record is not None:
expiration_datetime = record.expiry
if expiration_datetime is None or expiration_datetime <= datetime.utcnow():
self.db.session.delete(record)
self.db.session.commit()
record = None
# If the saved session still exists after checking for expiration, load the session data from the document
if record:
try:
session_data = self.serializer.loads(want_bytes(record.data))
return self.session_class(session_data, sid=sid)
except pickle.UnpicklingError:
return self.session_class(sid=sid, permanent=self.permanent)
return self.session_class(sid=sid, permanent=self.permanent)
def save_session(self, app, session, response):
if not self.should_set_cookie(app, session):
return
# Get the domain and path for the cookie from the app
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
# Generate a prefixed session id
prefixed_session_id = self.key_prefix + session.sid
# If the session is empty, do not save it to the database or set a cookie
if not session:
# If the session was deleted (empty and modified), delete the saved session from the database and tell the client to delete the cookie
if session.modified:
self.sql_session_model.query.filter_by(
session_id=prefixed_session_id
).delete()
self.db.session.commit()
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Serialize session data
serialized_session_data = self.serializer.dumps(dict(session))
# Get the new expiration time for the session
expiration_datetime = self.get_expiration_time(app, session)
# Update existing or create new session in the database
record = self.sql_session_model.query.filter_by(
session_id=prefixed_session_id
).first()
if record:
record.data = serialized_session_data
record.expiry = expiration_datetime
else:
record = self.sql_session_model(
session_id=prefixed_session_id,
data=serialized_session_data,
expiry=expiration_datetime,
)
self.db.session.add(record)
self.db.session.commit()
# Set the browser cookie
self.set_cookie_to_response(app, session, response, expiration_datetime)