Simplify database locking

Use a single lock for everything to ensure we don’t go inconsistent.
One exception is the firewall nodes table which is only accessed when
pushing updated config.
This commit is contained in:
Timotej Lazar 2023-05-19 09:30:28 +02:00
parent 93458c4782
commit 22cec64bef
5 changed files with 19 additions and 18 deletions

View file

@ -28,7 +28,7 @@ def create_app(test_config=None):
} }
from . import db from . import db
with db.locked('settings'): with db.locked():
settings |= db.read('settings') settings |= db.read('settings')
db.write('settings', settings) db.write('settings', settings)

View file

@ -14,7 +14,7 @@ def index():
try: try:
if not flask_login.current_user.is_admin: if not flask_login.current_user.is_admin:
return flask.Response('forbidden', status=403, mimetype='text/plain') return flask.Response('forbidden', status=403, mimetype='text/plain')
with db.locked('settings'): with db.locked():
if flask.request.method == 'POST': if flask.request.method == 'POST':
form = flask.request.form form = flask.request.form
db.write('settings', dict(zip(form.getlist('setting'), form.getlist('value')))) db.write('settings', dict(zip(form.getlist('setting'), form.getlist('value'))))

View file

@ -13,12 +13,12 @@ def lock(name, timeout=5):
time.sleep(1) time.sleep(1)
raise TimeoutError(f'could not lock {name}') raise TimeoutError(f'could not lock {name}')
def unlock(name): def unlock(name='db'):
lockfile = pathlib.Path.home() / f'{name}.lock' lockfile = pathlib.Path.home() / f'{name}.lock'
lockfile.unlink(missing_ok=True) lockfile.unlink(missing_ok=True)
@contextlib.contextmanager @contextlib.contextmanager
def locked(name): def locked(name='db'):
lock(name) lock(name)
try: try:
yield name yield name
@ -36,9 +36,9 @@ def write(name, data):
f.close() f.close()
def load(name): def load(name):
with locked(name): with locked():
return read(name) return read(name)
def save(name, data): def save(name, data):
with locked(name): with locked():
write(name, data) write(name, data)

View file

@ -29,10 +29,11 @@ def run(fun, args=()):
def save_config(): def save_config():
output = None output = None
try: try:
# Just load the settings here but don’t lock the database while we load # Just load the settings here but keep the database unlocked
# stuff from LDAP. # while we load group memberships from LDAP.
settings = db.load('settings') with db.locked():
groups = db.load('groups') settings = db.read('settings')
groups = db.read('groups')
# For each user build a list of networks they have access to, based on # For each user build a list of networks they have access to, based on
# group membership in AD. Only query groups associated with at least one # group membership in AD. Only query groups associated with at least one
@ -49,13 +50,13 @@ def save_config():
# Now read the settings again and lock the database while generating # Now read the settings again and lock the database while generating
# config files, then increment version before unlocking. # config files, then increment version before unlocking.
with db.locked('settings'): with db.locked():
settings = db.read('settings') settings = db.read('settings')
version = settings['version'] = int(settings.get('version', 0)) + 1 version = settings['version'] = int(settings.get('version', 0)) + 1
# Populate IP sets and translation maps for NAT. # Populate IP sets and translation maps for NAT.
ipsets = collections.defaultdict(set) ipsets = collections.defaultdict(set)
networks = db.load('networks') networks = db.read('networks')
for name, network in networks.items(): for name, network in networks.items():
for ip in network.get('ip', ()): for ip in network.get('ip', ()):
ipsets[name].add(ip) ipsets[name].add(ip)
@ -63,12 +64,12 @@ def save_config():
ipsets[f'{name}6'].update(ip6) ipsets[f'{name}6'].update(ip6)
# Load static and dynamic NAT translations. # Load static and dynamic NAT translations.
nat = db.load('nat') # { network name: public range… } nat = db.read('nat') # { network name: public range… }
netmap = db.load('netmap') # { private range: public range… } netmap = db.read('netmap') # { private range: public range… }
# Add registered VPN addresses for each network based on # Add registered VPN addresses for each network based on
# LDAP group membership. # LDAP group membership.
wireguard = db.load('wireguard') wireguard = db.read('wireguard')
for ip, key in wireguard.items(): for ip, key in wireguard.items():
for network in user_networks.get(key.get('user', ''), ()): for network in user_networks.get(key.get('user', ''), ()):
ipsets[network].add(f'{ip}/32') ipsets[network].add(f'{ip}/32')
@ -116,7 +117,7 @@ map {name} {{
# Print forwarding rules. # Print forwarding rules.
with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f: with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
for forward in db.load('forwards'): for forward in db.read('forwards'):
print(forward, file=f) print(forward, file=f)
# Print wireguard config. # Print wireguard config.

View file

@ -40,7 +40,7 @@ def new():
text=True, capture_output=True, shell=True).stdout.strip() text=True, capture_output=True, shell=True).stdout.strip()
host = ipaddress.ip_interface(settings.get('wg_net', '10.0.0.1/24')) host = ipaddress.ip_interface(settings.get('wg_net', '10.0.0.1/24'))
with db.locked('wireguard'): with db.locked():
# Find a free address for the new key. # Find a free address for the new key.
ips = db.read('wireguard') ips = db.read('wireguard')
for ip in host.network.hosts(): for ip in host.network.hosts():
@ -88,7 +88,7 @@ def delete():
return flask.Response('invalid key', status=400, mimetype='text/plain') return flask.Response('invalid key', status=400, mimetype='text/plain')
try: try:
with db.locked('wireguard'): with db.locked():
user = flask_login.current_user.get_id() user = flask_login.current_user.get_id()
ips = {k: v for k, v in db.read('wireguard').items() if v.get('user') != user or v.get('key') != pubkey} ips = {k: v for k, v in db.read('wireguard').items() if v.get('user') != user or v.get('key') != pubkey}
db.write('wireguard', ips) db.write('wireguard', ips)