From 22cec64befd69b88daa97cf03a337feff8545d4a Mon Sep 17 00:00:00 2001 From: Timotej Lazar Date: Fri, 19 May 2023 09:30:28 +0200 Subject: [PATCH] Simplify database locking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use a single lock for everything to ensure we don’t go inconsistent. One exception is the firewall nodes table which is only accessed when pushing updated config. --- web/__init__.py | 2 +- web/config.py | 2 +- web/db.py | 8 ++++---- web/system.py | 21 +++++++++++---------- web/vpn.py | 4 ++-- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/web/__init__.py b/web/__init__.py index 0bb99c2..f1286b7 100644 --- a/web/__init__.py +++ b/web/__init__.py @@ -28,7 +28,7 @@ def create_app(test_config=None): } from . import db - with db.locked('settings'): + with db.locked(): settings |= db.read('settings') db.write('settings', settings) diff --git a/web/config.py b/web/config.py index c7c352e..85cc9a6 100644 --- a/web/config.py +++ b/web/config.py @@ -14,7 +14,7 @@ def index(): try: if not flask_login.current_user.is_admin: return flask.Response('forbidden', status=403, mimetype='text/plain') - with db.locked('settings'): + with db.locked(): if flask.request.method == 'POST': form = flask.request.form db.write('settings', dict(zip(form.getlist('setting'), form.getlist('value')))) diff --git a/web/db.py b/web/db.py index 1cdee76..4ed5182 100644 --- a/web/db.py +++ b/web/db.py @@ -13,12 +13,12 @@ def lock(name, timeout=5): time.sleep(1) raise TimeoutError(f'could not lock {name}') -def unlock(name): +def unlock(name='db'): lockfile = pathlib.Path.home() / f'{name}.lock' lockfile.unlink(missing_ok=True) @contextlib.contextmanager -def locked(name): +def locked(name='db'): lock(name) try: yield name @@ -36,9 +36,9 @@ def write(name, data): f.close() def load(name): - with locked(name): + with locked(): return read(name) def save(name, data): - with locked(name): + with locked(): write(name, data) diff --git a/web/system.py b/web/system.py index 0dc6f32..ad445df 100644 --- a/web/system.py +++ b/web/system.py @@ -29,10 +29,11 @@ def run(fun, args=()): def save_config(): output = None try: - # Just load the settings here but don’t lock the database while we load - # stuff from LDAP. - settings = db.load('settings') - groups = db.load('groups') + # Just load the settings here but keep the database unlocked + # while we load group memberships from LDAP. + with db.locked(): + settings = db.read('settings') + groups = db.read('groups') # For each user build a list of networks they have access to, based on # group membership in AD. Only query groups associated with at least one @@ -49,13 +50,13 @@ def save_config(): # Now read the settings again and lock the database while generating # config files, then increment version before unlocking. - with db.locked('settings'): + with db.locked(): settings = db.read('settings') version = settings['version'] = int(settings.get('version', 0)) + 1 # Populate IP sets and translation maps for NAT. ipsets = collections.defaultdict(set) - networks = db.load('networks') + networks = db.read('networks') for name, network in networks.items(): for ip in network.get('ip', ()): ipsets[name].add(ip) @@ -63,12 +64,12 @@ def save_config(): ipsets[f'{name}6'].update(ip6) # Load static and dynamic NAT translations. - nat = db.load('nat') # { network name: public range… } - netmap = db.load('netmap') # { private range: public range… } + nat = db.read('nat') # { network name: public range… } + netmap = db.read('netmap') # { private range: public range… } # Add registered VPN addresses for each network based on # LDAP group membership. - wireguard = db.load('wireguard') + wireguard = db.read('wireguard') for ip, key in wireguard.items(): for network in user_networks.get(key.get('user', ''), ()): ipsets[network].add(f'{ip}/32') @@ -116,7 +117,7 @@ map {name} {{ # Print forwarding rules. with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f: - for forward in db.load('forwards'): + for forward in db.read('forwards'): print(forward, file=f) # Print wireguard config. diff --git a/web/vpn.py b/web/vpn.py index 7212de7..e76a691 100644 --- a/web/vpn.py +++ b/web/vpn.py @@ -40,7 +40,7 @@ def new(): text=True, capture_output=True, shell=True).stdout.strip() host = ipaddress.ip_interface(settings.get('wg_net', '10.0.0.1/24')) - with db.locked('wireguard'): + with db.locked(): # Find a free address for the new key. ips = db.read('wireguard') for ip in host.network.hosts(): @@ -88,7 +88,7 @@ def delete(): return flask.Response('invalid key', status=400, mimetype='text/plain') try: - with db.locked('wireguard'): + with db.locked(): user = flask_login.current_user.get_id() ips = {k: v for k, v in db.read('wireguard').items() if v.get('user') != user or v.get('key') != pubkey} db.write('wireguard', ips)