Timotej Lazar
3c25cbe88a
Custom keys are created by an admin and specify networks directly, bypassing AD permissions. They are intended for joining managed devices to networks where users are not allowed to create keys themselves. Also build a set directly with a set comprehension.
240 lines
11 KiB
Python
240 lines
11 KiB
Python
#!/usr/bin/python3
|
|
|
|
import collections
|
|
import email.message
|
|
import getpass
|
|
import multiprocessing
|
|
import os
|
|
import pathlib
|
|
import shutil
|
|
import smtplib
|
|
import socket
|
|
import subprocess
|
|
import syslog
|
|
|
|
import click
|
|
import flask
|
|
import flask.cli
|
|
import ldap3
|
|
|
|
from . import db
|
|
|
|
def init_app(app):
    """Register this module's CLI commands ('generate', then 'push') on the Flask app."""
    for command in (generate, push):
        app.cli.add_command(command)
|
|
|
|
def mail(rcpt, subject, body):
    """Send a plain-text mail with `body` to `rcpt` through the SMTP server on localhost.

    The subject is prefixed with 'friwall: ' and the sender is
    <current user>@<this host's FQDN>. Any Exception raised while building or
    sending the message is caught and written to syslog instead of propagating.
    """
    try:
        headers = (
            ('Subject', f'friwall: {subject}'),
            ('From', f'{getpass.getuser()}@{socket.getfqdn()}'),
            ('To', rcpt),
        )
        message = email.message.EmailMessage()
        for header, value in headers:
            message[header] = value
        message.set_content(body)
        with smtplib.SMTP('localhost') as smtp:
            smtp.send_message(message)
    except Exception as error:
        syslog.syslog(f'error sending mail: {error}')
|
|
|
|
def run(fun, args=()):
    """Call fun(*args) in a detached background process and return promptly.

    A helper process forks once more; the grandchild starts a new session
    (os.setsid) and runs fun, so it is not a child of the caller. The helper
    process exits right after forking and is reaped here.
    """
    def task():
        if os.fork() == 0:
            os.setsid()
            fun(*args)

    # task is a closure, which only the 'fork' start method can run (the other
    # start methods pickle the target). Request it explicitly, because newer
    # Pythons default to a different start method (forkserver on Linux since 3.14).
    process = multiprocessing.get_context('fork').Process(target=task)
    process.start()
    # The helper exits immediately after forking; join it so it does not
    # linger as a zombie process.
    process.join()
|
|
|
|
def _ldap_user_groups(settings):
    """Return {userPrincipalName: set of memberOf values} for matching LDAP users.

    Matches enabled accounts of person users; when the 'user_group' setting is
    set, only (recursive) members of that group are returned. Connection
    parameters come from the 'ldap_host', 'ldap_user', 'ldap_pass' and optional
    'ldap_base_dn' settings.
    """
    filters = [
        '(objectClass=user)',  # only users
        '(objectCategory=person)',  # that are people
        '(!(userAccountControl:1.2.840.113556.1.4.803:=2))',  # with enabled accounts
    ]
    if group := settings.get('user_group'):
        filters.append(f'(memberOf:1.2.840.113556.1.4.1941:={group})')  # in given group, recursively

    server = ldap3.Server(settings['ldap_host'], use_ssl=True)
    ldap = ldap3.Connection(server, settings['ldap_user'], settings['ldap_pass'], auto_bind=True)
    try:
        # NOTE(review): this search is not paged; Active Directory may truncate
        # large result sets (MaxPageSize) — consider a paged search if needed.
        ldap.search(settings.get('ldap_base_dn', ''),
                    f'(&{"".join(filters)})',  # conjunction (&(…)(…)(…)) of queries
                    attributes=['userPrincipalName', 'memberOf'])
        return {e.userPrincipalName.value: set(e.memberOf) for e in ldap.entries}
    finally:
        # Release the bound connection instead of leaving its socket open.
        ldap.unbind()


def _add_vpn_addresses(ipsets, wireguard, user_groups):
    """Append each WireGuard key's addresses to the IP sets it may access (in place).

    A key belongs to the networks listed in its own 'networks' entry (custom
    keys) plus every network whose 'vpn' group is one of the key owner's groups
    in `user_groups`.
    """
    # Map each VPN group to the networks its members may access. Not every
    # ipset has a 'vpn' entry, so use .get().
    group_networks = collections.defaultdict(list)
    for name, data in ipsets.items():
        if group := data.get('vpn'):
            group_networks[group].append(name)

    for ip, key in wireguard.items():
        # Find all networks this IP should belong to:
        # - manually specified networks for custom keys,
        # - networks accessible to any of the user’s groups.
        key_networks = set(key.get('networks', ()))
        for group in user_groups.get(key.get('user', ''), ()):
            key_networks.update(group_networks.get(group, ()))

        for network in key_networks:
            if network not in ipsets:
                # A custom key may still name a network that was removed since;
                # skip it rather than failing config generation for everyone.
                syslog.syslog(f'ignoring unknown network {network} for {ip}')
                continue
            ipsets[network].setdefault('ip', []).append(f'{ip}/32')
            if ip6 := key.get('ip6'):
                ipsets[network].setdefault('ip6', []).append(ip6)


def _write_config_files(output, version, ipsets, wireguard, settings):
    """Write the version file, nftables files and wg.conf under directory `output`.

    Any existing `output` directory is replaced. This also reads the 'netmap'
    and 'rules' tables, so call it while holding db.locked().
    """
    shutil.rmtree(output, ignore_errors=True)
    os.makedirs(output / 'etc/nftables.d', exist_ok=True)
    os.makedirs(output / 'etc/wireguard', exist_ok=True)

    # Print version.
    with open(output / 'version', 'w', encoding='utf-8') as f:
        f.write(f'{version}')

    # Print nftables sets.
    with open(output / 'etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f:
        nft_set = 'set {name} {{\n type ipv4_addr; flags interval; {ips}\n}}\n'
        nft_set6 = 'set {name}/6 {{\n type ipv6_addr; flags interval; {ips}\n}}\n'

        def make_set(ips):
            # return "elements = { ip1, ip2, … }", prefixed with "# " if no ips
            return f'{"" if ips else "# "}elements = {{ {", ".join(ips)} }}'

        for name, data in ipsets.items():
            f.write(nft_set.format(name=name, ips=make_set(data.get('ip', ()))))
            f.write(nft_set6.format(name=name, ips=make_set(data.get('ip6', ()))))
            f.write('\n')

    # Print static NAT (1:1) rules.
    with open(output / 'etc/nftables.d/netmap.nft', 'w', encoding='utf-8') as f:
        nft_map = 'map {name} {{\n type ipv4_addr : interval ipv4_addr; flags interval; elements = {{\n{ips}\n }}\n}}\n'

        def make_map(ips, reverse=False):
            # return "{ from1: to1, from2: to2, … }" with possibly reversed from and to
            return ',\n'.join(f"{b if reverse else a}: {a if reverse else b}" for a, b in ips)

        if netmap := db.read('netmap'):  # { private range: public range… }
            f.write(nft_map.format(name='netmap-out', ips=make_map(netmap.items())))
            f.write('\n')
            f.write(nft_map.format(name='netmap-in', ips=make_map(netmap.items(), reverse=True)))

    # Print dynamic NAT rules.
    with open(output / 'etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f:
        nft_nat = 'iif @inside oif @outside ip saddr @{name} snat to {nat}\n'
        for name, data in ipsets.items():
            if nat := data.get('nat'):
                f.write(nft_nat.format(name=name, nat=nat))

    # Print forwarding rules.
    with open(output / 'etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
        # Forwarding rules for VPN users.
        if vpn_networks := sorted(name for name, data in ipsets.items() if data.get('vpn')):
            nft_forward = 'iif @inside oif @inside ip saddr @{name} ip daddr @{name} accept\n'
            nft_forward6 = 'iif @inside oif @inside ip6 saddr @{name}/6 ip6 daddr @{name}/6 accept\n'
            f.write('# forward from the VPN interface to physical networks and back\n')
            for name in vpn_networks:
                f.write(nft_forward.format(name=name))
            for name in vpn_networks:
                f.write(nft_forward6.format(name=name))
            f.write('\n')

        # Custom forwarding rules.
        nft_rule = '# {index}. {name}\n{text}\n\n'
        for index, rule in enumerate(db.read('rules')):
            if rule.get('enabled') and rule.get('text'):
                f.write(nft_rule.format(index=index, name=rule.get('name', ''), text=rule['text']))

    # Print wireguard config.
    with open(output / 'etc/wireguard/wg.conf', 'w', encoding='utf-8') as f:
        # Server configuration.
        wg_intf = '[Interface]\nListenPort = {port}\nPrivateKey = {key}\n\n'
        f.write(wg_intf.format(port=settings.get('wg_port') or 51820, key=settings.get('wg_key')))

        # Client configuration.
        wg_peer = '# {user}\n[Peer]\nPublicKey = {key}\nAllowedIPs = {ips}\n\n'
        for ip, data in wireguard.items():
            f.write(wg_peer.format(
                user=data.get('user'),
                key=data.get('key'),
                ips=', '.join(filter(None, [ip, data.get('ip6')]))))


# Generate configuration files and create a config tarball.
def save_config():
    """Generate a new config version and pack it into ~/config/<version>.tar.gz.

    <version> is the stored 'version' setting plus one; the new version is
    written back only after the tarball has been moved into place. Returns
    True on success. On any error returns False, mails the traceback to the
    'admin_mail' setting if set, and writes a one-line summary to syslog.
    """
    output = None
    settings = {}  # keeps the error handler usable if reading settings fails
    try:
        # Just load required settings here but keep the database unlocked
        # while we load group memberships from LDAP.
        with db.locked():
            settings = db.read('settings')
        user_groups = _ldap_user_groups(settings)

        # Now read the settings again while keeping the database locked until
        # config files are generated, and increment version before unlocking.
        with db.locked():
            ipsets = db.read('ipsets')
            wireguard = db.read('wireguard')
            settings = db.read('settings')
            version = settings['version'] = int(settings.get('version') or '0') + 1

            _add_vpn_addresses(ipsets, wireguard, user_groups)

            # Assign output before writing, so the finally clause can clean up
            # a partially written directory.
            output = pathlib.Path.home() / 'config' / f'{version}'
            _write_config_files(output, version, ipsets, wireguard, settings)

            # Make a temporary config archive and move it to the final location,
            # so we avoid sending incomplete tars.
            tar_file = shutil.make_archive(f'{output}-tmp', 'gztar', root_dir=output, owner='root', group='root')
            os.rename(tar_file, f'{output}.tar.gz')

            # If we get here, write settings with the new version.
            db.write('settings', settings)
        return True

    except Exception as e:
        import traceback
        e.add_note(f'exception while generating config: {e}')
        msg = traceback.format_exc()
        if rcpt := settings.get('admin_mail'):
            mail(rcpt, 'error generating config', msg)
        # Logging the multi-line traceback to syslog did not seem to work, so
        # log a one-line summary; the full traceback goes to the admin mail.
        syslog.syslog(f'error generating config: {e}')
        return False

    finally:
        # Remove temporary directory.
        if output:
            shutil.rmtree(output, ignore_errors=True)
|
|
|
|
@click.command('generate')
@flask.cli.with_appcontext
def generate():
    """Generate a new config version and pack it into a tarball."""
    # save_config reports its own errors (mail/syslog) and returns False on
    # failure; reflect that in the exit status so scripts and timers notice.
    if not save_config():
        raise SystemExit(1)
|
|
|
|
@click.command('push')
@click.option('--version', '-v', type=click.INT, default=None, help="Config version to push")
@flask.cli.with_appcontext
def push(version=None):
    """Push a config tarball to every node that does not have that version yet.

    Without --version, the current 'version' setting is pushed. A node's entry
    in the 'nodes' table is updated after each successful push. Failures for
    individual nodes are collected; once all nodes were tried, the errors are
    mailed to the 'admin_mail' setting (if set) and logged to syslog.
    """
    try:
        with db.locked('nodes'):
            if version is None:
                version = db.load('settings').get('version', 0)

            nodes = db.read('nodes')
            tar_file = pathlib.Path.home() / 'config' / f'{version}.tar.gz'

            errors = []
            for node, node_version in nodes.items():
                if node_version == version:
                    continue
                try:
                    # Push config tarfile to node. There sshd runs a forced command that
                    # reads in a tarball, copies files to /etc and reloads services.
                    syslog.syslog(f'updating config for {node} from v{node_version} to v{version}')
                    # Open in binary mode and close the file after each push
                    # (previously one descriptor leaked per node).
                    with open(tar_file, 'rb') as tarball:
                        result = subprocess.run(
                            ['/usr/bin/ssh', '-T', '-o', 'ConnectTimeout=10', f'root@{node}'],
                            stdin=tarball, capture_output=True, text=True,
                            # Undecodable remote output must not abort the loop.
                            errors='replace')
                    if result.returncode != 0:
                        raise RuntimeError(f'error updating config to v{version}: {result.stderr}')
                    # Only the value of an existing key changes, which is safe
                    # while iterating over the dict.
                    nodes[node] = version
                    db.write('nodes', nodes)
                    syslog.syslog(f'successfully updated config for {node} to v{version}')
                except (OSError, RuntimeError) as e:
                    e.add_note(f'error while updating node {node}')
                    errors.append(e)

            if errors:
                raise ExceptionGroup('errors while updating nodes', errors)

    except Exception:
        import traceback
        msg = traceback.format_exc()
        if rcpt := db.load('settings').get('admin_mail'):
            mail(rcpt, 'error updating nodes', msg)
        syslog.syslog(msg)
|