#!/usr/bin/env python3

from argparse import ArgumentParser
from json import dump
from os import environ, makedirs, remove, scandir
from os.path import abspath, dirname, join
from sys import exit

import bwpass
from requests import post

from bundlewrap.utils.text import bold, red, validate_name
from bundlewrap.utils.ui import io

# API token for NetBox. The NETBOX_AUTH_TOKEN environment variable wins;
# otherwise the token is looked up through bwpass.
TOKEN = environ.get("NETBOX_AUTH_TOKEN")
if not TOKEN:
    try:
        TOKEN = bwpass.attr("netbox.franzi.business/kunsi", "token")
    except Exception:
        # Deliberately broad: whatever goes wrong in the fallback lookup, the
        # outcome for the user is the same - there is no usable token.
        TOKEN = None

    if not TOKEN:
        # sys.exit() with a string prints it to stderr and exits with status 1,
        # so the diagnostic does not end up in captured stdout.
        exit("NETBOX_AUTH_TOKEN missing")

# Directory the per-device JSON files are written to: "configs/netbox" inside
# the parent of the directory that contains this script.
TARGET_PATH = join(dirname(dirname(abspath(__file__))), "configs", "netbox")

# GraphQL query: all sites with their name and id, plus the name and vid of
# every VLAN attached to the site.
QUERY_SITES = """{
    site_list {
        name
        id
        vlans {
            name
            vid
        }
    }
}"""

# GraphQL query template: name and id of the devices of one site that carry
# the "bundlewrap" tag. The literal text SITE_ID is a placeholder that is
# substituted via str.replace() before the query is sent.
QUERY_DEVICES = """{
    device_list(filters: {tag: "bundlewrap", site_id: "SITE_ID"}) {
        name
        id
    }
}"""

# GraphQL query template: one device with all of its interfaces. The literal
# text DEVICE_ID is a placeholder that is substituted via str.replace() before
# the query is sent. Per interface it requests the basic settings, the IP
# addresses, the untagged/tagged VLAN names and the far end of the cabling:
# link peers (interfaces and front ports only) and connected endpoints
# (interfaces only).
QUERY_DEVICE_DETAILS = """{
    device(id: DEVICE_ID) {
        name
        interfaces {
            id
            name
            enabled
            description
            mode
            type
            ip_addresses {
                address
            }
            untagged_vlan {
                name
            }
            tagged_vlans {
                name
            }
            link_peers {
                ... on InterfaceType {
                    name
                    device {
                        name
                    }
                }
                ... on FrontPortType {
                    name
                    device {
                        name
                    }
                }
            }
            connected_endpoints {
                ... on InterfaceType {
                    name
                    device {
                        name
                    }
                }
            }
        }
    }
}"""


def graphql(query, *, timeout=60):
    """Send a GraphQL query to the NetBox API and return its ``data`` object.

    Args:
        query: The GraphQL query document as a string.
        timeout: Seconds to wait for the server, handed to ``requests.post``.
            Without a timeout a stalled connection would block the script
            forever.

    Returns:
        The value stored under the ``data`` key of the JSON response.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
        requests.Timeout: If the server does not answer within ``timeout``.
        RuntimeError: If the response contains no ``data``; the message
            carries whatever the response reported under ``errors``.
    """
    response = post(
        "https://netbox.franzi.business/graphql/",
        headers={
            "Accept": "application/json",
            "Authorization": f"Token {TOKEN}",
        },
        json={
            "query": query,
        },
        timeout=timeout,
    )
    response.raise_for_status()

    payload = response.json()
    data = payload.get("data")
    if data is None:
        # A failed query can still come back with HTTP 200; surface the
        # reported errors instead of failing later with an unrelated
        # KeyError/AttributeError.
        raise RuntimeError(f"GraphQL query failed: {payload.get('errors')}")
    return data


def filter_results(results, filter_by):
    """Restrict ``results`` to the entries selected by ``filter_by``.

    Args:
        results: Iterable of dicts that each have an ``id`` and a ``name`` key.
        filter_by: Collection of strings to match against, or ``None`` to
            disable filtering.

    Returns:
        ``results`` itself when ``filter_by`` is ``None``; otherwise a new list
        with the entries whose stringified ``id`` or whose ``name`` occurs in
        ``filter_by``, in their original order.
    """
    if filter_by is None:
        return results

    return [
        entry
        for entry in results
        if str(entry["id"]) in filter_by or entry["name"] in filter_by
    ]


# Command line: both options take one or more values that are matched against
# ids or names to limit the run to the given sites / devices.
parser = ArgumentParser()
for option in ("--only-site", "--only-device"):
    parser.add_argument(option, nargs="+", type=str)
args = parser.parse_args()

def _interface_description(interface):
    """Return the description text for one interface of the device query.

    A description set on the interface itself wins. Otherwise the text is
    derived from what the interface is cabled to: one "<device> (<port>)"
    entry per peer, de-duplicated, sorted and joined with "; ". Connected
    endpoints are preferred over link peers. Returns "" when neither a
    description nor any peer is available.
    """
    if interface["description"]:
        return interface["description"]

    peers = interface["connected_endpoints"] or interface["link_peers"]
    if not peers:
        return ""

    # The query only selects fields for some peer types (see the inline
    # fragments in QUERY_DEVICE_DETAILS). Peers of any other type are expected
    # to come back without "device"/"name", so skip them instead of failing
    # with a KeyError.
    peer_names = {
        "{} ({})".format(peer["device"]["name"], peer["name"])
        for peer in peers
        if peer and peer.get("device")
    }
    return "; ".join(sorted(peer_names))


def _interface_config(interface, description):
    """Build the dict that is dumped to JSON for one interface."""
    untagged_vlan = interface["untagged_vlan"]
    return {
        "description": description,
        "enabled": interface["enabled"],
        "mode": interface["mode"],
        "type": interface["type"],
        # Sets drop duplicates, sorting keeps the output stable between runs.
        "ips": sorted({ip["address"] for ip in interface["ip_addresses"]}),
        "untagged_vlan": untagged_vlan["name"] if untagged_vlan else None,
        "tagged_vlans": sorted(
            {vlan["name"] for vlan in interface["tagged_vlans"]}
        ),
    }


# Main flow: fetch sites -> devices -> interfaces from NetBox and write one
# JSON file per device into TARGET_PATH, then remove files of devices that no
# longer exist.
try:
    io.activate()
    # Names of all files written during this run; used by the cleanup below.
    filenames_used = set()

    with io.job("getting sites"):
        sites = filter_results(
            graphql(QUERY_SITES).get("site_list", []), args.only_site
        )

    io.stdout(f"Processing {len(sites)} sites in total")

    for site in sites:
        with io.job(f"{bold(site['name'])}  getting devices"):
            devices = filter_results(
                # str(): str.replace() requires a string, ids are not
                # guaranteed to be one (filter_results() makes the same
                # allowance).
                graphql(QUERY_DEVICES.replace("SITE_ID", str(site["id"]))).get(
                    "device_list", []
                ),
                args.only_device,
            )
        io.stdout(f"Site {bold(site['name'])} has {len(devices)} devices to process")

        for device in devices:
            if not device["name"] or not validate_name(device["name"]):
                # invalid node name, ignore
                continue

            with io.job(
                f"{bold(site['name'])}  {bold(device['name'])}  getting interfaces"
            ):
                details = graphql(
                    QUERY_DEVICE_DETAILS.replace("DEVICE_ID", str(device["id"]))
                )["device"]

                result = {
                    "interfaces": {},
                    "vlans": site["vlans"],
                }

                for interface in details["interfaces"]:
                    description = _interface_description(interface)

                    # Descriptions must be plain ASCII. Raise explicitly
                    # instead of using assert, which is stripped under
                    # "python -O" and gives no hint about the offender.
                    if not description.isascii():
                        raise ValueError(
                            f"{device['name']}: interface {interface['name']} "
                            f"has a non-ASCII description: {description!r}"
                        )

                    result["interfaces"][interface["name"]] = _interface_config(
                        interface, description
                    )

            if result["interfaces"]:
                filename = f"{device['name']}.json"
                filenames_used.add(filename)
                file_with_path = join(TARGET_PATH, filename)

                with io.job(
                    f"{bold(site['name'])}  {bold(device['name'])}  writing to {file_with_path}"
                ):
                    # A fresh checkout may not have the target directory yet.
                    makedirs(TARGET_PATH, exist_ok=True)
                    with open(file_with_path, "w", encoding="utf-8") as f:
                        dump(
                            result,
                            f,
                            indent=4,
                            sort_keys=True,
                        )
            else:
                io.stdout(
                    f"device {bold(device['name'])} has no interfaces, {red('not')} dumping!"
                )

    # Only clean up after a complete, unfiltered run that wrote at least one
    # file; a filtered or empty run must never delete other devices' files.
    if not args.only_site and not args.only_device and filenames_used:
        with io.job(f"cleaning leftover files from {TARGET_PATH}"):
            for direntry in scandir(TARGET_PATH):
                filename = direntry.name
                if filename.startswith("."):
                    # keep dotfiles
                    continue
                if not direntry.is_file():
                    io.stderr(
                        f"found non-file {filename} in {TARGET_PATH}, please check what's going on!"
                    )
                    continue
                if filename not in filenames_used:
                    remove(join(TARGET_PATH, filename))
finally:
    io.deactivate()