diff --git a/web/system.py b/web/system.py index 7c2298a..cf3fe93 100644 --- a/web/system.py +++ b/web/system.py @@ -34,17 +34,18 @@ def save_config(): settings = db.load('settings') groups = db.load('groups') - # Get users’ group membership from LDAP server. Only query the groups used - # by at least one network, and query each group just once. - user_groups = collections.defaultdict(set) + # For each user build a list of networks they have access to, based on + # group membership in AD. Only query groups associated with at least one + # network, and query each group only once. + user_networks = collections.defaultdict(set) ldap = ldap3.Connection(ldap3.Server(settings.get('ldap_host'), use_ssl=True), settings.get('ldap_user'), settings.get('ldap_pass'), auto_bind=True) - for group in groups: + for group, network in groups.items(): ldap.search(settings.get('ldap_base_dn', ''), f'(distinguishedName={group})', attributes='member') if ldap.entries: for user in ldap.entries[0]['member']: - user_groups[user].add(group) + user_networks[user].add(network) # Now read the settings again and lock the database while generating # config files, then increment version before unlocking. @@ -52,13 +53,26 @@ def save_config(): settings = db.read('settings') version = settings['version'] = int(settings.get('version', 0)) + 1 - # Populate IP sets. - wireguard = db.load('wireguard') + # Populate IP sets and translation maps for NAT. ipsets = collections.defaultdict(set) + networks = db.load('networks') + nat = {} + netmap = {} + for name, network in networks.items(): + for ip in network.get('ip', ()): + ipsets[name].add(ip) + if 'nat' in network: + nat[ip] = network['nat'] + for ip6 in network.get('ip6', ()): + ipsets[f'{name}6'].update(ip6) + netmap.update(network.get('netmap', {})) + + wireguard = db.load('wireguard') for ip, key in wireguard.items(): - for group in user_groups.get(key.get('user', ''), ()): - for name in groups[group]: - ipsets[name].add(f'{ip}/32') + for network in user_networks.get(key.get('user', ''), ()): + ipsets[network].add(f'{ip}/32') + if 'ip6' in key: + ipsets[f'{network}6'].add(f'{key["ip6"]}/128') # Create config files. output = pathlib.Path.home() / 'config' / f'{version}' @@ -68,44 +82,54 @@ def save_config(): # Add registered VPN addresses for each network based on # LDAP group membership. - with open(f'{output}/etc/nftables.d/sets-vpn.nft', 'w', encoding='utf-8') as f: + with open(f'{output}/etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f: def format_set(name, ips): return f'''\ set {name} {{ - typeof ip daddr; flags interval + type ipv4_addr; flags interval elements = {{ {', '.join(ips)} }} }}''' for name, ips in ipsets.items(): - print(format_set(name, ips), file=f) + if not name.endswith('6'): + print(format_set(name, ips), file=f) + + # Print NAT (dynamic and 1:1) rules. + with open(f'{output}/etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f: + def format_map(name, elements): + lines = ',\n'.join(f'{a}: {b}' for a, b in elements) + return f'''\ +map {name} {{ + type ipv4_addr : interval ipv4_addr; flags interval + elements = {{ +{lines} + }} +}} +''' + if nat: + print(format_map('nat', ((private, public) for private, public in nat.items())), file=f) + if netmap: + print(format_map('netmap-out', ((private, public) for private, public in netmap.items())), file=f) + print(format_map('netmap-in', ((public, private) for private, public in netmap.items())), file=f) # Print forwarding rules. with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f: - def format_forward(src, dst): - rule = 'iifname @ifaces_inside oifname @ifaces_inside' - if src: - rule += f' ip saddr @{src}' - if dst: - rule += f' ip daddr @{dst}' - return rule + ' accept' - for src, dst in db.load('forwards'): - print(format_forward(src, dst), file=f) + for forward in db.load('forwards'): + print(forward, file=f) # Print wireguard config. with open(f'{output}/etc/wireguard/wg.conf', 'w', encoding='utf-8') as f: - def format_wg_peer(ip, data): - return f'''\ -# {data.get('user')} -[Peer] -PublicKey = {data.get('key')} -AllowedIPs = {ip} -''' print(f'''\ [Interface] ListenPort = {settings.get('wg_port', 51820)} PrivateKey = {settings.get('wg_key')} ''', file=f) - for ip, key in wireguard.items(): - print(format_wg_peer(ip, key), file=f) + for ip, data in wireguard.items(): + print(f'''\ +# {data.get('user')} +[Peer] +PublicKey = {data.get('key')} +AllowedIPs = {ip} +''', file=f) # Make a config archive in a temporary place, so we don’t send # incomplete tars. diff --git a/web/templates/index.html b/web/templates/index.html index e9e2b2e..981f7e7 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -4,6 +4,7 @@