203 lines
7.8 KiB
203 lines
7.8 KiB
import collections
import multiprocessing
import os
import pathlib
import shutil
import subprocess
import syslog
import click
import flask
import flask.cli
import ldap3
from . import db
# Application setup hook (presumably the usual Flask `init_app(app)` pattern).
# NOTE(review): only the signature survives in this copy of the file — the
# body (and the file's indentation) is missing. Presumably it registers the
# `generate`/`push` CLI commands on `app`; restore it from the original
# source before editing.
def init_app(app):
# Run `fun` (presumably as fun(*args)) in a forked process.
# NOTE(review): only the outer signature, the nested task() and its
# `os.fork() == 0` test survive in this copy; the rest of the body is missing.
# os.fork() returns 0 in the child, so the code guarded by that test runs in
# the forked child. This file also imports `multiprocessing`, which may be
# what starts task() — confirm against the original source.
def run(fun, args=()):
# Nested worker; its body is only partially visible here.
def task():
if os.fork() == 0:
# Generate a new, versioned firewall configuration bundle.
#
# Visible steps:
#   1. Under db.locked(), read the 'settings' and 'groups' tables. Per the
#      comment below, the lock is released before the LDAP queries.
#   2. Connect to LDAP (SSL, credentials from settings) and look each group up
#      by distinguishedName, walking its 'member' attribute. Per the comment
#      below, this builds user_networks (user -> set of network names).
#   3. Under db.locked() again, re-read 'settings' and bump settings['version'].
#   4. Build nftables IP sets from the 'networks' table (plus registered
#      WireGuard addresses). Write sets.nft, netmap.nft, nat.nft, forward.nft
#      and wg.conf under ~/config/<version>/etc/...
#   5. Archive that tree to ~/config/<version>-tmp.tar.gz, rename it to
#      ~/config/<version>.tar.gz, then db.write() the bumped settings.
#
# Returns True on success. On any Exception it logs the error to syslog,
# opens /tmp/wtflog for appending (the write itself is missing here) and
# returns False. The ~/config/<version> directory is then removed (apparently
# in a `finally:` clause that is missing from this copy).
#
# NOTE(review): this copy of the file has lost its indentation and several
# lines (the `try:`/`finally:` keywords, the loop bodies that fill
# user_networks and the WireGuard ipsets, the closing lines of the f-string
# templates and the print(f'''...) openers for wg.conf). These comments only
# describe what is visible; restore the block from the original before
# changing any code.
def save_config():
output = None
# Just load the settings here but keep the database unlocked
# while we load group memberships from LDAP.
with db.locked():
settings = db.read('settings')
groups = db.read('groups')
# For each user build a list of networks they have access to, based on
# group membership in AD. Only query groups associated with at least one
# network, and query each group only once.
user_networks = collections.defaultdict(set)
# NOTE(review): no ldap.unbind() is visible in this copy; confirm the
# connection is closed (ldap3.Connection can also be used as a context manager).
ldap = ldap3.Connection(ldap3.Server(settings.get('ldap_host'), use_ssl=True),
settings.get('ldap_user'), settings.get('ldap_pass'), auto_bind=True)
for group, network in groups.items():
# NOTE(review): `group` is interpolated into the search filter unescaped.
# A DN containing '\', '*', '(' or ')' (e.g. an escaped comma "\," in a CN)
# yields a wrong or invalid filter. Consider
# ldap3.utils.conv.escape_filter_chars(group).
ldap.search(settings.get('ldap_base_dn', ''),
f'(distinguishedName={group})', attributes='member')
if ldap.entries:
# NOTE(review): the loop body (presumably adding `network` to
# user_networks[user]) is missing from this copy.
for user in ldap.entries[0]['member']:
# Now read the settings again and lock the database while generating
# config files, then increment version before unlocking.
with db.locked():
settings = db.read('settings')
version = settings['version'] = int(settings.get('version', 0)) + 1
# The bumped version is only persisted by db.write() further down, after the
# archive is in place. So (assuming db.read() returns a fresh dict) a failed
# run does not advance the stored counter.
# Populate IP sets and translation maps for NAT.
ipsets = collections.defaultdict(set)
for name, network in db.read('networks').items():
ipsets[name].update(network.get('ip', ()))
ipsets[f'{name}/6'].update(network.get('ip6', ()))
# NOTE(review): indexing the defaultdict above creates a '<name>/6' entry for
# every network, even one with no 'ip6' addresses. format_set() below may
# therefore be asked to render an empty set.
# Add registered VPN addresses for each network based on
# LDAP group membership.
wireguard = db.read('wireguard')
for ip, key in wireguard.items():
for network in user_networks.get(key.get('user', ''), ()):
if 'ip6' in key:
# NOTE(review): the statements that add `ip` (and key['ip6']) to the
# network's ipsets are missing from this copy.
# Create config files.
output = pathlib.Path.home() / 'config' / f'{version}'
# Clear any stale tree left over for this version number.
shutil.rmtree(output, ignore_errors=True)
os.makedirs(f'{output}/etc/nftables.d', exist_ok=True)
os.makedirs(f'{output}/etc/wireguard', exist_ok=True)
# Print nftables sets.
with open(f'{output}/etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f:
def format_set(name, ips):
# Render one nftables `set` declaration. A '/6' name suffix selects
# ipv6_addr, otherwise ipv4_addr.
# NOTE(review): with an empty `ips` this renders "elements = {  }";
# confirm nft accepts an empty element list.
return f'''\
set {name} {{
type {"ipv6_addr" if name.endswith('/6') else "ipv4_addr"}; flags interval
elements = {{ {', '.join(ips)} }}
for name, ips in ipsets.items():
print(format_set(name, ips), file=f)
# Print static NAT (1:1) rules.
# format_map() renders an nftables interval `map`. Below, netmap-out maps
# private -> public ranges and netmap-in maps public -> private ranges.
with open(f'{output}/etc/nftables.d/netmap.nft', 'w', encoding='utf-8') as f:
def format_map(name, elements):
lines = ',\n'.join(f'{a}: {b}' for a, b in elements)
return f'''\
map {name} {{
type ipv4_addr : interval ipv4_addr; flags interval
elements = {{ {lines} }}
netmap = db.read('netmap') # { private range: public range… }
if netmap:
print(format_map('netmap-out', ((private, public) for private, public in netmap.items())), file=f)
print(format_map('netmap-in', ((public, private) for private, public in netmap.items())), file=f)
# Print dynamic NAT rules.
with open(f'{output}/etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f:
nat = db.read('nat') # { network name: public range… }
# One snat rule per network, matching on that network's IP set.
for network, address in nat.items():
print(f'iif @inside oif @outside ip saddr @{network} snat to {address}', file=f)
# Print forwarding rules.
with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
# Only rules with a truthy 'enabled' and a non-empty 'text' are emitted.
# Named rules are preceded by a "# <index>. <name>" comment line.
for index, rule in enumerate(db.read('rules')):
if rule.get('enabled') and rule.get('text'):
if 'name' in rule:
print(f'# {index}. {rule["name"]}', file=f)
print(rule['text'], file=f)
# Print wireguard config.
# NOTE(review): the lines that open the [Interface]/[Peer] print(f'''...)
# calls are missing from this copy; only their bodies survive below.
with open(f'{output}/etc/wireguard/wg.conf', 'w', encoding='utf-8') as f:
ListenPort = {settings.get('wg_port', 51820)}
PrivateKey = {settings.get('wg_key')}
''', file=f)
for ip, data in wireguard.items():
# {data.get('user')}
PublicKey = {data.get('key')}
AllowedIPs = {ip}
''', file=f)
# Make a config archive in a temporary place, so we don’t send
# incomplete tars.
# make_archive() appends '.tar.gz' to the base name and returns the archive
# path, so tar_file is '<output>-tmp.tar.gz'.
tar_file = shutil.make_archive(f'{output}-tmp', 'gztar', root_dir=output, owner='root', group='root')
# Move config archive to the final destination.
# os.rename() within one directory is atomic on POSIX, so the final name
# never points at a partially written archive.
os.rename(tar_file, f'{output}.tar.gz')
# If we get here, write settings with the new version.
db.write('settings', settings)
return True
except Exception as e:
syslog.syslog(f'exception while generating config: {e}')
# NOTE(review): dumping tracebacks to a fixed, world-writable /tmp path
# looks like a debugging leftover (and is symlink-prone). Consider syslog or
# logging.exception() instead. The body of the `with` below is missing here.
import traceback
with open('/tmp/wtflog', 'a+') as f:
return False
# Remove temporary directory.
# NOTE(review): this appears to sit in a `finally:` (missing in this copy),
# so the unpacked tree is deleted even on success, leaving only the .tar.gz.
if output:
shutil.rmtree(output, ignore_errors=True)
# NOTE(review): only the signature survives in this copy. Its decorators and
# body are missing (push() below is a click command, so this likely is one
# too). Presumably it is the CLI entry point that calls save_config(); confirm
# against the original source.
def generate():
# Click CLI command (some of its decorators are missing in this copy) that
# pushes a config bundle to every firewall node not already at that version.
#
# version: config version to push. If None, it is taken from
#     db.load('settings')['version'], or 0 if that key is unset.
#
# Under db.locked('nodes'):
#   - write the wanted version to ~/config/version;
#   - for each entry in the 'nodes' table whose recorded version differs,
#     upload ~/config/<version>.tar.gz and then the version file to
#     root@<node> via sftp;
#   - after each successful upload, record the node's new version with
#     db.write('nodes', ...), so progress is kept even if a later node fails.
#
# Returns True if every node was already current or updated successfully,
# False if any upload failed. On an Exception it opens /tmp/wtflog for
# appending (the write itself is missing here) and returns False.
#
# NOTE(review): indentation and several lines (`try:`, `else:`, the statement
# after the missing-tarball log line, the except body) are missing from this
# copy of the file.
@click.option('--version', '-v', type=click.INT, default=None, help="Config version to push")
def push(version=None):
with db.locked('nodes'):
if version is None:
# NOTE(review): every other access in this file uses db.read(). Confirm
# that db.load() exists and is safe to call while db.locked('nodes') is held.
version = db.load('settings').get('version', 0)
# Write wanted version to file for uploading to firewall nodes.
# NOTE(review): this happens before the tarball-existence check below, so
# ~/config/version can name a version that has no archive.
version_file = pathlib.Path.home() / 'config' / 'version'
with open(version_file, 'w') as f:
print(version, file=f)
nodes = db.read('nodes')
tar_file = pathlib.Path.home() / 'config' / f'{version}.tar.gz'
done = True
for node, node_version in nodes.items():
if node_version != version:
# NOTE(review): the statement after this syslog call is missing.
# Presumably the function gives up here (e.g. `return False`); confirm.
if not os.path.exists(tar_file):
syslog.syslog(f'wanted to push version {version} but {version}.tar.gz doesn’t exist')
# Push config tarfile.
syslog.syslog(f'updating {node} from {node_version} to {version}')
# The tarball is sent before the version file, presumably so a node never
# sees the new version number before its archive has arrived.
# NOTE(review): `node` is interpolated unquoted into a shell=True command
# line. If node names are not fully trusted this is a shell-injection risk;
# passing ['sftp', '-o', 'ConnectTimeout=10', f'root@{node}'] without
# shell=True avoids it. Also confirm that a failed `put` is reflected in
# returncode: sftp's batch mode (`-b -`) aborts on failing commands,
# while plain stdin input may not.
result = subprocess.run([f'sftp -o ConnectTimeout=10 root@{node}'],
shell=True, text=True, capture_output=True,
input=f'put {tar_file}\nput {version_file}\n')
if result.returncode == 0:
nodes[node] = version
db.write('nodes', nodes)
# Upload failed (the `else:` is missing in this copy): log sftp's stderr and
# mark the run as incomplete, then carry on with the remaining nodes.
syslog.syslog(f'error updating node {node}: {result.stderr}')
done = False
return done
except Exception as e:
import traceback
with open('/tmp/wtflog', 'a+') as f:
return False