656 lines
23 KiB
Python
656 lines
23 KiB
Python
import sys
|
|
import threading
|
|
|
|
# Thread-local data: each SQL object keeps its open connection here, stored as an
# attribute whose name is the string returned by that object's _name() method
_data = threading.local()
|
|
|
|
|
|
def _enable_logging(f):
    """
    Decorate f so that the "cs50" logger is switched on while f runs inside a
    Flask app in development mode.

    If Flask cannot be imported, f is simply called. Otherwise the logger's
    prior `disabled` flag is remembered and restored once f returns or raises.
    """

    import logging
    import functools
    import os

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Without Flask there is nothing to toggle
        try:
            import flask
        except ModuleNotFoundError:
            return f(*args, **kwargs)

        # Remember logger's current state so it can be put back afterward
        logger = logging.getLogger("cs50")
        was_disabled = logger.disabled

        # Only un-silence the logger for an active app in development mode
        if flask.current_app and os.getenv("FLASK_ENV") == "development":
            logger.disabled = False

        try:
            outcome = f(*args, **kwargs)
        finally:
            logger.disabled = was_disabled
        return outcome

    return wrapper
|
|
|
|
|
|
class SQL(object):
|
|
"""Wrap SQLAlchemy to provide a simple SQL API."""
|
|
|
|
def __init__(self, url, **kwargs):
    """
    Create instance of sqlalchemy.engine.Engine and check that it can connect.

    URL should be a string that indicates database dialect and connection arguments.
    Any keyword arguments are passed through to sqlalchemy.create_engine.

    Raises RuntimeError if URL names a SQLite path that does not exist or is not a
    file, or if the test connection fails with sqlalchemy.exc.OperationalError.

    http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
    http://docs.sqlalchemy.org/en/latest/dialects/index.html
    """

    # Lazily import
    import logging
    import os
    import re
    import sqlalchemy
    import sqlalchemy.orm

    # Temporary fix for missing sqlite3 module on the buildpack stack;
    # remember its absence explicitly instead of relying on a NameError later
    try:
        import sqlite3
    except ImportError:
        sqlite3 = None

    # Require that file already exist for SQLite
    matches = re.search(r"^sqlite:///(.+)$", url)
    if matches:
        path = matches.group(1)
        if not os.path.exists(path):
            raise RuntimeError("does not exist: {}".format(path))
        if not os.path.isfile(path):
            raise RuntimeError("not a file: {}".format(path))

    # Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed;
    # without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and
    # "there is no transaction in progress" for our own COMMIT
    self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(
        autocommit=False, isolation_level="AUTOCOMMIT", no_parameters=True
    )

    # Avoid doubly escaping percent signs, since no_parameters=True anyway
    # https://github.com/cs50/python-cs50/issues/171
    self._engine.dialect.identifier_preparer._double_percents = False

    # Get logger
    self._logger = logging.getLogger("cs50")

    # Listener for connections
    def connect(dbapi_connection, connection_record):
        # Only SQLite connections need foreign key constraints enabled
        if sqlite3 is None or not isinstance(dbapi_connection, sqlite3.Connection):
            return

        # Best effort: a failing PRAGMA must not prevent the connection itself
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        except sqlite3.Error:
            pass
        finally:
            cursor.close()

    # Register listener
    sqlalchemy.event.listen(self._engine, "connect", connect)

    # Autocommit by default
    self._autocommit = True

    # Test database with logging silenced, restoring logger's state afterward
    disabled = self._logger.disabled
    self._logger.disabled = True
    try:
        # Context manager closes the connection even if the query raises
        with self._engine.connect() as connection:
            connection.execute(sqlalchemy.text("SELECT 1"))
    except sqlalchemy.exc.OperationalError as e:
        # Surface only the driver's message, without the SQLAlchemy traceback chain
        raise RuntimeError(_parse_exception(e)) from None
    finally:
        self._logger.disabled = disabled
|
|
|
|
def __del__(self):
    """Close this object's database connection, if any, when the object is finalized."""

    # All cleanup lives in _disconnect
    self._disconnect()
|
|
|
|
def _disconnect(self):
    """
    Close database connection.

    Closes the connection stored for this object in the thread-local _data (if
    there is one) and removes it, so that the next call to execute opens a new
    connection. Does nothing if no connection is stored.
    """

    # Compute the key once; it names this object's slot in thread-local storage
    name = self._name()
    if not hasattr(_data, name):
        return

    # Forget the connection even if closing it fails, lest a broken
    # connection be reused by a later query
    try:
        getattr(_data, name).close()
    finally:
        delattr(_data, name)
|
|
|
|
def _name(self):
    """Return the decimal string form of hash(self)."""

    digest = hash(self)
    return f"{digest}"
|
|
|
|
@_enable_logging
def execute(self, sql, *args, **kwargs):
    """
    Execute a SQL statement.

    Placeholders in sql are replaced with escaped values before execution:
    positional args fill qmark (?), numeric (:1) and format (%s) placeholders,
    keyword kwargs fill named (:name) and pyformat (%(name)s) placeholders.
    Only one statement and one paramstyle may be used per call.

    Unless a transaction was started with BEGIN or START, the statement is
    wrapped in its own BEGIN/COMMIT. COMMIT or ROLLBACK (but not a rollback to
    a savepoint) ends a transaction and restores that behavior.

    Returns a list of dict objects for SELECT, the new row's primary key (or
    None) for INSERT, the number of rows for DELETE or UPDATE, else True.

    Raises RuntimeError for malformed input (too many or no statements, mixed or
    mismatched parameters) and for operational/programming errors reported by
    the database; raises ValueError if an integrity constraint is violated.
    """

    # Lazily import
    import decimal
    import re
    import sqlalchemy
    import sqlparse
    import termcolor
    import warnings

    # Parse statement, stripping comments and then leading/trailing whitespace
    statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip())

    # Allow only one statement at a time, since SQLite doesn't support multiple
    # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
    if len(statements) > 1:
        raise RuntimeError("too many statements at once")
    elif len(statements) == 0:
        raise RuntimeError("missing statement")

    # Ensure named and positional parameters are mutually exclusive
    if args and kwargs:
        raise RuntimeError("cannot pass both positional and named parameters")

    # Infer command from flattened statement to a single string separated by spaces
    keyword_types = [
        sqlparse.tokens.Keyword,
        sqlparse.tokens.Keyword.DDL,
        sqlparse.tokens.Keyword.DML,
    ]
    full_statement = " ".join(
        str(token) for token in statements[0].tokens if token.ttype in keyword_types
    ).upper()

    # Set of possible commands; COMMIT and ROLLBACK must be recognized here,
    # else the end of a transaction could never be detected below
    commands = {
        "BEGIN",
        "COMMIT",
        "CREATE VIEW",
        "DELETE",
        "INSERT",
        "ROLLBACK",
        "SELECT",
        "START",
        "UPDATE",
    }

    # Check if the full_statement starts with any command
    command = next(
        (cmd for cmd in commands if full_statement.startswith(cmd)), None
    )

    # A ROLLBACK ... TO [SAVEPOINT] only unwinds part of a transaction, so it
    # does not end the transaction
    ends_transaction = command == "COMMIT" or (
        command == "ROLLBACK" and "TO" not in full_statement.split()
    )

    # Flatten statement
    tokens = list(statements[0].flatten())

    # Validate paramstyle, remembering each placeholder's index and name
    placeholders = {}
    paramstyle = None
    for index, token in enumerate(tokens):
        if token.ttype != sqlparse.tokens.Name.Placeholder:
            continue

        # Determine paramstyle, name
        _paramstyle, name = _parse_placeholder(token)

        # Remember first paramstyle seen, and ensure the rest are consistent
        if not paramstyle:
            paramstyle = _paramstyle
        elif _paramstyle != paramstyle:
            raise RuntimeError("inconsistent paramstyle")

        placeholders[index] = name

    # If no placeholders, error-check like qmark if args, like named if kwargs
    if not paramstyle:
        if args:
            paramstyle = "qmark"
        elif kwargs:
            paramstyle = "named"

    # In case of errors
    _placeholders = ", ".join([str(tokens[index]) for index in placeholders])
    _args = ", ".join([str(self._escape(arg)) for arg in args])

    # qmark and format are both purely positional, so they share validation
    if paramstyle in ("qmark", "format"):
        # Validate number of placeholders
        if len(placeholders) != len(args):
            comparison = "fewer" if len(placeholders) < len(args) else "more"
            raise RuntimeError(
                "{} placeholders ({}) than values ({})".format(
                    comparison, _placeholders, _args
                )
            )

        # Escape values
        for index, arg in zip(placeholders, args):
            tokens[index] = self._escape(arg)

    # numeric
    elif paramstyle == "numeric":
        # Escape values
        for index, i in placeholders.items():
            if i >= len(args):
                raise RuntimeError(
                    "missing value for placeholder (:{})".format(i + 1)
                )
            tokens[index] = self._escape(args[i])

        # Check if any values unused
        indices = set(range(len(args))) - set(placeholders.values())
        if indices:
            raise RuntimeError(
                "unused {} ({})".format(
                    "value" if len(indices) == 1 else "values",
                    ", ".join(
                        [str(self._escape(args[index])) for index in sorted(indices)]
                    ),
                )
            )

    # named
    elif paramstyle == "named":
        # Escape values
        for index, name in placeholders.items():
            if name not in kwargs:
                raise RuntimeError(
                    "missing value for placeholder (:{})".format(name)
                )
            tokens[index] = self._escape(kwargs[name])

        # Check if any keys unused
        keys = kwargs.keys() - placeholders.values()
        if keys:
            raise RuntimeError("unused values ({})".format(", ".join(keys)))

    # pyformat
    elif paramstyle == "pyformat":
        # Escape values
        for index, name in placeholders.items():
            if name not in kwargs:
                raise RuntimeError(
                    "missing value for placeholder (%{}s)".format(name)
                )
            tokens[index] = self._escape(kwargs[name])

        # Check if any keys unused
        keys = kwargs.keys() - placeholders.values()
        if keys:
            raise RuntimeError(
                "unused {} ({})".format(
                    "value" if len(keys) == 1 else "values", ", ".join(keys)
                )
            )

    # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
    # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
    for token in tokens:
        # In string literal
        # https://www.sqlite.org/lang_keywords.html
        if token.ttype in [
            sqlparse.tokens.Literal.String,
            sqlparse.tokens.Literal.String.Single,
        ]:
            token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value)

        # In identifier
        # https://www.sqlite.org/lang_keywords.html
        elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
            token.value = re.sub(r'(^"|\s+):', r"\1\:", token.value)

    # Join tokens into statement
    statement = "".join([str(token) for token in tokens])

    # Use this thread's connection, connecting to database if no connection yet
    name = self._name()
    if not hasattr(_data, name):
        setattr(_data, name, self._engine.connect())
    connection = getattr(_data, name)

    # Disconnect if/when a Flask app is torn down
    try:
        import flask
    except ModuleNotFoundError:
        flask = None
    if flask is not None and flask.current_app:
        # Reuse one handler per object so that the membership test below can
        # actually prevent the same object from registering over and over
        teardown_appcontext = getattr(self, "_teardown_appcontext", None)
        if teardown_appcontext is None:

            def teardown_appcontext(exception):
                self._disconnect()

            self._teardown_appcontext = teardown_appcontext

        try:
            if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs:
                flask.current_app.teardown_appcontext(teardown_appcontext)
        except AssertionError:
            # Registration may be refused with an AssertionError; as before,
            # carry on without a teardown handler
            pass

    # Catch SQLAlchemy warnings
    with warnings.catch_warnings():
        # Raise exceptions for warnings
        warnings.simplefilter("error")

        # Prepare, execute statement
        try:
            # Join tokens into statement, abbreviating binary data as <class 'bytes'>
            _statement = "".join(
                [
                    str(bytes)
                    if token.ttype == sqlparse.tokens.Other
                    else str(token)
                    for token in tokens
                ]
            )

            # Check for start of transaction
            if command in ["BEGIN", "START"]:
                self._autocommit = False

            # Execute statement, wrapped in its own transaction if autocommitting
            if self._autocommit:
                connection.execute(sqlalchemy.text("BEGIN"))
            result = connection.execute(sqlalchemy.text(statement))
            if self._autocommit:
                connection.execute(sqlalchemy.text("COMMIT"))

            # Check for end of transaction
            if ends_transaction:
                self._autocommit = True

            # Return value
            ret = True

            # If SELECT, return result set as list of dict objects
            if command == "SELECT":
                # Coerce types
                rows = [dict(row) for row in result.mappings().all()]
                for row in rows:
                    for column in row:
                        # Coerce decimal.Decimal objects to float objects
                        # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
                        if isinstance(row[column], decimal.Decimal):
                            row[column] = float(row[column])

                        # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
                        elif isinstance(row[column], memoryview):
                            row[column] = bytes(row[column])

                # Rows to be returned
                ret = rows

            # If INSERT, return primary key value for a newly inserted row (or None if none)
            elif command == "INSERT":
                # If PostgreSQL
                if self._engine.url.get_backend_name() == "postgresql":
                    # Return LASTVAL() or NULL, avoiding
                    # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session",
                    # a la https://stackoverflow.com/a/24186770/5156190;
                    # cf. https://www.psycopg.org/docs/errors.html re 55000
                    result = connection.execute(
                        sqlalchemy.text(
                            """
                            CREATE OR REPLACE FUNCTION _LASTVAL()
                            RETURNS integer LANGUAGE plpgsql
                            AS $$
                            BEGIN
                                BEGIN
                                    RETURN (SELECT LASTVAL());
                                EXCEPTION
                                    WHEN SQLSTATE '55000' THEN RETURN NULL;
                                END;
                            END $$;
                            SELECT _LASTVAL();
                            """
                        )
                    )
                    ret = result.first()[0]

                # If not PostgreSQL
                else:
                    ret = result.lastrowid if result.rowcount == 1 else None

            # If DELETE or UPDATE, return number of rows matched
            elif command in ["DELETE", "UPDATE"]:
                ret = result.rowcount

            # If CREATE VIEW, return True
            elif command == "CREATE VIEW":
                ret = True

        # If constraint violated
        except sqlalchemy.exc.IntegrityError as e:
            if self._autocommit:
                connection.execute(sqlalchemy.text("ROLLBACK"))
            self._logger.error(termcolor.colored(_statement, "red"))
            raise ValueError(e.orig) from None

        # If user error
        except (
            sqlalchemy.exc.OperationalError,
            sqlalchemy.exc.ProgrammingError,
        ) as e:
            self._disconnect()
            self._logger.error(termcolor.colored(_statement, "red"))
            raise RuntimeError(e.orig) from None

        # Return value
        else:
            self._logger.info(termcolor.colored(_statement, "green"))
            if self._autocommit:  # Don't stay connected unnecessarily
                self._disconnect()
            return ret
|
|
|
|
def _escape(self, value):
    """
    Turn a Python value into a sqlparse token holding its SQL literal, using the
    engine dialect's literal processors.

    Supports bool, bytes, datetime.datetime, datetime.date, datetime.time, float,
    int, str and None; a list or tuple of those becomes a comma-separated
    sqlparse.sql.TokenList. Raises RuntimeError for any other value (and for
    bytes on a back end other than mysql, sqlite or postgresql).

    https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
    """

    # Lazily import
    import sqlparse

    def escape_one(item):
        # Lazily import
        import datetime
        import sqlalchemy

        def literal(ttype, sa_type, raw):
            # Render raw through the dialect's literal processor for sa_type
            render = sa_type().literal_processor(self._engine.dialect)
            return sqlparse.sql.Token(ttype, render(raw))

        # bool (checked before int, since bool is a subclass of int)
        if isinstance(item, bool):
            return literal(sqlparse.tokens.Number, sqlalchemy.types.Boolean, item)

        # bytes, as a hexadecimal literal in the back end's own syntax
        if isinstance(item, bytes):
            backend = self._engine.url.get_backend_name()
            if backend in ["mysql", "sqlite"]:
                # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
                return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{item.hex()}'")
            if backend == "postgresql":
                # https://dba.stackexchange.com/a/203359
                return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{item.hex()}'")
            raise RuntimeError("unsupported value: {}".format(item))

        # datetime.datetime (before datetime.date, its base class), datetime.date, datetime.time
        temporal_formats = (
            (datetime.datetime, "%Y-%m-%d %H:%M:%S"),
            (datetime.date, "%Y-%m-%d"),
            (datetime.time, "%H:%M:%S"),
        )
        for temporal_type, pattern in temporal_formats:
            if isinstance(item, temporal_type):
                return literal(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String,
                    item.strftime(pattern),
                )

        # float
        if isinstance(item, float):
            return literal(sqlparse.tokens.Number, sqlalchemy.types.Float, item)

        # int
        if isinstance(item, int):
            return literal(sqlparse.tokens.Number, sqlalchemy.types.Integer, item)

        # str
        if isinstance(item, str):
            return literal(sqlparse.tokens.String, sqlalchemy.types.String, item)

        # None
        if item is None:
            return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null())

        # Unsupported value
        raise RuntimeError("unsupported value: {}".format(item))

    # A lone value is escaped directly
    if not isinstance(value, (list, tuple)):
        return escape_one(value)

    # Several values are escaped one by one and separated with commas
    joined = ", ".join([str(escape_one(v)) for v in value])
    return sqlparse.sql.TokenList(sqlparse.parse(joined))
|
|
|
|
|
|
def _parse_exception(e):
    """
    Return the message portion of an exception's string form.

    If str(e) looks like an OperationalError reported by the MySQL, PostgreSQL
    or SQLite driver, only the driver's message is returned; otherwise str(e)
    is returned unchanged.
    """

    # Lazily import
    import re

    text = str(e)

    # One pattern per back end, tried in order: MySQL, PostgreSQL, SQLite
    patterns = (
        r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$",
        r"^\(psycopg2\.OperationalError\) (.+)$",
        r"^\(sqlite3\.OperationalError\) (.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, text)
        if found:
            return found.group(1)

    # Default
    return text
|
|
|
|
|
|
def _parse_placeholder(token):
    """
    Work out which paramstyle a placeholder token uses and what it refers to.

    Returns a (paramstyle, name) pair: name is None for qmark (?) and format
    (%s), a zero-based index for numeric (:1 becomes 0), and the parameter's
    name for named (:name) and pyformat (%(name)s).

    Raises TypeError if token is not a sqlparse token of type Name.Placeholder,
    RuntimeError if its value matches none of the supported styles.
    """

    # Lazily load
    import re
    import sqlparse

    # Validate token
    is_placeholder = (
        isinstance(token, sqlparse.sql.Token)
        and token.ttype == sqlparse.tokens.Name.Placeholder
    )
    if not is_placeholder:
        raise TypeError()

    text = token.value

    # qmark and format are fixed strings that carry no name
    anonymous = {"?": "qmark", "%s": "format"}
    if text in anonymous:
        return anonymous[text], None

    # numeric, named and pyformat each capture the index or name in group 1
    captured = (
        ("numeric", r"^:([1-9]\d*)$", lambda digits: int(digits) - 1),
        ("named", r"^:([a-zA-Z]\w*)$", lambda name: name),
        ("pyformat", r"%\((\w+)\)s$", lambda name: name),
    )
    for style, pattern, convert in captured:
        found = re.search(pattern, text)
        if found:
            return style, convert(found.group(1))

    # Invalid
    raise RuntimeError("{}: invalid placeholder".format(token.value))
|