Skip to content

Commit ddd645c

Browse files
author
KirillMysnik
committed
Made database sessions more safe
1 parent 41ca610 commit ddd645c

File tree

4 files changed

+202
-216
lines changed

4 files changed

+202
-216
lines changed

srcds/addons/source-python/plugins/admin/core/orm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,21 @@
1919
))
2020
Base = declarative_base()
2121
Session = sessionmaker(bind=engine)
22+
23+
24+
# =============================================================================
25+
# >> CLASSES
26+
# =============================================================================
27+
class SessionContext:
28+
def __init__(self):
29+
self.session = None
30+
31+
def __enter__(self):
32+
self.session = Session()
33+
return self.session
34+
35+
def __exit__(self, exc_type, exc_val, exc_tb):
36+
if exc_type is not None:
37+
self.session.rollback()
38+
self.session.close()
39+
self.session = None

srcds/addons/source-python/plugins/admin/plugins/included/admin_comm_management/blocks/base.py

Lines changed: 57 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from admin.core.features import BaseFeature
2020
from admin.core.frontends.menus import AdminCommand, PlayerBasedAdminCommand
2121
from admin.core.helpers import format_player_name
22-
from admin.core.orm import Session
22+
from admin.core.orm import SessionContext
2323
from admin.core.paths import ADMIN_CFG_PATH, get_server_file
2424
from admin.core.strings import strings_common
2525

@@ -89,24 +89,21 @@ def _on_change(self):
8989
def refresh(self):
9090
self.clear()
9191

92-
session = Session()
92+
with SessionContext() as session:
93+
blocked_users = session.query(self.model).all()
9394

94-
blocked_users = session.query(self.model).all()
95+
current_time = time()
96+
for blocked_user in blocked_users:
97+
if blocked_user.is_unblocked:
98+
continue
9599

96-
current_time = time()
97-
for blocked_user in blocked_users:
98-
if blocked_user.is_unblocked:
99-
continue
100+
if 0 <= blocked_user.expires_at < current_time:
101+
continue
100102

101-
if 0 <= blocked_user.expires_at < current_time:
102-
continue
103-
104-
self[blocked_user.steamid64] = _BlockedCommUserInfo(
105-
blocked_user.steamid64, blocked_user.id, blocked_user.name,
106-
blocked_user.blocked_by, blocked_user.expires_at
107-
)
108-
109-
session.close()
103+
self[blocked_user.steamid64] = _BlockedCommUserInfo(
104+
blocked_user.steamid64, blocked_user.id, blocked_user.name,
105+
blocked_user.blocked_by, blocked_user.expires_at
106+
)
110107

111108
self._on_change()
112109

@@ -129,60 +126,55 @@ def save_block_to_database(self, blocked_by, steamid, name, duration):
129126
steamid = self._convert_steamid_to_db_format(steamid)
130127
blocked_by = self._convert_steamid_to_db_format(blocked_by)
131128

132-
session = Session()
133-
134-
blocked_user = self.model(steamid, name, blocked_by, duration)
129+
with SessionContext() as session:
130+
blocked_user = self.model(steamid, name, blocked_by, duration)
135131

136-
session.add(blocked_user)
137-
session.commit()
132+
session.add(blocked_user)
133+
session.commit()
138134

139-
self[steamid] = _BlockedCommUserInfo(
140-
steamid, blocked_user.id, name, blocked_by,
141-
blocked_user.expires_at)
142-
143-
session.close()
135+
self[steamid] = _BlockedCommUserInfo(
136+
steamid, blocked_user.id, name, blocked_by,
137+
blocked_user.expires_at)
144138

145139
self._on_change()
146140

147141
def get_all_blocks(
148142
self, steamid=None, blocked_by=None, expired=None, unblocked=None):
149143

150144
result = []
151-
session = Session()
152-
153-
query = session.query(self.model)
154145

155-
if steamid is not None:
156-
steamid = self._convert_steamid_to_db_format(steamid)
157-
query = query.filter_by(steamid64=steamid)
158-
159-
if blocked_by is not None:
160-
blocked_by = self._convert_steamid_to_db_format(blocked_by)
161-
query = query.filter_by(blocked_by=blocked_by)
162-
163-
if expired is not None:
164-
current_time = int(time())
165-
if expired:
166-
query = query.filter(and_(
167-
self.model.expires_at < current_time,
168-
self.model.expires_at >= 0
146+
with SessionContext() as session:
147+
query = session.query(self.model)
148+
149+
if steamid is not None:
150+
steamid = self._convert_steamid_to_db_format(steamid)
151+
query = query.filter_by(steamid64=steamid)
152+
153+
if blocked_by is not None:
154+
blocked_by = self._convert_steamid_to_db_format(blocked_by)
155+
query = query.filter_by(blocked_by=blocked_by)
156+
157+
if expired is not None:
158+
current_time = int(time())
159+
if expired:
160+
query = query.filter(and_(
161+
self.model.expires_at < current_time,
162+
self.model.expires_at >= 0
163+
))
164+
else:
165+
query = query.filter(or_(
166+
self.model.expires_at >= current_time,
167+
self.model.expires_at < 0
168+
))
169+
170+
if unblocked is not None:
171+
query = query.filter_by(is_unblocked=unblocked)
172+
173+
for blocked_user in query.all():
174+
result.append(_BlockedCommUserInfo(
175+
blocked_user.steamid64, blocked_user.id, blocked_user.name,
176+
blocked_user.blocked_by, blocked_user.expires_at
169177
))
170-
else:
171-
query = query.filter(or_(
172-
self.model.expires_at >= current_time,
173-
self.model.expires_at < 0
174-
))
175-
176-
if unblocked is not None:
177-
query = query.filter_by(is_unblocked=unblocked)
178-
179-
for blocked_user in query.all():
180-
result.append(_BlockedCommUserInfo(
181-
blocked_user.steamid64, blocked_user.id, blocked_user.name,
182-
blocked_user.blocked_by, blocked_user.expires_at
183-
))
184-
185-
session.close()
186178

187179
return result
188180

@@ -210,18 +202,15 @@ def get_active_blocks(self, blocked_by=None):
210202
def lift_block(self, id_, unblocked_by):
211203
unblocked_by = self._convert_steamid_to_db_format(unblocked_by)
212204

213-
session = Session()
205+
with SessionContext() as session:
206+
blocked_user = session.query(self.model).filter_by(id=id_).first()
214207

215-
blocked_user = session.query(self.model).filter_by(id=id_).first()
216-
217-
if blocked_user is None:
218-
session.close()
219-
return
208+
if blocked_user is None:
209+
return
220210

221-
blocked_user.lift_block(unblocked_by)
211+
blocked_user.lift_block(unblocked_by)
222212

223-
session.commit()
224-
session.close()
213+
session.commit()
225214

226215
for steamid64, blocked_comm_user_info in self.items():
227216
if blocked_comm_user_info.id != id_:

0 commit comments

Comments
 (0)