friwall/web/system.py

206 lines
7.7 KiB
Python
Raw Normal View History

2022-01-03 10:33:02 +00:00
#!/usr/bin/python3
import collections
import multiprocessing
import os
import pathlib
2022-01-03 10:33:02 +00:00
import shutil
import subprocess
import syslog
import click
import flask
import flask.cli
import ldap3
from . import db
def init_app(app):
    """Attach this module's Flask CLI commands ("generate", then "push") to *app*."""
    for command in (generate, push):
        app.cli.add_command(command)
def run(fun, args=()):
    """Run ``fun(*args)`` in a detached background process and return at once.

    A short-lived intermediate child process forks again; the grandchild
    starts a new session with ``os.setsid()`` and calls *fun*, so the work is
    not attached to the caller's session.

    Args:
        fun: callable to run in the background.
        args: positional arguments passed to *fun*.
    """
    def task():
        if os.fork() == 0:
            os.setsid()
            fun(*args)

    # task is a closure, which only the 'fork' start method can run: 'spawn'
    # and 'forkserver' must pickle the target, and 'forkserver' is the default
    # on Linux since Python 3.14.
    proc = multiprocessing.get_context('fork').Process(target=task)
    proc.start()
    # The intermediate child exits right after forking; reap it so it does not
    # linger as a zombie.
    proc.join()
def _ldap_user_networks(settings, groups):
    """Return ``{user DN: set of network names}`` from LDAP group membership.

    Args:
        settings: settings dict; reads ldap_host, ldap_user, ldap_pass and
            ldap_base_dn.
        groups: mapping of group DN -> network name whose members may access
            that network. Each group is searched exactly once.
    """
    # Local import: only needed here, and ldap3 is already a dependency.
    from ldap3.utils.conv import escape_filter_chars

    user_networks = collections.defaultdict(set)
    base_dn = settings.get('ldap_base_dn', '')
    ldap = ldap3.Connection(ldap3.Server(settings.get('ldap_host'), use_ssl=True),
                            settings.get('ldap_user'), settings.get('ldap_pass'),
                            auto_bind=True)
    try:
        for group, network in groups.items():
            # DNs may contain filter-special characters (e.g. the backslash
            # of an escaped comma), so escape the value before building the
            # filter.
            ldap.search(base_dn, f'(distinguishedName={escape_filter_chars(group)})',
                        attributes='member')
            if ldap.entries:
                for user in ldap.entries[0]['member']:
                    user_networks[user].add(network)
    finally:
        # Always release the LDAP session, even if a search fails.
        ldap.unbind()
    return user_networks


def _collect_ipsets(networks, wireguard, user_networks):
    """Return ``(ipsets, ipsets6)``, each mapping network name -> set of addresses.

    Static ranges come from each network's 'ip' / 'ip6' lists. Every
    registered WireGuard address is added as a /32 (and its 'ip6', if any,
    as a /128) to each network its user may access. IPv4 and IPv6 sets are
    kept in separate dicts so a network whose name ends in "6" cannot be
    mistaken for another network's IPv6 set. A set only appears once it has
    at least one element.
    """
    ipsets = collections.defaultdict(set)
    ipsets6 = collections.defaultdict(set)
    for name, network in networks.items():
        for ip in network.get('ip', ()):
            ipsets[name].add(ip)
        for ip6 in network.get('ip6', ()):
            # add(), not update(): update() would insert every character.
            ipsets6[name].add(ip6)
    for ip, peer in wireguard.items():
        for network in user_networks.get(peer.get('user', ''), ()):
            ipsets[network].add(f'{ip}/32')
            if 'ip6' in peer:
                ipsets6[network].add(f'{peer["ip6"]}/128')
    return ipsets, ipsets6


def _format_set(name, ips):
    """Return an nftables IPv4 interval set definition (elements sorted)."""
    elements = ', '.join(sorted(ips))
    return (f'set {name} {{\n'
            '\ttype ipv4_addr; flags interval\n'
            f'\telements = {{ {elements} }}\n'
            '}')


def _format_map(name, pairs):
    """Return an nftables IPv4 interval map definition from (key, value) pairs."""
    body = ',\n'.join(f'\t\t{key}: {value}' for key, value in pairs)
    return (f'map {name} {{\n'
            '\ttype ipv4_addr : interval ipv4_addr; flags interval\n'
            '\telements = {\n'
            f'{body}\n'
            '\t}\n'
            '}\n')


def _write_config_tree(output, settings, ipsets, nat, netmap, forwards, wireguard):
    """Write the nftables and WireGuard config files below the *output* directory."""
    nft_dir = output / 'etc' / 'nftables.d'
    wg_dir = output / 'etc' / 'wireguard'
    nft_dir.mkdir(parents=True, exist_ok=True)
    wg_dir.mkdir(parents=True, exist_ok=True)

    # IPv4 address sets, one per network, in a stable order.
    with open(nft_dir / 'sets.nft', 'w', encoding='utf-8') as f:
        for name in sorted(ipsets):
            print(_format_set(name, ipsets[name]), file=f)

    # Static NAT (1:1) maps in both directions.
    with open(nft_dir / 'netmap.nft', 'w', encoding='utf-8') as f:
        if netmap:
            print(_format_map('netmap-out', netmap.items()), file=f)
            print(_format_map('netmap-in', ((public, private) for private, public in netmap.items())), file=f)

    # Dynamic NAT rules.
    with open(nft_dir / 'nat.nft', 'w', encoding='utf-8') as f:
        for network, address in nat.items():
            print(f'iifname @inside oifname @outside ip saddr @{network} snat to {address}', file=f)

    # Forwarding rules are printed verbatim, one per line.
    with open(nft_dir / 'forward.nft', 'w', encoding='utf-8') as f:
        for forward in forwards:
            print(forward, file=f)

    # WireGuard config: interface section, then one peer per registered address.
    # NOTE(review): this file contains the private key; consider restricting
    # its mode before archiving (wg-quick warns about world-readable configs).
    with open(wg_dir / 'wg.conf', 'w', encoding='utf-8') as f:
        f.write('[Interface]\n'
                f"ListenPort = {settings.get('wg_port', 51820)}\n"
                f"PrivateKey = {settings.get('wg_key')}\n"
                '\n')
        for ip, peer in wireguard.items():
            f.write(f"# {peer.get('user')}\n"
                    '[Peer]\n'
                    f"PublicKey = {peer.get('key')}\n"
                    f'AllowedIPs = {ip}\n'
                    '\n')


def save_config():
    """Generate a new config version and pack it as ``~/config/<version>.tar.gz``.

    The version counter in the 'settings' database is incremented and written
    back only after the archive has been moved into place.

    Returns:
        True on success, False on any failure (details are logged to syslog).
    """
    output = None
    try:
        # Load settings without locking the database while we query LDAP.
        settings = db.load('settings')
        user_networks = _ldap_user_networks(settings, db.load('groups'))

        # Re-read settings under the lock while generating config files, then
        # persist the incremented version before unlocking.
        with db.locked('settings'):
            settings = db.read('settings')
            version = settings['version'] = int(settings.get('version', 0)) + 1

            networks = db.load('networks')
            nat = db.load('nat')  # { network name: public range… }
            netmap = db.load('netmap')  # { private range: public range… }
            wireguard = db.load('wireguard')
            # IPv6 sets are collected but not yet written out.
            ipsets, _ipsets6 = _collect_ipsets(networks, wireguard, user_networks)

            output = pathlib.Path.home() / 'config' / f'{version}'
            shutil.rmtree(output, ignore_errors=True)
            _write_config_tree(output, settings, ipsets, nat, netmap,
                               db.load('forwards'), wireguard)

            # Archive under a temporary name, then rename, so an incomplete
            # tarball never appears under the final name.
            tar_file = shutil.make_archive(f'{output}-tmp', 'gztar', root_dir=output,
                                           owner='root', group='root')
            os.rename(tar_file, f'{output}.tar.gz')

            # Only now persist the new version.
            db.write('settings', settings)
            return True
    except Exception as e:
        import traceback
        syslog.syslog(f'exception while generating config: {e}')
        # Log the full traceback too, one syslog record per line.
        for line in traceback.format_exc().splitlines():
            syslog.syslog(syslog.LOG_ERR, line)
        return False
    finally:
        if output is not None:
            # Remove the unpacked tree; only the archive is kept.
            shutil.rmtree(output, ignore_errors=True)
            # A failure between make_archive() and the rename leaves the
            # temporary archive behind; drop it (absent on success).
            try:
                os.remove(f'{output}-tmp.tar.gz')
            except OSError:
                pass
@click.command('generate')
@flask.cli.with_appcontext
def generate():
    """Generate a new config version from the command line.

    Exits with a non-zero status when generation fails, so scripts calling
    ``flask generate`` can detect the failure.
    """
    if not save_config():
        # save_config() logs the underlying exception to syslog.
        raise click.ClickException('config generation failed, see syslog for details')
@click.command('push')
@click.option('--version', '-v', type=click.INT, default=None, help="Config version to push")
@flask.cli.with_appcontext
def push(version=None):
    """Upload config archive *version* to every node not already running it.

    Args:
        version: config version to push; defaults to the 'version' stored in
            the 'settings' database.

    Returns:
        True if every outdated node was updated, False if any upload failed,
        the archive is missing, or an unexpected error occurred.
    """
    try:
        with db.locked('nodes'):
            if version is None:
                version = db.load('settings').get('version', 0)

            config_dir = pathlib.Path.home() / 'config'
            # Record the wanted version; this file is uploaded with the archive.
            version_file = config_dir / 'version'
            with open(version_file, 'w', encoding='utf-8') as f:
                print(version, file=f)

            nodes = db.read('nodes')
            tar_file = config_dir / f'{version}.tar.gz'

            done = True
            for node, node_version in nodes.items():
                if node_version == version:
                    continue
                if not tar_file.exists():
                    syslog.syslog(f'wanted to push version {version} but {version}.tar.gz doesn’t exist')
                    return False
                syslog.syslog(f'updating {node} from {node_version} to {version}')
                # Pass an argument list (no shell) so the node name is never
                # interpreted by a shell. "-b -" reads the commands from stdin
                # in batch mode, where sftp exits non-zero if a put fails;
                # batch mode requires non-interactive authentication.
                result = subprocess.run(
                    ['sftp', '-b', '-', '-o', 'ConnectTimeout=10', f'root@{node}'],
                    text=True, capture_output=True,
                    input=f'put {tar_file}\nput {version_file}\n')
                if result.returncode == 0:
                    nodes[node] = version
                    # Persist after each successful node.
                    db.write('nodes', nodes)
                else:
                    syslog.syslog(f'error updating node {node}: {result.stderr}')
                    done = False
            return done
    except Exception as e:
        import traceback
        syslog.syslog(f'exception while pushing config: {e}')
        # Log the full traceback too, one syslog record per line.
        for line in traceback.format_exc().splitlines():
            syslog.syslog(syslog.LOG_ERR, line)
        return False