Clean up save_config
This commit is contained in:
parent
ff2246df8c
commit
bb68978b22
|
@ -10,7 +10,6 @@ import shutil
|
||||||
import smtplib
|
import smtplib
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
|
||||||
import syslog
|
import syslog
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
@ -57,6 +56,13 @@ def ipset_add(ipsets, name, ip=None, ip6=None):
|
||||||
ipsets[f'{name}/6'].update(ip6 or ())
|
ipsets[f'{name}/6'].update(ip6 or ())
|
||||||
|
|
||||||
def save_config():
|
def save_config():
|
||||||
|
# Format strings for creating firewall config files.
|
||||||
|
nft_set = 'set {name} {{\n type ipv{family}_addr; flags interval; {elements}\n}}\n\n'
|
||||||
|
nft_map = 'map {name} {{\n type ipv4_addr : interval ipv4_addr; flags interval; {elements}\n}}\n\n'
|
||||||
|
nft_forward = '# {index}. {name}\n{text}\n\n'
|
||||||
|
wg_intf = '[Interface]\nListenPort = {port}\nPrivateKey = {key}\n\n'
|
||||||
|
wg_peer = '# {user}\n[Peer]\nPublicKey = {key}\nAllowedIPs = {ips}\n\n'
|
||||||
|
|
||||||
output = None
|
output = None
|
||||||
try:
|
try:
|
||||||
# Just load the settings here but keep the database unlocked
|
# Just load the settings here but keep the database unlocked
|
||||||
|
@ -111,74 +117,57 @@ def save_config():
|
||||||
# Create config files.
|
# Create config files.
|
||||||
output = pathlib.Path.home() / 'config' / f'{version}'
|
output = pathlib.Path.home() / 'config' / f'{version}'
|
||||||
shutil.rmtree(output, ignore_errors=True)
|
shutil.rmtree(output, ignore_errors=True)
|
||||||
os.makedirs(f'{output}/etc/nftables.d', exist_ok=True)
|
os.makedirs(output / 'etc/nftables.d', exist_ok=True)
|
||||||
os.makedirs(f'{output}/etc/wireguard', exist_ok=True)
|
os.makedirs(output / 'etc/wireguard', exist_ok=True)
|
||||||
|
|
||||||
# Print version.
|
# Print version.
|
||||||
with open(f'{output}/version', 'w', encoding='utf-8') as f:
|
with open(output / 'version', 'w', encoding='utf-8') as f:
|
||||||
print(version, file=f)
|
f.write(f'{version}')
|
||||||
|
|
||||||
# Print nftables sets.
|
# Print nftables sets.
|
||||||
with open(f'{output}/etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f:
|
with open(output / 'etc/nftables.d/sets.nft', 'w', encoding='utf-8') as f:
|
||||||
def format_set(name, ips):
|
|
||||||
return f'''\
|
|
||||||
set {name} {{
|
|
||||||
type {"ipv6_addr" if name.endswith('/6') else "ipv4_addr"}; flags interval
|
|
||||||
{"" if ips else "# "}elements = {{ {", ".join(ips)} }}
|
|
||||||
}}'''
|
|
||||||
for name, ips in ipsets.items():
|
for name, ips in ipsets.items():
|
||||||
print(format_set(name, ips), file=f)
|
f.write(nft_set.format(
|
||||||
|
name=name,
|
||||||
|
family='6' if name.endswith('/6') else '4',
|
||||||
|
elements=f'{"" if ips else "# "}elements = {{ {", ".join(ips)} }}'))
|
||||||
|
|
||||||
# Print static NAT (1:1) rules.
|
# Print static NAT (1:1) rules.
|
||||||
with open(f'{output}/etc/nftables.d/netmap.nft', 'w', encoding='utf-8') as f:
|
with open(output / 'etc/nftables.d/netmap.nft', 'w', encoding='utf-8') as f:
|
||||||
def format_map(name, elements):
|
|
||||||
lines = ',\n'.join(f'{a}: {b}' for a, b in elements)
|
|
||||||
return f'''\
|
|
||||||
map {name} {{
|
|
||||||
type ipv4_addr : interval ipv4_addr; flags interval
|
|
||||||
elements = {{ {lines} }}
|
|
||||||
}}
|
|
||||||
'''
|
|
||||||
netmap = db.read('netmap') # { private range: public range… }
|
netmap = db.read('netmap') # { private range: public range… }
|
||||||
if netmap:
|
if netmap:
|
||||||
print(format_map('netmap-out', ((private, public) for private, public in netmap.items())), file=f)
|
f.write(nft_map.format(
|
||||||
print(format_map('netmap-in', ((public, private) for private, public in netmap.items())), file=f)
|
name='netmap-out',
|
||||||
|
elements='elements = {' + ',\n'.join(f'{a}: {b}' for a, b in netmap.items()) + '}'))
|
||||||
|
f.write(nft_map.format(
|
||||||
|
name='netmap-in',
|
||||||
|
elements='elements = {' + ',\n'.join(f'{b}: {a}' for a, b in netmap.items()) + '}'))
|
||||||
|
|
||||||
# Print dynamic NAT rules.
|
# Print dynamic NAT rules.
|
||||||
with open(f'{output}/etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f:
|
with open(output / 'etc/nftables.d/nat.nft', 'w', encoding='utf-8') as f:
|
||||||
nat = db.read('nat') # { network name: public range… }
|
nat = db.read('nat') # { network name: public range… }
|
||||||
for network, address in nat.items():
|
for network, address in nat.items():
|
||||||
if address:
|
if address:
|
||||||
print(f'iif @inside oif @outside ip saddr @{network} snat to {address}', file=f)
|
f.write(f'iif @inside oif @outside ip saddr @{network} snat to {address}\n')
|
||||||
|
|
||||||
# Print forwarding rules.
|
# Print forwarding rules.
|
||||||
with open(f'{output}/etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
|
with open(output / 'etc/nftables.d/forward.nft', 'w', encoding='utf-8') as f:
|
||||||
for index, rule in enumerate(db.read('rules')):
|
for index, rule in enumerate(db.read('rules')):
|
||||||
if rule.get('enabled') and rule.get('text'):
|
if rule.get('enabled') and rule.get('text'):
|
||||||
if 'name' in rule:
|
f.write(nft_forward.format(index=index, name=rule.get('name', ''), text=rule['text']))
|
||||||
print(f'# {index}. {rule["name"]}', file=f)
|
|
||||||
print(rule['text'], file=f)
|
|
||||||
print(file=f)
|
|
||||||
|
|
||||||
# Print wireguard config.
|
# Print wireguard config.
|
||||||
with open(f'{output}/etc/wireguard/wg.conf', 'w', encoding='utf-8') as f:
|
with open(output / 'etc/wireguard/wg.conf', 'w', encoding='utf-8') as f:
|
||||||
print(f'''\
|
f.write(wg_intf.format(
|
||||||
[Interface]
|
port=settings.get('wg_port', 51820),
|
||||||
ListenPort = {settings.get('wg_port', 51820)}
|
key=settings.get('wg_key')))
|
||||||
PrivateKey = {settings.get('wg_key')}
|
|
||||||
''', file=f)
|
|
||||||
for ip, data in wireguard.items():
|
for ip, data in wireguard.items():
|
||||||
print(f'''\
|
f.write(wg_peer.format(
|
||||||
# {data.get('user')}
|
user=data.get('user'),
|
||||||
[Peer]
|
key=data.get('key'),
|
||||||
PublicKey = {data.get('key')}
|
ips=', '.join(filter(None, [ip, data.get('ip6')]))))
|
||||||
AllowedIPs = {ip}
|
|
||||||
''', file=f)
|
|
||||||
if 'ip6' in data:
|
|
||||||
print(f'AllowedIPs = {data["ip6"]}', file=f)
|
|
||||||
|
|
||||||
# Make a config archive in a temporary place, so we don’t send
|
# Make a config archive in a temporary place, so we don’t send incomplete tars.
|
||||||
# incomplete tars.
|
|
||||||
tar_file = shutil.make_archive(f'{output}-tmp', 'gztar', root_dir=output, owner='root', group='root')
|
tar_file = shutil.make_archive(f'{output}-tmp', 'gztar', root_dir=output, owner='root', group='root')
|
||||||
|
|
||||||
# Move config archive to the final destination.
|
# Move config archive to the final destination.
|
||||||
|
@ -247,4 +236,3 @@ def push(version=None):
|
||||||
if rcpt := db.load('settings').get('admin_mail'):
|
if rcpt := db.load('settings').get('admin_mail'):
|
||||||
mail(rcpt, 'error updating nodes', msg)
|
mail(rcpt, 'error updating nodes', msg)
|
||||||
syslog.syslog(msg)
|
syslog.syslog(msg)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue