Support IPv6 sets

Also some unrelated cleanups in system.save_config.
This commit is contained in:
Timotej Lazar 2023-05-29 12:59:57 +02:00
parent 765d4a3ce7
commit 6780f074c7

View file

@ -56,16 +56,9 @@ def save_config():
# Populate IP sets and translation maps for NAT. # Populate IP sets and translation maps for NAT.
ipsets = collections.defaultdict(set) ipsets = collections.defaultdict(set)
networks = db.read('networks') for name, network in db.read('networks').items():
for name, network in networks.items(): ipsets[name].update(network.get('ip', ()))
for ip in network.get('ip', ()): ipsets[f'{name}/6'].update(network.get('ip6', ()))
ipsets[name].add(ip)
for ip6 in network.get('ip6', ()):
ipsets[f'{name}6'].update(ip6)
# Load static and dynamic NAT translations.
nat = db.read('nat') # { network name: public range… }
netmap = db.read('netmap') # { private range: public range… }
# Add registered VPN addresses for each network based on # Add registered VPN addresses for each network based on
# LDAP group membership. # LDAP group membership.
@ -74,7 +67,7 @@ def save_config():
for network in user_networks.get(key.get('user', ''), ()): for network in user_networks.get(key.get('user', ''), ()):
ipsets[network].add(f'{ip}/32') ipsets[network].add(f'{ip}/32')
if 'ip6' in key: if 'ip6' in key:
ipsets[f'{network}6'].add(f'{key["ip6"]}/128') ipsets[f'{network}/6'].add(f'{key["ip6"]}/128')
# Create config files. # Create config files.
output = pathlib.Path.home() / 'config' / f'{version}' output = pathlib.Path.home() / 'config' / f'{version}'
@ -82,16 +75,15 @@ def save_config():
os.makedirs(f'{output}/etc/nftables.d', exist_ok=True) os.makedirs(f'{output}/etc/nftables.d', exist_ok=True)
os.makedirs(f'{output}/etc/wireguard', exist_ok=True) os.makedirs(f'{output}/etc/wireguard', exist_ok=True)
# Print nftables set for wireguard IPs. # Print nftables sets.
with open(f'{output}/etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f: with open(f'{output}/etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f:
def format_set(name, ips): def format_set(name, ips):
return f'''\ return f'''\
set {name} {{ set {name} {{
type ipv4_addr; flags interval type {"ipv6_addr" if name.endswith('/6') else "ipv4_addr"}; flags interval
elements = {{ {', '.join(ips)} }} elements = {{ {', '.join(ips)} }}
}}''' }}'''
for name, ips in ipsets.items(): for name, ips in ipsets.items():
if not name.endswith('6'):
print(format_set(name, ips), file=f) print(format_set(name, ips), file=f)
# Print static NAT (1:1) rules. # Print static NAT (1:1) rules.
@ -101,17 +93,17 @@ set {name} {{
return f'''\ return f'''\
map {name} {{ map {name} {{
type ipv4_addr : interval ipv4_addr; flags interval type ipv4_addr : interval ipv4_addr; flags interval
elements = {{ elements = {{ {lines} }}
{lines}
}}
}} }}
''' '''
netmap = db.read('netmap') # { private range: public range… }
if netmap: if netmap:
print(format_map('netmap-out', ((private, public) for private, public in netmap.items())), file=f) print(format_map('netmap-out', ((private, public) for private, public in netmap.items())), file=f)
print(format_map('netmap-in', ((public, private) for private, public in netmap.items())), file=f) print(format_map('netmap-in', ((public, private) for private, public in netmap.items())), file=f)
# Print dynamic NAT rules. # Print dynamic NAT rules.
with open(f'{output}/etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f: with open(f'{output}/etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f:
nat = db.read('nat') # { network name: public range… }
for network, address in nat.items(): for network, address in nat.items():
print(f'iif @inside oif @outside ip saddr @{network} snat to {address}', file=f) print(f'iif @inside oif @outside ip saddr @{network} snat to {address}', file=f)