from ipaddress import ip_network
from re import sub

from bundlewrap.exceptions import NoSuchNode
from bundlewrap.metadata import atomic

# Static metadata defaults contributed by this bundle.
defaults = {
    'apt': {
        # Request the wireguard package with no extra per-package options.
        'packages': {
            'wireguard': {},
        },
        'repos': {
            'backports': {
                'install_gpg_key': False, # repo uses the default debian signing key
                'items': {
                    # NOTE(review): {os_release} is not filled in here; presumably
                    # whatever consumes 'apt/repos' substitutes it - confirm.
                    'deb http://deb.debian.org/debian {os_release}-backports main',
                },
            },
        },
    },
    'wireguard': {
        # Private key obtained from repo.libs.keys.gen_privkey, using an
        # identifier that contains this node's name.
        'privatekey': repo.libs.keys.gen_privkey(repo, f'{node.name} wireguard privatekey'),
    },
}

# Nodes that also carry the telegraf bundle get the builtin wireguard input
# plugin enabled, plus CAP_NET_ADMIN in telegraf's additional capabilities
# (presumably required by that plugin to query the interfaces - confirm).
if node.has_bundle('telegraf'):
    defaults.update({
        'telegraf': {
            'input_plugins': {
                'builtin': {
                    'wireguard': [{}],
                },
            },
            'additional_capabilities': {
                'CAP_NET_ADMIN',
            },
        },
    })


@metadata_reactor.provides(
    'wireguard/peers',
)
def peers_auto_full_mesh(metadata):
    """Add an (empty) peer entry for every other auto-generated mesh node.

    Returns an empty dict when this node is not listed in
    ``repo.libs.s2s.WG_AUTOGEN_NODES``. Otherwise every listed name is turned
    into a peer, except ``None`` entries, this node itself, names that do not
    resolve to a node in the repo, and nodes flagged as ``dummy``.
    """
    autogen_names = repo.libs.s2s.WG_AUTOGEN_NODES
    if node.name not in autogen_names:
        return {}

    mesh = {}
    for candidate in autogen_names:
        if candidate is None or candidate == node.name:
            continue

        try:
            remote = repo.get_node(candidate)
        except NoSuchNode:
            # Listed, but not (or no longer) present in this repo.
            continue

        if not remote.dummy:
            # Details are left empty here; only the peer's existence is declared.
            mesh[remote.name] = {}

    return {'wireguard': {'peers': mesh}}


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_psks(metadata):
    """Give every configured peer an interface name and a pre-shared key.

    ``iface`` is the peer name with each run of characters other than
    ``a-z``, ``0-9``, ``-`` and ``_`` replaced by a single ``_``, cut to
    12 characters. ``psk`` comes from the repo vault; the identifier always
    puts the smaller of the two node names first, so both ends of a
    connection ask the vault for the same identifier.
    """
    result = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        # Order the two names once instead of branching around the vault call.
        if node.name < peer_name:
            first, second = node.name, peer_name
        else:
            first, second = peer_name, node.name

        result[peer_name] = {
            'iface': sub('[^a-z0-9-_]+', '_', peer_name)[:12],
            'psk': repo.vault.random_bytes_as_base64_for(f'{first} wireguard {second}'),
        }

    return {'wireguard': {'peers': result}}


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_pubkeys(metadata):
    """Derive the public key of every peer that is a node in this repo.

    The peer's own ``wireguard/privatekey`` metadata is passed to
    ``repo.libs.keys.get_pubkey_from_privkey``. Peers whose name does not
    resolve to a node are skipped, so no ``pubkey`` is generated for them
    here.
    """
    pubkeys = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        try:
            remote = repo.get_node(peer_name)
        except NoSuchNode:
            continue

        identifier = f'{remote.name} wireguard pubkey'
        remote_privkey = remote.metadata.get('wireguard/privatekey')
        pubkeys[peer_name] = {
            'pubkey': repo.libs.keys.get_pubkey_from_privkey(repo, identifier, remote_privkey),
        }

    return {'wireguard': {'peers': pubkeys}}


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_ips_and_ports(metadata):
    """Assign tunnel addresses and a listen port to auto-generated peers.

    Raises ``DoNotRunAgain`` when this node is not part of
    ``repo.libs.s2s.WG_AUTOGEN_NODES``. Peers that are unknown to the repo or
    not in that list themselves are left untouched.

    For each remaining peer the two addresses come from
    ``repo.libs.s2s.get_subnet_for_connection``, called with both node names
    in sorted order. ``my_port`` is 51820 plus the peer's position in
    ``WG_AUTOGEN_NODES``.
    """
    autogen_names = repo.libs.s2s.WG_AUTOGEN_NODES
    if node.name not in autogen_names:
        raise DoNotRunAgain

    first_port = 51820
    assignments = {}

    for peer_name in sorted(metadata.get('wireguard/peers', {})):
        try:
            remote = repo.get_node(peer_name)
        except NoSuchNode:
            continue

        if remote.name not in autogen_names:
            continue

        ip_a, ip_b = repo.libs.s2s.get_subnet_for_connection(repo, *sorted({node.name, peer_name}))

        # Which of the two addresses is "mine" depends only on the name
        # ordering, so the remote node's run of this reactor picks the
        # opposite one.
        if peer_name < node.name:
            local_ip, remote_ip = ip_a, ip_b
        else:
            local_ip, remote_ip = ip_b, ip_a

        assignments[remote.name] = {
            'my_ip': str(local_ip),
            'my_port': first_port + autogen_names.index(remote.name),
            'their_ip': str(remote_ip),
        }

    return {'wireguard': {'peers': assignments}}


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_endpoints(metadata):
    """Compute the ``host:port`` endpoint of every peer that is a repo node.

    The host is the peer's ``wireguard/external_hostname`` metadata, falling
    back to its ``hostname`` attribute. The port is the ``my_port`` the peer
    has configured for *this* node, falling back to 51820. Peers that do not
    resolve to a node get no endpoint.
    """
    endpoints = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        try:
            remote = repo.get_node(peer_name)
        except NoSuchNode:
            continue

        host = remote.metadata.get('wireguard/external_hostname', remote.hostname)
        port = remote.metadata.get(f'wireguard/peers/{node.name}/my_port', 51820)
        endpoints[remote.name] = {
            'endpoint': '{}:{}'.format(host, port),
        }

    return {'wireguard': {'peers': endpoints}}


@metadata_reactor.provides(
    'icinga2_api/wireguard/services',
)
def icinga2(metadata):
    """Define one icinga2 service per monitored WireGuard peer.

    Peers with a truthy ``exclude_from_monitoring`` are skipped. The check
    command runs ``check_wireguard_connected`` against ``wg_<iface>``; the
    peer's public key is inserted through its ``format_into`` method
    (NOTE(review): this implies ``pubkey`` is a bundlewrap Fault rather than a
    plain string - confirm for manually configured peers).
    """
    peers = metadata.get('wireguard/peers', {})
    services = {}

    for peer_name in sorted(peers):
        peer = peers[peer_name]
        if peer.get('exclude_from_monitoring', False):
            continue

        pubkey = peer['pubkey']
        # '{{}}' leaves a literal '{}' placeholder for format_into to fill.
        check_command = f'sudo /usr/local/share/icinga/plugins/check_wireguard_connected wg_{peer["iface"]} {{}}'
        services[f'WIREGUARD CONNECTION {peer_name}'] = {
            'command_on_monitored_host': pubkey.format_into(check_command),
        }

    return {
        'icinga2_api': {
            'wireguard': {
                'services': services,
            },
        },
    }


@metadata_reactor.provides(
    'firewall/port_rules',
)
def firewall(metadata):
    """Open each peer's UDP listen port in the firewall.

    For a peer that is a node in this repo, the ``<my_port>/udp`` rule is
    restricted to that peer's name. For a peer that is not a repo node (a
    "roadwarrior"), the rule uses the ``wireguard/restrict-to`` set instead,
    defaulting to an empty set. Both values are wrapped in ``atomic()``.

    ``wireguard/peers`` is read with an empty-dict default, like every other
    reactor in this file, so a node without peers yields no rules instead of
    raising ``KeyError``. A peer without ``my_port`` still raises
    ``KeyError``.
    """
    ports = {}

    for name, config in metadata.get('wireguard/peers', {}).items():
        port_rule = '{}/udp'.format(config['my_port'])

        try:
            # Only the existence of the node matters here, not the node itself.
            repo.get_node(name)
        except NoSuchNode:  # roadwarrior
            ports[port_rule] = atomic(set(metadata.get('wireguard/restrict-to', set())))
        else:
            ports[port_rule] = atomic({name})

    return {
        'firewall': {
            'port_rules': ports,
        },
    }


@metadata_reactor.provides(
    'interfaces',
)
def interface_ips(metadata):
    """Describe one ``wg_<iface>`` network interface per WireGuard peer.

    Each interface gets this node's tunnel address (``/31`` is appended when
    ``my_ip`` carries no prefix length) plus ``wireguard/snat_ip`` if that is
    set. Every entry in the peer's ``routes`` is routed via the peer's
    address with any prefix length removed. ``activation_policy`` is ``up``
    unless the peer sets ``auto_connection`` to a falsy value, in which case
    it is ``manual``.
    """
    snat_ip = metadata.get('wireguard/snat_ip', None)
    peers = metadata.get('wireguard/peers', {})
    interfaces = {}

    for peer_name in sorted(peers):
        config = peers[peer_name]

        local_ip = config['my_ip']
        if '/' not in local_ip:
            local_ip = '{}/31'.format(local_ip)

        addresses = {local_ip}
        if snat_ip:
            addresses.add(snat_ip)

        # Taking the part before the first '/' is a no-op for bare addresses.
        gateway = config['their_ip'].split('/')[0]
        routes = {
            destination: {'via': gateway}
            for destination in config.get('routes', set())
        }

        if config.get('auto_connection', True):
            policy = 'up'
        else:
            policy = 'manual'

        interfaces[f'wg_{config["iface"]}'] = {
            'activation_policy': policy,
            'ips': addresses,
            'routes': routes,
        }

    return {'interfaces': interfaces}


@metadata_reactor.provides(
    'nftables/forward/10-wireguard',
    'nftables/postrouting/10-wireguard',
)
def snat(metadata):
    """Generate nftables forward and source-NAT rules for the tunnels.

    Raises ``DoNotRunAgain`` when the node lacks the ``nftables`` bundle or
    its ``os`` is ``arch``. Otherwise traffic entering or leaving every
    ``wg_<iface>`` interface is accepted in the forward chain. Only when
    ``wireguard/snat_ip`` is set, a postrouting rule per peer rewrites
    packets sourced from ``my_ip`` and not destined for ``their_ip`` to that
    address. Both rule lists are returned sorted.
    """
    if not (node.has_bundle('nftables') and node.os != 'arch'):
        raise DoNotRunAgain

    snat_ip = metadata.get('wireguard/snat_ip', None)
    peers = metadata.get('wireguard/peers', {})

    # Sets collapse identical rules before the lists are sorted for output.
    forward_rules = set()
    snat_rules = set()

    for peer_name in sorted(peers):
        config = peers[peer_name]
        iface = f'wg_{config["iface"]}'
        forward_rules.update((
            f'iifname {iface} accept',
            f'oifname {iface} accept',
        ))

        if snat_ip:
            snat_rules.add('ip saddr {} ip daddr != {} snat to {}'.format(
                config['my_ip'],
                config['their_ip'],
                snat_ip,
            ))

    return {
        'nftables': {
            'forward': {
                '10-wireguard': sorted(forward_rules),
            },
            'postrouting': {
                '10-wireguard': sorted(snat_rules),
            },
        },
    }


@metadata_reactor.provides(
    'wireguard/health_checks',
    'systemd-timers/timers/wg-health-check',
)
def health_checks(metadata):
    """Select the peers to health-check and schedule the check if needed.

    A peer is included, mapped to its ``their_ip``, when it is not excluded
    from monitoring, has ``auto_connection`` enabled (the default) and has an
    ``endpoint``. The ``wg-health-check`` timer, which runs
    ``/usr/local/bin/wg_health_check`` minutely, is only defined when at
    least one peer was selected.
    """
    targets = {}

    for peer_name, config in metadata.get('wireguard/peers', {}).items():
        is_monitored = not config.get('exclude_from_monitoring', False)
        connects_automatically = config.get('auto_connection', True)

        if is_monitored and connects_automatically and 'endpoint' in config:
            targets[peer_name] = config['their_ip']

    timers = {}
    if targets:
        timers['wg-health-check'] = {
            'command': '/usr/local/bin/wg_health_check',
            'when': 'minutely',
        }

    return {
        'systemd-timers': {
            'timers': timers,
        },
        'wireguard': {
            'health_checks': targets,
        },
    }