from ipaddress import ip_network
from re import sub

from bundlewrap.exceptions import NoSuchNode
from bundlewrap.metadata import atomic

# Baseline metadata contributed by this bundle: install the wireguard package
# and derive this node's wireguard private key via repo.libs.keys.gen_privkey
# from a per-node identifier string.
defaults = {
    'apt': {
        'packages': {
            'wireguard': {},
        },
    },
    'wireguard': {
        'privatekey': repo.libs.keys.gen_privkey(
            repo,
            f'{node.name} wireguard privatekey',
        ),
    },
}

if node.os_version <= (11,):
    # OS version 11 and older: add the Debian backports apt repository
    # (presumably needed to get a usable wireguard package there — verify).
    defaults['apt'].update(
        repos={
            'backports': {
                'install_gpg_key': False,  # default debian signing key
                'items': {
                    # Plain string on purpose: '{os_release}' is a placeholder,
                    # presumably substituted by the apt bundle — verify.
                    'deb http://deb.debian.org/debian {os_release}-backports main',
                },
            },
        },
    )
else:
    # Newer releases: explicitly mark wireguard-dkms as not installed.
    defaults['apt']['packages'].update({
        'wireguard-dkms': {
            'installed': False,
        },
    })

if node.has_bundle('telegraf'):
    # Enable telegraf's builtin wireguard input plugin and grant telegraf
    # CAP_NET_ADMIN (presumably required to read wireguard device state).
    defaults.update(
        telegraf={
            'input_plugins': {
                'builtin': {
                    'wireguard': [{}],
                },
            },
            'additional_capabilities': {
                'CAP_NET_ADMIN',
            },
        },
    )


@metadata_reactor.provides(
    'wireguard/peers',
)
def peers_auto_full_mesh(metadata):
    """Add every other auto-generated mesh node as an (empty) wireguard peer.

    Only runs for nodes listed in ``repo.libs.s2s.WG_AUTOGEN_NODES``. Entries
    that are ``None``, this node itself, names that do not resolve to a
    bundlewrap node, and dummy nodes are skipped. The per-peer settings are
    filled in by the other reactors in this file.
    """
    autogen_nodes = repo.libs.s2s.WG_AUTOGEN_NODES

    if node.name not in autogen_nodes:
        return {}

    mesh = {}
    for candidate in autogen_nodes:
        # None entries are skipped (they still occupy a list position, which
        # matters because the list index is used to derive ports elsewhere).
        if candidate is None or candidate == node.name:
            continue

        try:
            other = repo.get_node(candidate)
        except NoSuchNode:
            continue

        if not other.dummy:
            mesh[other.name] = {}

    return {'wireguard': {'peers': mesh}}


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_psks(metadata):
    """Give every peer an interface suffix and, for node peers, a PSK.

    ``iface`` is the peer name with every run of characters outside
    ``[a-z0-9-_]`` replaced by ``_``, truncated to 12 characters. Other
    reactors prefix it with ``wg_`` (presumably to stay within the kernel's
    15-character interface-name limit).

    For peers that are bundlewrap nodes, ``psk`` comes from
    ``repo.vault.random_bytes_as_base64_for`` with an identifier built from
    the two node names in sorted order, so both ends of the tunnel use the
    same identifier string. Peers that are not nodes get no ``psk``.
    """
    result = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        entry = {
            'iface': sub('[^a-z0-9-_]+', '_', peer_name)[:12],
        }
        result[peer_name] = entry

        try:
            repo.get_node(peer_name)
        except NoSuchNode:
            # Not a bundlewrap node (e.g. a roadwarrior): no derived PSK.
            continue

        low, high = sorted((node.name, peer_name))
        entry['psk'] = repo.vault.random_bytes_as_base64_for(f'{low} wireguard {high}')

    return {
        'wireguard': {
            'peers': result,
        },
    }


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_pubkeys(metadata):
    """Set ``pubkey`` for every peer that is a bundlewrap node.

    The public key is computed via ``repo.libs.keys.get_pubkey_from_privkey``
    from the remote node's ``wireguard/privatekey`` metadata. Peers that are
    not nodes are left untouched.
    """
    pubkeys = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        try:
            peer_node = repo.get_node(peer_name)
        except NoSuchNode:
            continue

        label = f'{peer_node.name} wireguard pubkey'
        remote_privkey = peer_node.metadata.get('wireguard/privatekey')
        pubkeys[peer_name] = {
            'pubkey': repo.libs.keys.get_pubkey_from_privkey(repo, label, remote_privkey),
        }

    return {'wireguard': {'peers': pubkeys}}


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_ips_and_ports(metadata):
    """Assign transfer-net IPs and the local listen port for mesh peers.

    Only applies when both this node and the peer are listed in
    ``repo.libs.s2s.WG_AUTOGEN_NODES``; for other nodes the reactor stops
    running (``DoNotRunAgain`` is not imported here and is presumably
    injected by bundlewrap's metadata.py environment).

    For each pair, ``get_subnet_for_connection`` is called with the two node
    names in sorted order and returns two addresses. The node whose name
    sorts *later* takes the first address and the other node takes the
    second. Each side applies the same rule, so the two ends always pick
    complementary addresses.

    ``my_port`` is ``51820 + index of the peer in WG_AUTOGEN_NODES``. The
    remote side looks up this value when building its endpoint for us.
    """
    autogen_nodes = repo.libs.s2s.WG_AUTOGEN_NODES

    if node.name not in autogen_nodes:
        raise DoNotRunAgain

    peers = {}
    base_port = 51820

    # Iteration order does not affect the result (ports come from the list
    # index, not the loop position); sorted only for deterministic runs.
    for peer_name in sorted(metadata.get('wireguard/peers', {})):
        try:
            rnode = repo.get_node(peer_name)
        except NoSuchNode:
            continue

        if rnode.name not in autogen_nodes:
            continue

        ip_a, ip_b = repo.libs.s2s.get_subnet_for_connection(repo, *sorted({node.name, peer_name}))

        if peer_name < node.name:
            my_ip, their_ip = ip_a, ip_b
        else:
            my_ip, their_ip = ip_b, ip_a

        peers[rnode.name] = {
            'my_ip': str(my_ip),
            'my_port': base_port + autogen_nodes.index(rnode.name),
            'their_ip': str(their_ip),
        }

    return {
        'wireguard': {
            'peers': peers,
        },
    }


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_endpoints(metadata):
    """Set ``endpoint`` (``host:port``) for every peer that is a node.

    The host is the remote node's ``hostname``. The port is whatever the
    remote node has configured as ``my_port`` for its peer entry pointing
    back at this node. Peers flagged ``no_autoconnect`` in
    ``WG_AUTOGEN_SETTINGS`` get no endpoint, and neither do peers that are
    not bundlewrap nodes.
    """
    peers = {}

    for name in metadata.get('wireguard/peers', {}):
        try:
            rnode = repo.get_node(name)
        except NoSuchNode:
            continue

        if repo.libs.s2s.WG_AUTOGEN_SETTINGS.get(name, {}).get('no_autoconnect'):
            continue

        host = rnode.hostname
        # An IPv6 literal must be wrapped in brackets; otherwise its colons
        # are indistinguishable from the host:port separator. Hostnames and
        # IPv4 addresses never contain ':' and are left untouched.
        if ':' in host and not host.startswith('['):
            host = f'[{host}]'

        peers[rnode.name] = {
            'endpoint': '{}:{}'.format(
                host,
                rnode.metadata.get(f'wireguard/peers/{node.name}/my_port'),
            ),
        }

    return {
        'wireguard': {
            'peers': peers,
        },
    }


@metadata_reactor.provides(
    'icinga2_api/wireguard/services',
)
def icinga2(metadata):
    """Create one icinga2 service per peer that checks the tunnel is connected.

    Peers with ``exclude_from_monitoring`` set are skipped. The check command
    is built as a template whose ``{}`` is filled with the peer's ``pubkey``
    via ``format_into``. The doubled braces in the f-string produce that
    literal ``{}``.
    """
    checks = {}

    for peer, config in sorted(metadata.get('wireguard/peers', {}).items()):
        if config.get('exclude_from_monitoring', False):
            continue

        pubkey = config['pubkey']
        template = f'sudo /usr/local/share/icinga/plugins/check_wireguard_connected wg_{config["iface"]} {{}}'
        checks[f'WIREGUARD CONNECTION {peer}'] = {
            'command_on_monitored_host': pubkey.format_into(template),
        }

    return {'icinga2_api': {'wireguard': {'services': checks}}}


@metadata_reactor.provides(
    'firewall/port_rules',
)
def firewall(metadata):
    """Open each peer's local UDP listen port (``my_port``) in the firewall.

    For peers that are bundlewrap nodes, traffic is allowed from the peer
    itself plus any extra sources listed under ``firewall`` in
    ``WG_AUTOGEN_SETTINGS``. For non-node peers (roadwarriors), the allowed
    sources come from ``wireguard/restrict-to``. The value is wrapped in
    ``atomic`` so it replaces, rather than merges with, other contributions.
    """
    ports = {}

    # Default to {} like every other reactor in this file: a node without any
    # peers should simply yield no port rules instead of a persistent KeyError.
    for name, config in metadata.get('wireguard/peers', {}).items():
        port = '{}/udp'.format(config['my_port'])

        try:
            repo.get_node(name)
        except NoSuchNode:  # roadwarrior
            allowed = set(metadata.get('wireguard/restrict-to', set()))
        else:
            allowed = set(repo.libs.s2s.WG_AUTOGEN_SETTINGS.get(name, {}).get('firewall', set())) | {name}

        ports[port] = atomic(allowed)

    return {
        'firewall': {
            'port_rules': ports,
        },
    }


@metadata_reactor.provides(
    'interfaces',
)
def interface_ips(metadata):
    """Describe one ``wg_<iface>`` network interface per peer.

    Each interface gets ``my_ip``, with ``/31`` appended when it carries no
    prefix length. For peers in ``WG_AUTOGEN_NODES`` it also gets
    ``wireguard/snat_ip`` if that is set. Every entry in the peer's
    ``routes`` is routed via the peer's address (``their_ip`` without any
    prefix length). ``activation_policy`` is ``manual`` only when
    ``auto_connection`` is explicitly falsy.
    """
    snat_ip = metadata.get('wireguard/snat_ip', None)
    result = {}

    for peer, config in sorted(metadata.get('wireguard/peers', {}).items()):
        local_ip = config['my_ip']
        if '/' not in local_ip:
            local_ip = f'{local_ip}/31'

        addresses = {local_ip}
        if snat_ip and peer in repo.libs.s2s.WG_AUTOGEN_NODES:
            addresses.add(snat_ip)

        gateway = config['their_ip'].partition('/')[0]

        if config.get('auto_connection', True):
            policy = 'up'
        else:
            policy = 'manual'

        result[f'wg_{config["iface"]}'] = {
            'activation_policy': policy,
            'ips': addresses,
            'routes': {
                route: {'via': gateway}
                for route in config.get('routes', set())
            },
        }

    return {'interfaces': result}


@metadata_reactor.provides(
    'nftables/forward/10-wireguard',
    'nftables/postrouting/10-wireguard',
)
def snat(metadata):
    """Emit nftables forward/postrouting rules for every wireguard interface.

    Forwarding into and out of every ``wg_<iface>`` interface is accepted.
    For mesh peers (in ``WG_AUTOGEN_NODES``) with ``wireguard/snat_ip`` set,
    traffic from ``my_ip`` not destined to ``their_ip`` is SNATed to
    ``snat_ip``. Otherwise, peers with ``masquerade`` enabled get a
    masquerade rule on their interface. Only runs on nodes with the nftables
    bundle.
    """
    if not node.has_bundle('nftables'):
        raise DoNotRunAgain

    snat_ip = metadata.get('wireguard/snat_ip', None)

    forward = set()
    postrouting = set()
    for peer, config in sorted(metadata.get('wireguard/peers', {}).items()):
        # The interface is named from the sanitised/truncated 'iface' value,
        # not the raw peer name; every rule must use this exact name.
        ifname = f'wg_{config["iface"]}'

        forward.add(f'iifname {ifname} accept')
        forward.add(f'oifname {ifname} accept')

        if snat_ip and peer in repo.libs.s2s.WG_AUTOGEN_NODES:
            postrouting.add('ip saddr {} ip daddr != {} snat to {}'.format(
                config['my_ip'],
                config['their_ip'],
                snat_ip,
            ))
        elif config.get('masquerade', False):
            postrouting.add(f'oifname {ifname} masquerade')

    return {
        'nftables': {
            'forward': {
                '10-wireguard': sorted(forward),
            },
            'postrouting': {
                '10-wireguard': sorted(postrouting),
            },
        },
    }