245 lines
11 KiB
Python
245 lines
11 KiB
Python
#!/usr/bin/python3
|
|
|
|
import collections
|
|
import email.message
|
|
import getpass
|
|
import multiprocessing
|
|
import os
|
|
import pathlib
|
|
import shutil
|
|
import smtplib
|
|
import socket
|
|
import subprocess
|
|
import syslog
|
|
|
|
import click
|
|
import flask
|
|
import flask.cli
|
|
import ldap3
|
|
|
|
from . import db
|
|
from . import ipsets
|
|
|
|
def init_app(app):
    """Register this module's CLI commands on the given Flask application.

    Adds the `generate` and `push` click commands to `app.cli`.
    """
    for command in (generate, push):
        app.cli.add_command(command)
|
|
|
|
def mail(rcpt, subject, body):
    """Send a plain-text notification mail through the SMTP server on localhost.

    This is best effort: any exception raised while building or sending the
    message is logged to syslog and not propagated.

    rcpt    -- recipient address, used as the To header
    subject -- subject text, sent prefixed with "friwall: "
    body    -- plain-text message body
    """
    try:
        sender = f'{getpass.getuser()}@{socket.getfqdn()}'
        message = email.message.EmailMessage()
        message['Subject'] = f'friwall: {subject}'
        message['From'] = sender
        message['To'] = rcpt
        message.set_content(body)
        with smtplib.SMTP('localhost') as smtp:
            smtp.send_message(message)
    except Exception as exc:
        syslog.syslog(f'error sending mail: {exc}')
|
|
|
|
def run(fun, args=()):
    """Run fun(*args) in a detached background process and return immediately.

    A helper process is started through multiprocessing. It forks once more,
    and the grandchild calls os.setsid() before running fun, so the work is
    detached from the caller's session. The helper process itself returns
    right after the fork. This function does not wait for fun to finish.

    fun  -- callable to run in the background
    args -- positional arguments for fun (default: none)
    """
    def task():
        if os.fork() == 0:
            # Grandchild: start a new session, then do the actual work.
            os.setsid()
            fun(*args)

    # `task` is a nested function and cannot be pickled, which the 'spawn'
    # and 'forkserver' start methods require ('forkserver' is the default on
    # Linux since Python 3.14). Request 'fork' explicitly so this keeps working.
    multiprocessing.get_context('fork').Process(target=task).start()
|
|
|
|
# Generate configuration files and create a config tarball.
def _ldap_user_groups(settings):
    """Return {userPrincipalName: set(memberOf)} for enabled person accounts in LDAP.

    When settings['user_group'] is set, only (recursive) members of that group
    are returned. The LDAP connection is always unbound afterwards.
    """
    # Build LDAP query for users and groups.
    filters = [
        '(objectClass=user)',  # only users
        '(objectCategory=person)',  # that are people
        '(!(userAccountControl:1.2.840.113556.1.4.803:=2))',  # with enabled accounts
    ]
    if group := settings.get('user_group'):
        filters += [f'(memberOf:1.2.840.113556.1.4.1941:={group})']  # in given group, recursively

    # Run query and collect group membership data.
    server = ldap3.Server(settings['ldap_host'], use_ssl=True)
    ldap = ldap3.Connection(server, settings['ldap_user'], settings['ldap_pass'], auto_bind=True)
    try:
        ldap.search(settings.get('ldap_base_dn', ''),
                    f'(&{"".join(filters)})',  # conjunction (&(…)(…)(…)) of queries
                    attributes=['userPrincipalName', 'memberOf'])
        return {e.userPrincipalName.value: set(e.memberOf) for e in ldap.entries}
    finally:
        # Previously the connection was never closed.
        ldap.unbind()


def _add_vpn_addresses(sets, wireguard, user_groups):
    """Append each WireGuard peer's addresses to the IP sets it may access (in place).

    A peer gets the networks listed in its own 'networks' entry (custom keys),
    plus every set whose 'vpn' group is one of its user's LDAP groups.
    """
    # Map each VPN group to the names of the sets accessible to its members.
    group_networks = collections.defaultdict(list)
    for name, data in sets.items():
        if group := data.get('vpn'):
            group_networks[group].append(name)

    for ip, key in wireguard.items():
        key_networks = set(key.get('networks', ()))
        for group in user_groups.get(key.get('user', ''), ()):
            key_networks |= set(group_networks.get(group, ()))
        for network in key_networks:
            # 'ip'/'ip6' may be absent from a set (they are read with .get()
            # when writing sets.nft), so create the lists on demand.
            sets[network].setdefault('ip', []).append(f'{ip}/32')
            if ip6 := key.get('ip6'):
                sets[network].setdefault('ip6', []).append(ip6)


def _write_config_files(output, version, settings, sets, wireguard):
    """Write the version file, nftables snippets and wg.conf under `output`.

    Reads 'netmap' and 'rules' from the database, so call it while holding
    db.locked(). Any previous content of `output` is removed first.
    """
    shutil.rmtree(output, ignore_errors=True)
    os.makedirs(output / 'etc/nftables.d', exist_ok=True)
    os.makedirs(output / 'etc/wireguard', exist_ok=True)

    # Print version.
    with open(output / 'version', 'w', encoding='utf-8') as f:
        f.write(f'{version}')

    # Print nftables sets.
    with open(output / 'etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f:
        nft_set = 'set {name} {{\n type ipv4_addr; flags interval; {ips}\n}}\n'
        nft_set6 = 'set {name}/6 {{\n type ipv6_addr; flags interval; {ips}\n}}\n'

        def make_set(ips):
            # return "elements = { ip1, ip2, … }", prefixed with "# " if no ips
            return f'{"" if ips else "# "}elements = {{ {", ".join(ips)} }}'

        for name, data in sets.items():
            f.write(nft_set.format(name=name, ips=make_set(data.get('ip', ()))))
            f.write(nft_set6.format(name=name, ips=make_set(data.get('ip6', ()))))
            f.write('\n')

    # Print static NAT (1:1) rules.
    with open(output / 'etc/nftables.d/netmap.nft', 'w', encoding='utf-8') as f:
        nft_map = 'map {name} {{\n type ipv4_addr : interval ipv4_addr; flags interval; elements = {{\n{ips}\n }}\n}}\n'

        def make_map(ips, reverse=False):
            # return "{ from1: to1, from2: to2, … }" with possibly reversed from and to
            return ',\n'.join(f"{b if reverse else a}: {a if reverse else b}" for a, b in ips)

        if netmap := db.read('netmap'):  # { private range: public range… }
            f.write(nft_map.format(name='netmap-out', ips=make_map(netmap.items())))
            f.write('\n')
            f.write(nft_map.format(name='netmap-in', ips=make_map(netmap.items(), reverse=True)))

    # Print dynamic NAT rules.
    with open(output / 'etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f:
        no_nat_set = settings.get('no_nat_set')
        nft_nat = 'iif @inside oif @outside ip saddr @{name}'
        if no_nat_set:
            # don’t NAT for these destination addresses
            nft_nat += ' ip daddr != @{no_nat_set}'
        nft_nat += ' snat to {nat}\n'
        for name, data in sets.items():
            if nat := data.get('nat'):
                f.write(nft_nat.format(name=name, nat=nat, no_nat_set=no_nat_set))

    # Print forwarding rules.
    with open(output / 'etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
        # Forwarding rules for VPN users.
        if vpn_networks := sorted(name for name, data in sets.items() if data.get('vpn')):
            nft_forward = 'iif @inside oif @inside ip saddr @{name} ip daddr @{name} accept\n'
            nft_forward6 = 'iif @inside oif @inside ip6 saddr @{name}/6 ip6 daddr @{name}/6 accept\n'
            f.write('# forward from the VPN interface to physical networks and back\n')
            for name in vpn_networks:
                f.write(nft_forward.format(name=name))
            for name in vpn_networks:
                f.write(nft_forward6.format(name=name))
            f.write('\n')

        # Custom forwarding rules.
        nft_rule = '# {index}. {name}\n{text}\n\n'
        for index, rule in enumerate(db.read('rules')):
            if rule.get('enabled') and rule.get('text'):
                f.write(nft_rule.format(index=index, name=rule.get('name', ''), text=rule['text']))

    # Print wireguard config.
    with open(output / 'etc/wireguard/wg.conf', 'w', encoding='utf-8') as f:
        # Server configuration.
        wg_intf = '[Interface]\nListenPort = {port}\nPrivateKey = {key}\n\n'
        f.write(wg_intf.format(port=settings.get('wg_port') or 51820, key=settings.get('wg_key')))

        # Client configuration.
        wg_peer = '# {user}\n[Peer]\nPublicKey = {key}\nAllowedIPs = {ips}\n\n'
        for ip, data in wireguard.items():
            f.write(wg_peer.format(
                user=data.get('user'),
                key=data.get('key'),
                ips=', '.join(filter(None, [ip, data.get('ip6')]))))


def save_config():
    """Generate a new config version and pack it into ~/config/<version>.tar.gz.

    Queries LDAP for VPN group membership, then, with the database locked:
    reads IP sets, WireGuard keys and settings, increments the version,
    writes the nftables/WireGuard files into ~/config/<version>/, archives
    them, and saves the new version in settings. The unpacked directory is
    always removed afterwards.

    Returns True on success. On any exception, mails the traceback to
    settings['admin_mail'] (if set) and returns False.
    """
    output = None
    # Bound up front so the error handler works even if reading settings fails.
    settings = {}
    try:
        # Just load required settings here but keep the database unlocked
        # while we load group memberships from LDAP.
        with db.locked():
            settings = db.read('settings')
        user_groups = _ldap_user_groups(settings)

        # Now read the settings again while keeping the database locked until
        # config files are generated, and increment version before unlocking.
        with db.locked():
            sets = ipsets.read()
            wireguard = db.read('wireguard')
            settings = db.read('settings')
            version = settings['version'] = int(settings.get('version') or '0') + 1

            # Add VPN addresses to IP sets.
            _add_vpn_addresses(sets, wireguard, user_groups)

            # Create config files.
            output = pathlib.Path.home() / 'config' / f'{version}'
            _write_config_files(output, version, settings, sets, wireguard)

            # Make a temporary config archive and move it to the final location,
            # so we avoid sending incomplete tars.
            tar_file = shutil.make_archive(f'{output}-tmp', 'gztar', root_dir=output, owner='root', group='root')
            os.rename(tar_file, f'{output}.tar.gz')

            # If we get here, write settings with the new version.
            db.write('settings', settings)
        return True

    except Exception as e:
        import traceback
        e.add_note(f'exception while generating config: {e}')
        msg = traceback.format_exc()
        if rcpt := settings.get('admin_mail'):
            mail(rcpt, 'error generating config', msg)
        # TODO: also log msg with syslog.syslog(); that did not seem to work.
        return False

    finally:
        # Remove temporary directory.
        if output:
            shutil.rmtree(output, ignore_errors=True)
|
|
|
|
@click.command('generate')
@flask.cli.with_appcontext
def generate():
    """CLI command: generate a new config version via save_config().

    Exits with status 1 when generation fails, so scripts and service
    managers can detect the failure. save_config() has already mailed
    the traceback to the admin address, if one is configured.
    """
    if not save_config():
        raise SystemExit(1)
|
|
|
|
@click.command('push')
@click.option('--version', '-v', type=click.INT, default=None, help="Config version to push")
@flask.cli.with_appcontext
def push(version=None):
    """CLI command: push a config tarball to every node not yet at that version.

    version -- config version to push; defaults to settings['version'] (0 if unset).

    Each outdated node gets ~/config/<version>.tar.gz on the stdin of an ssh
    session as root. After each successful push, the node's version is saved
    to the 'nodes' table. Per-node failures are collected and raised together
    as an ExceptionGroup. Any exception is then caught here: its traceback is
    mailed to settings['admin_mail'] (if set) and logged to syslog, and it is
    not re-raised.
    """
    try:
        with db.locked('nodes'):
            if version is None:
                version = db.load('settings').get('version', 0)

            nodes = db.read('nodes')
            tar_file = pathlib.Path.home() / 'config' / f'{version}.tar.gz'

            errors = []
            for node, node_version in nodes.items():
                if node_version == version:
                    continue  # already up to date
                try:
                    # Push config tarfile to node. There sshd runs a forced command that
                    # reads in a tarball, copies files to /etc and reloads services.
                    syslog.syslog(f'updating config for {node} from v{node_version} to v{version}')
                    # Open in binary mode (it is gzip data) and close it afterwards;
                    # previously one file descriptor leaked per node.
                    with open(tar_file, 'rb') as tar:
                        result = subprocess.run(
                            ['/usr/bin/ssh', '-T', '-o', 'ConnectTimeout=10', f'root@{node}'],
                            stdin=tar, capture_output=True, text=True)
                    if result.returncode == 0:
                        # Reassigning an existing key is safe during iteration.
                        nodes[node] = version
                        db.write('nodes', nodes)
                        syslog.syslog(f'successfully updated config for {node} to v{version}')
                    else:
                        raise RuntimeError(f'error updating config to v{version}: {result.stderr}')
                except (FileNotFoundError, RuntimeError) as e:
                    e.add_note(f'error while updating node {node}')
                    errors.append(e)

            if errors:
                raise ExceptionGroup('errors while updating nodes', errors)

    except Exception:
        import traceback
        msg = traceback.format_exc()
        if rcpt := db.load('settings').get('admin_mail'):
            mail(rcpt, 'error updating nodes', msg)
        syslog.syslog(msg)
|