forked from GoogleCloudPlatform/python-docs-samples
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdatabase.py
More file actions
205 lines (177 loc) · 7.35 KB
/
database.py
File metadata and controls
205 lines (177 loc) · 7.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
from typing import Dict
import sqlalchemy
from sqlalchemy.orm import close_all_sessions
from sqlalchemy.pool import NullPool
import credentials
from middleware import logger
# This global variable is declared with a value of `None`, instead of calling
# `init_connection_engine()` immediately, to simplify testing. In general, it
# is safe to initialize your database connection pool when your script starts
# -- there is no need to wait for the first request.
db = None
def init_connection_engine() -> Dict[str, int]:
if os.getenv("TRAMPOLINE_CI", None):
logger.info("Using NullPool for testing")
db_config = {"poolclass": NullPool}
else:
db_config = {
# Pool size is the maximum number of permanent connections to keep.
"pool_size": 5,
# Temporarily exceeds the set pool_size if no connections are available.
"max_overflow": 2,
# The total number of concurrent connections for your application will be
# a total of pool_size and max_overflow.
# SQLAlchemy automatically uses delays between failed connection attempts,
# but provides no arguments for configuration.
# 'pool_timeout' is the maximum number of seconds to wait when retrieving a
# new connection from the pool. After the specified amount of time, an
# exception will be thrown.
"pool_timeout": 30, # 30 seconds
# 'pool_recycle' is the maximum number of seconds a connection can persist.
# Connections that live longer than the specified amount of time will be
# reestablished
"pool_recycle": 1800, # 30 minutes
}
if os.environ.get("DB_HOST"):
return init_tcp_connection_engine(db_config)
else:
return init_unix_connection_engine(db_config)
def init_tcp_connection_engine(
db_config: Dict[str, str]
) -> sqlalchemy.engine.base.Engine:
creds = credentials.get_cred_config()
db_user = creds["DB_USER"]
db_pass = creds["DB_PASSWORD"]
db_name = creds["DB_NAME"]
db_host = creds["DB_HOST"]
# Extract host and port from db_host
host_args = db_host.split(":")
db_hostname, db_port = host_args[0], int(host_args[1])
pool = sqlalchemy.create_engine(
# Equivalent URL:
# postgres+pg8000://<db_user>:<db_pass>@<db_host>:<db_port>/<db_name>
sqlalchemy.engine.url.URL.create(
drivername="postgresql+pg8000",
username=db_user, # e.g. "my-database-user"
password=db_pass, # e.g. "my-database-password"
host=db_hostname, # e.g. "127.0.0.1"
port=db_port, # e.g. 5432
database=db_name, # e.g. "my-database-name"
),
**db_config,
)
pool.dialect.description_encoding = None
logger.info("Database engine initialised from tcp connection")
return pool
# [START cloudrun_user_auth_sql_connect]
def init_unix_connection_engine(
db_config: Dict[str, str]
) -> sqlalchemy.engine.base.Engine:
creds = credentials.get_cred_config()
db_user = creds["DB_USER"]
db_pass = creds["DB_PASSWORD"]
db_name = creds["DB_NAME"]
db_socket_dir = creds.get("DB_SOCKET_DIR", "/cloudsql")
cloud_sql_connection_name = creds["CLOUD_SQL_CONNECTION_NAME"]
pool = sqlalchemy.create_engine(
# Equivalent URL:
# postgres+pg8000://<db_user>:<db_pass>@/<db_name>
# ?unix_sock=<socket_path>/<cloud_sql_instance_name>/.s.PGSQL.5432
sqlalchemy.engine.url.URL.create(
drivername="postgresql+pg8000",
username=db_user, # e.g. "my-database-user"
password=db_pass, # e.g. "my-database-password"
database=db_name, # e.g. "my-database-name"
query={
"unix_sock": "{}/{}/.s.PGSQL.5432".format(
db_socket_dir, cloud_sql_connection_name # e.g. "/cloudsql"
) # i.e "<PROJECT-NAME>:<INSTANCE-REGION>:<INSTANCE-NAME>"
},
),
**db_config,
)
pool.dialect.description_encoding = None
logger.info("Database engine initialised from unix conection")
return pool
# [END cloudrun_user_auth_sql_connect]
def create_tables() -> None:
# This is called before any request on the main app, ensuring the database has been setup
logger.info("Creating tables")
global db
db = init_connection_engine()
# Create pet_votes table if it doesn't already exist
with db.connect() as conn:
conn.execute(
"CREATE TABLE IF NOT EXISTS pet_votes"
"( vote_id SERIAL NOT NULL, "
"time_cast timestamp NOT NULL, "
"candidate VARCHAR(6) NOT NULL, "
"uid VARCHAR(128) NOT NULL, "
"PRIMARY KEY (vote_id)"
");"
)
def get_index_context() -> Dict:
votes = []
with db.connect() as conn:
# Execute the query and fetch all results
recent_votes = conn.execute(
"SELECT candidate, time_cast FROM pet_votes "
"ORDER BY time_cast DESC LIMIT 5"
).fetchall()
# Convert the results into a list of dicts representing votes
for row in recent_votes:
votes.append(
{
"candidate": row[0],
"time_cast": row[1],
}
)
stmt = sqlalchemy.text(
"SELECT COUNT(vote_id) FROM pet_votes WHERE candidate=:candidate"
)
# Count number of votes for cats
cats_result = conn.execute(stmt, candidate="CATS").fetchone()
cats_count = cats_result[0]
# Count number of votes for dogs
dogs_result = conn.execute(stmt, candidate="DOGS").fetchone()
dogs_count = dogs_result[0]
return {
"dogs_count": dogs_count,
"recent_votes": votes,
"cats_count": cats_count,
}
def save_vote(team: str, uid: str, time_cast: datetime.datetime) -> None:
# Preparing a statement before hand can help protect against injections.
stmt = sqlalchemy.text(
"INSERT INTO pet_votes (time_cast, candidate, uid)"
" VALUES (:time_cast, :candidate, :uid)"
)
# Using a with statement ensures that the connection is always released
# back into the pool at the end of statement (even if an error occurs)
with db.connect() as conn:
conn.execute(stmt, time_cast=time_cast, candidate=team, uid=uid)
logger.info("Vote for %s saved.", team)
def shutdown() -> None:
# Find all Sessions in memory and close them.
close_all_sessions()
logger.info("All sessions closed.")
# Each connection was released on execution, so just formally
# dispose of the db connection if it's been instantiated
if db:
db.dispose()
logger.info("Database connection disposed.")