# Gremlin-Python/.venv/lib/python3.10/site-packages/cs50/sql.py
#
# 656 lines
# 23 KiB
# Python

import sys
import threading
# Thread-local data: every thread sees its own set of attributes on this object.
# SQL instances keep their open connection here, stored under the key returned
# by SQL._name(), so a connection is never shared between threads.
_data = threading.local()
def _enable_logging(f):
"""Enable logging of SQL statements when Flask is in use."""
import logging
import functools
import os
@functools.wraps(f)
def decorator(*args, **kwargs):
# Infer whether Flask is installed
try:
import flask
except ModuleNotFoundError:
return f(*args, **kwargs)
# Enable logging in development mode
disabled = logging.getLogger("cs50").disabled
if flask.current_app and os.getenv("FLASK_ENV") == "development":
logging.getLogger("cs50").disabled = False
try:
return f(*args, **kwargs)
finally:
logging.getLogger("cs50").disabled = disabled
return decorator
class SQL(object):
    """
    Wrap SQLAlchemy to provide a simple SQL API.

    One instance wraps one SQLAlchemy engine. Statements are passed to
    execute() as strings, with values substituted for placeholders after being
    escaped as SQL literals. Each thread gets its own connection, kept in the
    module's thread-local storage under a per-instance key.

    Outside an explicit transaction every statement is wrapped in its own
    BEGIN/COMMIT and the connection is released afterwards. A BEGIN or START
    statement switches that off until a COMMIT or ROLLBACK is executed.
    """

    def __init__(self, url, **kwargs):
        """
        Create instance of sqlalchemy.engine.Engine.

        URL should be a string that indicates database dialect and connection arguments.
        Any keyword arguments are forwarded to sqlalchemy.create_engine.

        Raises RuntimeError if URL names a SQLite file that does not exist or is
        not a regular file, or if a test query against the database fails.

        http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
        http://docs.sqlalchemy.org/en/latest/dialects/index.html
        """
        # Lazily import
        import logging
        import os
        import re

        import sqlalchemy
        import sqlalchemy.orm

        # Temporary fix for missing sqlite3 module on the buildpack stack
        try:
            import sqlite3
        except ImportError:
            sqlite3 = None

        # Require that file already exist for SQLite
        matches = re.search(r"^sqlite:///(.+)$", url)
        if matches:
            path = matches.group(1)
            if not os.path.exists(path):
                raise RuntimeError("does not exist: {}".format(path))
            if not os.path.isfile(path):
                raise RuntimeError("not a file: {}".format(path))

        # Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed;
        # without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and
        # "there is no transaction in progress" for our own COMMIT
        self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(
            autocommit=False, isolation_level="AUTOCOMMIT", no_parameters=True
        )

        # Avoid doubly escaping percent signs, since no_parameters=True anyway
        # https://github.com/cs50/python-cs50/issues/171
        self._engine.dialect.identifier_preparer._double_percents = False

        # Get logger
        self._logger = logging.getLogger("cs50")

        # Listener for connections
        def connect(dbapi_connection, connection_record):
            # Only SQLite needs foreign key constraints switched on
            if sqlite3 is None or not isinstance(dbapi_connection, sqlite3.Connection):
                return

            # Best effort: a failing PRAGMA must not prevent the connection
            try:
                cursor = dbapi_connection.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()
            except sqlite3.Error:
                pass

        # Register listener
        sqlalchemy.event.listen(self._engine, "connect", connect)

        # Autocommit by default
        self._autocommit = True

        # Test database, silencing the logger meanwhile
        disabled = self._logger.disabled
        self._logger.disabled = True
        try:
            # Context manager closes the connection even if the query fails
            with self._engine.connect() as connection:
                connection.execute(sqlalchemy.text("SELECT 1"))
        except sqlalchemy.exc.OperationalError as e:
            raise RuntimeError(_parse_exception(e)) from None
        finally:
            self._logger.disabled = disabled

    def __del__(self):
        """Disconnect from database."""
        self._disconnect()

    def _disconnect(self):
        """Close and forget the calling thread's database connection, if any."""
        name = self._name()
        if hasattr(_data, name):
            getattr(_data, name).close()
            delattr(_data, name)

    def _name(self):
        """
        Return object's hash as a str.

        The result is the attribute name under which this instance's connection
        is kept in thread-local storage.

        NOTE(review): this relies on hash(self) being unique among live
        instances; confirm a stale entry left in another thread cannot be
        picked up by a later instance with the same hash.
        """
        return str(hash(self))

    def _register_teardown(self):
        """
        Disconnect if/when a Flask app is torn down.

        Does nothing when Flask is not installed or no app is current. The
        callback is created once per instance so that the membership test
        against the app's registered teardown functions can succeed; a new
        closure per call would be registered again on every statement.
        """
        try:
            import flask
        except ModuleNotFoundError:
            return

        # Explicit check rather than assert, which is stripped under -O
        if not flask.current_app:
            return

        teardown_appcontext = getattr(self, "_teardown_appcontext", None)
        if teardown_appcontext is None:

            def teardown_appcontext(exception):
                self._disconnect()

            self._teardown_appcontext = teardown_appcontext

        if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs:
            try:
                flask.current_app.teardown_appcontext(teardown_appcontext)
            except AssertionError:
                # The original code tolerated AssertionError at this point; keep doing so
                pass

    @_enable_logging
    def execute(self, sql, *args, **kwargs):
        """
        Execute a SQL statement.

        sql must contain exactly one statement. Values for its placeholders are
        given either positionally (qmark "?", numeric ":1", format "%s") or by
        keyword (named ":name", pyformat "%(name)s"), never both, and a single
        placeholder style must be used throughout.

        Returns a list of dicts for SELECT, the new row's primary key (or None)
        for INSERT, the affected row count for DELETE and UPDATE, and True
        otherwise.

        Raises RuntimeError for malformed input, parameter mismatches and
        operational or programming errors reported by the database, and
        ValueError when an integrity constraint is violated.
        """
        # Lazily import
        import decimal
        import re
        import warnings

        import sqlalchemy
        import sqlparse
        import termcolor

        # Parse statement, stripping comments and then leading/trailing whitespace
        statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip())

        # Allow only one statement at a time, since SQLite doesn't support multiple
        # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
        if len(statements) > 1:
            raise RuntimeError("too many statements at once")
        elif len(statements) == 0:
            raise RuntimeError("missing statement")

        # Ensure named and positional parameters are mutually exclusive
        if len(args) > 0 and len(kwargs) > 0:
            raise RuntimeError("cannot pass both positional and named parameters")

        # Infer command from flattened statement to a single string separated by spaces
        keyword_types = [
            sqlparse.tokens.Keyword,
            sqlparse.tokens.Keyword.DDL,
            sqlparse.tokens.Keyword.DML,
        ]
        full_statement = " ".join(
            str(token)
            for token in statements[0].tokens
            if token.ttype in keyword_types
        ).upper()

        # Set of possible commands; COMMIT and ROLLBACK must be listed, else the
        # end of a transaction is never recognized and autocommit stays off
        commands = {
            "BEGIN",
            "COMMIT",
            "CREATE VIEW",
            "DELETE",
            "INSERT",
            "ROLLBACK",
            "SELECT",
            "START",
            "UPDATE",
        }

        # Check if the full_statement starts with any command
        command = next(
            (cmd for cmd in commands if full_statement.startswith(cmd)), None
        )

        # ROLLBACK TO [SAVEPOINT] only rewinds within the transaction, which stays open
        ends_transaction = command == "COMMIT" or (
            command == "ROLLBACK" and "TO" not in full_statement.split()
        )

        # Flatten statement
        tokens = list(statements[0].flatten())

        # Validate paramstyle
        placeholders = {}
        paramstyle = None
        for index, token in enumerate(tokens):
            # If token is a placeholder
            if token.ttype == sqlparse.tokens.Name.Placeholder:
                # Determine paramstyle, name
                _paramstyle, name = _parse_placeholder(token)

                # Remember paramstyle
                if not paramstyle:
                    paramstyle = _paramstyle

                # Ensure paramstyle is consistent
                elif _paramstyle != paramstyle:
                    raise RuntimeError("inconsistent paramstyle")

                # Remember placeholder's index, name
                placeholders[index] = name

        # If no placeholders
        if not paramstyle:
            # Error-check like qmark if args
            if args:
                paramstyle = "qmark"

            # Error-check like named if kwargs
            elif kwargs:
                paramstyle = "named"

        # In case of errors
        _placeholders = ", ".join([str(tokens[index]) for index in placeholders])
        _args = ", ".join([str(self._escape(arg)) for arg in args])

        # qmark and format: purely positional, one value per placeholder in order
        if paramstyle in ("qmark", "format"):
            # Validate number of placeholders
            if len(placeholders) < len(args):
                raise RuntimeError(
                    "fewer placeholders ({}) than values ({})".format(
                        _placeholders, _args
                    )
                )
            if len(placeholders) > len(args):
                raise RuntimeError(
                    "more placeholders ({}) than values ({})".format(
                        _placeholders, _args
                    )
                )

            # Escape values
            for i, index in enumerate(placeholders):
                tokens[index] = self._escape(args[i])

        # numeric
        elif paramstyle == "numeric":
            # Escape values
            for index, i in placeholders.items():
                if i >= len(args):
                    raise RuntimeError(
                        "missing value for placeholder (:{})".format(i + 1)
                    )
                tokens[index] = self._escape(args[i])

            # Check if any values unused
            indices = set(range(len(args))) - set(placeholders.values())
            if indices:
                raise RuntimeError(
                    "unused {} ({})".format(
                        "value" if len(indices) == 1 else "values",
                        ", ".join(
                            [str(self._escape(args[index])) for index in sorted(indices)]
                        ),
                    )
                )

        # named
        elif paramstyle == "named":
            # Escape values
            for index, name in placeholders.items():
                if name not in kwargs:
                    raise RuntimeError(
                        "missing value for placeholder (:{})".format(name)
                    )
                tokens[index] = self._escape(kwargs[name])

            # Check if any keys unused
            keys = kwargs.keys() - placeholders.values()
            if keys:
                raise RuntimeError("unused values ({})".format(", ".join(keys)))

        # pyformat
        elif paramstyle == "pyformat":
            # Escape values
            for index, name in placeholders.items():
                if name not in kwargs:
                    raise RuntimeError(
                        "missing value for placeholder (%{}s)".format(name)
                    )
                tokens[index] = self._escape(kwargs[name])

            # Check if any keys unused
            keys = kwargs.keys() - placeholders.values()
            if keys:
                raise RuntimeError(
                    "unused {} ({})".format(
                        "value" if len(keys) == 1 else "values", ", ".join(keys)
                    )
                )

        # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
        # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
        for token in tokens:
            # In string literal
            # https://www.sqlite.org/lang_keywords.html
            if token.ttype in [
                sqlparse.tokens.Literal.String,
                sqlparse.tokens.Literal.String.Single,
            ]:
                token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value)

            # In identifier
            # https://www.sqlite.org/lang_keywords.html
            elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                token.value = re.sub(r'(^"|\s+):', r"\1\:", token.value)

        # Join tokens into statement
        statement = "".join([str(token) for token in tokens])

        # If no connection yet
        if not hasattr(_data, self._name()):
            # Connect to database
            setattr(_data, self._name(), self._engine.connect())

        # Use this connection
        connection = getattr(_data, self._name())

        # Disconnect if/when a Flask app is torn down
        self._register_teardown()

        # Catch SQLAlchemy warnings
        with warnings.catch_warnings():
            # Raise exceptions for warnings
            warnings.simplefilter("error")

            # Prepare, execute statement
            try:
                # Join tokens into statement, abbreviating binary data as <class 'bytes'>
                _statement = "".join(
                    [
                        str(bytes)
                        if token.ttype == sqlparse.tokens.Other
                        else str(token)
                        for token in tokens
                    ]
                )

                # Check for start of transaction
                if command in ["BEGIN", "START"]:
                    self._autocommit = False

                # Execute statement
                if self._autocommit:
                    connection.execute(sqlalchemy.text("BEGIN"))
                result = connection.execute(sqlalchemy.text(statement))
                if self._autocommit:
                    connection.execute(sqlalchemy.text("COMMIT"))

                # Check for end of transaction
                if ends_transaction:
                    self._autocommit = True

                # Return value
                ret = True

                # If SELECT, return result set as list of dict objects
                if command == "SELECT":
                    # Coerce types
                    rows = [dict(row) for row in result.mappings().all()]
                    for row in rows:
                        for column in row:
                            # Coerce decimal.Decimal objects to float objects
                            # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
                            if isinstance(row[column], decimal.Decimal):
                                row[column] = float(row[column])

                            # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
                            elif isinstance(row[column], memoryview):
                                row[column] = bytes(row[column])

                    # Rows to be returned
                    ret = rows

                # If INSERT, return primary key value for a newly inserted row (or None if none)
                elif command == "INSERT":
                    # If PostgreSQL
                    if self._engine.url.get_backend_name() == "postgresql":
                        # Return LASTVAL() or NULL, avoiding
                        # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session",
                        # a la https://stackoverflow.com/a/24186770/5156190;
                        # cf. https://www.psycopg.org/docs/errors.html re 55000
                        result = connection.execute(
                            sqlalchemy.text(
                                """
                                CREATE OR REPLACE FUNCTION _LASTVAL()
                                RETURNS integer LANGUAGE plpgsql
                                AS $$
                                BEGIN
                                    BEGIN
                                        RETURN (SELECT LASTVAL());
                                    EXCEPTION
                                        WHEN SQLSTATE '55000' THEN RETURN NULL;
                                    END;
                                END $$;
                                SELECT _LASTVAL();
                                """
                            )
                        )
                        ret = result.first()[0]

                    # If not PostgreSQL
                    else:
                        ret = result.lastrowid if result.rowcount == 1 else None

                # If DELETE or UPDATE, return number of rows matched
                elif command in ["DELETE", "UPDATE"]:
                    ret = result.rowcount

                # If CREATE VIEW, return True
                elif command == "CREATE VIEW":
                    ret = True

            # If constraint violated
            except sqlalchemy.exc.IntegrityError as e:
                if self._autocommit:
                    connection.execute(sqlalchemy.text("ROLLBACK"))
                self._logger.error(termcolor.colored(_statement, "red"))
                raise ValueError(e.orig) from None

            # If user error
            except (
                sqlalchemy.exc.OperationalError,
                sqlalchemy.exc.ProgrammingError,
            ) as e:
                self._disconnect()
                self._logger.error(termcolor.colored(_statement, "red"))
                raise RuntimeError(e.orig) from None

            # Return value
            else:
                self._logger.info(termcolor.colored(_statement, "green"))
                if self._autocommit:  # Don't stay connected unnecessarily
                    self._disconnect()
                return ret

    def _escape(self, value):
        """
        Escapes value using engine's conversion function.

        Returns a sqlparse token holding the SQL literal for value. Supported
        types are bool, bytes (MySQL, SQLite and PostgreSQL only), datetime,
        date, time, float, int, str and None. A list or tuple is escaped
        element-wise and joined with commas. Anything else raises RuntimeError.

        https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
        """
        # Lazily import
        import sqlparse

        def __escape(value):
            # Lazily import
            import datetime

            import sqlalchemy

            dialect = self._engine.dialect

            def literal(ttype, sqltype, raw):
                # Render raw with the dialect's literal processor for sqltype
                return sqlparse.sql.Token(
                    ttype, sqltype.literal_processor(dialect)(raw)
                )

            # bool (checked before int, of which it is a subclass)
            if isinstance(value, bool):
                return literal(
                    sqlparse.tokens.Number, sqlalchemy.types.Boolean(), value
                )

            # bytes
            elif isinstance(value, bytes):
                backend = self._engine.url.get_backend_name()
                if backend in ["mysql", "sqlite"]:
                    return sqlparse.sql.Token(
                        sqlparse.tokens.Other, f"x'{value.hex()}'"
                    )  # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
                elif backend == "postgresql":
                    return sqlparse.sql.Token(
                        sqlparse.tokens.Other, f"'\\x{value.hex()}'"
                    )  # https://dba.stackexchange.com/a/203359
                else:
                    raise RuntimeError("unsupported value: {}".format(value))

            # datetime.datetime (checked before date, of which it is a subclass)
            elif isinstance(value, datetime.datetime):
                return literal(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String(),
                    value.strftime("%Y-%m-%d %H:%M:%S"),
                )

            # datetime.date
            elif isinstance(value, datetime.date):
                return literal(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String(),
                    value.strftime("%Y-%m-%d"),
                )

            # datetime.time
            elif isinstance(value, datetime.time):
                return literal(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String(),
                    value.strftime("%H:%M:%S"),
                )

            # float
            elif isinstance(value, float):
                return literal(sqlparse.tokens.Number, sqlalchemy.types.Float(), value)

            # int
            elif isinstance(value, int):
                return literal(
                    sqlparse.tokens.Number, sqlalchemy.types.Integer(), value
                )

            # str
            elif isinstance(value, str):
                return literal(sqlparse.tokens.String, sqlalchemy.types.String(), value)

            # None
            elif value is None:
                return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null())

            # Unsupported value
            else:
                raise RuntimeError("unsupported value: {}".format(value))

        # Escape value(s), separating with commas as needed
        if isinstance(value, (list, tuple)):
            return sqlparse.sql.TokenList(
                sqlparse.parse(", ".join([str(__escape(v)) for v in value]))
            )
        else:
            return __escape(value)
def _parse_exception(e):
"""Parses an exception, returns its message."""
# Lazily import
import re
# MySQL
matches = re.search(
r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)
)
if matches:
return matches.group(1)
# PostgreSQL
matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e))
if matches:
return matches.group(1)
# SQLite
matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e))
if matches:
return matches.group(1)
# Default
return str(e)
def _parse_placeholder(token):
    """
    Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.

    Returns a (paramstyle, name) pair:
      "qmark", None       for ?
      "numeric", index    for :1, :2, ... (index is zero-based)
      "named", name       for :name
      "format", None      for %s
      "pyformat", name    for %(name)s

    Raises TypeError if token is not a sqlparse placeholder token, and
    RuntimeError if its text matches none of the styles above.
    """
    # Lazily load
    import re

    import sqlparse

    # Validate token
    if (
        not isinstance(token, sqlparse.sql.Token)
        or token.ttype != sqlparse.tokens.Name.Placeholder
    ):
        raise TypeError("expected a placeholder token, not {!r}".format(token))

    value = token.value

    # qmark
    if value == "?":
        return "qmark", None

    # numeric (one-based in SQL, zero-based for indexing into args)
    matches = re.search(r"^:([1-9]\d*)$", value)
    if matches:
        return "numeric", int(matches.group(1)) - 1

    # named
    matches = re.search(r"^:([a-zA-Z]\w*)$", value)
    if matches:
        return "named", matches.group(1)

    # format
    if value == "%s":
        return "format", None

    # pyformat (anchored at the start like the other patterns)
    matches = re.search(r"^%\((\w+)\)s$", value)
    if matches:
        return "pyformat", matches.group(1)

    # Invalid
    raise RuntimeError("{}: invalid placeholder".format(value))