#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-3.0-or-later
#
# fwctl - foomuuri firewall configuration manager
# Copyright (C) 2026  snix
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
fwctl - foomuuri (nftables firewall) configuration manager

Reads fw.yaml and generates /etc/foomuuri/fw.conf.
"""

import argparse
import difflib
import hashlib
import ipaddress
import json
import os
import re
import subprocess
import sys

try:
    import yaml
except ImportError:
    print("Error: PyYAML is required. Install with: pip install pyyaml", file=sys.stderr)
    sys.exit(1)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# fwctl release identifier.
VERSION = "3.0_r13"

# Directory that holds the generated configuration (default output target
# is "<FOOMUURI_DIR>/fw.conf", see FWCTL_DEFAULTS["output"]).
FOOMUURI_DIR        = "/etc/foomuuri"
# Directory of foomuuri's shipped *.conf files; load_builtin_conf() parses
# them for built-in macros (services) and foomuuri { } settings.
FOOMUURI_SHARE_DIR  = "/usr/share/foomuuri"
# Candidate locations of fw.yaml (presumably tried in this order - TODO confirm
# against the caller, which is outside this view).
DEFAULT_CONFIG_PATHS  = ["/etc/fw.yaml", "fw.yaml"]

# Presumably the defaults for the `fwctl:` section of fw.yaml (sync_ips() reads
# fwctl.auto_ips). Values are strings ("yes"/"no"), not booleans.
FWCTL_DEFAULTS = {
    "output":           "{}/fw.conf".format(FOOMUURI_DIR),
    "ulogd":            "/etc/ulogd.conf",
    "cmd":              "",
    "auto_ips":         "yes",
    "auto_forward":     "yes",
    "log_file":         "/var/log/fw.json",
    "exclude_services": "ping, ping6, dns, ntp, dhcp, igmp, ssdp, mdns, llmnr, netbios-ns",
}

# Baseline foomuuri { } settings; load_builtin_conf() overlays the settings it
# parses from FOOMUURI_SHARE_DIR on top of these.
FOOMUURI_SETTINGS_DEFAULTS = {
    "localhost_zone":    "localhost",
    "dbus_zone":         "public",
    "log_input":         "no",
    "log_output":        "no",
    "log_forward":       "no",
    "log_rpfilter":      "no",
    "log_invalid":       "no",
    "log_smurfs":        "no",
    "log_rate":          "1/second burst 3",
    "log_prefix":        "$(szone)-$(dzone) $(statement)",
    "log_level":         "level info flags skuid",
    "rpfilter":          "yes",
    "counter":           "no",
    "set_size":          "65535",
    "recursion_limit":   "10000",
    "priority_offset":   "5",
    "dbus_firewalld":    "no",
    "nft_bin":           "nft",
    "try-reload_timeout": "15",
}

# Tokens that start a new protocol clause; _group_proto_tokens() splits a
# token list at each of these and joins the clauses with "; ".
PROTO_KEYWORDS = {"tcp", "udp", "icmp", "icmpv6", "protocol", "multicast", "broadcast"}


# ---------------------------------------------------------------------------
# Shared helpers (extracted from repeated patterns)
# ---------------------------------------------------------------------------

def _as_list(v):
    """Normalize a value to a list: list passes through, scalar wraps, falsy returns []."""
    return v if isinstance(v, list) else ([v] if v else [])


def _iter_pods(pods):
    """Iterate over pods dict, skipping the default_zones key."""
    for name, pod in pods.items():
        if name != "default_zones":
            yield name, pod


def _svc_matches(svc, name):
    """True if service token matches macro name (e.g. 'ssh' matches 'ssh' and 'ssh ipv4')."""
    return svc == name or svc.startswith(name + " ")


def _group_proto_tokens(tokens):
    """Split *tokens* into clauses at protocol keywords and join them with '; '.

    Each token found in PROTO_KEYWORDS opens a new clause (unless it is the
    very first token); other tokens extend the current clause. For example
    ['tcp', '80', 'udp', '53'] becomes 'tcp 80; udp 53'. No tokens gives ''.
    """
    clauses = []
    for token in tokens:
        if not clauses or token in PROTO_KEYWORDS:
            clauses.append([token])
        else:
            clauses[-1].append(token)
    return "; ".join(" ".join(clause) for clause in clauses)


def _is_bidir_key(key):
    return " <-> " in key or " \u2194 " in key


def _forward_key_candidates(src, dst, bidir):
    """Return list of possible forward key strings (unicode first, ascii second)."""
    if bidir:
        return ["{} \u2194 {}".format(src, dst), "{} <-> {}".format(src, dst)]
    return ["{} \u2192 {}".format(src, dst), "{} -> {}".format(src, dst)]


# ---------------------------------------------------------------------------
# Parsing foomuuri built-in conf files
# ---------------------------------------------------------------------------

def _parse_conf_file(path, services, settings):
    in_macro    = False
    in_foomuuri = False
    with open(path) as f:
        for line in f:
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            stripped = stripped.split("#")[0].strip()
            if not stripped:
                continue
            if stripped.startswith("macro") and "{" in stripped:
                in_macro = True
                continue
            if stripped.startswith("foomuuri") and "{" in stripped:
                in_foomuuri = True
                continue
            if stripped == "}":
                in_macro = False
                in_foomuuri = False
                continue
            if in_macro or in_foomuuri:
                parts = stripped.split(None, 1)
                key = parts[0]
                val = parts[1] if len(parts) > 1 else ""
                if in_macro:
                    services[key] = val
                else:
                    settings[key] = val


def load_builtin_conf(share_dir=FOOMUURI_SHARE_DIR):
    """Load foomuuri's shipped services and settings from *share_dir*.

    Every ``*.conf`` file in *share_dir* is parsed in sorted filename order,
    so a later file overrides an earlier definition of the same name. Files
    that raise OSError while being read are skipped.

    Returns:
        ``(services, settings)`` - the built-in macro definitions, and the
        parsed ``foomuuri { }`` settings layered over
        FOOMUURI_SETTINGS_DEFAULTS.
    """
    from glob import glob as _find_files

    services, settings = {}, {}
    for conf_path in sorted(_find_files(os.path.join(share_dir, "*.conf"))):
        try:
            _parse_conf_file(conf_path, services, settings)
        except OSError:
            continue
    return services, {**FOOMUURI_SETTINGS_DEFAULTS, **settings}


# ---------------------------------------------------------------------------
# YAML loading and normalization
# ---------------------------------------------------------------------------

def load_config(path):
    """Parse the YAML document at *path* with ``yaml.safe_load`` and return it.

    An empty file yields None (safe_load's behaviour for an empty stream).
    """
    with open(path) as stream:
        document = yaml.safe_load(stream)
    return document


def normalize_config(cfg):
    """Bring a loaded fw.yaml mapping into the canonical shapes used later.

    The mapping is modified in place:

    - Top-level sections and macro sub-sections that are present but empty
      (a bare ``dnat:`` key loads as None) become an empty dict/list.
    - ``macros.net``: each zone's CIDR(s) become ``{"v4": cidr, "v6": cidr}``.
      Duplicates are dropped; with several CIDRs of one family the last wins.
    - ``pods``: a string ``default_zones``/``ip``/``zones`` becomes a list. A
      string ``i``/``o`` other than ``"accept"`` also becomes a list. ``i``
      defaults to ``[]`` and a null pod body becomes ``{}``. When neither the
      pod nor ``default_zones`` names zones, zones are auto-detected by
      matching the pod's ``macros.pod`` address against the ``macros.net``
      CIDRs.
    - ``policy``/``forward``: a string rule other than ``"accept"`` becomes a
      list, and a null ``policy.<zone>`` becomes ``{}``.
    - ``dnat``/``redirect``: a string ``service``/``iif`` becomes a list.
      ``snat`` int ``ip`` values become strings.

    Returns:
        The normalized mapping (a fresh empty dict when *cfg* is None, i.e.
        an empty YAML file).
    """
    if cfg is None:
        cfg = {}

    # Sections written as a bare key (every entry commented out) load as
    # None; replace them with the container type later code iterates over.
    for section, empty in (("macros", dict), ("zones", dict), ("settings", dict),
                           ("policy", dict), ("forward", dict), ("log", dict),
                           ("pods", dict), ("dnat", list), ("snat", list),
                           ("redirect", list)):
        if section in cfg and cfg[section] is None:
            cfg[section] = empty()
    macros = cfg.get("macros", {})
    for sub in ("net", "ip", "proto", "log", "pod"):
        if sub in macros and macros[sub] is None:
            macros[sub] = {}

    # macros.net: zone -> {"v4": cidr, "v6": cidr}
    net_norm = {}
    for zone, val in macros.get("net", {}).items():
        cidrs = [val] if isinstance(val, str) else (val or [])
        entry = {}
        for cidr in dict.fromkeys(cidrs):  # order-preserving de-duplication
            entry["v6" if ":" in cidr else "v4"] = cidr
        net_norm[zone] = entry
    if "net" in macros:
        macros["net"] = net_norm

    pods_raw = cfg.get("pods", {})
    dz = pods_raw.get("default_zones")
    if isinstance(dz, str):
        pods_raw["default_zones"] = [dz]
    pod_macros = macros.get("pod", {})
    for pname, pod in _iter_pods(pods_raw):
        if pod is None:
            # "mypod:" with no body: an empty mapping lets validate() report
            # the missing ip instead of crashing on None.get().
            pod = pods_raw[pname] = {}
        for key in ("ip", "zones"):
            if isinstance(pod.get(key), str):
                pod[key] = [pod[key]]
        for key in ("i", "o"):
            val = pod.get(key)
            if isinstance(val, str) and val != "accept":
                pod[key] = [val]
        pod.setdefault("i", [])

        # Auto-detect zones from the pod IP if not explicitly set.
        if not pod.get("zones") and not dz:
            detected = set()
            for ip_ref in pod.get("ip") or []:
                ip_addr = pod_macros.get(ip_ref)
                if not ip_addr:
                    continue
                try:
                    addr = ipaddress.ip_address(ip_addr)
                except ValueError:
                    continue  # not a single address; cannot be placed
                for zone, nets in net_norm.items():
                    for cidr in nets.values():
                        try:
                            network = ipaddress.ip_network(cidr, strict=False)
                        except ValueError:
                            # Skip just this malformed CIDR instead of
                            # abandoning detection for every remaining zone.
                            continue
                        if addr in network:
                            detected.add(zone)
            if detected:
                pod["zones"] = sorted(detected)

    policy = cfg.get("policy", {})
    for zone, dirs in policy.items():
        if dirs is None:
            policy[zone] = dirs = {}
        if not isinstance(dirs, dict):
            continue
        for direction in ("i", "o"):
            val = dirs.get(direction)
            if isinstance(val, str) and val != "accept":
                dirs[direction] = [val]

    fwd = cfg.get("forward")
    if isinstance(fwd, dict):
        for key, val in list(fwd.items()):
            if isinstance(val, str) and val != "accept":
                fwd[key] = [val]

    for rule in cfg.get("dnat", []):
        for key in ("service", "iif"):
            if isinstance(rule.get(key), str):
                rule[key] = [rule[key]]

    for rule in cfg.get("snat", []):
        if isinstance(rule.get("ip"), int):
            rule["ip"] = str(rule["ip"])

    for rule in cfg.get("redirect", []):
        if isinstance(rule.get("service"), str):
            rule["service"] = [rule["service"]]

    return cfg


# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------

def _is_direct_ip(val):
    return "." in str(val) or ":" in str(val)

def _known_services(raw, builtins):
    proto_macros = set((raw.get("macros") or {}).get("proto", {}).keys())
    return set(builtins.keys()) | proto_macros


def _check_services_or_exit(services, raw, builtins, context):
    """Exit with status 1 if any entry in *services* names an unknown service.

    The service of an entry is its first word. ``"accept"`` entries are
    ignored, and *services* may be None. The first offender is reported on
    stderr before exiting.
    """
    known = _known_services(raw, builtins)
    candidates = [entry for entry in (services or []) if entry != "accept"]
    for entry in candidates:
        head = entry.strip().split()[0]
        if head in known:
            continue
        print("Error: unknown service '{}' in {} - not in built-in services or macros.proto".format(
            head, context), file=sys.stderr)
        sys.exit(1)


def _check_ip_macro_or_exit(name, raw, context):
    if not name:
        return
    ip_macros = (raw.get("macros") or {}).get("ip", {})
    if name[-1] in ("4", "6"):
        if name not in ip_macros:
            print("Error: IP macro '{}' not found in macros.ip ({})".format(name, context), file=sys.stderr)
            sys.exit(1)
    else:
        if name not in ip_macros and "{}4".format(name) not in ip_macros and "{}6".format(name) not in ip_macros:
            print("Error: '{}' not found in macros.ip "
                  "(no '{}', '{}4' or '{}6') ({})".format(name, name, name, name, context), file=sys.stderr)
            sys.exit(1)


def _check_zone_or_exit(name, raw, context):
    if name not in raw.get("zones", {}):
        defined = ", ".join(raw.get("zones", {}).keys()) or "(none)"
        print("Error: zone '{}' not defined in zones ({}). Defined: {}".format(
            name, context, defined), file=sys.stderr)
        sys.exit(1)


def _check_log_macro_or_exit(name, raw, context):
    if name in (None, "drop", "accept", ""):
        return
    log_macros = (raw.get("macros") or {}).get("log", {})
    if name not in log_macros:
        defined = ", ".join(log_macros.keys()) or "(none)"
        print("Error: log macro '{}' not found in macros.log ({}). Defined: {}".format(
            name, context, defined), file=sys.stderr)
        sys.exit(1)



def validate(cfg, builtins):
    """Check a normalized config for semantic problems before generation.

    Args:
        cfg: config mapping as produced by normalize_config().
        builtins: foomuuri built-in services (name -> definition).

    Returns:
        ``(errors, warnings)``: lists of messages prefixed "ERROR:" / "WARN:".
        Nothing is printed and the process is never exited from here.

    A rule's service is identified by its first word, so qualified entries
    such as "ssh ipv4" are checked as "ssh" in every section (policy,
    forward, dnat, redirect and pods).
    """
    errors = []
    warnings = []

    # "or {}" / "or []" throughout: an empty YAML section loads as None.
    zone_map = cfg.get("zones") or {}
    zones = set(zone_map.keys())
    macros = cfg.get("macros") or {}
    proto_map = macros.get("proto") or {}
    proto_macros = set(proto_map.keys())
    ip_macros = set((macros.get("ip") or {}).keys())
    net_macros = set((macros.get("net") or {}).keys())
    log_macros = set((macros.get("log") or {}).keys())
    known_services = set(builtins.keys()) | proto_macros

    # Zones without an interface are localhost zones; at most one may exist.
    null_zones = [z for z, iface in zone_map.items() if iface is None]
    if len(null_zones) > 1:
        errors.append("ERROR: multiple zones without interface (localhost zones): {}".format(", ".join(null_zones)))

    # Iterate the ordered mappings rather than the sets so messages appear in
    # config order instead of hash-randomized order.
    for name in proto_map:
        if name in builtins:
            warnings.append("WARN: proto macro '{}' shadows a foomuuri built-in service".format(name))

    policy = cfg.get("policy") or {}
    forward_raw = cfg.get("forward") or {}
    for zone in zone_map:
        if zone == "fw":
            continue
        has_policy = zone in policy
        has_forward = any(zone in _parse_forward_key(k) for k in forward_raw)
        if not has_policy and not has_forward:
            warnings.append("WARN: zone '{}' has no policy or forward rules".format(zone))

    def check_service(token, context):
        # Only the first word names the service; the rest ("ipv4", extra
        # matchers) qualifies it. str() tolerates YAML scalars such as bare
        # numbers, and a blank entry is reported as '' instead of raising.
        words = str(token).split()
        svc = words[0] if words else ""
        if svc not in known_services:
            errors.append("ERROR: unknown service '{}' in {}".format(svc, context))

    for zone, dirs in policy.items():
        if zone not in zones:
            errors.append("ERROR: zone '{}' used in policy but not defined in zones".format(zone))
            continue
        if not isinstance(dirs, dict):
            errors.append("ERROR: policy.{} must be a mapping with 'i'/'o' keys".format(zone))
            continue
        for direction, rules in dirs.items():
            if not isinstance(rules, list):
                continue
            for rule in rules:
                check_service(rule, "policy.{}.{}".format(zone, direction))

    for key, val in forward_raw.items():
        src, dst = _parse_forward_key(key)
        for z in (src, dst):
            if z not in zones:
                errors.append("ERROR: zone '{}' used in forward '{}' but not defined in zones".format(z, key))
        if val == "accept":
            continue
        if isinstance(val, list):
            for rule in val:
                check_service(rule, "forward '{}'".format(key))

    for rule in cfg.get("dnat") or []:
        frm = rule.get("from", "")
        if frm:
            if frm[-1] in ("4", "6"):
                if frm not in ip_macros:
                    errors.append("ERROR: ip macro '{}' used in dnat 'from' not defined in macros.ip".format(frm))
            else:
                if "{}4".format(frm) not in ip_macros and "{}6".format(frm) not in ip_macros:
                    errors.append("ERROR: dnat 'from: {}' - no '{}4' or '{}6' found in macros.ip".format(frm, frm, frm))
        for svc in _as_list(rule.get("service")):
            check_service(svc, "dnat")
        for v in _as_list(rule.get("iif")):
            if v in ip_macros:
                base = v[:-1] if v[-1] in ("4", "6") else v
                hint = " - did you mean zone '{}'?".format(base) if base in zones else ""
                errors.append("ERROR: dnat iif '{}' is an IP macro, not a zone or interface{}".format(v, hint))
        to = rule.get("to", "")
        if to:
            if to[-1] in ("4", "6"):
                if to not in ip_macros:
                    errors.append("ERROR: ip macro '{}' used in dnat 'to' not defined in macros.ip".format(to))
            else:
                if to not in ip_macros and "{}4".format(to) not in ip_macros and "{}6".format(to) not in ip_macros:
                    errors.append("ERROR: dnat 'to: {}' - not found in macros.ip".format(to))

    for rule in cfg.get("snat") or []:
        zone = rule.get("zone")
        # generate_conf() indexes rule["zone"] and rule["via"] directly, so a
        # missing key must be reported here rather than surface as KeyError.
        if not zone:
            errors.append("ERROR: snat rule missing 'zone'")
        elif zone not in zones:
            errors.append("ERROR: zone '{}' used in snat but not defined in zones".format(zone))
        if "via" not in rule:
            errors.append("ERROR: snat rule missing 'via'")
            continue
        via = rule["via"]
        if via not in zones:
            errors.append("ERROR: zone '{}' used in snat 'via' but not defined in zones".format(via))
        if zone and zone not in net_macros:
            errors.append("ERROR: zone '{}' in snat has no entry in macros.net".format(zone))
        ip_filter = rule.get("ip")
        if ip_filter not in (None, "4", "6"):
            errors.append("ERROR: snat 'ip' must be '4' or '6', got '{}'".format(ip_filter))

    for rule in cfg.get("redirect") or []:
        for svc in _as_list(rule.get("service")):
            check_service(svc, "redirect")
        to = rule.get("to", "")
        if not to:
            errors.append("ERROR: redirect rule missing 'to'")
        elif not _is_direct_ip(to):
            if to not in ip_macros:
                errors.append("ERROR: ip macro '{}' used in redirect 'to' not defined in macros.ip".format(to))

    pod_macros = set((macros.get("pod") or {}).keys())
    pods_section = cfg.get("pods") or {}
    for dz_zone in _as_list(pods_section.get("default_zones")):
        if dz_zone not in zones:
            errors.append("ERROR: zone '{}' in pods.default_zones not defined in zones".format(dz_zone))
    for pname, pod in _iter_pods(pods_section):
        pod = pod or {}
        ip_list = _as_list(pod.get("ip"))
        if not ip_list:
            errors.append("ERROR: pod '{}' has no ip defined".format(pname))
        for ip_ref in ip_list:
            if ip_ref not in pod_macros:
                errors.append("ERROR: pod macro '{}' used in pods.{}.ip not defined in macros.pod".format(ip_ref, pname))
        for direction in ("i", "o"):
            rules = pod.get(direction)
            if isinstance(rules, list):
                for svc in rules:
                    check_service(svc, "pods.{}.{}".format(pname, direction))
        for ext_svc in (pod.get("dnat") or {}):
            check_service(ext_svc, "pods.{}.dnat".format(pname))
        for zone in _as_list(pod.get("zones")):
            if zone not in zones:
                errors.append("ERROR: zone '{}' in pods.{}.zones not defined in zones".format(zone, pname))

    for zone, log_pair in (cfg.get("log") or {}).items():
        if zone not in zones:
            errors.append("ERROR: zone '{}' in log section not defined in zones".format(zone))
        for macro_name in log_pair or []:
            if macro_name in (None, "drop", "accept", ""):
                continue
            if macro_name not in log_macros:
                errors.append("ERROR: log macro '{}' in log.{} not defined in macros.log".format(macro_name, zone))

    return errors, warnings


def _parse_forward_key(key):
    for sep in [" \u2192 ", " -> ", " \u2194 ", " <-> "]:
        if sep in key:
            parts = key.split(sep, 1)
            return parts[0].strip(), parts[1].strip()
    parts = key.split()
    if len(parts) >= 3:
        return parts[0], parts[2]
    return key, key


# ---------------------------------------------------------------------------
# Output generation helpers
# ---------------------------------------------------------------------------

def _resolve_drop_action(val):
    if not val or val == "drop":
        return "drop"
    if val == "accept":
        return None
    return "{} drop".format(val)


def _add_accept(rule):
    ACTION_KEYWORDS = {"accept", "drop", "reject", "continue", "return", "log"}
    parts = rule.strip().split()
    if parts and parts[-1].lower() not in ACTION_KEYWORDS:
        return rule.strip() + " accept"
    return rule.strip()


def _column_align(pairs, indent="  "):
    if not pairs:
        return []
    max_len = max(len(k) for k, _ in pairs)
    return ["{}{:<{}}  {}".format(indent, k, max_len, v) for k, v in pairs]


def _ip_family(addr):
    return "v6" if ":" in addr else "v4"


def _dnat_resolve_families(from_base, to_base, ip_macros_dict):
    def is_explicit(name):
        return name[-1] in ("4", "6") and name in ip_macros_dict

    if is_explicit(from_base):
        suffix = from_base[-1]
        if to_base in ip_macros_dict:
            return [(from_base, to_base)]
        to_macro = "{}{}".format(to_base, suffix)
        if to_macro in ip_macros_dict:
            return [(from_base, to_macro)]
        return []
    else:
        result = []
        for suffix in ("4", "6"):
            from_macro = "{}{}".format(from_base, suffix)
            if from_macro not in ip_macros_dict:
                continue
            if is_explicit(to_base):
                if to_base[-1] != suffix:
                    continue
                result.append((from_macro, to_base))
            else:
                to_macro = "{}{}".format(to_base, suffix)
                if to_macro in ip_macros_dict:
                    result.append((from_macro, to_macro))
                elif to_base in ip_macros_dict:
                    val = ip_macros_dict[to_base]
                    val_suffix = "4" if "." in val else "6" if ":" in val else None
                    if val_suffix == suffix:
                        result.append((from_macro, to_base))
        return result


def _resolve_port_num(port, known_services):
    if isinstance(port, int):
        return port
    defn = known_services.get(str(port), "")
    for token in defn.replace(";", " ").split():
        if token.isdigit():
            return int(token)
    try:
        return int(port)
    except (ValueError, TypeError):
        return None


def _find_service_for_port(port, known_services):
    port_str = str(port)
    for name, defn in known_services.items():
        tokens = str(defn).replace(";", " ").split()
        if port_str in tokens:
            return name
    return None


# ---------------------------------------------------------------------------
# fw.conf generation
# ---------------------------------------------------------------------------

def _zone_has_pods(pods, zone, default_zones=None):
    """Return True if at least one pod is placed in *zone*.

    A pod's placement is its own 'zones' list, falling back to
    *default_zones*. A pod with neither counts as being in every zone.
    """
    def placement(pod):
        return pod.get("zones") or default_zones

    return any(
        where is None or zone in where
        for where in (placement(pod) for _name, pod in _iter_pods(pods))
    )


def generate_conf(cfg, builtins):
    """Render the foomuuri fw.conf text for a normalized, validated config.

    Args:
        cfg: config mapping (as produced by normalize_config()).
        builtins: foomuuri built-in services (name -> definition); consulted,
            together with macros.proto, only to resolve dnat/redirect
            ``port`` values to numbers.

    Returns:
        The fw.conf contents as one string (collected lines joined with "\\n").

    Sections are emitted in this order: foomuuri settings, snat, macro, zone,
    per-zone templates, pod templates, zone-fw / fw-zone blocks, forward
    blocks, dnat, and output nat dstnat (redirect). A warning is printed to
    stderr when a dnat/redirect port cannot be resolved to a number.
    """
    macros = cfg.get("macros", {})
    zones = cfg.get("zones", {})
    settings = cfg.get("settings", {})
    policy = cfg.get("policy", {})
    log_cfg = cfg.get("log", {})
    forward_raw = cfg.get("forward", {})
    pods = cfg.get("pods", {})

    # Prefixes applied to ip/pod/log macro names in the generated macro { }
    # section (proto macros are emitted under their own names).
    def _ip(n):  return "ip_{}".format(n)
    def _pod(n): return "pod_{}".format(n)
    def _log(n): return "log_{}".format(n)

    # Same as _resolve_drop_action(), but a log macro is referenced by its
    # emitted name (with the "log_" prefix used in the macro { } section).
    def _drop_action(val):
        action = _resolve_drop_action(val)
        if action and action != "drop":
            macro, rest = action.split(" ", 1)
            return "log_{} {}".format(macro, rest)
        return action

    pods_default_zones = pods.get("default_zones")

    lines = []

    # 1. foomuuri { } settings block
    # Python booleans (YAML yes/no/true/false) are written back as yes/no.
    lines.append("foomuuri {")
    for key, val in settings.items():
        if isinstance(val, bool):
            val = "yes" if val else "no"
        lines.append("  {} {}".format(key, val))
    lines.append("}")
    lines.append("")

    # 2. snat { }
    # One rule per family for which the source zone has a macros.net subnet
    # and '<via>4' / '<via>6' exists in macros.ip; rule 'ip' ('4'/'6')
    # restricts output to that one family.
    snat_rules = cfg.get("snat", [])
    ip_macros_dict = macros.get("ip", {})
    net_macros_norm = macros.get("net", {})

    if snat_rules:
        lines.append("snat {")
        for rule in snat_rules:
            zone_name = rule["zone"]
            nets = net_macros_norm.get(zone_name, {})
            via = rule["via"]
            # NOTE(review): if 'via' is a zone defined without an interface
            # (YAML null), .get() returns None and 'oifname "None"' is
            # emitted - confirm validation prevents this.
            oif = cfg.get("zones", {}).get(via, via)
            ip_filter = rule.get("ip")
            for suffix, family in (("4", "v4"), ("6", "v6")):
                if ip_filter and ip_filter != suffix:
                    continue
                to_name = "{}{}".format(via, suffix)
                ip_val = ip_macros_dict.get(to_name, "")
                if ip_val and family in nets:
                    subnet = nets[family]
                    lines.append('  saddr {} oifname "{}" snat to {}'.format(subnet, oif, _ip(to_name)))
        lines.append("}")
        lines.append("")

    # 3. macro { }
    # Column-aligned; proto names unprefixed, ip/log/pod names prefixed.
    macro_pairs = []
    for name, val in macros.get("proto", {}).items():
        macro_pairs.append((name, val))
    for name, val in macros.get("ip", {}).items():
        macro_pairs.append((_ip(name), val))
    for name, val in macros.get("log", {}).items():
        macro_pairs.append((_log(name), val))
    for name, val in macros.get("pod", {}).items():
        macro_pairs.append((_pod(name), val))

    lines.append("macro {")
    lines.extend(_column_align(macro_pairs))
    lines.append("}")
    lines.append("")

    # 4. zone { }
    # A zone with no interface (YAML null) is written as its bare name.
    lines.append("zone {")
    for zone, iface in zones.items():
        if iface:
            lines.append("  {}  {}".format(zone, iface))
        else:
            lines.append("  {}".format(zone))
    lines.append("}")
    lines.append("")

    # 5-6. template <zone>_i / <zone>_o
    # A direction that is absent or an empty list gets no template. Each rule
    # gets an implicit trailing "accept" via _add_accept().
    # _zone_families / _add_family_qualifier are defined outside this view;
    # presumably they restrict each rule to the zone's address families -
    # TODO confirm.
    for zone, dirs in policy.items():
        zfamilies = _zone_families(zone, cfg)
        for direction, tpl_name in (("i", "{}_i".format(zone)), ("o", "{}_o".format(zone))):
            if direction not in dirs:
                continue
            val = dirs[direction]
            if isinstance(val, list) and not val:
                continue
            lines.append("template {} {{".format(tpl_name))
            if val == "accept":
                lines.append("  accept")
            elif isinstance(val, list):
                for rule in val:
                    qualified = _add_family_qualifier(rule, zfamilies)
                    lines.append("  {}".format(_add_accept(qualified)))
            lines.append("}")
            lines.append("")

    # Pod templates
    # The rules of ALL pods are collected into one shared pod_i and one
    # shared pod_o template, one rule per (pod ip macro, service).
    pod_i_rules = []
    pod_o_rules = []
    for pname, pod in _iter_pods(pods):
        for ref in pod.get("ip", []):
            pref = _pod(ref)
            i_val = pod.get("i")
            if i_val == "accept":
                pod_i_rules.append("  daddr {} accept".format(pref))
            elif isinstance(i_val, list):
                for svc in i_val:
                    pod_i_rules.append("  {} daddr {} accept".format(svc, pref))
            o_val = pod.get("o")
            if o_val == "accept":
                pod_o_rules.append("  daddr {} accept".format(pref))
            elif isinstance(o_val, list):
                for svc in o_val:
                    pod_o_rules.append("  {} daddr {} accept".format(svc, pref))

    if pod_i_rules:
        lines.append("template pod_i {")
        lines.extend(pod_i_rules)
        lines.append("}")
        lines.append("")

    if pod_o_rules:
        lines.append("template pod_o {")
        lines.extend(pod_o_rules)
        lines.append("}")
        lines.append("")

    # 7. zone-fw / fw-zone blocks
    # Only zones listed under policy get these blocks.
    for zone, dirs in policy.items():
        # log.<zone> is [input_action, output_action]; a missing entry means
        # a plain "drop", "accept" means no terminal rule at all.
        log_pair = log_cfg.get(zone, [])
        drop_i = _drop_action(log_pair[0] if len(log_pair) > 0 else None)
        drop_o = _drop_action(log_pair[1] if len(log_pair) > 1 else None)

        # A zone hosting at least one pod includes the entire shared pod
        # template, i.e. every pod's rules, not only those of its own pods.
        # NOTE(review): confirm that this cross-zone sharing is intended.
        has_pod_i = bool(pod_i_rules) and _zone_has_pods(pods, zone, pods_default_zones)
        has_pod_o = bool(pod_o_rules) and _zone_has_pods(pods, zone, pods_default_zones)

        # Mirrors the skip condition used when the templates were emitted.
        i_val = dirs.get("i")
        o_val = dirs.get("o")
        has_i_tpl = "i" in dirs and not (isinstance(i_val, list) and not i_val)
        has_o_tpl = "o" in dirs and not (isinstance(o_val, list) and not o_val)

        block_i = []
        if has_i_tpl:
            block_i.append("  template {}_i".format(zone))
        if has_pod_i:
            block_i.append("  template pod_i")
        if drop_i:
            block_i.append("  {}".format(drop_i))
        if block_i:
            lines.append("{}-fw {{".format(zone))
            lines.extend(block_i)
            lines.append("}")
            lines.append("")

        block_o = []
        if has_o_tpl:
            block_o.append("  template {}_o".format(zone))
        if has_pod_o:
            block_o.append("  template pod_o")
        if drop_o:
            block_o.append("  {}".format(drop_o))
        if block_o:
            lines.append("fw-{} {{".format(zone))
            lines.extend(block_o)
            lines.append("}")
            lines.append("")

    # 8. Forward blocks
    # A bidirectional key ('<->' / '↔') emits the same rules in both
    # directions.
    def emit_forward(src, dst, val):
        lines.append("{}-{} {{".format(src, dst))
        if val == "accept":
            lines.append("  accept")
        elif isinstance(val, list):
            for rule in val:
                lines.append("  {}".format(_add_accept(rule)))
        elif isinstance(val, str):
            lines.append("  {}".format(_add_accept(val)))
        lines.append("}")
        lines.append("")

    for key, val in forward_raw.items():
        src, dst = _parse_forward_key(key)
        emit_forward(src, dst, val)
        if _is_bidir_key(key):
            emit_forward(dst, src, val)

    # 9. dnat { }
    # pods.<name>.dnat maps external service -> internal port and is expanded
    # once per pod ip macro.
    dnat_rules = cfg.get("dnat", [])
    pod_dnat = [
        (pname, ip_ref, ext_svc, int_port)
        for pname, pod in _iter_pods(pods)
        for ext_svc, int_port in pod.get("dnat", {}).items()
        for ip_ref in pod.get("ip", [])
    ]
    if dnat_rules or pod_dnat:
        lines.append("dnat {")
        for rule in dnat_rules:
            frm      = rule.get("from") or None
            services = rule.get("service", [])
            iif_raw  = rule.get("iif", [])
            to_base  = rule.get("to", "")
            port     = rule.get("port")

            # iif entries naming a zone are replaced by that zone's
            # interface; duplicates are dropped, order kept.
            iif_list = []
            for v in iif_raw:
                resolved = zones.get(v, v)
                if resolved not in iif_list:
                    iif_list.append(resolved)

            port_suffix = ""
            if port is not None:
                # NOTE(review): this lookup table is rebuilt for every rule;
                # it is loop-invariant and could be built once.
                known_svcs = dict(builtins)
                known_svcs.update(macros.get("proto", {}))
                pnum = _resolve_port_num(port, known_svcs)
                if pnum is None:
                    print("Warning: cannot resolve port '{}' to a number - using as-is".format(port), file=sys.stderr)
                port_suffix = ":{}".format(pnum) if pnum is not None else ":{}".format(port)

            # With 'from': match daddr on the from macro, paired per family by
            # _dnat_resolve_families(). Without it: one rule per existing
            # target macro (the bare name, else its 4/6 variants).
            # NOTE(review): for an IPv6 target, "<macro>:<port>" has no
            # brackets - confirm foomuuri accepts this form.
            if frm:
                pairs = _dnat_resolve_families(frm, to_base, ip_macros_dict)
                for from_macro, to_macro in pairs:
                    for svc in services:
                        for iif in iif_list:
                            lines.append(
                                '  daddr {} {}  iifname "{}"'
                                '  dnat to {}{}'.format(_ip(from_macro), svc, iif, _ip(to_macro), port_suffix))
            else:
                if to_base in ip_macros_dict:
                    to_macros = [to_base]
                else:
                    to_macros = ["{}{}".format(to_base, s) for s in ("4", "6")
                                 if "{}{}".format(to_base, s) in ip_macros_dict]
                for to_macro in to_macros:
                    for svc in services:
                        for iif in iif_list:
                            lines.append(
                                '  {}  iifname "{}"'
                                '  dnat to {}{}'.format(svc, iif, _ip(to_macro), port_suffix))
        for pname, ip_ref, ext_svc, int_port in pod_dnat:
            lines.append('  daddr {} {}  dnat to {}:{}'.format(_pod(ip_ref), ext_svc, _pod(ip_ref), int_port))
        lines.append("}")
        lines.append("")

    # 10. output nat dstnat { } (redirect)
    # 'to' is either a literal address (kept verbatim) or an ip macro name.
    redirect_rules = cfg.get("redirect", [])
    if redirect_rules:
        lines.append("output nat dstnat {")
        for rule in redirect_rules:
            services = rule.get("service", [])
            to       = rule.get("to", "")
            port     = rule.get("port")

            if _is_direct_ip(to):
                to_str = to
            else:
                to_str = _ip(to)

            # NOTE(review): a literal IPv6 'to' plus a port yields e.g.
            # "::1:53" with no brackets - confirm foomuuri/nft parse this.
            port_suffix = ""
            if port is not None:
                known_svcs = dict(builtins)
                known_svcs.update(macros.get("proto", {}))
                pnum = _resolve_port_num(port, known_svcs)
                if pnum is None:
                    print("Warning: cannot resolve port '{}' to a number - using as-is".format(port), file=sys.stderr)
                port_suffix = ":{}".format(pnum) if pnum is not None else ":{}".format(port)

            for svc in services:
                lines.append("  {} dnat to {}{}".format(svc, to_str, port_suffix))
        lines.append("}")
        lines.append("")

    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Auto IP sync
# ---------------------------------------------------------------------------

def _get_iface_ips(iface):
    try:
        result = subprocess.run(
            ["ip", "-j", "addr", "show", iface],
            capture_output=True, text=True,
        )
        if result.returncode != 0:
            return None, None
        data = json.loads(result.stdout)
        ipv4 = ipv6 = None
        for entry in data:
            for addr in entry.get("addr_info", []):
                local = addr.get("local", "")
                if addr.get("family") == "inet" and ipv4 is None:
                    ipv4 = local
                elif addr.get("family") == "inet6" and ipv6 is None:
                    if not local.startswith("fe80"):
                        ipv6 = local
        return ipv4, ipv6
    except (OSError, json.JSONDecodeError, subprocess.SubprocessError):
        return None, None


def sync_ips(raw, config_path):
    """Refresh per-zone IP macros from the addresses currently on each interface.

    For every ``zones.<zone>: <iface>`` entry, the addresses returned by
    ``_get_iface_ips`` are stored as ``macros.ip.<zone>4`` and
    ``macros.ip.<zone>6``.  The config file is rewritten (via ``_save_raw``)
    only if a macro value actually changed.

    Nothing is done when ``fwctl.auto_ips`` is false.  That covers a YAML
    boolean or one of the strings "no", "false", "off", "0" or "", so a
    quoted ``"no"`` disables it too.  Empty ``fwctl:``, ``zones:`` or
    ``macros:`` sections (loaded as None) are tolerated.

    Returns *raw*, which is mutated in place.
    """
    fwctl_sec = raw.get("fwctl")
    enabled = fwctl_sec.get("auto_ips", True) if isinstance(fwctl_sec, dict) else True
    if isinstance(enabled, str):
        enabled = enabled.strip().lower() not in ("", "no", "false", "off", "0")
    if not enabled:
        return raw

    if raw.get("macros") is None:
        raw["macros"] = {}
    macros = raw["macros"]
    if macros.get("ip") is None:
        macros["ip"] = {}
    ip_macros = macros["ip"]

    changed = False
    for zone, iface in (raw.get("zones") or {}).items():
        if not iface:
            continue
        ipv4, ipv6 = _get_iface_ips(iface)
        for suffix, addr in (("4", ipv4), ("6", ipv6)):
            if not addr:
                continue
            key = "{}{}".format(zone, suffix)
            if ip_macros.get(key) != addr:
                ip_macros[key] = addr
                changed = True

    if changed:
        _save_raw(raw, config_path)

    return raw


def normalize_raw_yaml(raw):
    """Normalize a raw YAML dict in place to the canonical string/list forms.

    Scalar strings are wrapped into one-element lists where a list is
    expected.  ``"accept"`` is kept as a string where it means "allow all".
    ``fwctl.exclude_services`` given as a comma-separated string (the
    format of the built-in default) is split into one entry per service.
    Malformed (non-dict) entries are skipped rather than rejected.

    Returns a list of human-readable messages describing each correction;
    the list is empty when *raw* was already canonical.
    """
    messages = []

    # policy.<zone>.i / .o  -- string (non "accept") -> list
    for zone, dirs in (raw.get("policy") or {}).items():
        if not isinstance(dirs, dict):
            continue
        for direction in ("i", "o"):
            val = dirs.get(direction)
            if isinstance(val, str) and val != "accept":
                dirs[direction] = [val]
                messages.append("  policy.{}.{}: string -> list".format(zone, direction))

    # forward.<key> -- string (non "accept") -> list
    fwd = raw.get("forward")
    if isinstance(fwd, dict):
        for key, val in list(fwd.items()):
            if isinstance(val, str) and val != "accept":
                fwd[key] = [val]
                messages.append("  forward['{}']: string -> list".format(key))

    # dnat[*].service / .iif  -- string -> list
    dnat = raw.get("dnat")
    if isinstance(dnat, list):
        for i, rule in enumerate(dnat):
            if not isinstance(rule, dict):
                continue
            for field in ("service", "iif"):
                val = rule.get(field)
                if isinstance(val, str):
                    rule[field] = [val]
                    messages.append("  dnat[{}].{}: string -> list".format(i, field))

    # redirect[*].service -- string -> list
    redirect = raw.get("redirect")
    if isinstance(redirect, list):
        for i, rule in enumerate(redirect):
            if not isinstance(rule, dict):
                continue
            val = rule.get("service")
            if isinstance(val, str):
                rule["service"] = [val]
                messages.append("  redirect[{}].service: string -> list".format(i))

    # macros.net.<zone> -- string -> list (normalize_config converts further to dict)
    net = (raw.get("macros") or {}).get("net")
    if isinstance(net, dict):
        for zone, val in list(net.items()):
            if isinstance(val, str):
                net[zone] = [val]
                messages.append("  macros.net.{}: string -> list".format(zone))

    # pods.default_zones -- string -> list
    pods = raw.get("pods")
    if isinstance(pods, dict):
        dz = pods.get("default_zones")
        if isinstance(dz, str):
            pods["default_zones"] = [dz]
            messages.append("  pods.default_zones: string -> list")

        # pods.<name>.ip -- string -> list
        # pods.<name>.i / .o / .zones -- string -> list (except "accept")
        for pname, pod in pods.items():
            if pname == "default_zones" or not isinstance(pod, dict):
                continue
            ip_val = pod.get("ip")
            if isinstance(ip_val, str):
                pod["ip"] = [ip_val]
                messages.append("  pods.{}.ip: string -> list".format(pname))
            for field in ("i", "o"):
                val = pod.get(field)
                if isinstance(val, str) and val != "accept":
                    pod[field] = [val]
                    messages.append("  pods.{}.{}: string -> list".format(pname, field))
            zval = pod.get("zones")
            if isinstance(zval, str):
                pod["zones"] = [zval]
                messages.append("  pods.{}.zones: string -> list".format(pname))

    # fwctl.exclude_services -- comma-separated string -> list.  Splitting
    # matters: the default is written as "ping, ping6, dns, ...", and wrapping
    # it whole would produce one bogus service name containing commas.
    fwctl_sec = raw.get("fwctl")
    if isinstance(fwctl_sec, dict):
        val = fwctl_sec.get("exclude_services")
        if isinstance(val, str):
            fwctl_sec["exclude_services"] = [s.strip() for s in val.split(",") if s.strip()]
            messages.append("  fwctl.exclude_services: string -> list")

    return messages


def auto_fix_rules(raw, config_path, builtins=None):
    """Normalize *raw* and add the forward/SNAT rules implied by its DNAT rules.

    Runs ``normalize_raw_yaml``, then ``_ensure_forward_for_dnat`` for every
    DNAT rule that is a mapping.  An empty ``dnat:`` section (None) and
    malformed non-dict entries are skipped.  All messages are printed, and
    the file is saved with ``_save_raw`` only when something changed.

    Returns *raw*, which is mutated in place.
    """
    norm_msgs = normalize_raw_yaml(raw)

    messages = []
    dnat = raw.get("dnat")
    if isinstance(dnat, list):
        for rule in dnat:
            if isinstance(rule, dict):
                messages.extend(_ensure_forward_for_dnat(rule, raw, raw, builtins))

    if norm_msgs:
        print("Normalized {}:".format(os.path.basename(config_path)))
        for msg in norm_msgs:
            print(msg)
    for msg in messages:
        print(msg)
    if norm_msgs or messages:
        _save_raw(raw, config_path)
    return raw


# ---------------------------------------------------------------------------
# YAML helpers
# ---------------------------------------------------------------------------

def _yaml_dumper():
    """Build a PyYAML ``Dumper`` subclass that emits fwctl's compact layout.

    - booleans are written as ``yes`` / ``no``;
    - lists made only of strings (including empty lists) use flow style;
    - non-empty lists of flat mappings (no nested dict values) are written one
      flow-style mapping per block-style list item;
    - mappings whose values are all ints (bools count as ints) use flow style;
    - everything else falls back to block style.
    """
    seq_tag = "tag:yaml.org,2002:seq"
    map_tag = "tag:yaml.org,2002:map"

    def represent_bool(dumper, flag):
        return dumper.represent_scalar("tag:yaml.org,2002:bool", "yes" if flag else "no")

    def is_flat_mapping(item):
        return isinstance(item, dict) and not any(isinstance(v, dict) for v in item.values())

    def represent_list(dumper, seq):
        if all(isinstance(item, str) for item in seq):
            return dumper.represent_sequence(seq_tag, seq, flow_style=True)
        if seq and all(is_flat_mapping(item) for item in seq):
            rows = [dumper.represent_mapping(map_tag, item, flow_style=True) for item in seq]
            return yaml.SequenceNode(seq_tag, rows, flow_style=False)
        return dumper.represent_sequence(seq_tag, seq, flow_style=False)

    def represent_dict(dumper, mapping):
        compact = bool(mapping) and all(isinstance(v, int) for v in mapping.values())
        return dumper.represent_mapping(map_tag, mapping, flow_style=compact)

    class FwDumper(yaml.Dumper):
        pass

    for data_type, representer in ((bool, represent_bool),
                                   (list, represent_list),
                                   (dict, represent_dict)):
        FwDumper.add_representer(data_type, representer)
    return FwDumper


def _yq_available():
    import shutil as _shutil
    return bool(_shutil.which("yq"))


def _yq_print(data):
    """Render *data* as YAML with the fwctl dumper and print it.

    If ``yq`` is on PATH the YAML is piped through ``yq .``; otherwise it is
    written to stdout unchanged.
    """
    rendered = yaml.dump(
        data,
        Dumper=_yaml_dumper(),
        default_flow_style=False,
        sort_keys=False,
        allow_unicode=True,
    )
    if not _yq_available():
        print(rendered, end="")
        return
    subprocess.run(["yq", "."], input=rendered, text=True)


def _show_yq(cfg, section, sub):
    CFG_SECTIONS = ["settings", "zones", "macros", "policy", "forward", "dnat", "snat", "redirect", "pods", "hairpin"]
    SECTION_MAP = {"zone": "zones", "macro": "macros", "pod": "pods"}
    yaml_section = SECTION_MAP.get(section, section)
    if section is None:
        data = {k: cfg[k] for k in CFG_SECTIONS if k in cfg}
    elif yaml_section == "policy" and sub:
        policy = cfg.get("policy", {})
        if sub not in policy:
            print("zone '{}' not found in policy".format(sub), file=sys.stderr)
            sys.exit(1)
        data = {"policy": {sub: policy[sub]}}
    elif yaml_section == "macros" and sub in {"proto", "ip", "pod", "net", "log"}:
        data = {"macros": {sub: cfg.get("macros", {}).get(sub, {})}}
    else:
        data = {yaml_section: cfg.get(yaml_section)}
    _yq_print(data)


# ---------------------------------------------------------------------------
# CLI commands
# ---------------------------------------------------------------------------

def cmd_check(cfg, builtins, args):
    """``fwctl check``: validate the config and report the result on stdout.

    Warnings are printed first, then errors.  Exits with status 1 if any
    error was found; otherwise prints an OK line, noting when warnings exist.
    """
    errors, warnings = validate(cfg, builtins)
    for line in [*warnings, *errors]:
        print(line)
    if errors:
        sys.exit(1)
    suffix = " (with warnings)" if warnings else ""
    print("OK: config is valid" + suffix)


def _nflog_group(cfg):
    log_level = cfg.get("settings", {}).get("log_level", "group 0")
    log_level = log_level.strip('"').strip("'")
    for token in log_level.split():
        if token.isdigit():
            return int(token)
    return 0


def generate_ulogd_conf(cfg):
    """Render an ulogd.conf that records packets from foomuuri's NFLOG group.

    Both stacks read NFLOG group ``_nflog_group(cfg)``:

    - the JSON stack writes to ``fwctl.log_file`` (default
      ``/var/log/fw.json``), so the log location follows that setting instead
      of being hard-coded;
    - the LOGEMU stack writes a text log to ``/var/log/fw``.
    """
    group = _nflog_group(cfg)
    log_file = (cfg.get("fwctl") or {}).get("log_file") or FWCTL_DEFAULTS["log_file"]
    return """\
[global]
logfile="/var/log/ulogd.log"
plugin="/usr/lib/ulogd/ulogd_inppkt_NFLOG.so"
plugin="/usr/lib/ulogd/ulogd_raw2packet_BASE.so"
plugin="/usr/lib/ulogd/ulogd_filter_IFINDEX.so"
plugin="/usr/lib/ulogd/ulogd_filter_IP2STR.so"
plugin="/usr/lib/ulogd/ulogd_filter_PRINTPKT.so"
plugin="/usr/lib/ulogd/ulogd_output_LOGEMU.so"
plugin="/usr/lib/ulogd/ulogd_output_JSON.so"
stack=log1:NFLOG,base1:BASE,ifi1:IFINDEX,ip2str1:IP2STR,json1:JSON
stack=log1:NFLOG,base1:BASE,ifi1:IFINDEX,ip2str1:IP2STR,print1:PRINTPKT,emu1:LOGEMU

[log1]
group={group}
bufsize=150000

[base1]

[ifi1]

[ip2str1]

[print1]

[json1]
sync=1
file="{log_file}"
timestamp=1

[emu1]
sync=1
file="/var/log/fw"
""".format(group=group, log_file=log_file)


def _macro_to_ip(name, ip_macros):
    for key in (name, "{}4".format(name), "{}6".format(name)):
        val = ip_macros.get(key, "")
        if val:
            return val
    return ""


def _find_zone_for_ip(ip_str, cfg):
    if not ip_str:
        return None
    try:
        ip_obj = ipaddress.ip_address(ip_str)
    except ValueError:
        return None
    net_macros = (cfg.get("macros") or {}).get("net", {})
    for zone, entry in net_macros.items():
        if isinstance(entry, dict):
            cidrs = [entry.get("v4", ""), entry.get("v6", "")]
        elif isinstance(entry, list):
            cidrs = entry
        elif isinstance(entry, str):
            cidrs = [entry]
        else:
            continue
        for cidr in cidrs:
            if not cidr:
                continue
            try:
                if ip_obj in ipaddress.ip_network(cidr, strict=False):
                    return zone
            except ValueError:
                continue
    return None


def _ensure_forward_for_dnat(rule, raw, cfg, builtins=None):
    """Make sure a DNAT rule is backed by matching forward (and hairpin SNAT) rules.

    The inbound zones come from ``rule['iif']`` (zone names or interface
    names).  The target zone is the ``macros.net`` zone containing the
    address of ``rule['to']``; if none is found, nothing is done.  For each
    inbound zone different from the target zone:

    - an existing one-way ``iif → to`` rule that is a service list gets any
      missing services appended;
    - an ``accept`` rule or a bidirectional (``↔``/``<->``) rule is left
      untouched;
    - otherwise (no rule, or only a reverse ``to → iif`` rule) a new
      ``"iif → to"`` rule is created.

    If ``rule['port']`` maps to a known service, that service replaces the
    rule's services.  Services get an ``ipv4``/``ipv6`` qualifier when the
    target resolves to a single address family.  When the target zone is
    listed in ``hairpin``, an SNAT rule ``{zone: iif_zone, via: to_zone}``
    is added, except when the inbound zone is ``settings.dbus_zone``.

    *raw* is mutated; *cfg* is only read (the visible callers pass the same
    dict for both).  Empty (None) sections are tolerated.  Returns a list of
    human-readable messages describing the changes made.
    """
    macros    = cfg.get("macros") or {}
    ip_macros = macros.get("ip") or {}
    zones     = cfg.get("zones") or {}
    messages  = []

    iif_list = rule.get("iif", [])
    if isinstance(iif_list, str):
        iif_list = [iif_list]
    to_base  = rule.get("to", "")
    services = rule.get("service", [])
    if isinstance(services, str):
        services = [services]

    # With an explicit internal port, the forward rule must allow the service
    # behind that port, not the externally exposed one.
    port = rule.get("port")
    if port is not None:
        known = dict(builtins or {})
        known.update(macros.get("proto") or {})
        int_port = _resolve_port_num(port, known)
        int_svc  = _find_service_for_port(int_port, known) if int_port else None
        if int_svc:
            services = [int_svc]

    # Pin services to the target's address family when it only has one.
    to_families = _resolve_ip_family(to_base, ip_macros)
    if to_families == {'4'}:
        services = ["{} ipv4".format(s) if "ipv4" not in s and "ipv6" not in s else s for s in services]
    elif to_families == {'6'}:
        services = ["{} ipv6".format(s) if "ipv4" not in s and "ipv6" not in s else s for s in services]

    # iif entries may be zone names or interface names; map both to zones.
    iif_zones = []
    for v in iif_list:
        if v in zones:
            iif_zones.append(v)
            continue
        for zname, iface in zones.items():
            if iface == v:
                iif_zones.append(zname)
                break

    ip_str  = _macro_to_ip(to_base, ip_macros)
    to_zone = _find_zone_for_ip(ip_str, cfg)
    if not to_zone:
        return messages

    if raw.get("forward") is None:
        raw["forward"] = {}
    forward = raw["forward"]

    hairpin = cfg.get("hairpin") or []
    if isinstance(hairpin, str):
        hairpin = [hairpin]
    wan_zone = (cfg.get("settings") or {}).get("dbus_zone")

    for iif_zone in iif_zones:
        if iif_zone == to_zone:
            continue

        same_dir_keys = ("{} \u2192 {}".format(iif_zone, to_zone), "{} -> {}".format(iif_zone, to_zone))
        reverse_keys  = ("{} \u2192 {}".format(to_zone, iif_zone), "{} -> {}".format(to_zone, iif_zone))
        bidir_keys    = (
            "{} \u2194 {}".format(iif_zone, to_zone), "{} <-> {}".format(iif_zone, to_zone),
            "{} \u2194 {}".format(to_zone, iif_zone), "{} <-> {}".format(to_zone, iif_zone),
        )
        existing_key = next((k for k in same_dir_keys + reverse_keys + bidir_keys if k in forward), None)
        bidir_key = existing_key is not None and _is_bidir_key(existing_key)

        # A one-way rule in the opposite direction does not admit the DNAT'ed
        # traffic, so treat it as absent.  Compare against the exact reverse
        # keys rather than a zone-name prefix, so zones like "lan"/"lan2"
        # are not confused (which used to create duplicate rules).
        if existing_key in reverse_keys and not bidir_key:
            existing_key = None

        if existing_key is None:
            key = "{} \u2192 {}".format(iif_zone, to_zone)
            forward[key] = list(services)
            messages.append("Added forward rule: '{}: {}'".format(key, forward[key]))
        else:
            existing_val = forward[existing_key]
            if existing_val != "accept" and not bidir_key and isinstance(existing_val, list):
                # Compare by base service name so "http" matches "http ipv4".
                existing_bases = {s.split()[0] for s in existing_val
                                  if isinstance(s, str) and s.split()}
                missing = [s for s in services if s.split() and s.split()[0] not in existing_bases]
                if missing:
                    forward[existing_key] = existing_val + missing
                    messages.append("Updated forward '{}': added {}".format(existing_key, missing))

        if to_zone in hairpin and iif_zone != wan_zone:
            # Only create/replace the snat section when a rule is needed.
            snat_rules = raw.get("snat")
            if not isinstance(snat_rules, list):
                snat_rules = raw["snat"] = []
            if not any(isinstance(r, dict) and r.get("zone") == iif_zone and r.get("via") == to_zone
                       for r in snat_rules):
                snat_rules.append({"zone": iif_zone, "via": to_zone})
                messages.append("Added SNAT rule: {} -> {}".format(iif_zone, to_zone))

    return messages


def _resolve_ip_family(name, ip_macros):
    if '.' in name: return {'4'}
    if ':' in name: return {'6'}
    families = set()
    for key in (name, "{}4".format(name), "{}6".format(name)):
        val = ip_macros.get(key, "")
        if '.' in val:   families.add('4')
        elif ':' in val: families.add('6')
    return families


def _zone_families(zone, cfg):
    ip_macros = (cfg.get("macros") or {}).get("ip", {})
    has4 = "{}4".format(zone) in ip_macros
    has6 = "{}6".format(zone) in ip_macros
    if has4 and not has6:
        return {'4'}
    if has6 and not has4:
        return {'6'}
    return {'4', '6'}


def _add_family_qualifier(rule, families):
    tokens = rule.split()
    if 'ipv4' in tokens or 'ipv6' in tokens:
        return rule
    if families == {'4'}:
        return rule + ' ipv4'
    if families == {'6'}:
        return rule + ' ipv6'
    return rule


def _needed_forward_families(cfg):
    families = set()
    ip_macros = (cfg.get("macros") or {}).get("ip", {})

    def _svc_family(svc):
        tokens = svc.split()
        if 'ipv4' in tokens: return '4'
        if 'ipv6' in tokens: return '6'
        return None

    for services in cfg.get("forward", {}).values():
        if services == "accept" or not isinstance(services, list):
            families |= {'4', '6'}
        else:
            fams = [_svc_family(s) for s in services]
            if fams and all(f == '4' for f in fams):
                families.add('4')
            elif fams and all(f == '6' for f in fams):
                families.add('6')
            else:
                families |= {'4', '6'}

    for rule in cfg.get("snat", []):
        ip_filter = str(rule.get("ip", ""))
        if ip_filter == "4":   families.add('4')
        elif ip_filter == "6": families.add('6')
        else:                  families |= {'4', '6'}

    for rule in cfg.get("dnat", []):
        frm = rule.get("from") or None
        to  = rule.get("to", "")
        if frm:
            families |= _resolve_ip_family(frm, ip_macros)
        else:
            families |= _resolve_ip_family(to, ip_macros)

    return families


def _enable_forwarding(families):
    paths = []
    if '4' in families:
        paths.append("/proc/sys/net/ipv4/ip_forward")
    if '6' in families:
        paths.append("/proc/sys/net/ipv6/conf/all/forwarding")
    for path in paths:
        try:
            with open(path) as f:
                if f.read().strip() == "1":
                    continue
            with open(path, "w") as f:
                f.write("1\n")
            print("Enabled forwarding: {}".format(path))
        except OSError as e:
            print("Warning: could not enable forwarding ({}): {}".format(path, e), file=sys.stderr)


def _validate_or_exit(cfg, builtins):
    """Validate *cfg*, print warnings then errors to stderr, and exit(1) on errors."""
    errors, warnings = validate(cfg, builtins)
    for line in [*warnings, *errors]:
        print(line, file=sys.stderr)
    if errors:
        sys.exit(1)


def _write_conf(conf, output_path):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w") as f:
        f.write(conf)
    print("Written: {}".format(output_path))


def cmd_gen(cfg, builtins, args):
    """``fwctl gen``: validate the config and render the foomuuri fw.conf.

    When ``args.print`` is set, the result goes to stdout and nothing is
    written.  Otherwise it is written to ``fwctl.output`` (default
    ``<FOOMUURI_DIR>/fw.conf``).  With ``args.ulogd``, an ulogd.conf is also
    written to ``fwctl.ulogd`` (default ``/etc/ulogd.conf``).
    """
    _validate_or_exit(cfg, builtins)
    rendered = generate_conf(cfg, builtins)

    if args.print:
        sys.stdout.write(rendered)
        return

    tool_opts = cfg.get("fwctl", {})
    _write_conf(rendered, tool_opts.get("output", "{}/fw.conf".format(FOOMUURI_DIR)))
    if args.ulogd:
        _write_conf(generate_ulogd_conf(cfg), tool_opts.get("ulogd", "/etc/ulogd.conf"))


def cmd_apply(cfg, builtins, args):
    """``fwctl apply``: validate, write fw.conf, enable forwarding and reload.

    When ``args.dry_run`` is set, only a unified diff between the current
    output file (treated as empty if missing) and the newly generated config
    is printed, or "No changes." if they match.

    Otherwise the config is written, and kernel forwarding is enabled for the
    needed families unless ``fwctl.auto_forward`` is false.  Then
    ``fwctl.cmd`` (if set) is run through the shell and fwctl exits with its
    return code; without a command, a restart reminder is printed.
    """
    _validate_or_exit(cfg, builtins)
    tool_opts = cfg.get("fwctl", {})
    output_path = tool_opts.get("output", "{}/fw.conf".format(FOOMUURI_DIR))
    rendered = generate_conf(cfg, builtins)

    if args.dry_run:
        try:
            with open(output_path) as fh:
                old_lines = fh.readlines()
        except FileNotFoundError:
            old_lines = []
        delta = list(difflib.unified_diff(
            old_lines,
            rendered.splitlines(keepends=True),
            fromfile=output_path,
            tofile=output_path + ".new",
        ))
        if not delta:
            print("No changes.")
        else:
            sys.stdout.writelines(delta)
        return

    _write_conf(rendered, output_path)

    if tool_opts.get("auto_forward", True):
        families = _needed_forward_families(cfg)
        if families:
            _enable_forwarding(families)

    reload_cmd = tool_opts.get("cmd", "")
    if not reload_cmd:
        print("foomuuri needs restart")
        return
    print("Running: {}".format(reload_cmd))
    # fwctl.cmd comes from the (root-owned) config file and may use shell syntax.
    sys.exit(subprocess.run(reload_cmd, shell=True).returncode)


def cmd_add_macro(cfg, builtins, args):
    """``fwctl add macro <type> <name> [values...]``: add a new macro to the YAML.

    Supported types:

    - ``proto``: one or more tokens (e.g. ``tcp 443``).  Shadowing a
      foomuuri built-in, or duplicating an existing definition, is refused
      unless ``args.force`` is set, in which case only a warning is printed.
    - ``ip`` / ``pod``: exactly one address.
    - ``net``: one or more CIDRs, de-duplicated; a single CIDR is stored as a
      plain string.
    - ``log``: stored as ``log "<label>"``, where the label defaults to the
      macro name.

    Redefining an existing macro is an error.  All errors exit with status 1
    before anything is saved.
    """
    mtype = args.mtype
    name = args.name
    values = args.values
    force = getattr(args, "force", False)

    raw, config_path = _load_raw(args)
    macros = raw.setdefault("macros", {})

    if name in macros.get(mtype, {}):
        print("Error: macro {} '{}' already exists. Use 'fwctl edit' to modify.".format(mtype, name), file=sys.stderr)
        sys.exit(1)

    single_value_errors = {
        "ip": "Error: ip requires exactly one value (e.g. 103.109.234.7)",
        "pod": "Error: pod requires exactly one IP address (e.g. 10.0.0.41)",
    }

    if mtype == "proto":
        if not values:
            print("Error: proto requires at least one value (e.g. tcp 443)", file=sys.stderr)
            sys.exit(1)
        if name in builtins:
            if not force:
                print("Error: '{}' is already a foomuuri built-in service - use it directly without defining it".format(name),
                      file=sys.stderr)
                print("  use -f to override anyway", file=sys.stderr)
                sys.exit(1)
            print("Warning: '{}' shadows a foomuuri built-in service".format(name))

        definition = values[0] if len(values) == 1 else _group_proto_tokens(values)

        # Compare token-wise, treating ';' as a separator, to spot duplicates.
        def tokens_of(text):
            return str(text).replace(";", " ").split()

        wanted = definition.replace(";", " ").split()
        dupes = ["{} (builtin)".format(bname) for bname, bdef in builtins.items()
                 if bname != name and tokens_of(bdef) == wanted]
        dupes += [cname for cname, cdef in macros.get("proto", {}).items()
                  if cname != name and tokens_of(cdef) == wanted]
        if dupes:
            joined = ", ".join(dupes)
            if not force:
                print("Error: same definition as: {}".format(joined), file=sys.stderr)
                print("  use -f to add anyway", file=sys.stderr)
                sys.exit(1)
            print("Warning: same definition as: {}".format(joined))
        macros.setdefault("proto", {})[name] = definition

    elif mtype in single_value_errors:
        if len(values) != 1:
            print(single_value_errors[mtype], file=sys.stderr)
            sys.exit(1)
        macros.setdefault(mtype, {})[name] = values[0]

    elif mtype == "net":
        if not values:
            print("Error: net requires at least one CIDR (e.g. 10.0.0.0/24)", file=sys.stderr)
            sys.exit(1)
        unique = list(dict.fromkeys(values))
        macros.setdefault("net", {})[name] = unique[0] if len(unique) == 1 else unique

    elif mtype == "log":
        label = values[0] if values else name
        macros.setdefault("log", {})[name] = 'log "{}"'.format(label)

    _save_raw(raw, config_path)
    print("Macro '{}' added to macros.{}".format(name, mtype))


def cmd_add_forward(cfg, builtins, args):
    """``fwctl add forward <src> <dst> [-] <services...|accept>``: add a forward rule.

    A leading ``-`` in ``args.rules`` makes the rule bidirectional
    (``src <-> dst``); otherwise it is one-way (``src -> dst``).  A lone
    ``accept`` is stored as the string ``"accept"``.  Both zones and all
    services are validated first.

    A rule counts as a duplicate if the same zone pair and direction already
    exists under either arrow spelling (``->``/``→``; ``<->``/``↔`` in either
    zone order for bidirectional rules).  This matches the key forms that
    other commands write.  Exits with status 1 on any error.
    """
    src   = args.src
    dst   = args.dst
    rules = args.rules or []

    bidir = bool(rules) and rules[0] == "-"
    if bidir:
        rules = rules[1:]

    if not rules:
        print("Error: specify services or 'accept'", file=sys.stderr)
        sys.exit(1)

    val = "accept" if rules == ["accept"] else rules
    if bidir:
        key = "{} <-> {}".format(src, dst)
        equivalent_keys = [
            key, "{} \u2194 {}".format(src, dst),
            "{} <-> {}".format(dst, src), "{} \u2194 {}".format(dst, src),
        ]
    else:
        key = "{} -> {}".format(src, dst)
        equivalent_keys = [key, "{} \u2192 {}".format(src, dst)]

    raw, config_path = _load_raw(args)

    _check_zone_or_exit(src, raw, "forward src")
    _check_zone_or_exit(dst, raw, "forward dst")
    if isinstance(val, list):
        _check_services_or_exit(val, raw, builtins, "forward '{}'".format(key))

    # An empty "forward:" section loads as None.
    if raw.get("forward") is None:
        raw["forward"] = {}
    forward = raw["forward"]

    existing = next((k for k in equivalent_keys if k in forward), None)
    if existing is not None:
        print("Error: forward rule '{}' already exists.".format(existing), file=sys.stderr)
        sys.exit(1)

    forward[key] = val

    _save_raw(raw, config_path)
    print("Forward '{}' added".format(key))


def cmd_add_dnat(cfg, builtins, args):
    """``fwctl add dnat``: append a DNAT rule plus the forward rules it implies.

    Inputs: services ``args.s``, inbound zones/interfaces ``args.i``, target
    ``args.to``, optional source ``args.frm`` and internal port ``args.p``.
    An ``args.i`` entry that is an IP macro (e.g. ``wan4``) is mapped to its
    zone (``wan``); if that zone does not exist, this is an error.  A single
    service or iif is stored as a scalar.  Every input is validated, and a
    rule with the same from/services/iifs/to is rejected.

    After saving, ``_ensure_forward_for_dnat`` adds implied forward/SNAT
    rules.  With a port, a warning (and a suggested ``fwctl edit`` command)
    is printed for each inbound zone whose forward rule does not accept the
    internal port's service.  Errors exit with status 1.
    """
    services = args.s[0] if len(args.s) <= 1 else args.s

    raw, config_path = _load_raw(args)

    ip_macros_dict = (raw.get("macros") or {}).get("ip", {})
    zones_dict = raw.get("zones", {})
    iif_names = []
    for given in args.i:
        if given not in ip_macros_dict:
            iif_names.append(given)
            continue
        zone_name = given[:-1] if given[-1] in ("4", "6") else given
        if zone_name not in zones_dict:
            print("Error: iif '{}' is an IP macro but zone '{}' not found - use a zone name or interface name".format(given, zone_name), file=sys.stderr)
            sys.exit(1)
        print("Note: iif '{}' is an IP macro - using zone '{}' (interface '{}')".format(given, zone_name, zones_dict[zone_name]))
        iif_names.append(zone_name)
    iif = iif_names[0] if len(iif_names) <= 1 else iif_names

    rule = {"service": services, "iif": iif, "to": args.to}
    if args.frm:
        rule["from"] = args.frm
    if args.p is not None:
        rule["port"] = args.p

    _check_services_or_exit(_as_list(services), raw, builtins, "dnat service")
    if args.frm:
        _check_ip_macro_or_exit(args.frm, raw, "dnat from")
    _check_ip_macro_or_exit(args.to,  raw, "dnat to")
    for iif_val in _as_list(iif):
        _check_zone_or_exit(iif_val, raw, "dnat iif")

    wanted_svcs = sorted(_as_list(services))
    wanted_iifs = sorted(_as_list(iif))
    for existing in raw.get("dnat", []):
        duplicate = (existing.get("from") == args.frm
                     and sorted(_as_list(existing.get("service", []))) == wanted_svcs
                     and sorted(_as_list(existing.get("iif", []))) == wanted_iifs
                     and existing.get("to") == args.to)
        if duplicate:
            print("Error: DNAT rule already exists. Use 'fwctl edit' to modify.", file=sys.stderr)
            sys.exit(1)

    if not isinstance(raw.get("dnat"), list):
        raw["dnat"] = []
    raw["dnat"].append(rule)
    _save_raw(raw, config_path)
    print("DNAT rule added: {}  [{}]  iif=[{}]  -> {}{}".format(
        args.frm if args.frm else "(any)",
        ", ".join(_as_list(services)),
        ", ".join(_as_list(iif)),
        args.to,
        ":{}".format(args.p) if args.p is not None else "",
    ))

    fwd_msgs = _ensure_forward_for_dnat(rule, raw, raw, builtins)
    for msg in fwd_msgs:
        print(msg)
    if fwd_msgs:
        _save_raw(raw, config_path)

    if args.p is None:
        return

    # With an internal port, check that forwarding admits the service behind it.
    known_dict = dict(builtins)
    known_dict.update(raw.get("macros", {}).get("proto", {}))
    int_port = _resolve_port_num(args.p, known_dict)
    int_svc = _find_service_for_port(int_port, known_dict) if int_port else None

    to_ip = _macro_to_ip(args.to, raw.get("macros", {}).get("ip", {}))
    dest_zone = _find_zone_for_ip(to_ip, raw)
    zones = raw.get("zones", {})
    iif_zones = []
    for v in _as_list(iif):
        if v in zones:
            iif_zones.append(v)
            continue
        for zname, iface in zones.items():
            if iface == v:
                iif_zones.append(zname)
                break

    forward = raw.get("forward", {}) or {}
    if not (int_port and dest_zone and iif_zones):
        return

    for iif_zone in iif_zones:
        if iif_zone == dest_zone:
            continue
        candidates = [
            "{} → {}".format(iif_zone, dest_zone), "{} -> {}".format(iif_zone, dest_zone),
            "{} ↔ {}".format(iif_zone, dest_zone), "{} <-> {}".format(iif_zone, dest_zone),
            "{} ↔ {}".format(dest_zone, iif_zone), "{} <-> {}".format(dest_zone, iif_zone),
        ]
        fwd_key = next((k for k in candidates if k in forward), None)
        fwd_val = forward.get(fwd_key) if fwd_key else None
        fwd_bases = [s.split()[0] for s in fwd_val] if isinstance(fwd_val, list) else []
        if fwd_val == "accept" or (int_svc and int_svc in fwd_bases):
            continue
        print("Warning: no accept found for internal port {} in forward '{} -> {}'".format(
            int_port, iif_zone, dest_zone))
        if int_svc:
            current = fwd_val if isinstance(fwd_val, list) else []
            print("  -> fwctl edit forward {} {} {}".format(
                iif_zone, dest_zone, " ".join(current + [int_svc])))
        else:
            print("  -> fwctl show macros {}  (find service name)".format(int_port))
            print("  -> fwctl edit forward {} {} ... <svc>".format(iif_zone, dest_zone))


def cmd_add_snat(cfg, builtins, args):
    """``fwctl add snat``: add an SNAT rule for ``args.zone`` leaving via ``args.via``.

    Both zones must exist, and the zone/via pair must not already have an
    SNAT rule.  ``args.ip`` (if set) restricts the rule to one IP version.
    If no forward rule ``zone → via`` (either arrow spelling) and no
    bidirectional rule between the two zones exists, ``zone → via: accept``
    is added as well.

    An empty ``snat:`` section (None) or a non-list value is treated as an
    empty list (and replaced by one); non-mapping entries are ignored by the
    duplicate check.  Errors exit with status 1.
    """
    raw, config_path = _load_raw(args)

    _check_zone_or_exit(args.zone, raw, "snat zone")
    _check_zone_or_exit(args.via,  raw, "snat via")

    snat_rules = raw.get("snat")
    if not isinstance(snat_rules, list):
        snat_rules = []
    for existing in snat_rules:
        if (isinstance(existing, dict)
                and existing.get("zone") == args.zone
                and existing.get("via") == args.via):
            print("Error: SNAT rule '{} -> {}' already exists. Use 'fwctl edit' to modify.".format(args.zone, args.via), file=sys.stderr)
            sys.exit(1)

    rule = {"zone": args.zone, "via": args.via}
    if args.ip:
        rule["ip"] = args.ip

    raw["snat"] = snat_rules
    snat_rules.append(rule)

    forward = raw.setdefault("forward", {}) or {}
    key_variants = [
        "{} → {}".format(args.zone, args.via), "{} -> {}".format(args.zone, args.via),
        "{} ↔ {}".format(args.zone, args.via), "{} <-> {}".format(args.zone, args.via),
        "{} ↔ {}".format(args.via, args.zone), "{} <-> {}".format(args.via, args.zone),
    ]
    fwd_key = next((k for k in key_variants if k in forward), None)
    fwd_added = False
    if fwd_key is None:
        new_key = "{} → {}".format(args.zone, args.via)
        forward[new_key] = "accept"
        # forward may be a fresh dict if the section was None; store it back.
        raw["forward"] = forward
        fwd_added = True

    _save_raw(raw, config_path)
    msg = "SNAT rule added: {} -> {}".format(args.zone, args.via)
    if args.ip:
        msg += " (IPv{} only)".format(args.ip)
    print(msg)
    if fwd_added:
        print("Added forward rule: '{} → {}: accept'".format(args.zone, args.via))


def cmd_add_hairpin(cfg, builtins, args):
    """Append ``args.zone`` to the ``hairpin`` zone list.

    The zone must exist.  A scalar ``hairpin`` value is promoted to a list.
    Exits with status 1 if the zone is already listed.
    """
    raw, config_path = _load_raw(args)
    _check_zone_or_exit(args.zone, raw, "hairpin zone")

    current = raw.get("hairpin") or []
    zones = [current] if isinstance(current, str) else current
    if args.zone in zones:
        print("Error: zone '{}' already in hairpin.".format(args.zone), file=sys.stderr)
        sys.exit(1)

    zones.append(args.zone)
    raw["hairpin"] = zones
    _save_raw(raw, config_path)
    print("Hairpin zone added: {}".format(args.zone))


def cmd_del_hairpin(cfg, builtins, args):
    """Remove ``args.zone`` from ``hairpin``, dropping the key once it is empty.

    Exits with status 1 if the zone is not listed.
    """
    raw, config_path = _load_raw(args)

    current = raw.get("hairpin") or []
    zones = [current] if isinstance(current, str) else current
    if args.zone not in zones:
        print("Error: zone '{}' not in hairpin.".format(args.zone), file=sys.stderr)
        sys.exit(1)

    remaining = [z for z in zones if z != args.zone]
    if not remaining:
        raw.pop("hairpin", None)
    else:
        raw["hairpin"] = remaining
    _save_raw(raw, config_path)
    print("Hairpin zone removed: {}".format(args.zone))


def cmd_add_zone(cfg, builtins, args):
    """``fwctl add zone``: create a zone with an empty policy and login/logout logging.

    The interface defaults to the zone name.  ``macros.ip.<zone>4``/``6``
    are filled from the interface's current addresses.  ``macros.net.<zone>``
    comes from ``args.net`` (de-duplicated), or, when none is given, from
    the networks reported by ``_iface_networks``.  A warning is printed for
    each of the ``login``/``logout`` log macros that is not defined.  Exits
    with status 1 if the zone already exists.
    """
    name = args.name
    iface = args.iface or name
    nets = args.net or []

    raw, config_path = _load_raw(args)

    if name in raw.get("zones", {}):
        print("Error: zone '{}' already exists.".format(name), file=sys.stderr)
        sys.exit(1)

    raw.setdefault("zones", {})[name] = iface
    raw.setdefault("policy", {})[name] = {"i": [], "o": []}

    defined_logs = (raw.get("macros") or {}).get("log", {})
    for log_macro in ("login", "logout"):
        if log_macro not in defined_logs:
            print("Warning: log macro '{}' not defined in macros.log - add it or edit log.{} afterwards".format(log_macro, name))
    raw.setdefault("log", {})[name] = ["login", "logout"]

    if nets:
        nets = list(dict.fromkeys(nets))
        raw.setdefault("macros", {}).setdefault("net", {})[name] = nets[0] if len(nets) == 1 else nets

    ipv4, ipv6 = _get_iface_ips(iface)
    ip_macros = raw.setdefault("macros", {}).setdefault("ip", {})
    detected = (("4", ipv4), ("6", ipv6))
    for suffix, addr in detected:
        if addr:
            ip_macros["{}{}".format(name, suffix)] = addr

    if not nets:
        detected_nets = _iface_networks(iface)
        if detected_nets:
            net_val = detected_nets[0] if len(detected_nets) == 1 else detected_nets
            raw.setdefault("macros", {}).setdefault("net", {})[name] = net_val
            nets = detected_nets

    _save_raw(raw, config_path)

    print("Zone '{}' added (iface={})".format(name, iface))
    for suffix, addr in detected:
        if addr:
            print("  {}{} = {}".format(name, suffix, addr))
    if nets:
        print("  net: {}".format(nets))
    print("  Edit policy.{}.i and policy.{}.o to define rules.".format(name, name))


def cmd_add_policy(cfg, builtins, args):
    """Create the ``policy`` entry for an existing zone.

    Both ``-i`` and ``-o`` service lists are validated before anything is
    stored.  A lone ``accept`` is saved as the bare string ``"accept"``, and an
    omitted direction is left out of the entry.  Exits with status 1 when the
    zone is unknown or already has a policy.
    """
    zone = args.zone
    raw, config_path = _load_raw(args)

    if zone not in raw.get("zones", {}):
        print("Error: zone '{}' not defined in zones.".format(zone), file=sys.stderr)
        sys.exit(1)

    policy = raw.setdefault("policy", {})
    if zone in policy:
        print("Error: policy for zone '{}' already exists. Use 'fwctl edit' to modify.".format(zone), file=sys.stderr)
        sys.exit(1)

    directions = ("i", "o")
    for direction in directions:
        _check_services_or_exit(getattr(args, direction), raw, builtins,
                                "policy.{}.{}".format(zone, direction))

    entry = {}
    for direction in directions:
        services = getattr(args, direction)
        if services is None:
            continue
        entry[direction] = "accept" if services == ["accept"] else services
    policy[zone] = entry

    _save_raw(raw, config_path)
    print("Policy for zone '{}' added.".format(zone))


def cmd_add_pod(cfg, builtins, args):
    """Add a pod entry to the ``pods`` section.

    Creates ``macros.pod.<name>`` holding the pod IP and ``pods.<name>``, whose
    ``ip`` references that macro.  An existing pod macro with the same IP is
    reused; one with a different IP is an error.

    Each ``--dnat service:port`` is resolved to an internal port number.  For
    every such port not already covered by an inbound service, a matching
    service (if one is found) is appended to the pod's inbound list.

    ``-o`` with no values, or a lone ``accept``, is stored as the bare string
    ``"accept"``.  Exits with status 1 on any validation error; nothing is
    saved in that case.
    """
    name = args.name
    ip = args.ip

    raw, config_path = _load_raw(args)

    if name == "default_zones":
        print("Error: 'default_zones' is a reserved key in the pods section.", file=sys.stderr)
        sys.exit(1)

    if name in raw.get("pods", {}):
        print("Error: pod '{}' already exists.".format(name), file=sys.stderr)
        sys.exit(1)

    # Create macro in macros.pod
    macros = raw.setdefault("macros", {})
    pod_macros = macros.setdefault("pod", {})
    if name in pod_macros:
        if pod_macros[name] != ip:
            print("Error: pod macro '{}' already exists with different IP ({})".format(name, pod_macros[name]), file=sys.stderr)
            sys.exit(1)
    else:
        pod_macros[name] = ip

    pod = {"ip": [name]}

    in_svcs = list(args.in_svcs or [])

    if args.dnat:
        # Builtin services overlaid with the user's proto macros; built once
        # instead of on every --dnat entry.
        known_dict = dict(builtins)
        known_dict.update(macros.get("proto") or {})

        dnat_dict = {}
        for entry in args.dnat:
            svc, sep, port = entry.partition(":")
            svc, port = svc.strip(), port.strip()
            if not sep or not svc or not port:
                print("Error: --dnat format is service:port (e.g. http:8080)", file=sys.stderr)
                sys.exit(1)
            int_port = _resolve_port_num(port, known_dict)
            if int_port is None:
                print("Error: cannot resolve port '{}' in --dnat '{}'".format(port, entry), file=sys.stderr)
                sys.exit(1)
            dnat_dict[svc] = int_port

        for ext_svc, int_port in dnat_dict.items():
            # Skip ports already listed by a service the user put in -i.
            port_covered = False
            for existing in in_svcs:
                if existing in known_dict:
                    existing_tokens = str(known_dict[existing]).replace(";", " ").split()
                    if str(int_port) in existing_tokens:
                        port_covered = True
                        break
            if port_covered:
                continue
            int_svc = _find_service_for_port(int_port, known_dict)
            if int_svc:
                if int_svc not in in_svcs:
                    in_svcs.append(int_svc)
                    print("Note: added '{}' to inbound (port {} for dnat {}:{})".format(int_svc, int_port, ext_svc, int_port))
            else:
                print("Warning: no service found for internal port {} "
                      "(dnat {}:{}) - add a macro and include it in -i manually".format(int_port, ext_svc, int_port))
        pod["dnat"] = dnat_dict

    if in_svcs:
        pod["i"] = in_svcs

    if args.out is not None:
        # Empty -o or a lone "accept" means accept-all, stored as the bare
        # string (same convention as cmd_add_policy / cmd_edit_pod).
        pod["o"] = "accept" if args.out in ([], ["accept"]) else args.out

    if args.zones:
        pod["zones"] = args.zones

    raw.setdefault("pods", {})[name] = pod

    _save_raw(raw, config_path)
    print("Pod '{}' added".format(name))


def cmd_del_pod(cfg, builtins, args):
    """Remove ``pods.<name>`` and possibly its pod macro.

    ``macros.pod.<name>`` is deleted too unless another remaining pod still
    lists it in its ``ip``.  Exits with status 1 when the pod does not exist.
    """
    name = args.name
    raw, config_path = _load_raw(args)

    pods = raw.get("pods", {})
    if name not in pods:
        print("Error: pod '{}' not found in pods".format(name), file=sys.stderr)
        sys.exit(1)
    del pods[name]

    pod_macros = (raw.get("macros") or {}).get("pod", {})
    if name in pod_macros:
        referenced = False
        for _pname, other in _iter_pods(pods):
            if name in other.get("ip", []):
                referenced = True
                break
        if referenced:
            print("Note: pod macro '{}' kept (still referenced by other pods)".format(name))
        else:
            del pod_macros[name]
            print("Macro pod '{}' removed".format(name))

    _save_raw(raw, config_path)
    print("Pod '{}' removed".format(name))


def _find_macro_usages(mtype, name, raw):
    usages = []

    if mtype == "proto":
        for zone, dirs in raw.get("policy", {}).items():
            for d in ("i", "o"):
                val = dirs.get(d, [])
                if isinstance(val, list) and any(_svc_matches(s, name) for s in val):
                    usages.append("policy.{}.{}".format(zone, d))

        for key, val in raw.get("forward", {}).items():
            if isinstance(val, list) and any(_svc_matches(s, name) for s in val):
                usages.append("forward.{!r}".format(key))

        for pname, pod in _iter_pods(raw.get("pods", {})):
            for field in ("i", "o"):
                val = pod.get(field, [])
                if isinstance(val, list) and any(_svc_matches(s, name) for s in val):
                    usages.append("pods.{}.{}".format(pname, field))
            if name in pod.get("dnat", {}):
                usages.append("pods.{}.dnat".format(pname))

        for i, rule in enumerate(raw.get("dnat", [])):
            svcs = rule.get("service", [])
            if isinstance(svcs, str):
                svcs = [svcs]
            if any(_svc_matches(s, name) for s in svcs):
                usages.append("dnat[{}]".format(i))

    elif mtype == "ip":
        base = name[:-1] if name[-1] in "46" else name
        for i, rule in enumerate(raw.get("dnat", [])):
            frm = rule.get("from", "")
            to  = rule.get("to", "")
            if frm in (name, base):
                usages.append("dnat[{}].from".format(i))
            if to in (name, base):
                usages.append("dnat[{}].to".format(i))

    elif mtype == "pod":
        for pname, pod in _iter_pods(raw.get("pods", {})):
            if name in pod.get("ip", []):
                usages.append("pods.{}.ip".format(pname))

    elif mtype == "log":
        for zone, pair in raw.get("log", {}).items():
            if isinstance(pair, list):
                if len(pair) > 0 and pair[0] == name:
                    usages.append("log.{}[inbound]".format(zone))
                if len(pair) > 1 and pair[1] == name:
                    usages.append("log.{}[outbound]".format(zone))

    elif mtype == "net":
        for i, rule in enumerate(raw.get("snat", [])):
            if rule.get("zone") == name:
                usages.append("snat[{}]".format(i))

    return usages


def _remove_macro_from_rules(mtype, name, raw):
    if mtype == "proto":
        def clean_list(lst):
            return [s for s in lst if not _svc_matches(s, name)]

        for zone, dirs in raw.get("policy", {}).items():
            for d in ("i", "o"):
                val = dirs.get(d)
                if isinstance(val, list):
                    new = clean_list(val)
                    if new:
                        dirs[d] = new
                    else:
                        del dirs[d]
            if not dirs:
                del raw["policy"][zone]

        for key in list(raw.get("forward", {}).keys()):
            val = raw["forward"][key]
            if isinstance(val, list):
                new = clean_list(val)
                if new:
                    raw["forward"][key] = new
                else:
                    del raw["forward"][key]

        for pname, pod in _iter_pods(raw.get("pods", {})):
            for field in ("i", "o"):
                val = pod.get(field)
                if isinstance(val, list):
                    new = clean_list(val)
                    if new:
                        pod[field] = new
                    else:
                        del pod[field]
            if name in pod.get("dnat", {}):
                del pod["dnat"][name]

        new_dnat = []
        for rule in raw.get("dnat", []):
            svcs = rule.get("service", [])
            if isinstance(svcs, str):
                svcs = [svcs]
            new_svcs = [s for s in svcs if not _svc_matches(s, name)]
            if not new_svcs:
                continue
            rule["service"] = new_svcs if len(new_svcs) > 1 else new_svcs[0]
            new_dnat.append(rule)
        raw["dnat"] = new_dnat

    elif mtype == "log":
        for zone, pair in raw.get("log", {}).items():
            if isinstance(pair, list):
                raw["log"][zone] = ["drop" if v == name else v for v in pair]

    elif mtype == "ip":
        base = name[:-1] if name[-1] in "46" else name
        raw["dnat"] = [r for r in raw.get("dnat", [])
                       if r.get("from") not in (name, base) and r.get("to") not in (name, base)]

    elif mtype == "pod":
        for pname, pod in _iter_pods(raw.get("pods", {})):
            ip_list = pod.get("ip", [])
            if name in ip_list:
                ip_list.remove(name)

    elif mtype == "net":
        raw["snat"] = [r for r in raw.get("snat", []) if r.get("zone") != name]


def cmd_del_macro(cfg, builtins, args):
    """Delete ``macros.<mtype>.<name>``.

    While the macro is still referenced, the command lists the locations and
    exits with status 1.  With ``-f`` every reference is first stripped via
    ``_remove_macro_from_rules`` and the macro is then deleted.
    """
    mtype, name = args.mtype, args.name
    force = getattr(args, "force", False)

    raw, config_path = _load_raw(args)

    section = raw.get("macros", {}).get(mtype, {})
    if name not in section:
        print("Error: macro '{}' not found in macros.{}".format(name, mtype), file=sys.stderr)
        sys.exit(1)

    usages = _find_macro_usages(mtype, name, raw)
    if usages:
        if not force:
            print("Error: macro '{}' is used in:".format(name), file=sys.stderr)
            for location in usages:
                print("  {}".format(location), file=sys.stderr)
            print("Use -f to remove it from all rules and delete it.", file=sys.stderr)
            sys.exit(1)
        _remove_macro_from_rules(mtype, name, raw)
        print("Removed '{}' from {} location(s).".format(name, len(usages)))

    del section[name]

    _save_raw(raw, config_path)
    print("Macro '{}' removed from macros.{}".format(name, mtype))


def cmd_del_zone(cfg, builtins, args):
    """Delete a zone together with its ``policy``, ``log`` and ``macros.net`` entries.

    Forward rules, DNAT rules (via ``iif``) and SNAT rules that still name the
    zone are not touched.  A ``WARN`` line is printed for each one so it can
    be fixed by hand.  Exits with status 1 if the zone does not exist.
    """
    name = args.name
    raw, config_path = _load_raw(args)

    zones = raw.get("zones", {})
    if name not in zones:
        print("Error: zone '{}' not found in zones".format(name), file=sys.stderr)
        sys.exit(1)

    del zones[name]
    for section in ("policy", "log"):
        raw.get(section, {}).pop(name, None)
    raw.get("macros", {}).get("net", {}).pop(name, None)

    for key in raw.get("forward", {}):
        src, dst = _parse_forward_key(key)
        if name in (src, dst):
            print("  WARN: forward rule '{}' still references zone '{}'".format(key, name))

    for rule in raw.get("dnat", []):
        iif = rule.get("iif", [])
        interfaces = [iif] if isinstance(iif, str) else iif
        if name in interfaces:
            print("  WARN: dnat rule (from={} service={}) still references zone '{}' in iif".format(
                rule.get('from'), rule.get('service'), name))

    for rule in raw.get("snat", []):
        if rule.get("zone") == name:
            print("  WARN: snat rule still references zone '{}'".format(name))

    _save_raw(raw, config_path)
    print("Zone '{}' removed".format(name))


def cmd_del_forward(cfg, builtins, args):
    """Delete the forward rule between ``args.src`` and ``args.dst``.

    Every key spelling produced by ``_forward_key_candidates`` is tried in
    order, and the first one present under ``forward`` is removed.  Exits with
    status 1 when none matches.
    """
    src, dst = args.src, args.dst
    bidir = bool(args.bidir)
    candidates = _forward_key_candidates(src, dst, bidir)

    raw, config_path = _load_raw(args)
    forward = raw.get("forward", {})

    for key in candidates:
        if key in forward:
            break
    else:
        arrow = "<->" if bidir else "->"
        print("Error: forward rule '{} {} {}' not found".format(src, arrow, dst), file=sys.stderr)
        sys.exit(1)

    del forward[key]

    _save_raw(raw, config_path)
    print("Forward rule '{}' removed".format(key))


def cmd_del_dnat(cfg, builtins, args):
    """Delete one DNAT rule.

    The rule is selected either by ``--hash`` (as produced by ``_dnat_hash``)
    or by an exact match on ``to``, ``-s`` services, ``-i`` interfaces, the
    optional ``-f`` source and the optional ``-p`` port.  Service and iif lists
    are compared regardless of order.  Only the first match is removed.  Exits
    with status 1 when the selector is incomplete or nothing matches.
    """
    raw, config_path = _load_raw(args)
    rules = raw.get("dnat", [])

    if args.hash:
        def is_target(rule):
            return _dnat_hash(rule) == args.hash
    else:
        if not (args.to and args.s and args.i):
            print("Error: provide --hash, or all of: to -s -i [-f from]", file=sys.stderr)
            sys.exit(1)

        def is_target(rule):
            return (rule.get("from", None) == args.frm
                    and rule.get("to") == args.to
                    and sorted(_as_list(rule.get("service", []))) == sorted(_as_list(args.s))
                    and sorted(_as_list(rule.get("iif", []))) == sorted(_as_list(args.i))
                    and str(rule.get("port", "")) == str(args.p or ""))

    position = next((i for i, rule in enumerate(rules) if is_target(rule)), None)
    if position is None:
        print("Error: DNAT rule not found", file=sys.stderr)
        sys.exit(1)

    del rules[position]
    _save_raw(raw, config_path)
    print("DNAT rule removed")


def cmd_del_snat(cfg, builtins, args):
    """Delete the SNAT rule(s) for ``args.zone``.

    With ``--via``, only rules whose ``via`` equals it are removed; otherwise
    every rule for the zone is removed.  Exits with status 1 when nothing
    matches.
    """
    raw, config_path = _load_raw(args)
    snat = raw.get("snat", [])
    via = getattr(args, "via", None)

    def selected(rule):
        if rule.get("zone") != args.zone:
            return False
        return rule.get("via") == via if via else True

    label = "'{} -> {}'".format(args.zone, via) if via else "for zone '{}'".format(args.zone)
    doomed = [i for i, rule in enumerate(snat) if selected(rule)]
    if not doomed:
        print("Error: SNAT rule {} not found".format(label), file=sys.stderr)
        sys.exit(1)

    # Delete from the back so the remaining indices stay valid.
    for i in sorted(doomed, reverse=True):
        del snat[i]

    _save_raw(raw, config_path)
    print("SNAT rule(s) {} removed".format(label))


def _redirect_hash(rule):
    svc = rule.get("service", "")
    if isinstance(svc, list):
        svc = ",".join(svc)
    key = "{}:{}:{}".format(svc, rule.get('to'), rule.get('port', ''))
    return hashlib.sha1(key.encode()).hexdigest()[:8]


def cmd_add_redirect(cfg, builtins, args):
    """Append a redirect rule: service(s) -> ``to`` with an optional ``-p`` port.

    The services are validated, and ``to`` must be a literal IP or a known ip
    macro.  An identical rule (same services in any order, same target, same
    port) is rejected with exit status 1.
    """
    services = args.s[0] if len(args.s) <= 1 else args.s

    raw, config_path = _load_raw(args)

    _check_services_or_exit(_as_list(services), raw, builtins, "redirect service")
    target = args.to
    if not _is_direct_ip(target):
        _check_ip_macro_or_exit(target, raw, "redirect to")

    duplicate = any(
        sorted(_as_list(existing.get("service", []))) == sorted(_as_list(services))
        and existing.get("to") == target
        and str(existing.get("port", "")) == str(args.p or "")
        for existing in raw.get("redirect", [])
    )
    if duplicate:
        print("Error: redirect rule already exists. Use 'fwctl edit' to modify.", file=sys.stderr)
        sys.exit(1)

    rule = {"service": services, "to": target}
    if args.p is not None:
        rule["port"] = args.p

    redirects = raw.get("redirect")
    if not isinstance(redirects, list):
        redirects = raw["redirect"] = []
    redirects.append(rule)
    _save_raw(raw, config_path)

    svc_str = ", ".join(_as_list(services))
    port_str = "" if args.p is None else ":{}".format(args.p)
    print("Redirect rule added: [{}] -> {}{}".format(svc_str, target, port_str))


def cmd_del_redirect(cfg, builtins, args):
    """Delete one redirect rule.

    The rule is selected by ``--hash`` (as produced by ``_redirect_hash``) or
    by an exact match on ``to``, ``-s`` services (compared regardless of
    order) and the optional ``-p`` port.  Only the first match is removed.
    Exits with status 1 when the selector is incomplete or nothing matches.
    """
    raw, config_path = _load_raw(args)
    rules = raw.get("redirect", [])

    if args.hash:
        def is_target(rule):
            return _redirect_hash(rule) == args.hash
    else:
        if not (args.to and args.s):
            print("Error: provide --hash, or all of: to -s [-p port]", file=sys.stderr)
            sys.exit(1)

        def is_target(rule):
            return (rule.get("to") == args.to
                    and sorted(_as_list(rule.get("service", []))) == sorted(_as_list(args.s))
                    and str(rule.get("port", "")) == str(args.p or ""))

    position = next((i for i, rule in enumerate(rules) if is_target(rule)), None)
    if position is None:
        print("Error: redirect rule not found", file=sys.stderr)
        sys.exit(1)

    del rules[position]
    _save_raw(raw, config_path)
    print("Redirect rule removed")


def cmd_edit_redirect(args):
    """Edit a redirect rule, either interactively or by ``--hash``.

    Interactive mode (no ``--hash``): the rules are listed with their hashes
    and the user picks one by number.  The user is then prompted for
    ``service`` (comma-separated), ``to`` and ``port``; empty input keeps the
    current value.  Ctrl+C or end-of-input at any prompt aborts without
    saving.

    Hash mode: ``-s`` and ``--to`` are validated against the config and
    applied directly; ``-p ""`` removes the port.  Exits with status 1 on an
    invalid selection or an unknown hash.
    """
    raw, config_path = _load_raw(args)
    redirect = raw.get("redirect", [])

    def _rule_str(rule):
        svc = rule.get("service", [])
        svc_s = ", ".join(svc) if isinstance(svc, list) else svc
        port = rule.get("port")
        port_s = ":{}".format(port) if port is not None else ""
        return "[{}] -> {}{}".format(svc_s, rule.get('to'), port_s)

    if not getattr(args, "hash", None):
        if not redirect:
            print("No redirect rules defined.")
            return
        print("=== Redirect rules ===")
        for i, rule in enumerate(redirect):
            print("  [{}] {}  {}".format(i + 1, _redirect_hash(rule), _rule_str(rule)))
        try:
            choice = input("\nEnter number to edit (Ctrl+C to abort): ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nAborted.")
            return
        try:
            idx = int(choice) - 1
            if not 0 <= idx < len(redirect):
                raise ValueError
        except ValueError:
            print("Invalid selection.", file=sys.stderr)
            sys.exit(1)
        rule = redirect[idx]
        print("Editing: {}".format(_rule_str(rule)))

        def _prompt_list(field, prompt):
            cur = rule.get(field, [])
            cur_s = ", ".join(cur) if isinstance(cur, list) else str(cur)
            val = input("  {} [{}]: ".format(prompt, cur_s)).strip()
            # Input such as "," is non-empty but has no items: keep current value.
            parts = [p.strip() for p in val.split(",") if p.strip()]
            if parts:
                rule[field] = parts if len(parts) > 1 else parts[0]

        # Changes go into `rule` in place, but the file is only written once
        # every prompt has been answered.
        try:
            _prompt_list("service", "service (comma-sep)")
            for field in ("to", "port"):
                val = input("  {} [{}]: ".format(field, rule.get(field, ''))).strip()
                if val:
                    rule[field] = val
        except (KeyboardInterrupt, EOFError):
            print("\nAborted. No changes saved.")
            return
        _save_raw(raw, config_path)
        print("Redirect rule updated.")
        return

    h = args.hash
    rule = next((r for r in redirect if _redirect_hash(r) == h), None)
    if rule is None:
        print("Error: redirect rule with hash '{}' not found.".format(h), file=sys.stderr)
        sys.exit(1)
    if args.s:
        builtins, _ = load_builtin_conf()
        _check_services_or_exit(args.s, raw, builtins, "redirect service")
        rule["service"] = args.s if len(args.s) > 1 else args.s[0]
    if args.to:
        if not _is_direct_ip(args.to):
            _check_ip_macro_or_exit(args.to, raw, "redirect to")
        rule["to"] = args.to
    if args.p is not None:
        if args.p == "":
            rule.pop("port", None)
        else:
            rule["port"] = args.p
    _save_raw(raw, config_path)
    print("Redirect rule {} updated.".format(h))


def _dnat_hash(rule):
    svc = rule.get("service", "")
    if isinstance(svc, list):
        svc = ",".join(svc)
    iif = rule.get("iif", "")
    if isinstance(iif, list):
        iif = ",".join(sorted(iif))
    key = "{}:{}:{}:{}:{}".format(rule.get('from'), svc, iif, rule.get('to'), rule.get('port', ''))
    return hashlib.sha1(key.encode()).hexdigest()[:8]


def _load_raw(args):
    """Load the YAML config located by ``find_config(args.config)``.

    Returns:
        ``(raw, config_path)``, where *raw* is the parsed top-level mapping.
        An empty file yields ``{}`` so that callers can use ``raw.get(...)``
        safely.

    Exits with status 1 (after printing an error) when the file is not valid
    YAML or its top level is not a mapping.
    """
    config_path = find_config(args.config)
    try:
        with open(config_path) as f:
            raw = yaml.safe_load(f)
    except yaml.YAMLError as e:
        print("Error: cannot parse {}: {}".format(config_path, e), file=sys.stderr)
        sys.exit(1)
    if raw is None:
        raw = {}
    elif not isinstance(raw, dict):
        print("Error: {} must contain a YAML mapping at the top level.".format(config_path), file=sys.stderr)
        sys.exit(1)
    return raw, config_path


def _save_raw(raw, config_path):
    """Write the config mapping *raw* back to *config_path* as block-style YAML.

    Keys keep their insertion order (``sort_keys=False``) and non-ASCII text
    is written as-is (``allow_unicode=True``).  An empty mapping under
    ``dnat``, ``snat`` or ``redirect`` (which is what ``key: {}`` parses to)
    is normalised to an empty list, because those sections hold lists of
    rules.

    The YAML text is produced before the file is opened for writing, so a
    serialisation error leaves the existing config intact instead of
    truncated.
    """
    for key in ("dnat", "snat", "redirect"):
        value = raw.get(key)
        # Only an *empty* mapping is replaced.  A non-empty one is user
        # content in the wrong shape and is kept rather than silently
        # discarded.
        if isinstance(value, dict) and not value:
            raw[key] = []
    text = yaml.dump(raw, default_flow_style=False, allow_unicode=True, sort_keys=False,
                     Dumper=_yaml_dumper())
    with open(config_path, "w") as f:
        f.write(text)


def cmd_edit_zone(args):
    """Rebind an existing zone to interface ``args.iface``.

    Exits with status 1 if the zone does not exist.
    """
    raw, config_path = _load_raw(args)
    zones = raw.get("zones", {})
    if args.name not in zones:
        print("Error: zone '{}' not found.".format(args.name), file=sys.stderr)
        sys.exit(1)
    zones[args.name] = args.iface
    _save_raw(raw, config_path)
    print("Zone '{}' updated: iface={}".format(args.name, args.iface))


def cmd_edit_policy(args):
    """Replace the ``i`` and/or ``o`` service list of an existing zone policy.

    Both lists are validated first.  A lone ``accept`` is stored as the bare
    string ``"accept"``, and an omitted direction keeps its current value.
    Exits with status 1 if the zone has no policy.
    """
    raw, config_path = _load_raw(args)
    zone = args.zone
    policies = raw.get("policy", {})
    if zone not in policies:
        print("Error: policy for zone '{}' not found.".format(zone), file=sys.stderr)
        sys.exit(1)

    builtins, _ = load_builtin_conf()
    for direction in ("i", "o"):
        _check_services_or_exit(getattr(args, direction), raw, builtins,
                                "policy.{}.{}".format(zone, direction))

    entry = policies[zone]
    for direction in ("i", "o"):
        services = getattr(args, direction)
        if services is not None:
            entry[direction] = "accept" if services == ["accept"] else services

    _save_raw(raw, config_path)
    print("Policy for zone '{}' updated.".format(zone))


def cmd_edit_pod(args):
    """Update fields of an existing pod entry.

    ``--ip``, ``-i``, ``-o`` and ``--zones`` overwrite the matching field; a
    lone ``accept`` for ``-i``/``-o`` is stored as the bare string.

    ``--dnat service:port`` replaces the pod's ``dnat`` mapping.  For each
    internal port not already covered by an inbound service, a matching
    service is appended to ``i``.  This is skipped when ``i`` is ``"accept"``,
    because appending would turn the bare accept into a service list.

    Exits with status 1 on invalid input; nothing is saved in that case.
    """
    raw, config_path = _load_raw(args)
    pods = raw.setdefault("pods", {})
    if args.name not in pods:
        print("Error: pod '{}' not found.".format(args.name), file=sys.stderr)
        sys.exit(1)
    pod = pods[args.name]
    if args.ip:
        pod["ip"] = args.ip
    if args.in_svcs is not None:
        pod["i"] = "accept" if args.in_svcs == ["accept"] else args.in_svcs
    if args.out is not None:
        pod["o"] = "accept" if args.out == ["accept"] else args.out
    if args.dnat is not None:
        builtins_edit, _ = load_builtin_conf()
        # Builtin services overlaid with the user's proto macros.
        known_dict = dict(builtins_edit)
        known_dict.update(raw.get("macros", {}).get("proto") or {})

        dnat_dict = {}
        for entry in args.dnat:
            svc, sep, port = entry.partition(":")
            svc, port = svc.strip(), port.strip()
            if not sep or not svc or not port:
                print("Error: --dnat format is service:port (e.g. http:8080)", file=sys.stderr)
                sys.exit(1)
            int_port = _resolve_port_num(port, known_dict)
            if int_port is None:
                print("Error: cannot resolve port '{}' in --dnat '{}'".format(port, entry), file=sys.stderr)
                sys.exit(1)
            dnat_dict[svc] = int_port

        cur_in = pod.get("i")
        in_svcs = None
        if cur_in != "accept":
            in_svcs = list(cur_in) if isinstance(cur_in, list) else ([cur_in] if cur_in else [])
            for ext_svc, int_port in dnat_dict.items():
                # Skip ports already listed by an existing inbound service.
                port_covered = False
                for existing in in_svcs:
                    if existing in known_dict:
                        existing_tokens = str(known_dict[existing]).replace(";", " ").split()
                        if str(int_port) in existing_tokens:
                            port_covered = True
                            break
                if port_covered:
                    continue
                int_svc = _find_service_for_port(int_port, known_dict)
                if int_svc:
                    if int_svc not in in_svcs:
                        in_svcs.append(int_svc)
                        print("Note: added '{}' to inbound (port {} for dnat {}:{})".format(int_svc, int_port, ext_svc, int_port))
                else:
                    print("Warning: no service found for internal port {} "
                          "(dnat {}:{}) - add a macro and include it in -i manually".format(int_port, ext_svc, int_port))
        pod["dnat"] = dnat_dict
        if in_svcs:
            pod["i"] = in_svcs
    if args.zones is not None:
        pod["zones"] = args.zones
    _save_raw(raw, config_path)
    print("Pod '{}' updated.".format(args.name))


def cmd_edit_forward(args):
    """Replace the service list of an existing forward rule.

    A leading ``-`` in ``args.rules`` selects the bidirectional form of the
    ``src``/``dst`` key; the remaining items are the new services.  A lone
    ``accept`` is stored as the bare string; any other list is validated.
    Exits with status 1 when no services are given or the rule does not exist.
    """
    raw, config_path = _load_raw(args)
    forward = raw.get("forward", {})

    given = args.rules or []
    bidir = bool(given) and given[0] == "-"
    rules = given[1:] if bidir else given
    if not rules:
        print("Error: specify services or 'accept'", file=sys.stderr)
        sys.exit(1)

    candidates = _forward_key_candidates(args.src, args.dst, bidir)
    key_found = next((c for c in candidates if c in forward), None)
    if key_found is None:
        arrow = '<->' if bidir else '->'
        print("Error: forward rule '{} {} {}' not found.".format(args.src, arrow, args.dst), file=sys.stderr)
        sys.exit(1)

    if rules == ["accept"]:
        value = "accept"
    else:
        builtins, _ = load_builtin_conf()
        _check_services_or_exit(rules, raw, builtins, "forward '{}'".format(key_found))
        value = rules
    forward[key_found] = value

    _save_raw(raw, config_path)
    print("Forward '{}' updated.".format(key_found))


def cmd_edit_log(args):
    """Change the inbound (``-i``) and/or outbound (``-o``) log macro of a zone.

    The log entry is a ``[inbound, outbound]`` pair; an outbound value is
    appended when the pair has only one element.  Both macros are validated
    first.  Exits with status 1 if the zone has no log entry.
    """
    raw, config_path = _load_raw(args)
    zone = args.zone
    logs = raw.get("log", {})
    if zone not in logs:
        print("Error: log entry for zone '{}' not found.".format(zone), file=sys.stderr)
        sys.exit(1)

    _check_log_macro_or_exit(args.i, raw, "log.{} inbound".format(zone))
    _check_log_macro_or_exit(args.o, raw, "log.{} outbound".format(zone))

    pair = list(logs[zone])
    if args.i is not None:
        pair[0] = args.i
    if args.o is not None:
        if len(pair) >= 2:
            pair[1] = args.o
        else:
            pair.append(args.o)
    logs[zone] = pair

    _save_raw(raw, config_path)
    print("Log for zone '{}' updated.".format(zone))


def cmd_edit_macro(args):
    """Overwrite the value of an existing ``macros.<mtype>.<name>`` entry.

    How ``args.values`` is stored depends on the kind:
      * proto: a single value as-is; several are grouped by
        ``_group_proto_tokens``.
      * ip, pod: the first value.
      * net: a single value as a scalar; several as a list.
      * log: ``log "<label>"``, where the label is the first value or, if
        there are no values, the macro name.
    Other kinds are saved unchanged.  Exits with status 1 if the macro does
    not exist.
    """
    raw, config_path = _load_raw(args)
    mtype, name = args.mtype, args.name
    section = raw.get("macros", {}).get(mtype, {})
    if name not in section:
        print("Error: macro {} '{}' not found.".format(mtype, name), file=sys.stderr)
        sys.exit(1)

    if mtype == "proto":
        section[name] = args.values[0] if len(args.values) == 1 else _group_proto_tokens(args.values)
    elif mtype in ("ip", "pod"):
        section[name] = args.values[0]
    elif mtype == "net":
        section[name] = args.values if len(args.values) > 1 else args.values[0]
    elif mtype == "log":
        section[name] = 'log "{}"'.format(args.values[0] if args.values else name)

    _save_raw(raw, config_path)
    print("Macro {} '{}' updated.".format(mtype, name))


def cmd_edit_dnat(args):
    """Edit a DNAT rule, either interactively or by ``--hash``.

    Interactive mode (no ``--hash``): the rules are listed with their hashes
    and the user picks one by number.  The user is then prompted for each
    field; empty input keeps the current value, and ``service``/``iif`` take
    comma-separated lists.  Ctrl+C or end-of-input at any prompt aborts
    without saving.

    Hash mode: the ``-f``/``-s``/``-i``/``--to``/``-p`` options are validated
    against the config and applied directly; ``-p ""`` removes the port.
    Exits with status 1 on an invalid selection or an unknown hash.
    """
    raw, config_path = _load_raw(args)
    dnat = raw.get("dnat", [])

    def _rule_str(rule):
        svc = rule.get("service", [])
        iif = rule.get("iif", [])
        svc_s = ", ".join(svc) if isinstance(svc, list) else svc
        iif_s = ", ".join(iif) if isinstance(iif, list) else iif
        port = rule.get("port")
        port_s = ":{}".format(port) if port is not None else ""
        return "{}  [{}]  iif=[{}]  -> {}{}".format(rule.get('from'), svc_s, iif_s, rule.get('to'), port_s)

    if not getattr(args, "hash", None):
        if not dnat:
            print("No DNAT rules defined.")
            return
        print("=== DNAT rules ===")
        for i, rule in enumerate(dnat):
            print("  [{}] {}  {}".format(i + 1, _dnat_hash(rule), _rule_str(rule)))
        try:
            choice = input("\nEnter number to edit (Ctrl+C to abort): ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nAborted.")
            return
        try:
            idx = int(choice) - 1
            if not 0 <= idx < len(dnat):
                raise ValueError
        except ValueError:
            print("Invalid selection.", file=sys.stderr)
            sys.exit(1)
        rule = dnat[idx]
        print("Editing: {}".format(_rule_str(rule)))

        def _prompt_list(field, prompt):
            cur = rule.get(field, [])
            cur_s = ", ".join(cur) if isinstance(cur, list) else str(cur)
            val = input("  {} [{}]: ".format(prompt, cur_s)).strip()
            # Input such as "," is non-empty but has no items: keep current value.
            parts = [p.strip() for p in val.split(",") if p.strip()]
            if parts:
                rule[field] = parts if len(parts) > 1 else parts[0]

        # Changes go into `rule` in place, but the file is only written once
        # every prompt has been answered.
        try:
            for field, prompt in [("from", "from"), ("service", "service (comma-sep)"),
                                  ("iif", "iif (comma-sep)"), ("to", "to"), ("port", "port")]:
                if field in ("service", "iif"):
                    _prompt_list(field, prompt)
                else:
                    val = input("  {} [{}]: ".format(prompt, rule.get(field, ''))).strip()
                    if val:
                        rule[field] = val
        except (KeyboardInterrupt, EOFError):
            print("\nAborted. No changes saved.")
            return
        _save_raw(raw, config_path)
        print("DNAT rule updated.")
        return

    h = args.hash
    rule = next((r for r in dnat if _dnat_hash(r) == h), None)
    if rule is None:
        print("Error: DNAT rule with hash '{}' not found.".format(h), file=sys.stderr)
        sys.exit(1)
    if args.frm:
        _check_ip_macro_or_exit(args.frm, raw, "dnat from")
        rule["from"] = args.frm
    if args.s:
        builtins, _ = load_builtin_conf()
        _check_services_or_exit(args.s, raw, builtins, "dnat service")
        rule["service"] = args.s if len(args.s) > 1 else args.s[0]
    if args.i:
        for iif_val in args.i:
            _check_zone_or_exit(iif_val, raw, "dnat iif")
        rule["iif"] = args.i if len(args.i) > 1 else args.i[0]
    if args.to:
        _check_ip_macro_or_exit(args.to, raw, "dnat to")
        rule["to"] = args.to
    if args.p is not None:
        if args.p == "":
            rule.pop("port", None)
        else:
            rule["port"] = args.p
    _save_raw(raw, config_path)
    print("DNAT rule {} updated.".format(h))


def cmd_edit_snat(args):
    """Update the first SNAT rule for ``args.zone``.

    ``--via`` is validated as a zone and replaces ``via``.  ``--ip`` sets the
    source IP, and ``--ip ""`` removes it.  Exits with status 1 if the zone
    has no SNAT rule.
    """
    raw, config_path = _load_raw(args)
    for rule in raw.get("snat", []):
        if rule.get("zone") == args.zone:
            break
    else:
        print("Error: SNAT rule for zone '{}' not found.".format(args.zone), file=sys.stderr)
        sys.exit(1)

    if args.via:
        _check_zone_or_exit(args.via, raw, "snat via")
        rule["via"] = args.via
    if args.ip is not None:
        if args.ip == "":
            rule.pop("ip", None)
        else:
            rule["ip"] = args.ip

    _save_raw(raw, config_path)
    print("SNAT rule for zone '{}' updated.".format(args.zone))


def cmd_edit(cfg_unused, builtins, args):
    """Open the YAML config in an editor and save it back only if it validates.

    The config is copied to a temporary file.  After each editor session the
    copy is parsed, normalised with ``normalize_raw_yaml`` and validated.  On
    a YAML or validation error the user can press Enter to re-edit or Ctrl+C
    to abort.  The real config is replaced only when the result validates and
    differs from the original.

    The editor comes from ``$EDITOR`` or ``$VISUAL``, split shell-style so
    that values such as ``"code --wait"`` work.  Otherwise the first of nano,
    vim and vi found on ``PATH`` is used.
    """
    import shlex
    import shutil
    import tempfile

    config_path = find_config(args.config)

    editor_cmd = []
    editor = os.environ.get("EDITOR") or os.environ.get("VISUAL")
    if editor:
        try:
            editor_cmd = shlex.split(editor)
        except ValueError:  # unbalanced quotes: treat as a plain command
            editor_cmd = [editor]
    if not editor_cmd:
        found = next((c for c in ("nano", "vim", "vi") if shutil.which(c)), None)
        if found is None:
            print("Error: no editor found. Set $EDITOR.", file=sys.stderr)
            sys.exit(1)
        editor_cmd = [found]

    def _retry():
        """Wait for Enter; print the abort message and return False on Ctrl+C/EOF."""
        try:
            input("\nPress Enter to re-edit, Ctrl+C to abort...")
        except (KeyboardInterrupt, EOFError):
            print("\nAborted. No changes saved.")
            return False
        return True

    with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        shutil.copy2(config_path, tmp_path)
        while True:
            try:
                subprocess.run(editor_cmd + [tmp_path])
            except OSError as e:
                print("Error: cannot run editor '{}': {}".format(editor_cmd[0], e), file=sys.stderr)
                sys.exit(1)

            try:
                with open(tmp_path) as f:
                    raw = yaml.safe_load(f)
            except yaml.YAMLError as e:
                print("\nYAML error: {}".format(e), file=sys.stderr)
                if not _retry():
                    return
                continue
            if not isinstance(raw, dict):
                # An empty, scalar or list document cannot be normalised/validated.
                print("\nYAML error: top level must be a mapping", file=sys.stderr)
                if not _retry():
                    return
                continue

            builtins, _ = load_builtin_conf()
            norm_msgs = normalize_raw_yaml(raw)
            if norm_msgs:
                print("Normalized {}:".format(os.path.basename(config_path)))
                for msg in norm_msgs:
                    print(msg)
                with open(tmp_path, "w") as f:
                    yaml.dump(raw, f, default_flow_style=False, allow_unicode=True,
                              sort_keys=False, Dumper=_yaml_dumper())
            cfg = normalize_config(raw)
            errors, warnings = validate(cfg, builtins)
            for w in warnings or []:
                print(w)
            if errors:
                for e in errors:
                    print(e, file=sys.stderr)
                if not _retry():
                    return
                continue

            with open(config_path) as f:
                original = f.read()
            with open(tmp_path) as f:
                modified = f.read()
            if original == modified:
                print("No changes.")
            else:
                shutil.copy2(tmp_path, config_path)
                print("Saved: {}".format(config_path))
            return
    except KeyboardInterrupt:
        print("\nAborted. No changes saved.")
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)


def _detect_wan_iface():
    try:
        result = subprocess.run(
            ["ip", "route", "get", "1.1.1.1"],
            capture_output=True, text=True,
        )
        tokens = result.stdout.split()
        if "dev" in tokens:
            return tokens[tokens.index("dev") + 1]
    except (OSError, subprocess.SubprocessError, ValueError, IndexError):
        pass
    return None


def _iface_networks(iface):
    try:
        result = subprocess.run(
            ["ip", "-j", "addr", "show", iface],
            capture_output=True, text=True,
        )
        if result.returncode != 0:
            return []
        data = json.loads(result.stdout)
        nets = []
        for entry in data:
            for addr in entry.get("addr_info", []):
                local  = addr.get("local", "")
                prefix = addr.get("prefixlen")
                if not local or prefix is None:
                    continue
                try:
                    net = ipaddress.ip_interface("{}/{}".format(local, prefix)).network
                    if net.is_link_local:
                        continue
                    nets.append(str(net))
                except ValueError:
                    pass
        return list(dict.fromkeys(nets))
    except (OSError, json.JSONDecodeError, subprocess.SubprocessError):
        return []


def cmd_init(cfg_unused, builtins, args):
    """Write a starter fw.yaml (default /etc/fw.yaml) for a single WAN interface.

    The WAN interface is ``--iface``, else the interface used to reach
    1.1.1.1, else ``eth0``. Addresses found on it become the ``wan4``/``wan6``
    ip macros and its networks the ``wan`` net macro. An existing file is
    only overwritten with ``--force``; otherwise the command exits with 1.
    """
    out = args.output or "/etc/fw.yaml"

    if not args.force and os.path.exists(out):
        print("Error: '{}' already exists. Use --force to overwrite.".format(out), file=sys.stderr)
        sys.exit(1)

    iface = args.iface or _detect_wan_iface() or "eth0"

    ipv4, ipv6 = _get_iface_ips(iface)
    networks = _iface_networks(iface)

    # Each present address becomes an indented mapping line; with none the
    # macro is written as an empty flow mapping (`ip: {}`).
    ip_block = "".join(
        "\n    {}: {}".format(label, addr)
        for label, addr in (("wan4", ipv4), ("wan6", ipv6))
        if addr
    ) or " {}"

    if not networks:
        net_block = " {}"
    elif len(networks) == 1:
        net_block = "\n    wan: {}".format(networks[0])
    else:
        quoted = ", ".join('"{}"'.format(n) for n in networks)
        net_block = "\n    wan: [{}]".format(quoted)

    # The zone declared without an interface (`fw: ~` below) is the local host.
    localhost_zone = "fw"

    if ipv4 or ipv6 or networks:
        found = [item for item in [ipv4, ipv6] + networks if item]
        print("Detected: {} ({})".format(iface, ", ".join(found)))

    template = """\
fwctl:
  output: {foomuuri_dir}/fw.conf
  cmd: ""
  auto_ips: yes
  exclude_services: [ping, ping6, dns, ntp, dhcp, igmp, ssdp, mdns, llmnr, netbios-ns]

settings:
  localhost_zone: {localhost_zone}
  dbus_zone: wan
  log_level: '"group 0"'
  log_input: yes
  log_output: yes
  log_forward: yes

zones:
  fw:  ~
  wan: {iface}

macros:
  proto:
    dns: "tcp 53; udp 53"
  ip:{ip_block}
  net:{net_block}
  log:
    login:  'log "login"'
    logout: 'log "logout"'

log:
  wan: [login, logout]

policy:
  wan:
    i: []
    o: [http, http2, dns, ntp]

forward: {{}}
dnat: []
snat: []
pods: {{}}
"""
    content = template.format(
        foomuuri_dir=FOOMUURI_DIR,
        localhost_zone=localhost_zone,
        iface=iface,
        ip_block=ip_block,
        net_block=net_block,
    )
    os.makedirs(os.path.dirname(os.path.abspath(out)), exist_ok=True)
    with open(out, "w") as f:
        f.write(content)
    print("Written: {}".format(out))


def _use_color():
    return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()

def _c(code, text):
    if _use_color():
        return "\033[{}m{}\033[0m".format(code, text)
    return str(text)

def _c_title(text):   return _c("1;36", text)   # bold cyan
def _c_key(text):     return _c("32", text)      # green
def _c_val(text):     return _c("33", text)      # yellow
def _c_dim(text):     return _c("2", text)       # dim/gray
def _c_hi(text):      return _c("1;33", text)    # bold yellow (highlight)


def _fmt_yaml_value(val):
    if isinstance(val, bool):
        return "yes" if val else "no"
    if isinstance(val, list):
        return " ".join(str(x) for x in val) if val else "[]"
    return str(val)


def _kv(indent, key, val, width=0):
    """Format one ``key  value`` line, padding ``key`` to ``width`` columns."""
    # Multiplying a str by a non-positive count yields "", so no guard is needed.
    padding = " " * (width - len(key))
    return indent + _c_key(key) + padding + "  " + _c_val(val)


def _show_settings(cfg, foomuuri_settings=None, show_default=False):
    """Print foomuuri settings, split into defaults and values set in fw.yaml.

    Args:
        cfg: parsed configuration dict; user values come from cfg["settings"].
        foomuuri_settings: mapping of known setting -> default value; falls
            back to FOOMUURI_SETTINGS_DEFAULTS when None.
        show_default: print only the defaults table and return.

    Every key present in cfg["settings"] is listed under "Custom" (including
    keys the defaults table does not know); the remaining known keys are
    listed under "Default".
    """
    sd = foomuuri_settings if foomuuri_settings is not None else FOOMUURI_SETTINGS_DEFAULTS
    # `settings:` written with no body loads as None.
    settings = cfg.get("settings") or {}

    if show_default:
        print(_c_title("=== Settings (foomuuri defaults) ==="))
        # default=0 so an empty defaults table prints just the title
        # instead of raising ValueError from max().
        w = max((len(k) for k in sd), default=0)
        for key, val in sd.items():
            print(_kv("  ", key, _fmt_yaml_value(val), w))
        return

    custom = dict(settings)
    defaults = {k: v for k, v in sd.items() if k not in settings}

    print(_c_title("=== Settings ==="))
    if defaults:
        print("  {}:".format(_c_dim("Default")))
        w = max(len(k) for k in defaults)
        for key, val in defaults.items():
            print(_kv("    ", key, _fmt_yaml_value(val), w))
    if custom:
        print("  {}:".format(_c_hi("Custom")))
        w = max(len(k) for k in custom)
        for key, val in custom.items():
            print(_kv("    ", key, _fmt_yaml_value(val), w))


def _show_zones(cfg):
    """Print each zone with its interface, or "(no interface)" when unset."""
    zones = cfg.get("zones", {})
    print(_c_title("=== Zones ==="))
    width = max(len(name) for name in zones) if zones else 0
    for name, iface in zones.items():
        if not iface:
            padding = " " * (width - len(name))
            print("  {}{}  {}".format(_c_key(name), padding, _c_dim("(no interface)")))
            continue
        print(_kv("  ", name, iface, width))


def _macro_matches(name, val, query):
    if name == query:
        return True
    tokens = val.replace(";", " ").split()
    return query in tokens


def _show_macros(cfg, builtins, subtype=None, query=None, show_default=False, include_builtin=True):
    """Print built-in and/or custom macros, optionally filtered.

    Args:
        cfg: parsed configuration dict; custom macros are read from the
            "proto", "ip", "pod", "log" and "net" sections of cfg["macros"].
        builtins: mapping of built-in macro name -> definition, printed
            sorted by name.
        subtype: print only one kind: "builtin", "net", or one of the custom
            section names ("proto", "ip", "pod", "log").
        query: keep only macros whose name equals it or whose value contains
            it as a token (see _macro_matches). The "net" section is not
            printed while a query is active.
        show_default: print only the (query-filtered) built-ins and return.
        include_builtin: in the full listing, also print the built-ins.
    """
    macros = cfg.get("macros", {})
    # Custom sections rendered as flat name -> value tables, in display order.
    custom_sections = [
        ("proto", macros.get("proto", {})),
        ("ip",    macros.get("ip", {})),
        ("pod",   macros.get("pod", {})),
        ("log",   macros.get("log", {})),
    ]
    custom_map = {label: items for label, items in custom_sections}
    # "net" is rendered separately; print_net() below iterates it as a
    # nested mapping {zone: {family: cidr}}.
    net = macros.get("net", {})

    def filtered(items):
        # Apply the optional query (name match or value-token match).
        if not query:
            return items
        return {k: v for k, v in items.items() if _macro_matches(k, v, query)}

    def print_items(items, indent="    "):
        # Every caller checks for a non-empty mapping first (max() needs a key).
        w = max(len(k) for k in items)
        for name, val in items.items():
            print(_kv(indent, name, val, w))

    def print_labeled(label, items, indent="  "):
        # One custom section under its label; skipped when nothing matches.
        items = filtered(items)
        if not items:
            return
        print("{}{}:".format(indent, _c_dim(label)))
        print_items(items, indent + "  ")

    def print_net(indent="  "):
        # Net macros are not query-filtered; they are hidden whenever a query is set.
        if not query and net:
            print("{}{}:".format(indent, _c_dim("net")))
            for zone_name, families in net.items():
                for fam, cidr in families.items():
                    print("{}  {}  {}".format(indent, _c_key("{}/{}".format(zone_name, fam)), _c_val(cidr)))

    def print_type(t, indent="  "):
        # Single-subtype view; an unknown subtype prints nothing.
        if t == "builtin":
            items = filtered({k: builtins[k] for k in sorted(builtins)})
            if items:
                print("{}{}:".format(indent, _c_dim("Built-in (foomuuri)")))
                print_items(items, indent + "  ")
        elif t == "net":
            print_net(indent)
        elif t in custom_map:
            print_labeled(t, custom_map[t], indent)

    if show_default:
        print(_c_title("=== Macros (foomuuri built-in) ==="))
        items = filtered({k: builtins[k] for k in sorted(builtins)})
        if items:
            print_items(items)
        return

    if query:
        print(_c_title("=== Macros matching '{}' ===".format(query)))
    else:
        print(_c_title("=== Macros ==="))

    if subtype:
        print_type(subtype)
        return

    if include_builtin:
        items = filtered({k: builtins[k] for k in sorted(builtins)})
        if items:
            print("  {}:".format(_c_dim("Default (foomuuri built-in)")))
            print_items(items)
    # Print the "Custom" header only if at least one custom section will
    # produce output (mirrors the checks in print_labeled/print_net).
    has_custom = any([
        filtered(macros.get("proto", {})),
        filtered(macros.get("ip", {})),
        filtered(macros.get("pod", {})),
        filtered(macros.get("log", {})),
        net and not query,
    ])
    if has_custom:
        print("  {}:".format(_c_hi("Custom")))
    for label, section_items in custom_sections:
        print_labeled(label, section_items, "    ")
        # The net section is shown right after the ip section.
        if label == "ip":
            print_net("    ")


def _show_policy(cfg, zone_filter=None, dir_filter=None, compact=False):
    """Print per-zone policy: inbound (zone -> fw) and outbound (fw -> zone).

    Args:
        cfg: parsed configuration dict (reads cfg["policy"] and cfg["log"]).
        zone_filter: show only this zone; None shows every zone in policy.
        dir_filter: "i" or "o" to show one direction only, None for both.
        compact: print bare service lists (no title, colours or actions),
            one line per direction. The "i: "/"o: " prefix is omitted when
            dir_filter selects that direction.

    In the normal view each line is suffixed with the action resolved from
    cfg["log"][zone] (index 0 for inbound, 1 for outbound), or "accept".
    """
    # Sections or zones written with no body load as None.
    policy = cfg.get("policy") or {}
    log_cfg = cfg.get("log") or {}
    zones = [zone_filter] if zone_filter else list(policy.keys())
    if not compact:
        print(_c_title("=== Policy ==="))
    for zone in zones:
        if zone not in policy:
            if not compact:
                print("  zone '{}' not found in policy".format(zone))
            continue
        dirs = policy[zone] or {}
        log_pair = log_cfg.get(zone) or []
        drop_i = _resolve_drop_action(log_pair[0] if len(log_pair) > 0 else None)
        drop_o = _resolve_drop_action(log_pair[1] if len(log_pair) > 1 else None)
        show_i = dir_filter in (None, "i") and "i" in dirs
        show_o = dir_filter in (None, "o") and "o" in dirs
        if compact:
            for direction, show in (("i", show_i), ("o", show_o)):
                if not show:
                    continue
                svcs = dirs[direction]
                if svcs is None:
                    svcs = []
                elif not isinstance(svcs, list):
                    svcs = [svcs]
                prefix = "" if dir_filter == direction else direction + ": "
                # str(): YAML loads bare port numbers as ints, which join() rejects.
                print(prefix + " ".join(str(s) for s in svcs))
        else:
            if show_i:
                drop_str = _c_dim("[{}]".format(drop_i or "accept"))
                print("  {} -> {}:  {}  {}".format(_c_key(zone), _c_key("fw"), _c_val(_fmt_yaml_value(dirs['i'])), drop_str))
            if show_o:
                drop_str = _c_dim("[{}]".format(drop_o or "accept"))
                print("  {} -> {}:  {}  {}".format(_c_key("fw"), _c_key(zone), _c_val(_fmt_yaml_value(dirs['o'])), drop_str))


def _show_forward(cfg):
    """Print each entry of the ``forward`` section as ``key  value``."""
    entries = cfg.get("forward", {})
    print(_c_title("=== Forward ==="))
    for source, targets in entries.items():
        rendered = _c_val(_fmt_yaml_value(targets))
        print("  {}  {}".format(_c_key(source), rendered))


def _show_dnat(cfg, show_hash=False):
    """Print DNAT rules as ``from  [services]  iif=[ifaces]  -> to[:port]``.

    A rule without ``from`` is shown as "(any)". With ``show_hash`` each line
    is prefixed by the rule's hash from _dnat_hash().
    """
    def joined(value):
        # List fields are comma-separated; scalars are shown unchanged.
        return ", ".join(value) if isinstance(value, list) else value

    rules = cfg.get("dnat", [])
    print(_c_title("=== DNAT ==="))
    for rule in rules:
        source = rule.get("from")
        port = rule.get("port")
        text = "  {}  [{}]  iif=[{}]  -> {}{}".format(
            source if source else _c_dim("(any)"),
            _c_val(joined(rule.get("service", []))),
            _c_dim(joined(rule.get("iif", []))),
            _c_key(rule.get("to", "")),
            _c_val("" if port is None else ":{}".format(port)),
        )
        if not show_hash:
            print(text)
            continue
        print("  {}  {}".format(_c_dim(_dnat_hash(rule)), text.strip()))


def _show_snat(cfg):
    """Print SNAT rules that have a ``via`` key as ``zone -> via [ip=...]``.

    Rules without ``via`` are not listed.
    """
    rules = cfg.get("snat", [])
    print(_c_title("=== SNAT ==="))
    for rule in (r for r in rules if "via" in r):
        ip_filter = rule.get("ip")
        extra = ""
        if ip_filter:
            extra = " " + _c_dim("ip={}".format(ip_filter))
        print("  {} -> {}{}".format(_c_key(rule["zone"]), _c_val(rule["via"]), extra))


def _show_hairpin(cfg):
    """Print the zones listed under ``hairpin`` (a single name or a list)."""
    zones = cfg.get("hairpin") or []
    zones = [zones] if isinstance(zones, str) else zones
    print(_c_title("=== Hairpin ==="))
    lines = ["  " + _c_key(zone) for zone in zones] or ["  (none)"]
    for line in lines:
        print(line)


def _show_redirect(cfg, show_hash=False):
    """Print redirect rules as ``[services] -> to[:port]``.

    With ``show_hash`` each line is prefixed by the rule's _redirect_hash().
    """
    rules = cfg.get("redirect", [])
    print(_c_title("=== Redirect ==="))
    for rule in rules:
        services = rule.get("service", [])
        if isinstance(services, list):
            services = ", ".join(services)
        port = rule.get("port")
        suffix = "" if port is None else ":{}".format(port)
        text = "  [{}] -> {}{}".format(_c_val(services), _c_key(rule.get("to", "")), _c_val(suffix))
        if not show_hash:
            print(text)
            continue
        print("  {}  {}".format(_c_dim(_redirect_hash(rule)), text.strip()))


def _show_pods(cfg):
    """Print pod definitions: ip, inbound services, outbound, DNAT map, zones.

    List-valued fields are shown comma-separated. A field given as a single
    string (e.g. ``zones: lan``) is shown as-is. Previously ``", ".join``
    split a string ``zones`` into characters, and a ``default_zones`` list
    was printed as a Python repr.
    """
    def as_text(value):
        # Join lists (str() tolerates non-string items such as ports);
        # leave scalars and strings untouched.
        if isinstance(value, (list, tuple)):
            return ", ".join(str(v) for v in value)
        return value

    # `pods:` with no body loads as None.
    pods = cfg.get("pods") or {}
    print(_c_title("=== Pods ==="))
    dz = pods.get("default_zones")
    if dz:
        print("  {}: {}".format(_c_dim("default_zones"), _c_val(as_text(dz))))
    for pname, pod in _iter_pods(pods):
        ip    = pod.get("ip", "")
        i_val = pod.get("i", [])
        o_val = pod.get("o", False)
        dnat  = pod.get("dnat", {})
        zones = pod.get("zones")
        i_str = as_text(i_val)
        # `o` is either a list of services or a boolean "accept all outbound".
        if isinstance(o_val, (list, tuple)):
            o_str = as_text(o_val)
        else:
            o_str = "accept" if o_val else "-"
        ip_str = as_text(ip)
        print("  {}  {}".format(_c_hi(pname), _c_dim("ip=[{}]".format(ip_str))))
        if i_str:
            print("    {}    {}".format(_c_dim("i:"), _c_val(i_str)))
        if o_val is not False:
            print("    {}    {}".format(_c_dim("o:"), _c_val(o_str)))
        if dnat:
            dnat_str = "  ".join("{}->{}".format(svc, port) for svc, port in dnat.items())
            print("    {} {}".format(_c_dim("dnat:"), _c_val(dnat_str)))
        if zones:
            print("    {} {}".format(_c_dim("zones:"), _c_val(as_text(zones))))


def _resolve_set_section(key, foomuuri_settings=None):
    """Return ``(section_name, defaults_mapping)`` that owns config ``key``.

    Keys in FWCTL_DEFAULTS belong to the ``fwctl`` section. Every other key
    belongs to ``settings``, with ``foomuuri_settings`` as defaults, or
    FOOMUURI_SETTINGS_DEFAULTS when that is None.
    """
    if key not in FWCTL_DEFAULTS:
        if foomuuri_settings is None:
            return "settings", FOOMUURI_SETTINGS_DEFAULTS
        return "settings", foomuuri_settings
    return "fwctl", FWCTL_DEFAULTS


def cmd_set(cfg, builtins, args):
    """Set ``args.key`` to the space-joined ``args.value`` words in fw.yaml.

    Setting a key to exactly its default (compared as a string) removes it
    from the file instead of storing a redundant value. The file is
    rewritten in both cases.
    """
    key = args.key
    value = " ".join(args.value)

    _, foomuuri_settings = load_builtin_conf()
    raw, config_path = _load_raw(args)

    section, defaults = _resolve_set_section(key, foomuuri_settings)
    target = raw.setdefault(section, {})
    default = defaults.get(key)

    if default is not None and value == str(default):
        target.pop(key, None)
        message = "{}.{} removed (matches default: {})".format(section, key, default)
    else:
        target[key] = value
        message = "{}.{} = {}".format(section, key, value)
    print(message)

    _save_raw(raw, config_path)


def cmd_unset(cfg, builtins, args):
    """Remove ``args.key`` from fw.yaml so its default applies again.

    The section is chosen by _resolve_set_section(). When the key is not
    present nothing is written. A message is printed either way; "not found"
    goes to stderr when the key has no known default.
    """
    key = args.key

    _, foomuuri_settings = load_builtin_conf()
    raw, config_path = _load_raw(args)

    section, defaults = _resolve_set_section(key, foomuuri_settings)
    default = defaults.get(key)
    block = raw.get(section, {})

    if key in block:
        del block[key]
        _save_raw(raw, config_path)
        if default is None:
            print("{}.{} removed".format(section, key))
        else:
            print("{}.{} removed (default: {})".format(section, key, default))
        return

    if default is None:
        print("{}.{} not found".format(section, key), file=sys.stderr)
    else:
        print("{}.{} is already at default ({})".format(section, key, default))


def cmd_show(cfg, builtins, args, foomuuri_settings=None):
    """Print the parsed configuration, or one section of it.

    When ``yq`` is available and no option requires fwctl's own renderer,
    output is delegated to _show_yq(cfg, section, sub). Otherwise the
    ``_show_*`` helpers are used. An unknown section prints an error and
    exits with status 1.
    """
    section      = args.section
    sub          = getattr(args, "sub", None)
    compact      = getattr(args, "compact", False)
    show_default = getattr(args, "default", False)
    show_hash    = getattr(args, "show_hash", False)
    if getattr(args, "filter_in", False):
        dir_filter = "i"
    elif getattr(args, "filter_out", False):
        dir_filter = "o"
    else:
        dir_filter = None

    macro_subtypes = {"proto", "ip", "pod", "net", "log", "builtin"}
    is_macro       = section == "macro"
    macros_query   = is_macro and bool(sub) and sub not in macro_subtypes
    macros_builtin = is_macro and sub == "builtin"
    macros_all     = is_macro and sub is None
    # _show_yq() only receives (section, sub), so it cannot honour -i/-o;
    # previously `show policy <zone> -i` silently ignored the filter when
    # yq was installed.
    policy_dir     = section == "policy" and dir_filter is not None

    needs_native = (compact or show_default or show_hash or macros_query
                    or macros_builtin or macros_all or policy_dir)
    if _yq_available() and not needs_native:
        _show_yq(cfg, section, sub)
        return

    if section is None:
        _show_settings(cfg, foomuuri_settings);          print()
        _show_zones(cfg);                                print()
        _show_macros(cfg, builtins, include_builtin=False); print()
        _show_policy(cfg);                               print()
        _show_forward(cfg);                              print()
        _show_dnat(cfg);                                 print()
        _show_snat(cfg)
        if cfg.get("redirect"):
            print(); _show_redirect(cfg)
        if cfg.get("pods"):
            print(); _show_pods(cfg)
        if cfg.get("hairpin"):
            print(); _show_hairpin(cfg)
    elif section == "settings":
        _show_settings(cfg, foomuuri_settings, show_default=show_default)
    elif section == "zone":
        _show_zones(cfg)
    elif section == "macro":
        # A known subtype (or none) selects a view; anything else is a search query.
        if sub in macro_subtypes or sub is None:
            _show_macros(cfg, builtins, sub, show_default=show_default)
        else:
            _show_macros(cfg, builtins, query=sub, show_default=show_default)
    elif section == "policy":
        _show_policy(cfg, sub, dir_filter=dir_filter, compact=compact)
    elif section == "forward":
        _show_forward(cfg)
    elif section == "dnat":
        _show_dnat(cfg, show_hash=show_hash)
    elif section == "snat":
        _show_snat(cfg)
    elif section == "redirect":
        _show_redirect(cfg, show_hash=show_hash)
    elif section == "pod":
        _show_pods(cfg)
    elif section == "hairpin":
        _show_hairpin(cfg)
    else:
        print("Unknown section '{}'. Valid: settings, zone, macro, policy, forward, dnat, snat, redirect, pod, hairpin".format(section),
              file=sys.stderr)
        sys.exit(1)


# ---------------------------------------------------------------------------
# Log viewer
# ---------------------------------------------------------------------------

def _extract_log_prefixes(cfg):
    log_macros = cfg.get("macros", {}).get("log", {}) if cfg else {}
    result = {"login": "login", "logout": "logout"}
    for name in ("login", "logout"):
        val = log_macros.get(name, "")
        m = re.search(r'log\s+"([^"]+)"', str(val))
        if m:
            result[name] = m.group(1)
    return result["login"], result["logout"]


class _ServiceMapper:
    def __init__(self, services_file="/etc/services"):
        self.tcp = {}
        self.udp = {}
        try:
            with open(services_file) as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith("#"):
                        continue
                    parts = line.split()
                    if len(parts) < 2 or "/" not in parts[1]:
                        continue
                    name = parts[0]
                    port_str, proto = parts[1].split("/", 1)
                    try:
                        port = int(port_str)
                    except ValueError:
                        continue
                    if proto == "tcp":
                        self.tcp[port] = name
                    elif proto == "udp":
                        self.udp[port] = name
        except OSError:
            pass

    def get(self, port, protocol):
        if protocol == 6:
            return self.tcp.get(port, "unknown")
        if protocol == 17:
            return self.udp.get(port, "unknown")
        return "unknown"


class _LogFilter:
    PROTO_MAP    = {1: "icmp", 2: "igmp", 6: "tcp", 17: "udp", 58: "icmpv6"}
    EXCLUDE_SVCS = {"ping", "ping6", "dns", "ntp", "dhcp", "igmp", "ssdp", "mdns", "llmnr", "netbios-ns"}

    def __init__(self, args, login_prefix, logout_prefix, svc_mapper, exclude_svcs=None):
        self.login        = login_prefix
        self.logout       = logout_prefix
        self.mapper       = svc_mapper
        self.exclude_svcs = exclude_svcs if exclude_svcs is not None else self.EXCLUDE_SVCS
        self.show_all     = args.all
        self.filter_in    = args.filter_in
        self.filter_out   = args.filter_out
        self.filter_net   = args.network
        self.filter_svc   = args.service
        self.filter_proto = args.proto
        self.filter_ip    = args.ip
        self.filter_port  = args.port
        self.target_hash  = args.hash

    def _prefix(self, entry):
        return entry.get("oob.prefix", "").strip()

    def _service(self, entry):
        pr  = entry.get("ip.protocol", 0)
        oob = self._prefix(entry)
        sp  = entry.get("src_port", 0)
        dp  = entry.get("dest_port", 0)
        if pr == 1:  return "ping"
        if pr == 58: return "ping6"
        if pr == 2:  return "igmp"
        if pr == 17:
            if sp == 1900 or dp == 1900:          return "ssdp"
            if sp == 53   or dp == 53:            return "dns"
            if sp in (67, 68) or dp in (67, 68): return "dhcp"
        if pr == 6 and (sp == 53 or dp == 53):    return "dns"
        port = dp if oob == self.login else sp
        return self.mapper.get(port, pr)

    def should_include(self, entry):
        pr  = entry.get("ip.protocol", 0)
        oob = self._prefix(entry)
        if self.filter_in or self.filter_out:
            match = (self.filter_in  and oob == self.login) or \
                    (self.filter_out and oob == self.logout)
            if not match:
                return False
        si = not self.show_all and self.filter_in  and not self.filter_out
        so = not self.show_all and self.filter_out and not self.filter_in
        if self.filter_net:
            if   si and entry.get("oob.in",  "") != self.filter_net: return False
            elif so and entry.get("oob.out", "") != self.filter_net: return False
            elif not si and not so and self.filter_net not in (entry.get("oob.in", ""), entry.get("oob.out", "")): return False
        if self.filter_ip:
            if   si and entry.get("dest_ip") != self.filter_ip: return False
            elif so and entry.get("src_ip")  != self.filter_ip: return False
            elif not si and not so and self.filter_ip not in (entry.get("src_ip"), entry.get("dest_ip")): return False
        if self.filter_port is not None:
            if   si and entry.get("dest_port") != self.filter_port: return False
            elif so and entry.get("src_port")  != self.filter_port: return False
            elif not si and not so and self.filter_port not in (entry.get("src_port"), entry.get("dest_port")): return False
        pt = self.PROTO_MAP.get(pr, "proto-{}".format(pr))
        if self.filter_proto and pt.lower() != self.filter_proto:
            return False
        svc = self._service(entry)
        if self.filter_svc and (not svc or svc.lower() != self.filter_svc):
            return False
        if not self.show_all and svc and svc.lower() in self.exclude_svcs:
            return False
        return True

    def format(self, entry):
        pr  = entry.get("ip.protocol", 0)
        oob = self._prefix(entry)
        pt  = self.PROTO_MAP.get(pr, "proto-{}".format(pr))
        svc = self._service(entry)
        ts  = entry.get("timestamp", "")
        h   = hashlib.sha256(ts.encode()).hexdigest() if ts else ""
        if self.show_all:
            entry["hash"]     = h
            entry["protocol"] = pt
            if svc: entry["service"] = svc
            return entry
        out = {"hash": h, "timestamp": ts, "dvc": entry.get("dvc"), "prefix": oob}
        iface_key = "oob.in" if oob == self.login else "oob.out"
        if iface_key in entry:
            out["interface"] = entry[iface_key]
        out["protocol"] = pt.upper()
        if svc: out["service"] = svc.upper()
        if "src_ip"    in entry: out["source"]      = entry["src_ip"]
        if "dest_ip"   in entry: out["destination"] = entry["dest_ip"]
        if "dest_port" in entry: out["port"]         = entry["dest_port"]
        return out


def cmd_log(cfg, builtins, args):
    """Stream matching firewall log entries through ``jq -C .``.

    Reads the log named by ``fwctl.log_file`` (default /var/log/fw.json) as
    one JSON object per line; blank or non-JSON lines are skipped. With
    ``--hash``, entries whose timestamp hash matches are emitted unchanged.
    Otherwise entries are filtered and reshaped by _LogFilter. Exits with
    status 1 if the log file or jq is missing, or the file cannot be read.
    """
    fwctl_cfg = (cfg.get("fwctl") or {}) if cfg else {}
    # An empty `log_file:` value falls back to the default path.
    log_file = fwctl_cfg.get("log_file") or "/var/log/fw.json"

    login_prefix, logout_prefix = _extract_log_prefixes(cfg)

    exclude_svcs = None  # None -> _LogFilter's built-in exclude list
    if "exclude_services" in fwctl_cfg:
        configured = fwctl_cfg["exclude_services"] or []
        if isinstance(configured, str):
            # A plain string is possible: FWCTL_DEFAULTS uses a comma-separated
            # one, and `fwctl set` stores its words space-joined. set() over a
            # str would yield single characters.
            configured = re.split(r"[\s,]+", configured)
        exclude_svcs = {str(s).strip().lower() for s in configured if str(s).strip()}
    log_filter = _LogFilter(args, login_prefix, logout_prefix, _ServiceMapper(), exclude_svcs)

    if not os.path.exists(log_file):
        print("Log file not found: {}".format(log_file), file=sys.stderr)
        sys.exit(1)

    try:
        jq = subprocess.Popen(
            ["jq", "-C", "."],
            stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr, text=True,
        )
    except FileNotFoundError:
        print("Error: jq not found. Install with: apt install jq", file=sys.stderr)
        sys.exit(1)

    try:
        with open(log_file) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                try:
                    # Writes are buffered; the pipe is flushed when stdin is closed.
                    if log_filter.target_hash:
                        ts = entry.get("timestamp", "")
                        if hashlib.sha256(ts.encode()).hexdigest() == log_filter.target_hash:
                            jq.stdin.write(json.dumps(entry) + "\n")
                    elif log_filter.should_include(entry):
                        jq.stdin.write(json.dumps(log_filter.format(entry)) + "\n")
                except BrokenPipeError:
                    # jq went away (e.g. its reader quit); stop producing.
                    break
    except OSError as e:
        print("Error reading log file: {}".format(e), file=sys.stderr)
        sys.exit(1)
    finally:
        try:
            jq.stdin.close()
        except OSError:
            # Flushing into a dead pipe fails; the pipe is closed regardless.
            pass
        # Always reap jq, even when closing its stdin failed.
        jq.wait()


# ---------------------------------------------------------------------------
# Shell completion
# ---------------------------------------------------------------------------

_COMPLETION_BASH = r"""# fwctl bash completion
_fwctl() {
    local cur prev skip_next i cmd="" subcmd="" subsubcmd=""
    COMPREPLY=()
    cur="${COMP_WORDS[COMP_CWORD]}"
    prev="${COMP_WORDS[COMP_CWORD-1]}"

    # Collect positional words (skip flags and their values)
    skip_next=0
    for ((i=1; i<COMP_CWORD; i++)); do
        local w="${COMP_WORDS[i]}"
        if ((skip_next)); then skip_next=0; continue; fi
        if [[ "$w" == "--config" ]]; then skip_next=1; continue; fi
        if [[ "$w" == -* ]]; then continue; fi
        if [[ -z "$cmd" ]]; then cmd="$w"
        elif [[ -z "$subcmd" ]]; then subcmd="$w"
        elif [[ -z "$subsubcmd" ]]; then subsubcmd="$w"
        fi
    done

    [[ "$prev" == "--config" ]] && { COMPREPLY=($(compgen -f -- "$cur")); return; }

    case "$cmd" in
        "")
            COMPREPLY=($(compgen -W "init check gen apply log add del edit show set unset completion --config --version -V" -- "$cur")) ;;
        init)
            COMPREPLY=($(compgen -W "--output --iface --force" -- "$cur")) ;;
        gen)
            COMPREPLY=($(compgen -W "--print --ulogd" -- "$cur")) ;;
        apply)
            COMPREPLY=($(compgen -W "--dry-run" -- "$cur")) ;;
        log)
            COMPREPLY=($(compgen -W "-a -i -o -n -s -p --ip --port --hash" -- "$cur")) ;;
        completion)
            COMPREPLY=($(compgen -W "bash zsh" -- "$cur")) ;;
        set|unset)
            COMPREPLY=($(compgen -W "output ulogd cmd auto_ips log_file exclude_services localhost_zone dbus_zone log_level log_input log_output log_forward" -- "$cur")) ;;
        add)
            case "$subcmd" in
                "")      COMPREPLY=($(compgen -W "macro zone policy pod forward dnat snat redirect" -- "$cur")) ;;
                macro)   [[ -z "$subsubcmd" ]] && COMPREPLY=($(compgen -W "proto ip pod net log" -- "$cur")) ;;
                zone)    COMPREPLY=($(compgen -W "--net" -- "$cur")) ;;
                policy)  COMPREPLY=($(compgen -W "--in --out" -- "$cur")) ;;
                pod)     COMPREPLY=($(compgen -W "--in --dnat --out --zones" -- "$cur")) ;;
                dnat)    COMPREPLY=($(compgen -W "-f --from -s --service -i --iif -p --port" -- "$cur")) ;;
                snat)     COMPREPLY=($(compgen -W "-i" -- "$cur")) ;;
                redirect) COMPREPLY=($(compgen -W "-s --service -p --port" -- "$cur")) ;;
            esac ;;
        del)
            case "$subcmd" in
                "")       COMPREPLY=($(compgen -W "pod macro zone forward dnat snat redirect" -- "$cur")) ;;
                macro)    [[ -z "$subsubcmd" ]] && COMPREPLY=($(compgen -W "proto ip pod net log" -- "$cur")) ;;
                dnat)     COMPREPLY=($(compgen -W "--hash -f --from -s -i -p" -- "$cur")) ;;
                redirect) COMPREPLY=($(compgen -W "--hash -s -p" -- "$cur")) ;;
            esac ;;
        edit)
            case "$subcmd" in
                "")      COMPREPLY=($(compgen -W "zone policy pod forward log macro dnat snat redirect" -- "$cur")) ;;
                macro)   [[ -z "$subsubcmd" ]] && COMPREPLY=($(compgen -W "proto ip pod net log" -- "$cur")) ;;
                policy)  COMPREPLY=($(compgen -W "--in --out" -- "$cur")) ;;
                pod)     COMPREPLY=($(compgen -W "--ip --in --out --dnat --zones" -- "$cur")) ;;
                log)     COMPREPLY=($(compgen -W "--in --out" -- "$cur")) ;;
                dnat)    COMPREPLY=($(compgen -W "-f --from -s --service -i --iif -t --to -p --port" -- "$cur")) ;;
                snat)     COMPREPLY=($(compgen -W "-v --via -i" -- "$cur")) ;;
                redirect) COMPREPLY=($(compgen -W "-s --service -t --to -p --port" -- "$cur")) ;;
            esac ;;
        show)
            case "$subcmd" in
                "")       COMPREPLY=($(compgen -W "settings zone macro policy forward dnat snat redirect pod" -- "$cur")) ;;
                macro)    [[ -z "$subsubcmd" ]] && COMPREPLY=($(compgen -W "proto ip pod net log builtin" -- "$cur")) ;;
                policy)   COMPREPLY=($(compgen -W "-i -o -c --compact" -- "$cur")) ;;
                dnat)     COMPREPLY=($(compgen -W "--hash" -- "$cur")) ;;
                redirect) COMPREPLY=($(compgen -W "--hash" -- "$cur")) ;;
                settings) COMPREPLY=($(compgen -W "--default" -- "$cur")) ;;
            esac ;;
    esac
}

complete -F _fwctl fwctl
"""

_COMPLETION_ZSH = r"""#compdef fwctl

_fwctl() {
    local context state line
    typeset -A opt_args
    _arguments -C \
        '--config[use specific config file]:file:_files' \
        '(-V --version)'{-V,--version}'[print version]' \
        '1:command:->cmd' \
        '*::args:->args'
    case $state in
        cmd)
            local commands=(
                'init:generate initial fw.yaml'
                'check:validate fw.yaml'
                'gen:generate fw.conf'
                'apply:generate and reload firewall'
                'log:view firewall log'
                'add:add a resource'
                'del:remove a resource'
                'edit:edit a resource'
                'show:show parsed config'
                'set:set a config value'
                'unset:remove a config key'
                'completion:print shell completion script'
            )
            _describe 'command' commands ;;
        args)
            case $line[1] in
                add)        _fwctl_add ;;
                del)        _fwctl_del ;;
                edit)       _fwctl_edit ;;
                show)       _fwctl_show ;;
                init)       _arguments '--output[output file]:file:_files' '--iface[WAN interface]:iface' '--force[overwrite]' ;;
                gen)        _arguments '--print[print to stdout]' '--ulogd[also generate ulogd.conf]' ;;
                apply)      _arguments '--dry-run[show diff]' ;;
                log)        _arguments '-a[all fields]' '-i[inbound]' '-o[outbound]' \
                                '-n[interface]:iface' '-s[service]:svc' '-p[protocol]:proto' \
                                '--ip[IP address]:ip' '--port[port]:port' '--hash[entry hash]:hash' ;;
                set|unset)  _fwctl_set_keys ;;
                completion) compadd bash zsh ;;
            esac ;;
    esac
}

_fwctl_add() {
    local context state line
    _arguments -C '1:subcommand:->sub' '*::args:->args'
    case $state in
        sub)
            local subs=(macro zone policy pod forward dnat snat redirect)
            _describe 'subcommand' subs ;;
        args)
            case $line[1] in
                macro)  (( $#line <= 2 )) && compadd proto ip pod net log ;;
                zone)   _arguments '--net[subnet CIDR]:cidr' ;;
                policy) _arguments '--in[inbound services]:svc' '--out[outbound services]:svc' ;;
                pod)    _arguments '--in[inbound services]:svc' '--dnat[svc:port mapping]:mapping' \
                                   '--out[accept outbound]' '--zones[zones]:zone' ;;
                dnat)   _arguments '(-f --from)'{-f,--from}'[from IP (omit=transparent)]:from' \
                                   '(-s --service)'{-s,--service}'[services]:svc' \
                                   '(-i --iif)'{-i,--iif}'[inbound interface]:iif' \
                                   '(-p --port)'{-p,--port}'[port]:port' ;;
                snat)     _arguments '-i[IP family]:family:(4 6)' ;;
                redirect) _arguments '(-s --service)'{-s,--service}'[services]:svc' \
                                     '(-p --port)'{-p,--port}'[port]:port' ;;
            esac ;;
    esac
}

_fwctl_del() {
    local context state line
    _arguments -C '1:subcommand:->sub' '*::args:->args'
    case $state in
        sub)
            local subs=(pod macro zone forward dnat snat redirect)
            _describe 'subcommand' subs ;;
        args)
            case $line[1] in
                macro) (( $#line <= 2 )) && compadd proto ip pod net log ;;
                dnat)     _arguments '--hash[rule hash]:hash' \
                                     '(-f --from)'{-f,--from}'[from IP (omit=transparent)]:from' \
                                     '-s[services]:svc' '-i[iif]:iif' '-p[port]:port' ;;
                redirect) _arguments '--hash[rule hash]:hash' \
                                     '-s[services]:svc' '-p[port]:port' ;;
            esac ;;
    esac
}

_fwctl_edit() {
    local context state line
    _arguments -C '1:subcommand:->sub' '*::args:->args'
    case $state in
        sub)
            local subs=(zone policy pod forward log macro dnat snat redirect)
            _describe 'subcommand' subs ;;
        args)
            case $line[1] in
                macro)  (( $#line <= 2 )) && compadd proto ip pod net log ;;
                policy) _arguments '--in[inbound services]:svc' '--out[outbound services]:svc' ;;
                pod)    _arguments '--ip[IP address]:ip' '--in[services]:svc' '--out[services]:svc' \
                                   '--dnat[mapping]:mapping' '--zones[zones]:zone' ;;
                log)    _arguments '--in[inbound action]:action' '--out[outbound action]:action' ;;
                dnat)   _arguments '(-f --from)'{-f,--from}'[from]:from' \
                                   '(-s --service)'{-s,--service}'[services]:svc' \
                                   '(-i --iif)'{-i,--iif}'[iif]:iif' \
                                   '(-t --to)'{-t,--to}'[to]:to' \
                                   '(-p --port)'{-p,--port}'[port]:port' ;;
                snat)     _arguments '(-v --via)'{-v,--via}'[via zone]:zone' \
                                     '-i[IP family]:family:(4 6)' ;;
                redirect) _arguments '(-s --service)'{-s,--service}'[services]:svc' \
                                     '(-t --to)'{-t,--to}'[to]:to' \
                                     '(-p --port)'{-p,--port}'[port]:port' ;;
            esac ;;
    esac
}

_fwctl_show() {
    local context state line
    _arguments -C '1:section:->sec' '*::args:->args'
    case $state in
        sec)
            local secs=(settings zone macro policy forward dnat snat redirect pod)
            _describe 'section' secs ;;
        args)
            case $line[1] in
                macro)    compadd proto ip pod net log builtin ;;
                policy)   _arguments '(-i --in)'{-i,--in}'[inbound only]' \
                                     '(-o --out)'{-o,--out}'[outbound only]' \
                                     '(-c --compact)'{-c,--compact}'[compact]' ;;
                dnat)     _arguments '--hash[show with hashes]' ;;
                redirect) _arguments '--hash[show with hashes]' ;;
                settings) _arguments '(-d --default)'{-d,--default}'[show defaults]' ;;
            esac ;;
    esac
}

_fwctl_set_keys() {
    local keys=(
        'output:output file path' 'ulogd:ulogd config path' 'cmd:reload command'
        'auto_ips:auto-update IP macros' 'log_file:log file path'
        'exclude_services:services to exclude from log' 'localhost_zone:localhost zone name'
        'dbus_zone:dbus zone name' 'log_level:nflog group'
        'log_input:log inbound' 'log_output:log outbound' 'log_forward:log forwarded'
    )
    _describe 'key' keys
}

_fwctl
"""

def cmd_completion(args):
    """Write the shell completion script selected by ``args.shell`` to stdout.

    Supported values are "bash" and "zsh". Any other value prints an error
    to stderr and terminates the process with exit status 1. The script is
    printed with ``end=""`` so no extra newline is appended.
    """
    requested = args.shell
    if requested not in ("bash", "zsh"):
        print("Error: unknown shell '{}'. Valid: bash, zsh".format(requested), file=sys.stderr)
        sys.exit(1)
    # Resolve the script only after validation; only the chosen constant is read.
    script = _COMPLETION_BASH if requested == "bash" else _COMPLETION_ZSH
    print(script, end="")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def find_config(explicit=None):
    """Return the fw.yaml path to load, exiting with status 1 if none exists.

    If *explicit* is truthy, it must name an existing path. Otherwise the
    first existing entry of DEFAULT_CONFIG_PATHS is returned, searched in
    order. When nothing is found, every tried location is listed on stderr.
    """
    if explicit:
        if os.path.exists(explicit):
            return explicit
        print("Error: config file not found: {}".format(explicit), file=sys.stderr)
        sys.exit(1)

    first_hit = next((cand for cand in DEFAULT_CONFIG_PATHS if os.path.exists(cand)), None)
    if first_hit is not None:
        return first_hit

    report = ["Error: no config file found. Tried:"]
    report.extend("  {}".format(cand) for cand in DEFAULT_CONFIG_PATHS)
    print("\n".join(report), file=sys.stderr)
    sys.exit(1)


def main():
    fmt = argparse.RawDescriptionHelpFormatter
    parser = argparse.ArgumentParser(
        description="fwctl - foomuuri firewall config manager",
        formatter_class=fmt,
    )
    parser.add_argument("--version", "-V", action="version", version="fwctl {}".format(VERSION))
    parser.add_argument("-c", "--config", metavar="FILE", help="path to fw.yaml (default: /etc/foomuuri/fw.yaml)")
    sub = parser.add_subparsers(dest="command")

    # --- add ---
    add_p = sub.add_parser("add", help="add resources to fw.yaml",
        description="Add resources to an existing fw.yaml.",
        formatter_class=fmt)
    add_sub = add_p.add_subparsers(dest="add_command")
    add_macro_p = add_sub.add_parser("macro", help="add a macro",
        description="Add a macro to fw.yaml.",
        formatter_class=fmt,
        epilog="""\
types:
  proto   protocol/port macro  - values joined with '; '
  ip      single IP address
  net     subnet CIDR(s)
  log     log label - script wraps in log "..."

examples:
  fwctl add macro proto svc1 tcp 8080
  fwctl add macro proto svc2 tcp 8080 udp 8080
  fwctl add macro ip host1 1.2.3.4
  fwctl add macro net zone1 10.0.0.0/24
  fwctl add macro net zone1 10.0.0.0/24 fd00::/48
  fwctl add macro log event1
  fwctl add macro log event1 "my event"
""")
    add_macro_p.add_argument("mtype", choices=["proto", "ip", "pod", "net", "log"],
                             metavar="type", help="macro type: proto, ip, pod, net, log")
    add_macro_p.add_argument("name",   help="macro name")
    add_macro_p.add_argument("values", nargs="*", help="value(s)")
    add_macro_p.add_argument("-f", "--force", action="store_true",
                             help="allow proto macro even if name shadows a foomuuri built-in")

    add_forward_p = add_sub.add_parser("forward", help="add a forward rule",
        description="Add a forward rule to fw.yaml.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl add forward zone1 zone2 http https
  fwctl add forward zone1 zone2 accept
  fwctl add forward zone1 zone2 - accept       # bidirectional
  fwctl add forward zone1 zone2 - http https   # bidirectional with services
""")
    add_forward_p.add_argument("src",   help="source zone")
    add_forward_p.add_argument("dst",   help="destination zone")
    add_forward_p.add_argument("rules", nargs="*",
                               help="services or 'accept' (prefix with - for bidirectional)")

    add_dnat_p = add_sub.add_parser("dnat", help="add a DNAT rule",
        description="Add a DNAT rule to fw.yaml.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl add dnat host1 -s http https -i zone1 zone2 -f zone1
  fwctl add dnat host1 -s svc1 -i zone1 -f zone1 -p 8080
  fwctl add dnat pihole -s dns -i wan          # transparent (no -f = any source)

from/to without family suffix auto-expand to v4+v6.
-i accepts zone names (resolved to interface) or interface names directly.
-p accepts port number or service name.
-f (from) is optional; omit for transparent proxy (no daddr filter).
""")
    add_dnat_p.add_argument("to",                   help="destination IP/zone prefix (e.g. host1, host4)")
    add_dnat_p.add_argument("-f", "--from", dest="frm", default=None, metavar="FROM",
                            help="source IP/zone prefix; omit for transparent proxy")
    add_dnat_p.add_argument("-s", "--service", dest="s", nargs="+", required=True, metavar="SVC",
                            help="service(s)")
    add_dnat_p.add_argument("-i", "--iif",     dest="i", nargs="+", required=True, metavar="IIF",
                            help="input interface(s) or zone name(s)")
    add_dnat_p.add_argument("-p", "--port",    dest="p", default=None, metavar="PORT",
                            help="destination port override (number or service name)")

    add_snat_p = add_sub.add_parser("snat", help="add a SNAT rule",
        description="Add a SNAT rule to fw.yaml.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl add snat zone1 zone2          # both IPv4 and IPv6
  fwctl add snat zone1 zone2 -i 4     # IPv4 only
  fwctl add snat zone1 zone2 -i 6     # IPv6 only
""")
    add_snat_p.add_argument("zone", help="source zone")
    add_snat_p.add_argument("via",  help="outbound zone name")
    add_snat_p.add_argument("-i", "--ip", choices=["4", "6"], default=None,
                            help="restrict to IPv4 or IPv6 only")

    add_hairpin_p = add_sub.add_parser("hairpin", help="mark a zone as hairpin",
        description="Mark a zone so DNATs targeting it auto-add SNAT.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl add hairpin awg
""")
    add_hairpin_p.add_argument("zone", help="zone name")

    add_redirect_p = add_sub.add_parser("redirect", help="add a redirect rule",
        description="Add an output DNAT redirect rule to fw.yaml.",
        formatter_class=fmt,
        epilog="""\
Redirect intercepts traffic generated by the firewall itself (output hook).
'to' can be an IP macro (e.g. nat4) or a direct IP (e.g. 127.0.0.1).

examples:
  fwctl add redirect nat4 -s http
  fwctl add redirect nat4 -s http -p 8080
  fwctl add redirect 127.0.0.1 -s dns -p 53
  fwctl add redirect awg4 -s http http2
""")
    add_redirect_p.add_argument("to", help="destination IP macro or direct IP address")
    add_redirect_p.add_argument("-s", "--service", dest="s", nargs="+", required=True, metavar="SVC",
                                help="service(s)")
    add_redirect_p.add_argument("-p", "--port", dest="p", default=None, metavar="PORT",
                                help="destination port override (number or service name)")

    add_pod_p = add_sub.add_parser("pod", help="add a pod",
        description="Add a pod to fw.yaml.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl add pod pod1 10.0.0.10 -i http -d http:8080
  fwctl add pod pod1 10.0.0.10 -i http https -o           # o: accept
  fwctl add pod pod1 10.0.0.10 -i http -o mqtt http       # o: [mqtt, http]
  fwctl add pod pod1 10.0.0.10 -i svc1 -d http:8080 svc1:9000
  fwctl add pod pod1 10.0.0.10 -i svc1 -z zone1 zone2
""")
    add_pod_p.add_argument("name", help="pod name")
    add_pod_p.add_argument("ip",   help="pod IP address")
    add_pod_p.add_argument("-i", "--in",    nargs="+", dest="in_svcs", metavar="SVC",
                           help="accessible services")
    add_pod_p.add_argument("-d", "--dnat",  nargs="+", metavar="SVC:PORT",
                           help="DNAT mappings (e.g. svc1:8080 svc2:9000)")
    add_pod_p.add_argument("-o", "--out",   nargs="*", dest="out", metavar="SVC",
                           help="outbound services, or no args for accept all")
    add_pod_p.add_argument("-z", "--zones", nargs="+", metavar="ZONE",
                           help="accessible from zones (default: all)")

    add_zone_p = add_sub.add_parser("zone", help="add a zone",
        description="Add a new zone to fw.yaml with empty policy.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl add zone zone1
  fwctl add zone zone1 eth0 -n 10.0.0.0/24
  fwctl add zone zone1 eth0 -n 10.0.0.0/24 -n fd00::/48
""")
    add_zone_p.add_argument("name",  help="zone name")
    add_zone_p.add_argument("iface", nargs="?", default=None,
                            help="interface name (default: same as zone name)")
    add_zone_p.add_argument("-n", "--net", metavar="CIDR", action="append",
                            help="subnet for this zone (repeat for v4+v6)")

    add_policy_p = add_sub.add_parser("policy", help="add policy rules for a zone",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl add policy zone1 -i svc1 svc2 -o svc1 svc2
  fwctl add policy zone1 -i accept
  fwctl add policy zone1 -o accept
""")
    add_policy_p.add_argument("zone", help="zone name")
    add_policy_p.add_argument("-i", "--in",  nargs="*", dest="i", metavar="SVC",
                              help="inbound services or 'accept'")
    add_policy_p.add_argument("-o", "--out", nargs="*", dest="o", metavar="SVC",
                              help="outbound services or 'accept'")

    # --- del ---
    del_p = sub.add_parser("del", help="remove resources from fw.yaml",
        description="Remove resources from an existing fw.yaml.",
        formatter_class=fmt)
    del_sub = del_p.add_subparsers(dest="del_command")

    del_pod_p = del_sub.add_parser("pod", help="remove a pod", formatter_class=fmt,
        epilog="""\
examples:
  fwctl del pod pod1
""")
    del_pod_p.add_argument("name", help="pod name")

    del_macro_p = del_sub.add_parser("macro", help="remove a macro", formatter_class=fmt,
        epilog="""\
examples:
  fwctl del macro proto svc1
  fwctl del macro proto svc1 -f   # remove also from all rules
  fwctl del macro ip host1
  fwctl del macro net zone1
  fwctl del macro log event1
""")
    del_macro_p.add_argument("mtype", choices=["proto", "ip", "pod", "net", "log"],
                             metavar="type", help="macro type: proto, ip, pod, net, log")
    del_macro_p.add_argument("name", help="macro name")
    del_macro_p.add_argument("-f", "--force", action="store_true",
                             help="remove macro from all rules that use it")

    del_zone_p = del_sub.add_parser("zone", help="remove a zone", formatter_class=fmt,
        epilog="""\
removes zone from: zones, policy, log, macros.net
warns about remaining references in: forward, dnat, snat

examples:
  fwctl del zone zone1
""")
    del_zone_p.add_argument("name", help="zone name")

    del_forward_p = del_sub.add_parser("forward", help="remove a forward rule", formatter_class=fmt,
        epilog="""\
examples:
  fwctl del forward zone1 zone2
  fwctl del forward zone1 zone2 -    # bidirectional
""")
    del_forward_p.add_argument("src",   help="source zone")
    del_forward_p.add_argument("dst",   help="destination zone")
    del_forward_p.add_argument("bidir", nargs="?", metavar="-",
                               help="pass - for bidirectional rule")

    del_dnat_p = del_sub.add_parser("dnat", help="remove a DNAT rule", formatter_class=fmt,
        epilog="""\
examples:
  fwctl del dnat --hash a3f2c1b0
  fwctl del dnat host1 -s http https -i zone1 zone2 -f zone1
  fwctl del dnat host1 -s svc1 -i zone1 -p 8080 -f zone14
  fwctl del dnat pihole -s dns -i wan             # transparent rule (no -f)
""")
    del_dnat_p.add_argument("to",                   nargs="?", default=None,
                            help="destination IP/zone prefix")
    del_dnat_p.add_argument("-f", "--from", dest="frm", default=None, metavar="FROM",
                            help="source IP/zone prefix (omit for transparent rules)")
    del_dnat_p.add_argument("--hash",          dest="hash", default=None, metavar="HASH",
                            help="identify rule by hash (see: fwctl show dnat --hash)")
    del_dnat_p.add_argument("-s", "--service", dest="s", nargs="+", default=None, metavar="SVC")
    del_dnat_p.add_argument("-i", "--iif",     dest="i", nargs="+", default=None, metavar="IIF")
    del_dnat_p.add_argument("-p", "--port",    dest="p", default=None, metavar="PORT")

    del_snat_p = del_sub.add_parser("snat", help="remove a SNAT rule", formatter_class=fmt,
        epilog="""\
examples:
  fwctl del snat zone1              # remove all SNAT rules for zone1
  fwctl del snat zone1 --via wan    # remove only the zone1 -> wan rule
""")
    del_snat_p.add_argument("zone", help="zone name")
    del_snat_p.add_argument("-v", "--via", default=None, help="filter by outbound zone")

    del_hairpin_p = del_sub.add_parser("hairpin", help="unmark a zone as hairpin", formatter_class=fmt,
        epilog="""\
examples:
  fwctl del hairpin awg
""")
    del_hairpin_p.add_argument("zone", help="zone name")

    del_redirect_p = del_sub.add_parser("redirect", help="remove a redirect rule", formatter_class=fmt,
        epilog="""\
examples:
  fwctl del redirect --hash a3f2c1b0
  fwctl del redirect nat4 -s http
  fwctl del redirect nat4 -s http -p 8080
""")
    del_redirect_p.add_argument("to", nargs="?", default=None,
                                help="destination IP macro or direct IP")
    del_redirect_p.add_argument("--hash", dest="hash", default=None, metavar="HASH",
                                help="identify rule by hash (see: fwctl show redirect --hash)")
    del_redirect_p.add_argument("-s", "--service", dest="s", nargs="+", default=None, metavar="SVC")
    del_redirect_p.add_argument("-p", "--port", dest="p", default=None, metavar="PORT")

    # --- set ---
    _w = max(len(k) for k in list(FWCTL_DEFAULTS) + list(FOOMUURI_SETTINGS_DEFAULTS))
    _fwctl_keys   = "\n".join("  {:<{}}  (default: {})".format(k, _w, v) for k, v in FWCTL_DEFAULTS.items())
    _settings_keys = "\n".join("  {:<{}}  (default: {})".format(k, _w, v) for k, v in FOOMUURI_SETTINGS_DEFAULTS.items())
    set_p = sub.add_parser("set", help="set a value in fwctl: or settings:",
        description="Set a value in fw.yaml (fwctl: or settings: section, auto-detected by key).",
        formatter_class=fmt,
        epilog="""\
fwctl: keys:
{fwctl_keys}

settings: keys:
{settings_keys}

examples:
  fwctl set cmd "systemctl restart foomuuri"
  fwctl set log_input yes
  fwctl set log_rate "1/second burst 5"
""".format(fwctl_keys=_fwctl_keys, settings_keys=_settings_keys))
    set_p.add_argument("key",   help="setting key (e.g. log_input)")
    set_p.add_argument("value", nargs="+", help="value")

    # --- unset ---
    unset_p = sub.add_parser("unset", help="remove a custom setting (restore foomuuri default)",
        description="Remove a setting from fw.yaml, restoring the foomuuri default.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl unset log_input
  fwctl unset log_rate
""")
    unset_p.add_argument("key", help="setting key to remove (e.g. log_input)")

    # --- edit ---
    edit_p = sub.add_parser("edit", help="edit fw.yaml or modify a resource",
        description="Without arguments: open fw.yaml in $EDITOR with validation.\nWith subcommand: modify a specific resource.",
        formatter_class=fmt)
    edit_sub = edit_p.add_subparsers(dest="edit_command")

    edit_zone_p = edit_sub.add_parser("zone", formatter_class=fmt)
    edit_zone_p.add_argument("name",  help="zone name")
    edit_zone_p.add_argument("iface", help="new interface")

    edit_policy_p = edit_sub.add_parser("policy", formatter_class=fmt)
    edit_policy_p.add_argument("zone", help="zone name")
    edit_policy_p.add_argument("-i", "--in",  nargs="*", dest="i", metavar="SVC")
    edit_policy_p.add_argument("-o", "--out", nargs="*", dest="o", metavar="SVC")

    edit_pod_p = edit_sub.add_parser("pod", formatter_class=fmt)
    edit_pod_p.add_argument("name", help="pod name")
    edit_pod_p.add_argument("--ip",          nargs="+", metavar="REF")
    edit_pod_p.add_argument("-i", "--in",    nargs="+", dest="in_svcs", metavar="SVC")
    edit_pod_p.add_argument("-o", "--out",   nargs="*", dest="out", metavar="SVC")
    edit_pod_p.add_argument("-d", "--dnat",  nargs="+", metavar="SVC:PORT")
    edit_pod_p.add_argument("-z", "--zones", nargs="+", metavar="ZONE")

    edit_forward_p = edit_sub.add_parser("forward", formatter_class=fmt)
    edit_forward_p.add_argument("src")
    edit_forward_p.add_argument("dst")
    edit_forward_p.add_argument("rules", nargs="+", metavar="SVC")

    edit_log_p = edit_sub.add_parser("log", formatter_class=fmt)
    edit_log_p.add_argument("zone", help="zone name")
    edit_log_p.add_argument("-i", "--in",  dest="i", metavar="ACTION")
    edit_log_p.add_argument("-o", "--out", dest="o", metavar="ACTION")

    edit_macro_p = edit_sub.add_parser("macro", formatter_class=fmt)
    edit_macro_p.add_argument("mtype", choices=["proto", "ip", "pod", "net", "log"])
    edit_macro_p.add_argument("name")
    edit_macro_p.add_argument("values", nargs="+")

    edit_dnat_p = edit_sub.add_parser("dnat", formatter_class=fmt)
    edit_dnat_p.add_argument("hash", nargs="?", default=None, metavar="HASH")
    edit_dnat_p.add_argument("-f", "--from",     dest="frm", metavar="FROM")
    edit_dnat_p.add_argument("-s", "--service",  dest="s", nargs="+", metavar="SVC")
    edit_dnat_p.add_argument("-i", "--iif",      dest="i", nargs="+", metavar="IIF")
    edit_dnat_p.add_argument("-t", "--to",       metavar="TO")
    edit_dnat_p.add_argument("-p", "--port",     dest="p", default=None, metavar="PORT")

    edit_snat_p = edit_sub.add_parser("snat", formatter_class=fmt)
    edit_snat_p.add_argument("zone", help="zone name")
    edit_snat_p.add_argument("-v", "--via", metavar="ZONE", help="WAN zone")
    edit_snat_p.add_argument("-i", "--ip",  metavar="4|6", help="restrict to IPv4 or IPv6 (empty to remove)")

    edit_redirect_p = edit_sub.add_parser("redirect", formatter_class=fmt)
    edit_redirect_p.add_argument("hash", nargs="?", default=None, metavar="HASH")
    edit_redirect_p.add_argument("-s", "--service", dest="s", nargs="+", metavar="SVC")
    edit_redirect_p.add_argument("-t", "--to", metavar="TO")
    edit_redirect_p.add_argument("-p", "--port", dest="p", default=None, metavar="PORT")

    init_p = sub.add_parser("init", help="generate a minimal default fw.yaml",
        description="Generate a minimal fw.yaml with WAN zone only.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl init
  fwctl init -o /path/to/fw.yaml
  fwctl init -i eth0
  fwctl init -f
""")
    init_p.add_argument("--output", "-o", metavar="FILE", help="output path (default: fw.yaml)")
    init_p.add_argument("--iface",  "-i", metavar="IFACE", help="WAN interface (default: eth0)")
    init_p.add_argument("--force",  "-f", action="store_true", help="overwrite if file exists")

    # --- check ---
    sub.add_parser("check", help="validate fw.yaml",
        description="Validate fw.yaml and report errors/warnings. Exit 1 on error.",
        formatter_class=fmt)

    # --- gen ---
    gen_p = sub.add_parser("gen", help="generate fw.conf",
        description="Validate fw.yaml and write /etc/foomuuri/fw.conf.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl gen
  fwctl gen -p
  fwctl -c /path/to/fw.yaml gen -p
""")
    gen_p.add_argument("-p", "--print",  action="store_true", help="print fw.conf to stdout instead of writing")
    gen_p.add_argument("-u", "--ulogd", action="store_true", help="also generate ulogd.conf")

    # --- apply ---
    apply_p = sub.add_parser("apply", help="generate fw.conf and reload foomuuri",
        description="Validate fw.yaml, write fw.conf, and reload foomuuri.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl apply
  fwctl apply --dry-run
""")
    apply_p.add_argument("-n", "--dry-run", action="store_true",
                         help="show unified diff vs current fw.conf without writing")

    # --- log ---
    log_p = sub.add_parser("log", help="view firewall log",
        description="Filter and display entries from the firewall JSON log (/var/log/fw.json).",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl log                     # all entries (minus noise)
  fwctl log -a                  # all JSON fields
  fwctl log -i                  # inbound only
  fwctl log -o                  # outbound only
  fwctl log -i -s svc1          # inbound by service
  fwctl log -p tcp              # by protocol
  fwctl log -n eth0             # by interface
  fwctl log --ip 1.2.3.4        # by IP address
  fwctl log --port 8080         # by destination port
  fwctl log --hash <hash>       # show raw entry for specific hash
""")
    log_p.add_argument("-a", "--all",     action="store_true",        help="show all JSON fields")
    log_p.add_argument("-i", "--in",      action="store_true", dest="filter_in",  help="inbound only")
    log_p.add_argument("-o", "--out",     action="store_true", dest="filter_out", help="outbound only")
    log_p.add_argument("-n", "--network", metavar="NET",   help="filter by network interface")
    log_p.add_argument("-s", "--service", metavar="SVC",   help="filter by service name")
    log_p.add_argument("-p", "--proto",   metavar="PROTO", help="filter by protocol (tcp/udp/icmp/icmpv6)")
    log_p.add_argument("--ip",            metavar="IP",    help="filter by IP address")
    log_p.add_argument("--port",          metavar="PORT",  type=int, help="filter by destination port")
    log_p.add_argument("--hash",          metavar="HASH",  help="show raw entry for specific hash")

    # --- show ---
    show_p = sub.add_parser("show", help="show parsed config",
        description="Show the parsed fw.yaml in a readable format.",
        formatter_class=fmt,
        epilog="""\
sections:
  settings          foomuuri global settings
  zone              network zones
  macro             all macros
  macro proto       protocol/port macros
  macro ip          IP address macros
  macro pod         pod IP macros
  macro net         network/subnet macros
  macro log         log macros
  macro builtin     foomuuri built-in services
  macro <query>     search by name, port or protocol
  policy            all zone policies
  policy <zone>     policy for a specific zone
  forward           forward rules
  dnat              DNAT rules
  dnat --hash       DNAT rules with identifying hashes
  snat              SNAT rules

examples:
  fwctl show
  fwctl show zone
  fwctl show policy zone1
  fwctl show policy zone1 -c
  fwctl show policy zone1 -i -c
  fwctl show policy zone1 -o -c
  fwctl show macro ip
  fwctl show macro svc1
  fwctl show macro 443
  fwctl show macro tcp
  fwctl show dnat --hash
""")
    show_p.add_argument("section", nargs="?", default=None,
                        choices=["settings", "zone", "macro", "policy", "forward", "dnat", "snat", "redirect", "pod", "hairpin"],
                        help="section to show (default: all)")
    show_p.add_argument("sub", nargs="?", default=None,
                        help="subsection or zone name (e.g. 'proto', 'ip')")
    show_p.add_argument("-d", "--default",  action="store_true",
                        help="show foomuuri built-in defaults (only for settings)")
    show_p.add_argument("-i", "--in",       action="store_true", dest="filter_in",
                        help="show inbound policy only (with -c)")
    show_p.add_argument("-o", "--out",      action="store_true", dest="filter_out",
                        help="show outbound policy only (with -c)")
    show_p.add_argument("-c", "--compact",  action="store_true", dest="compact",
                        help="print services space-separated (ready to paste into edit policy)")
    show_p.add_argument("--hash",           action="store_true", dest="show_hash",
                        help="show DNAT rules with identifying hashes (only with 'dnat')")

    # --- completion ---
    completion_p = sub.add_parser("completion", help="print shell completion script",
        description="Print the shell completion script for fwctl.",
        formatter_class=fmt,
        epilog="""\
examples:
  fwctl completion bash > /etc/bash_completion.d/fwctl
  fwctl completion zsh  > /etc/zsh/zcompletion/_fwctl
""")
    completion_p.add_argument("shell", choices=["bash", "zsh"], help="target shell")

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(0)

    WRITE_COMMANDS = {"check", "gen", "apply", "edit", "add", "del", "set", "unset", "init"}
    if args.command in WRITE_COMMANDS and os.getuid() != 0:
        cfg_path = args.config if args.config else None
        if cfg_path is None:
            for p in DEFAULT_CONFIG_PATHS:
                if os.path.exists(p):
                    cfg_path = p
                    break
            else:
                cfg_path = DEFAULT_CONFIG_PATHS[0]
        check_path = cfg_path if os.path.exists(cfg_path) else os.path.dirname(os.path.abspath(cfg_path))
        if not os.access(check_path, os.W_OK):
            print("Error: '{}' requires root. Use sudo.".format(args.command), file=sys.stderr)
            sys.exit(1)

    EDIT_DISPATCH = {
        "zone":    cmd_edit_zone,
        "policy":  cmd_edit_policy,
        "pod":     cmd_edit_pod,
        "forward": cmd_edit_forward,
        "log":     cmd_edit_log,
        "macro":   cmd_edit_macro,
        "dnat":     cmd_edit_dnat,
        "snat":     cmd_edit_snat,
        "redirect": cmd_edit_redirect,
    }
    ADD_DISPATCH = {
        "macro":    cmd_add_macro,
        "zone":     cmd_add_zone,
        "policy":   cmd_add_policy,
        "forward":  cmd_add_forward,
        "dnat":     cmd_add_dnat,
        "snat":     cmd_add_snat,
        "redirect": cmd_add_redirect,
        "pod":      cmd_add_pod,
        "hairpin":  cmd_add_hairpin,
    }
    DEL_DISPATCH = {
        "macro":    cmd_del_macro,
        "zone":     cmd_del_zone,
        "forward":  cmd_del_forward,
        "dnat":     cmd_del_dnat,
        "snat":     cmd_del_snat,
        "redirect": cmd_del_redirect,
        "pod":      cmd_del_pod,
        "hairpin":  cmd_del_hairpin,
    }

    def _run_sub_dispatch(table, subcmd_attr):
        """Dispatch a second-level subcommand such as ``fwctl add zone``.

        Reads ``args.<subcmd_attr>`` and looks it up in *table*. On a match,
        the handler is called as ``handler(None, builtins, args)``, where
        ``builtins`` is the first value returned by ``load_builtin_conf()``.

        If the subcommand is missing or unknown, a usage line listing the
        valid keys of *table* is printed to stderr and the process exits
        with status 1.
        """
        subcmd = getattr(args, subcmd_attr, None)
        handler = table.get(subcmd)
        if handler is None:
            # Validate before loading builtins so that a missing or unknown
            # subcommand reports usage without running load_builtin_conf().
            keys = ",".join(table)
            print("Usage: fwctl {} {{{}}} ...".format(args.command, keys), file=sys.stderr)
            sys.exit(1)
        builtins, _ = load_builtin_conf()
        handler(None, builtins, args)

    if args.command == "completion":
        cmd_completion(args)
    elif args.command == "edit":
        edit_cmd = getattr(args, "edit_command", None)
        if edit_cmd and edit_cmd in EDIT_DISPATCH:
            EDIT_DISPATCH[edit_cmd](args)
        else:
            cmd_edit(None, None, args)
    elif args.command == "init":
        cmd_init(None, None, args)
    elif args.command == "add":
        _run_sub_dispatch(ADD_DISPATCH, "add_command")
    elif args.command == "del":
        _run_sub_dispatch(DEL_DISPATCH, "del_command")
    elif args.command in ("set", "unset"):
        (cmd_set if args.command == "set" else cmd_unset)(None, None, args)
    elif args.command == "log":
        cfg = None
        for p in ([args.config] if args.config else DEFAULT_CONFIG_PATHS):
            if os.path.exists(p):
                try:
                    cfg = normalize_config(load_config(p))
                except Exception:
                    pass
                break
        cmd_log(cfg, None, args)
    else:
        config_path = find_config(args.config)
        builtins, foomuuri_settings = load_builtin_conf()
        try:
            raw = load_config(config_path)
        except yaml.YAMLError as e:
            print("Error parsing YAML: {}".format(e), file=sys.stderr)
            sys.exit(1)
        except OSError as e:
            print("Error reading config: {}".format(e), file=sys.stderr)
            sys.exit(1)
        raw = sync_ips(raw, config_path)
        raw = auto_fix_rules(raw, config_path, builtins)
        cfg = normalize_config(raw)
        full_dispatch = {
            "check": cmd_check,
            "gen":   cmd_gen,
            "apply": cmd_apply,
            "show":  lambda c, bi, a: cmd_show(c, bi, a, foomuuri_settings),
        }
        full_dispatch[args.command](cfg, builtins, args)


# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
