from ipaddress import ip_network

from bundlewrap.exceptions import NoSuchNode
from bundlewrap.metadata import atomic


defaults = {
    'apt': {
        'packages': {
            'wireguard': {},
        },
        'repos': {
            'backports': {
                # backports are signed with Debian's default archive key,
                # so no additional key has to be installed for this repo
                'install_gpg_key': False,
                'items': {
                    # '{os_release}' is a literal placeholder (not an f-string);
                    # presumably filled in by the apt bundle — confirm
                    'deb http://deb.debian.org/debian {os_release}-backports main',
                },
            },
        },
    },
    'wireguard': {
        # per-node private key, generated by the repo's keys lib from an
        # identifier derived from the node name
        'privatekey': repo.libs.keys.gen_privkey(
            repo,
            f'{node.name} wireguard privatekey',
        ),
    },
}

# When telegraf is deployed on this node, enable its builtin wireguard input.
if node.has_bundle('telegraf'):
    defaults.update({
        'telegraf': {
            'input_plugins': {
                'builtin': {
                    'wireguard': [{}],
                },
            },
            # NOTE(review): presumably required by the wireguard input to
            # query device stats — confirm against telegraf docs
            'additional_capabilities': {
                'CAP_NET_ADMIN',
            },
        },
    })



@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_psks(metadata):
    """Attach a pre-shared key ('psk') to every configured WireGuard peer.

    The vault identifier always lists the lexically smaller of the two node
    names first, so both ends of a tunnel build the identical identifier
    string (and, assuming the vault derivation is deterministic per
    identifier, the same key).
    """
    psks = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        low, high = sorted((node.name, peer_name))
        psks[peer_name] = {
            'psk': repo.vault.random_bytes_as_base64_for(f'{low} wireguard {high}'),
        }

    return {
        'wireguard': {
            'peers': psks,
        },
    }


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_pubkeys(metadata):
    """Set each peer's 'pubkey', derived from that peer node's private key.

    Peers that are not nodes in this repo are skipped. Their public key
    cannot be derived here and must come from elsewhere in the metadata.
    """
    pubkeys = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        try:
            remote = repo.get_node(peer_name)
        except NoSuchNode:
            # not a repo node (e.g. a roadwarrior): no private key available
            continue

        remote_privkey = remote.metadata.get('wireguard/privatekey')
        pubkeys[peer_name] = {
            'pubkey': repo.libs.keys.get_pubkey_from_privkey(
                repo,
                f'{remote.name} wireguard pubkey',
                remote_privkey,
            ),
        }

    return {
        'wireguard': {
            'peers': pubkeys,
        },
    }


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_ips_and_ports(metadata):
    """Assign tunnel addresses and a local listen port for each repo-node peer.

    The listen port is 51820 plus the peer's index in the sorted list of all
    configured peers. Non-node peers still consume an index but are skipped.
    The address pair comes from the s2s lib for the (sorted) node-name pair.
    Each side picks its half by comparing names, so one end's 'my_ip' is
    the other end's 'their_ip'.
    """
    first_port = 51820
    assignments = {}
    peer_names = sorted(metadata.get('wireguard/peers', {}))

    for index, peer_name in enumerate(peer_names):
        try:
            remote = repo.get_node(peer_name)
        except NoSuchNode:
            continue

        pair = sorted({node.name, peer_name})
        ip_a, ip_b = repo.libs.s2s.get_subnet_for_connection(repo, *pair)
        my_ip, their_ip = (ip_a, ip_b) if peer_name < node.name else (ip_b, ip_a)

        assignments[remote.name] = {
            'my_ip': str(my_ip),
            'my_port': first_port + index,
            'their_ip': str(their_ip),
        }

    return {
        'wireguard': {
            'peers': assignments,
        },
    }


@metadata_reactor.provides(
    'wireguard/peers',
)
def peer_endpoints(metadata):
    """Set 'endpoint' ("host:port") for every peer that is a node in this repo.

    The host is the peer's 'wireguard/external_hostname', falling back to
    its hostname. The port is the 'my_port' the *peer* chose for its tunnel
    towards this node, defaulting to 51820 until that is known.
    """
    endpoints = {}

    for peer_name in metadata.get('wireguard/peers', {}):
        try:
            remote = repo.get_node(peer_name)
        except NoSuchNode:
            continue

        host = remote.metadata.get('wireguard/external_hostname', remote.hostname)
        port = remote.metadata.get(f'wireguard/peers/{node.name}/my_port', 51820)
        endpoints[remote.name] = {
            'endpoint': '{}:{}'.format(host, port),
        }

    return {
        'wireguard': {
            'peers': endpoints,
        },
    }


@metadata_reactor.provides(
    'icinga2_api/wireguard/services',
)
def icinga2(metadata):
    """Define one icinga2 connection check per monitored WireGuard peer.

    The interface index wg<N> is the peer's position in the sorted peer
    list. It is counted before 'exclude_from_monitoring' peers are filtered
    out, so it stays aligned with the wg<N> numbering used elsewhere here.
    """
    sorted_peers = sorted(metadata.get('wireguard/peers', {}).items())
    services = {
        f'WIREGUARD CONNECTION {peer}': {
            'command_on_monitored_host': config['pubkey'].format_into(f'sudo /usr/local/share/icinga/plugins/check_wireguard_connected wg{number} {{}}'),
        }
        for number, (peer, config) in enumerate(sorted_peers)
        if not config.get('exclude_from_monitoring', False)
    }

    return {
        'icinga2_api': {
            'wireguard': {
                'services': services,
            },
        },
    }


@metadata_reactor.provides(
    'firewall/port_rules',
)
def firewall(metadata):
    """Open each peer's local WireGuard UDP port in the firewall.

    A peer that is a node in this repo may only reach its port from that
    node. A peer that is not a repo node (a roadwarrior) is restricted to
    'wireguard/restrict-to', which is empty unless configured.

    A node without any configured peers contributes no rules. Like the
    other reactors in this file, it defaults 'wireguard/peers' to {} rather
    than raising KeyError.
    """
    ports = {}

    for name, config in metadata.get('wireguard/peers', {}).items():
        port = '{}/udp'.format(config['my_port'])
        try:
            repo.get_node(name)
        except NoSuchNode:  # roadwarrior
            ports[port] = atomic(set(metadata.get('wireguard/restrict-to', set())))
        else:
            ports[port] = atomic({name})

    return {
        'firewall': {
            'port_rules': ports,
        },
    }


@metadata_reactor.provides(
    'interfaces',
)
def interface_ips(metadata):
    """Give interface wg<N> this node's tunnel address towards peer N.

    N is the peer's index in the sorted peer list. A 'my_ip' without a '/'
    gets '/31' appended; a value that already contains '/' is used as-is.
    """
    sorted_peers = sorted(metadata.get('wireguard/peers', {}).items())
    interfaces = {}

    for index, (_peer, config) in enumerate(sorted_peers):
        address = config['my_ip']
        if '/' not in address:
            address = '{}/31'.format(address)
        interfaces[f'wg{index}'] = {'ips': {address}}

    return {
        'interfaces': interfaces,
    }


@metadata_reactor.provides(
    'nftables/rules/10-wireguard',
)
def snat(metadata):
    """Emit nftables rules for WireGuard forwarding and optional source NAT.

    Every wg<N> interface gets forward-accept rules for both directions. A
    peer with 'snat_to' additionally gets a postrouting SNAT rule for traffic
    from our tunnel address to anything other than the peer's tunnel address.
    Nodes without the nftables bundle, or running Arch, raise DoNotRunAgain
    (bundlewrap's signal to stop re-running this reactor).
    """
    if not (node.has_bundle('nftables') and node.os != 'arch'):
        raise DoNotRunAgain

    sorted_peers = sorted(metadata.get('wireguard/peers', {}).items())
    rules = set()

    for index, (_peer, config) in enumerate(sorted_peers):
        rules.update((
            f'inet filter forward iif wg{index} accept',
            f'inet filter forward oif wg{index} accept',
        ))

        if 'snat_to' not in config:
            continue

        rules.add('nat postrouting ip saddr {} ip daddr != {} snat to {}'.format(
            config['my_ip'],
            config['their_ip'],
            config['snat_to'],
        ))

    return {
        'nftables': {
            'rules': {
                '10-wireguard': sorted(rules),
            },
        },
    }