#!/usr/bin/python3
# SPDX-FileCopyrightText: 2019-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

from __future__ import annotations

import os
import subprocess
import sys
import time
from typing import TYPE_CHECKING

import click
from dns.resolver import NXDOMAIN, Resolver
from ldap.dn import escape_dn_chars
from ldap.filter import filter_format

from univention.config_registry import handler_set, handler_unset, ucr
from univention.udm import UDM, NoObject


if TYPE_CHECKING:
    from univention.udm.base import BaseModule, BaseObject


# Module-level cache for the list of DNS forward zones; it starts out as None
# and is filled by forward_zones() on first use.
_forward_zones = None


@click.command(context_settings={"help_option_names": ['-h', '--help']})
@click.argument("fqdn")
@click.argument("port")
@click.option("--aliases", multiple=True, help="Additional FQDNs for this vhost entry.")
@click.option("--conffile", multiple=True, help="Path to a file with additional configuration of vhost entry.")
@click.option('--binddn', help="DN of account to use to write to LDAP (e.g. uid=Administrator,cn=users,..).")
@click.option('--bindpwdfile', help="File containing password of user provided by --binddn.")
@click.option('--ssl', is_flag=True, help="Whether this vhost should have SSL enabled (cert, private-key, ca will have default values; 443 is ssl by default).")
@click.option('--cert', help="Path to SSL certificate.")
@click.option('--private-key', help="Path to SSL private key.")
@click.option('--ca', help="Path to SSL CA.")
@click.option("--remove", is_flag=True, help="Remove previously created aliases and UCR variables.")
@click.option("--dont-reload-services", is_flag=True, help="DO NOT reload local services although necessary, i.e., Apache and Bind.")
def main(fqdn: str, port: str, aliases: list[str], conffile: list[str], binddn: str, bindpwdfile: str, ssl: bool, cert: str, private_key: str, ca: str, remove: bool, dont_reload_services: bool) -> None:
    """
    Create an Apache vhost entry (and DNS alias) with hostname
    FQDN on port PORT.

    FQDN: Fully qualified domain name the vhost should be created for.

    PORT: Port, usually 80 or 443.
    """
    # NOTE: the docstring above is rendered by click as the --help text, so it
    # is user-facing output; developer notes belong in comments like this one.
    #
    # Flow: obtain LDAP credentials -> set/unset the vhost UCR variables ->
    # create/remove DNS aliases for FQDN and all aliases -> reload Apache and
    # Bind (or only print the commands when --dont-reload-services is given).
    # PORT is declared without a click type and therefore arrives as a string.
    if os.geteuid() != 0:
        click.echo(click.style("This script must be executed as root.", fg="red"), err=True)
        sys.exit(1)

    # On master/backup no explicit credentials are collected; everywhere else
    # a bind DN and password are taken from the options or asked for.
    if ucr["server/role"] in ("domaincontroller_master", "domaincontroller_backup"):
        binddn = password = None
    else:
        if not binddn:
            username = click.prompt("Username to use for LDAP connection").strip()
            mod = UDM.machine().version(1).get("users/user")
            # Use the first user object matching the uid, if there is one.
            for obj in mod.search(filter_format("uid=%s", (username,))):
                binddn = obj.dn
                break
            else:
                click.echo(
                    click.style(f"Cannot find DN for username '{username}'.", fg="red"),
                    err=True,
                )
                sys.exit(1)
        if bindpwdfile:
            with click.open_file(bindpwdfile, "r") as fp:
                password = fp.read().strip()
        else:
            password = click.prompt(f"Password for '{binddn}'", hide_input=True).strip()

    needs_apache_reload = False
    needs_dns_reload = False
    # Expected outcome of the DNS lookup after the reload (True: resolves,
    # False: NXDOMAIN); stays None as long as no DNS change was attempted.
    dns_should_resolve = None

    if not binddn:
        udm = UDM.admin().version(1)
    else:
        server = ucr["ldap/master"]
        server_port = ucr["ldap/master/port"]
        udm = UDM.credentials(binddn, password, server=server, port=server_port).version(1)

    if port == "443":
        ssl = True
    if remove:
        # The aliases to clean up are the ones stored when the vhost was
        # created; they have to be read before the variables are unset.
        aliases = ucr.get(f"apache2/vhosts/{fqdn}/{port}/aliases")
        aliases = aliases.split(",") if aliases else []
        needs_apache_reload |= unset_ucr_vars(fqdn, port)
    else:
        needs_apache_reload |= set_ucr_vars(udm, fqdn, port, ssl, aliases, conffile, cert, private_key, ca)

    ucr.load()

    fqdn_still_used = False
    if remove:
        # DNS entries are only removed once no vhost variables for this FQDN
        # are left (e.g. a vhost on another port). The trailing slash keeps
        # the test from matching other host names that merely start with FQDN.
        vhost_prefix = f"apache2/vhosts/{fqdn}/"
        fqdn_still_used = any(key.startswith(vhost_prefix) for key in ucr.keys())

    for host in [fqdn, *aliases]:
        if remove:
            if not fqdn_still_used:
                needs_dns_reload |= remove_dns_entry(udm, host)
                dns_should_resolve = False
        else:
            needs_dns_reload |= create_dns_entry(udm, host)
            dns_should_resolve = True

    if dont_reload_services:
        if needs_apache_reload or needs_dns_reload:
            click.echo(click.style("Please now reload the DNS and the web servers:", bold=True))
        if needs_apache_reload:
            click.echo(click.style("$ service apache2 reload", bold=True))
        if needs_dns_reload:
            click.echo(click.style("$ nscd -i hosts", bold=True))
            click.echo(click.style("$ service named reload", bold=True))
    else:
        if needs_apache_reload:
            click.echo('Reloading Apache ...')
            if subprocess.call(['service', 'apache2', 'reload']) != 0:
                click.echo(click.style(
                    "Reload failed",
                    fg="red"),
                )
        if needs_dns_reload:
            click.echo('Reloading Bind ...')
            subprocess.call(['nscd', '-i', 'hosts'])
            subprocess.call(['service', 'named', 'reload'])
            click.echo(f'Checking on {fqdn}...')
            # normally, we would need to check for aliases as well
            # but checking the most important one should be good enough
            timeout = 60
            pause = 5
            # Count down on a separate variable so that the message below can
            # still report the full timeout instead of the exhausted counter.
            remaining = timeout
            while remaining > 0:
                if resolve_host(fqdn, pause) == dns_should_resolve:
                    break
                time.sleep(pause)
                remaining -= pause
            else:
                click.echo(click.style(
                    "DNS reconfiguration not complete.",
                    fg="red"),
                )
                click.echo(f"`host {fqdn}` did not respond as expected after {timeout} seconds.")
                click.echo(click.style("You may want to try", bold=True))
                click.echo(click.style("$ nscd -i hosts", bold=True))
                click.echo(click.style("$ service named reload", bold=True))


def resolve_host(hostname, timeout):
    """
    Try to resolve `hostname` via DNS.

    :param hostname: name to look up.
    :param timeout: value assigned to the resolver's `lifetime`.
    :returns: `True` if the query succeeded, `False` if it raised
        `NXDOMAIN`, `None` if it failed with any other exception (a warning
        is printed in that case).
    """
    dns_resolver = Resolver()
    dns_resolver.lifetime = timeout
    try:
        dns_resolver.query(hostname)
        return True
    except NXDOMAIN:
        return False
    except Exception as exc:
        warning = click.style(
            f"DNS query resulted in {exc}.",
            fg="yellow",
        )
        click.echo(warning)
        # should not happen, retry anyway
        return None


def get_wildcard_certificate(udm: UDM) -> None:
    """
    Register the 'Wildcard Certificate' service for this host and fetch the
    wildcard certificate.

    The `settings/service` object is created below
    `cn=services,cn=univention,<ldap/base>` if it cannot be looked up, the
    service name is appended to the host object (`ldap/hostdn`) if missing,
    and finally `univention-fetch-certificate` is run for `*.<hostname>`
    against `ldap/master`.

    :param udm: UDM connection used to read and modify the LDAP objects.
    :raises subprocess.CalledProcessError: if `univention-fetch-certificate`
        exits with a non-zero status.
    """
    service_name = 'Wildcard Certificate'
    host = udm.obj_by_dn(ucr['ldap/hostdn'])
    service_mod = udm.get('settings/service')

    try:
        service_mod.get_by_id(service_name)
    except NoObject:
        # The service object is not there yet: create it in the services container.
        container = f"cn=services,cn=univention,{ucr['ldap/base']}"
        new_service = service_mod.new(superordinate=container)
        new_service.props.name = service_name
        new_service.position = container
        new_service.save()

    if service_name not in host.props.service:
        host.props.service.append(service_name)
        host.save()

    wildcard_name = '*.' + ucr['hostname']
    fetch_cmd = ['univention-fetch-certificate', wildcard_name, ucr['ldap/master']]
    subprocess.check_call(fetch_cmd)
    print('')


def forward_zones(udm: UDM) -> list[BaseObject]:
    """
    Return all `dns/forward_zone` objects, longest zone name first.

    The search is done once; its result is kept in the module-level
    `_forward_zones` and returned on every later call.

    :param udm: UDM connection used for the search on the first call.
    :returns: the zone objects sorted by descending length of `props.zone`.
    """
    global _forward_zones
    # Compare against None rather than testing truthiness: an empty search
    # result is a valid cached value and must not trigger a new search on
    # every call.
    if _forward_zones is None:
        dns_forward_zone_mod: BaseModule = udm.get("dns/forward_zone")
        _forward_zones = sorted(
            dns_forward_zone_mod.search(),
            key=lambda x: len(x.props.zone),
            reverse=True,
        )
    return _forward_zones


def superordinate_of_fqdn(udm: UDM, fqdn: str) -> BaseObject | None:
    """
    Find the forward zone object that `fqdn` belongs to.

    The zones are checked in the order returned by `forward_zones()`, and the
    first match wins.

    :param udm: UDM connection passed on to `forward_zones()`.
    :param fqdn: fully qualified host name to find the zone for.
    :returns: the matching zone object, or `None` if `fqdn` is in none of the
        known zones.
    """
    for zone in forward_zones(udm):
        zone_name = zone.props.zone
        # Match on a label boundary only: "foo.notexample.com" must not be
        # treated as a member of the zone "example.com".
        if fqdn == zone_name or fqdn.endswith(f".{zone_name}"):
            return zone
    return None


def host_obj(udm: UDM, hostname: str, superordinate: BaseObject) -> BaseObject | None:
    """
    Look up the object named `hostname` directly below a DNS zone.

    :param udm: UDM connection used for the lookup.
    :param hostname: relative name inside the zone; it is DN-escaped here.
    :param superordinate: zone object whose DN is the parent of the entry.
    :returns: the object found at
        `relativeDomainName=<hostname>,<superordinate.dn>`, or `None` if the
        lookup raises `NoObject`.
    """
    rdn = f"relativeDomainName={escape_dn_chars(hostname)}"
    dn = f"{rdn},{superordinate.dn}"
    try:
        found = udm.obj_by_dn(dn)
    except NoObject:
        return None
    return found


def create_dns_entry(udm: UDM, fqdn: str, alias_target: str | None = None) -> bool:
    """
    Create a DNS alias (CNAME) for `fqdn` pointing to `alias_target`.

    Nothing is created if `fqdn` is the alias target itself or `*`, if it is
    not part of any known forward zone, or if an object with that name
    already exists in the zone.

    :param udm: UDM connection used to look up and create DNS objects.
    :param fqdn: fully qualified name the alias should be created for.
    :param alias_target: name the alias points to; defaults to
        `<hostname>.<domainname>` from UCR.
    :returns: `True` if a new alias object was saved, `False` otherwise.
    """
    if not alias_target:
        alias_target = "{hostname}.{domainname}".format(**ucr)
    if fqdn in [alias_target, '*']:
        return False
    click.echo(f"Creating DNS alias for '{fqdn}'...")
    dns_alias_mod: BaseModule = udm.get("dns/alias")
    superordinate = superordinate_of_fqdn(udm, fqdn)
    if not superordinate:
        click.echo(click.style(
            f"'{fqdn}' is not part of any of the hosted DNS zones. Not creating an alias.",
            fg="yellow"))
        return False
    # Cut the zone off the END of the name only. Replacing every occurrence
    # of the zone string would mangle names that contain it elsewhere as well
    # (e.g. "demo.de" in zone "de" must become "demo", not "mo").
    zone_name = superordinate.props.zone
    alias_name = fqdn[:len(fqdn) - len(zone_name)].rstrip(".")
    # check for existing dns/alias or dns/host
    if host_obj(udm, alias_name, superordinate):
        click.echo(click.style(f"Alias/Host '{fqdn}' exists.", fg="green"))
        return False
    alias_obj: BaseObject = dns_alias_mod.new(superordinate=superordinate)
    alias_obj.props.name = alias_name
    # The cname value is stored with exactly one trailing dot.
    alias_obj.props.cname = "{}.".format(alias_target.rstrip("."))
    alias_obj.save()
    click.echo(click.style(f"Created DNS alias '{fqdn}' -> '{alias_target}'.", fg="green"))
    return True


def remove_dns_entry(udm: UDM, fqdn: str) -> bool:
    """
    Delete the DNS alias for `fqdn`, if there is one.

    Nothing is deleted if `fqdn` is this host's own `<hostname>.<domainname>`
    or `*`, if it is not part of any known forward zone, if no object with
    that name exists in the zone, or if the object found is not of type
    `dns/alias`.

    :param udm: UDM connection used to look up and delete DNS objects.
    :param fqdn: fully qualified name whose alias should be removed.
    :returns: `True` if an alias object was deleted, `False` otherwise.
    """
    alias_target = "{hostname}.{domainname}".format(**ucr)
    if fqdn in [alias_target, '*']:
        return False
    click.echo(f"Deleting DNS alias for '{fqdn}'...")
    superordinate = superordinate_of_fqdn(udm, fqdn)
    if not superordinate:
        click.echo(click.style(f"'{fqdn}' is not part of any of the hosted DNS zones.", fg="yellow"))
        return False
    # Cut the zone off the END of the name only. Replacing every occurrence
    # of the zone string would mangle names that contain it elsewhere as well
    # (e.g. "demo.de" in zone "de" must become "demo", not "mo").
    zone_name = superordinate.props.zone
    alias_name = fqdn[:len(fqdn) - len(zone_name)].rstrip(".")
    obj = host_obj(udm, alias_name, superordinate)
    if not obj:
        click.echo(click.style(f"Alias '{fqdn}' does not exist (anymore).", fg="yellow"))
        return False
    udm_module = obj._udm_module.name
    # Only aliases are deleted; any other object type found under this name
    # is reported and left untouched.
    if udm_module != "dns/alias":
        click.echo(click.style(
            f"Not deleting '{fqdn}': it is not an alias, but of type '{udm_module}'!",
            fg="red"),
        )
        return False
    obj.delete()
    click.echo(click.style(f"Deleted DNS alias '{fqdn}'.", fg="green"))
    return True


def set_ucr_vars(udm: UDM, fqdn: str, port: str, ssl: bool, aliases: list[str] | None = None, path: list[str] | None = None, cert: str | None = None, private_key: str | None = None, ca: str | None = None) -> bool:
    """
    Set the `apache2/vhosts/<fqdn>/<port>/...` UCR variables for a vhost.

    Nothing is changed if variables for this FQDN/port combination already
    exist.

    :param udm: UDM connection, only used when the wildcard certificate has to
        be fetched.
    :param fqdn: host name of the vhost.
    :param port: port of the vhost.
    :param ssl: fill in default certificate paths for values not given
        explicitly.
    :param aliases: additional host names, stored comma-separated.
    :param path: additional configuration files, stored comma-separated.
    :param cert: path to the SSL certificate.
    :param private_key: path to the SSL private key.
    :param ca: path to the SSL CA certificate.
    :returns: `True` if variables were set, `False` if they already existed.
    """
    click.echo("Setting UCR variables for Apache vhost configuration...")
    ns = f"apache2/vhosts/{fqdn}/{port}"
    # Match on the full path component: without the trailing slash, existing
    # variables for port 8080 would be mistaken for variables of port 80.
    prefix = f"{ns}/"
    if any(key.startswith(prefix) for key in ucr.keys()):
        click.echo(click.style(f"UCR values for '{ns}' exist. Not setting UCR values.", fg="yellow"))
        return False
    ucrvs = [
        f"{ns}/enabled=true",
        "{}/aliases={}".format(ns, ",".join(aliases or [])),
    ]
    if ssl:
        alias_target = "{hostname}.{domainname}".format(**ucr)
        if fqdn in [alias_target, '*']:
            # The vhost is for this host itself: default to the host certificate.
            if cert is None:
                cert = ucr.get("apache2/ssl/certificate", "/etc/univention/ssl/{}.{}/cert.pem".format(ucr["hostname"], ucr["domainname"]))
            if private_key is None:
                private_key = ucr.get("apache2/ssl/key", "/etc/univention/ssl/{}.{}/private.key".format(ucr["hostname"], ucr["domainname"]))
            if ca is None:
                ca = ucr.get("apache2/ssl/ca", "/etc/univention/ssl/ucsCA/CAcert.pem")
        elif fqdn.endswith(f".{alias_target}"):
            # The vhost is a name below this host (note the leading dot, so
            # that e.g. "myhost.dom" is not treated as covered by
            # "*.host.dom"): default to the wildcard certificate and fetch
            # it first if its cert.pem is not present. The "*" in the path
            # is passed to os.path.exists() literally, not as a glob.
            wildcard_dir = "/etc/univention/ssl/*.{}.{}".format(ucr["hostname"], ucr["domainname"])
            if not os.path.exists(f"{wildcard_dir}/cert.pem"):
                get_wildcard_certificate(udm)
            if cert is None:
                cert = f"{wildcard_dir}/cert.pem"
            if private_key is None:
                private_key = f"{wildcard_dir}/private.key"
            if ca is None:
                ca = ucr.get("apache2/ssl/ca", "/etc/univention/ssl/ucsCA/CAcert.pem")
    # SSL variables are only written when all three paths are known.
    if cert and private_key and ca:
        ucrvs.extend([
            f"{ns}/ssl/certificate={cert}",
            f"{ns}/ssl/key={private_key}",
            f"{ns}/ssl/ca={ca}",
        ])
    if path:
        ucrvs.append("{}/files={}".format(ns, ','.join(path)))
    handler_set(ucrvs)
    click.echo(click.style("Done setting UCR variables.", fg="green"))
    return True


def unset_ucr_vars(fqdn: str, port: str) -> bool:
    """
    Unset all UCR variables below `apache2/vhosts/<fqdn>/<port>/`.

    :param fqdn: host name of the vhost.
    :param port: port of the vhost.
    :returns: `True` if variables were unset, `False` if there were none.
    """
    click.echo("Unsetting UCR variables...")
    # The trailing slash limits the match to exactly this port; a bare prefix
    # match would also unset the variables of e.g. port 8080 when port 80 is
    # removed.
    prefix = f"apache2/vhosts/{fqdn}/{port}/"
    ucrvs = [key for key in ucr.keys() if key.startswith(prefix)]
    if not ucrvs:
        return False
    handler_unset(ucrvs)
    click.echo(click.style("Done unsetting UCR variables.", fg="green"))
    return True


# Run the click command when this file is executed as a script.
if __name__ == '__main__':
    main()
