import logging
from datetime import datetime

from cs50 import SQL
from flask import Flask, flash, redirect, render_template, request, session
from flask_session import Session
from werkzeug.security import check_password_hash, generate_password_hash

from helpers import apology, login_required, lookup, usd

# Configure application
app = Flask(__name__)

# Custom filter
app.jinja_env.filters["usd"] = usd

# Configure session to use filesystem (instead of signed cookies)
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
Session(app)

# Configure CS50 Library to use SQLite database
db = SQL("sqlite:///finance.db")
# Alternative backend, kept for reference:
# db = SQL("postgresql://postgres:password@db/finance")


@app.after_request
def after_request(response):
    """Ensure responses aren't cached"""
    response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
    response.headers["Expires"] = 0
    response.headers["Pragma"] = "no-cache"
    return response


def _get_holdings(user_id):
    """Return one row per (upper-cased) symbol with the user's net shares and net costs.

    Each row has the keys ``symbol``, ``SUM(shares)`` and ``SUM(costs)``.
    Symbols are grouped case-insensitively so that rows stored as e.g.
    "aapl" and "AAPL" count as a single holding.
    """
    return db.execute(
        "SELECT UPPER(symbol) AS symbol, SUM(shares), SUM(costs) FROM transactions "
        "WHERE user_id = ? GROUP BY UPPER(symbol)",
        user_id,
    )


def _get_cash(user_id):
    """Return the current cash balance of *user_id* as a float."""
    rows = db.execute("SELECT cash FROM users WHERE id = ?", user_id)
    return float(rows[0]["cash"])


def _read_order_form():
    """Validate the ``symbol`` and ``shares`` fields shared by /buy and /sell.

    Returns:
        ``(symbol, shares, None)`` on success, where *symbol* is stripped and
        upper-cased and *shares* is a positive int; otherwise
        ``(None, None, response)`` where *response* is the apology to return.
    """
    symbol = (request.form.get("symbol") or "").strip().upper()
    if not symbol:
        return None, None, apology("must provide symbol", 400)

    raw_shares = (request.form.get("shares") or "").strip()
    if not raw_shares:
        return None, None, apology("must provide shares", 400)

    # str.isdigit() alone also accepts characters like "²" that int() rejects,
    # so require plain ASCII digits before converting.
    if not (raw_shares.isascii() and raw_shares.isdigit()):
        return None, None, apology("must provide a digit number of shares", 400)

    shares = int(raw_shares)
    if shares <= 0:
        return None, None, apology("must provide a positive number of shares", 400)

    return symbol, shares, None


@app.route("/")
@login_required
def index():
    """Show the user's cash, open positions and their gain/loss."""
    user = db.execute("SELECT * FROM users WHERE id = ?", session["user_id"])

    stocks_total = 0.0
    portfolio = []
    for holding in _get_holdings(session["user_id"]):
        shares = holding["SUM(shares)"]
        if shares <= 0:
            continue  # fully sold positions are not shown

        stock = lookup(holding["symbol"])
        if not stock:
            # Keep the position visible instead of crashing the whole page.
            logging.warning("price lookup failed for %s", holding["symbol"])
            price = 0.0
        else:
            price = float(stock["price"])

        holding["actual_price"] = price * shares
        stocks_total += holding["actual_price"]
        # Buys are recorded with negative costs and sells with positive
        # proceeds, so current value + net cash flow is the gain/loss.
        holding["difference"] = holding["actual_price"] + holding["SUM(costs)"]
        if holding["actual_price"]:
            holding["percentage"] = holding["difference"] / holding["actual_price"] * 100
        else:
            holding["percentage"] = 0.0
        portfolio.append(holding)

    total = float(user[0]["cash"]) + stocks_total
    return render_template(
        "index.html", user=user[0], portfolio=portfolio, stocks_total=stocks_total, total=total
    )


@app.route("/buy", methods=["GET", "POST"])
@login_required
def buy():
    """Buy shares of a stock at the current price."""
    # User reached route via GET (as by clicking a link or via redirect)
    if request.method != "POST":
        return render_template("buy.html")

    symbol, shares, error = _read_order_form()
    if error is not None:
        return error

    # Lookup the actual price and check that the symbol is valid
    stock_prices = lookup(symbol)
    if not stock_prices:
        return apology("symbol not found", 400)

    cash = _get_cash(session["user_id"])
    costs = float(stock_prices["price"]) * shares
    logging.debug("cash = %s, costs = %s * %s = %s", cash, stock_prices["price"], shares, costs)

    if costs > cash:
        return apology("not enough money left", 400)

    # Debit the cash and record the purchase (costs stored as a negative cash flow)
    db.execute("UPDATE users SET cash = ? WHERE id = ?", cash - costs, session["user_id"])
    db.execute(
        "INSERT INTO transactions (user_id, buy_sell, datetime, symbol, shares, costs) "
        "VALUES(?, ?, ?, ?, ?, ?)",
        session["user_id"], 1, datetime.now(), symbol, shares, -costs,
    )

    flash(f"Bought {shares} shares of {symbol} for {usd(costs)}!")
    return redirect("/")


@app.route("/history")
@login_required
def history():
    """Show the user's history of transactions."""
    transactions = db.execute("SELECT * FROM transactions WHERE user_id = ?", session["user_id"])
    return render_template("history.html", transactions=transactions)


@app.route("/login", methods=["GET", "POST"])
def login():
    """Log user in"""

    # Forget any user_id
    session.clear()

    # User reached route via GET (as by clicking a link or via redirect)
    if request.method != "POST":
        return render_template("login.html")

    username = request.form.get("username")
    if not username:
        return apology("must provide username", 400)

    password = request.form.get("password")
    if not password:
        return apology("must provide password", 400)

    # Ensure username exists and password is correct
    rows = db.execute("SELECT * FROM users WHERE username = ?", username)
    if len(rows) != 1 or not check_password_hash(rows[0]["hash"], password):
        return apology("invalid username and/or password", 400)

    # Remember which user has logged in
    session["user_id"] = rows[0]["id"]
    return redirect("/")


@app.route("/logout")
def logout():
    """Log user out"""

    # Forget any user_id
    session.clear()

    # Redirect user to login form
    return redirect("/")


@app.route("/quote", methods=["GET", "POST"])
@login_required
def quote():
    """Get stock quote."""
    if request.method != "POST":
        return render_template("quote.html")

    symbol = request.form.get("symbol")
    # Validate before logging: the field may be missing (None).
    if not symbol:
        return apology("must provide symbol", 400)
    logging.debug("symbol = %s", symbol)

    stock_prices = lookup(symbol)
    logging.debug("stock_prices = %s", stock_prices)
    if not stock_prices:
        return apology("symbol not found", 400)

    return render_template("quoted.html", stock_prices=stock_prices)


@app.route("/register", methods=["GET", "POST"])
def register():
    """Register a new user and log them in."""

    # User reached route via GET (as by clicking a link or via redirect)
    if request.method != "POST":
        return render_template("register.html")

    username = request.form.get("username")
    if not username:
        return apology("must provide username", 400)

    if db.execute("SELECT * FROM users WHERE username = ?", username):
        return apology("user already exists", 400)

    password = request.form.get("password")
    if not password:
        return apology("must provide password", 400)

    confirmation = request.form.get("confirmation")
    if not confirmation:
        return apology("must provide confirmation", 400)

    if password != confirmation:
        return apology("password and confirmation does not match", 400)

    try:
        # The CS50 SQL library returns the new row's primary key for INSERT.
        user_id = db.execute(
            "INSERT INTO users (username, hash) VALUES(?, ?)",
            username, generate_password_hash(password),
        )
    except ValueError:
        # Raised by the CS50 library on a constraint violation; presumably a
        # UNIQUE index on username hit by a concurrent registration — verify schema.
        return apology("user already exists", 400)

    # Log the new user in so the redirect to the home page succeeds
    session["user_id"] = user_id
    return redirect("/")


@app.route("/sell", methods=["GET", "POST"])
@login_required
def sell():
    """Sell shares of a stock the user currently owns."""
    holdings = _get_holdings(session["user_id"])

    # User reached route via GET (as by clicking a link or via redirect)
    if request.method != "POST":
        portfolio = [holding for holding in holdings if holding["SUM(shares)"] > 0]
        return render_template("sell.html", portfolio=portfolio)

    symbol, shares, error = _read_order_form()
    if error is not None:
        return error

    # Lookup the actual price and check that the symbol is valid
    stock_prices = lookup(symbol)
    if not stock_prices:
        return apology("symbol not found", 400)

    # Check that enough shares are available (holdings are keyed by upper-case symbol)
    old_shares = next(
        (holding["SUM(shares)"] for holding in holdings if holding["symbol"] == symbol), 0
    )
    if shares > old_shares:
        return apology(
            "You have only " + str(old_shares) + " shares of " + symbol + " available!", 400
        )

    # Credit the proceeds and record the sale (negative shares, positive cash flow)
    profit = float(stock_prices["price"]) * shares
    cash = _get_cash(session["user_id"]) + profit
    db.execute("UPDATE users SET cash = ? WHERE id = ?", cash, session["user_id"])
    db.execute(
        "INSERT INTO transactions (user_id, buy_sell, datetime, symbol, shares, costs) "
        "VALUES(?, ?, ?, ?, ?, ?)",
        session["user_id"], 0, datetime.now(), symbol, -shares, profit,
    )

    flash(f"Sold {shares} shares of {symbol} for {usd(profit)}!")
    return redirect("/")