Simplify database locking
Use a single lock for everything to ensure we don’t go inconsistent. One exception is the firewall nodes table which is only accessed when pushing updated config.
This commit is contained in:
parent
93458c4782
commit
22cec64bef
|
@ -28,7 +28,7 @@ def create_app(test_config=None):
|
||||||
}
|
}
|
||||||
|
|
||||||
from . import db
|
from . import db
|
||||||
with db.locked('settings'):
|
with db.locked():
|
||||||
settings |= db.read('settings')
|
settings |= db.read('settings')
|
||||||
db.write('settings', settings)
|
db.write('settings', settings)
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ def index():
|
||||||
try:
|
try:
|
||||||
if not flask_login.current_user.is_admin:
|
if not flask_login.current_user.is_admin:
|
||||||
return flask.Response('forbidden', status=403, mimetype='text/plain')
|
return flask.Response('forbidden', status=403, mimetype='text/plain')
|
||||||
with db.locked('settings'):
|
with db.locked():
|
||||||
if flask.request.method == 'POST':
|
if flask.request.method == 'POST':
|
||||||
form = flask.request.form
|
form = flask.request.form
|
||||||
db.write('settings', dict(zip(form.getlist('setting'), form.getlist('value'))))
|
db.write('settings', dict(zip(form.getlist('setting'), form.getlist('value'))))
|
||||||
|
|
|
@ -13,12 +13,12 @@ def lock(name, timeout=5):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
raise TimeoutError(f'could not lock {name}')
|
raise TimeoutError(f'could not lock {name}')
|
||||||
|
|
||||||
def unlock(name):
|
def unlock(name='db'):
|
||||||
lockfile = pathlib.Path.home() / f'{name}.lock'
|
lockfile = pathlib.Path.home() / f'{name}.lock'
|
||||||
lockfile.unlink(missing_ok=True)
|
lockfile.unlink(missing_ok=True)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def locked(name):
|
def locked(name='db'):
|
||||||
lock(name)
|
lock(name)
|
||||||
try:
|
try:
|
||||||
yield name
|
yield name
|
||||||
|
@ -36,9 +36,9 @@ def write(name, data):
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
def load(name):
|
def load(name):
|
||||||
with locked(name):
|
with locked():
|
||||||
return read(name)
|
return read(name)
|
||||||
|
|
||||||
def save(name, data):
|
def save(name, data):
|
||||||
with locked(name):
|
with locked():
|
||||||
write(name, data)
|
write(name, data)
|
||||||
|
|
|
@ -29,10 +29,11 @@ def run(fun, args=()):
|
||||||
def save_config():
|
def save_config():
|
||||||
output = None
|
output = None
|
||||||
try:
|
try:
|
||||||
# Just load the settings here but don’t lock the database while we load
|
# Just load the settings here but keep the database unlocked
|
||||||
# stuff from LDAP.
|
# while we load group memberships from LDAP.
|
||||||
settings = db.load('settings')
|
with db.locked():
|
||||||
groups = db.load('groups')
|
settings = db.read('settings')
|
||||||
|
groups = db.read('groups')
|
||||||
|
|
||||||
# For each user build a list of networks they have access to, based on
|
# For each user build a list of networks they have access to, based on
|
||||||
# group membership in AD. Only query groups associated with at least one
|
# group membership in AD. Only query groups associated with at least one
|
||||||
|
@ -49,13 +50,13 @@ def save_config():
|
||||||
|
|
||||||
# Now read the settings again and lock the database while generating
|
# Now read the settings again and lock the database while generating
|
||||||
# config files, then increment version before unlocking.
|
# config files, then increment version before unlocking.
|
||||||
with db.locked('settings'):
|
with db.locked():
|
||||||
settings = db.read('settings')
|
settings = db.read('settings')
|
||||||
version = settings['version'] = int(settings.get('version', 0)) + 1
|
version = settings['version'] = int(settings.get('version', 0)) + 1
|
||||||
|
|
||||||
# Populate IP sets and translation maps for NAT.
|
# Populate IP sets and translation maps for NAT.
|
||||||
ipsets = collections.defaultdict(set)
|
ipsets = collections.defaultdict(set)
|
||||||
networks = db.load('networks')
|
networks = db.read('networks')
|
||||||
for name, network in networks.items():
|
for name, network in networks.items():
|
||||||
for ip in network.get('ip', ()):
|
for ip in network.get('ip', ()):
|
||||||
ipsets[name].add(ip)
|
ipsets[name].add(ip)
|
||||||
|
@ -63,12 +64,12 @@ def save_config():
|
||||||
ipsets[f'{name}6'].update(ip6)
|
ipsets[f'{name}6'].update(ip6)
|
||||||
|
|
||||||
# Load static and dynamic NAT translations.
|
# Load static and dynamic NAT translations.
|
||||||
nat = db.load('nat') # { network name: public range… }
|
nat = db.read('nat') # { network name: public range… }
|
||||||
netmap = db.load('netmap') # { private range: public range… }
|
netmap = db.read('netmap') # { private range: public range… }
|
||||||
|
|
||||||
# Add registered VPN addresses for each network based on
|
# Add registered VPN addresses for each network based on
|
||||||
# LDAP group membership.
|
# LDAP group membership.
|
||||||
wireguard = db.load('wireguard')
|
wireguard = db.read('wireguard')
|
||||||
for ip, key in wireguard.items():
|
for ip, key in wireguard.items():
|
||||||
for network in user_networks.get(key.get('user', ''), ()):
|
for network in user_networks.get(key.get('user', ''), ()):
|
||||||
ipsets[network].add(f'{ip}/32')
|
ipsets[network].add(f'{ip}/32')
|
||||||
|
@ -116,7 +117,7 @@ map {name} {{
|
||||||
|
|
||||||
# Print forwarding rules.
|
# Print forwarding rules.
|
||||||
with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
|
with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
|
||||||
for forward in db.load('forwards'):
|
for forward in db.read('forwards'):
|
||||||
print(forward, file=f)
|
print(forward, file=f)
|
||||||
|
|
||||||
# Print wireguard config.
|
# Print wireguard config.
|
||||||
|
|
|
@ -40,7 +40,7 @@ def new():
|
||||||
text=True, capture_output=True, shell=True).stdout.strip()
|
text=True, capture_output=True, shell=True).stdout.strip()
|
||||||
|
|
||||||
host = ipaddress.ip_interface(settings.get('wg_net', '10.0.0.1/24'))
|
host = ipaddress.ip_interface(settings.get('wg_net', '10.0.0.1/24'))
|
||||||
with db.locked('wireguard'):
|
with db.locked():
|
||||||
# Find a free address for the new key.
|
# Find a free address for the new key.
|
||||||
ips = db.read('wireguard')
|
ips = db.read('wireguard')
|
||||||
for ip in host.network.hosts():
|
for ip in host.network.hosts():
|
||||||
|
@ -88,7 +88,7 @@ def delete():
|
||||||
return flask.Response('invalid key', status=400, mimetype='text/plain')
|
return flask.Response('invalid key', status=400, mimetype='text/plain')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with db.locked('wireguard'):
|
with db.locked():
|
||||||
user = flask_login.current_user.get_id()
|
user = flask_login.current_user.get_id()
|
||||||
ips = {k: v for k, v in db.read('wireguard').items() if v.get('user') != user or v.get('key') != pubkey}
|
ips = {k: v for k, v in db.read('wireguard').items() if v.get('user') != user or v.get('key') != pubkey}
|
||||||
db.write('wireguard', ips)
|
db.write('wireguard', ips)
|
||||||
|
|
Loading…
Reference in a new issue