#!/usr/bin/python3

'''
Quick Network Settings
(c) 2022-2024 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>
Licensed under GPLv2+.

Usage:
	nw show [ip|addr|route|device|connection|bridge|team|bond|lacp] [DEVICE]
	nw [addr|a] [DEVICE]  # show addr
	nw [connection|bridge|team|bond|lacp] show
	# permanent commands (nmcli):
	nw add ethernet DEVICE
	nw add vlan DEVICE.X [vlanNAME]
	nw add bridge NAME ports DEVICE DEVICE ...
	nw add bridge NAME vlan DEVICE.vlanid  # bridge from VLAN
	nw add iface BRIDGE_NAME DEVICE.vlanid  # add VLAN to bridge
	nw add team NAME ports DEVICE DEVICE ...
	nw add bond NAME ports DEVICE DEVICE ...
	nw modify DEVICE addr auto|disabled|IP/MASK gw GATEWAY dns IP
	nw modify DEVICE addr6 auto|disabled|IP/MASK gw6 GATEWAY dns6 IP
	nw up|down|restart CONNECTION
	nw delete CONNECTION
	# temporary commands (ip link):
	nw init vlan DEVICE.X [vlanNAME]
	nw init bridge NAME ports DEVICE DEVICE ...

Examples:
	nw add bond bond0 ports ens4 ens5	# create bond device
	nw add bridge wan0 ports bond0		# create bridge and add slave
	nw modify wan0 addr 10.0.0.1/24		# add IP address to interface
'''

import sys, os, json, re

def nmcli(cmd):
    '''Echo an nmcli command line, then run it through the shell.

    cmd: argument string placed after "nmcli ".
    '''
    full_cmd = "nmcli " + cmd
    print(full_cmd)
    os.system(full_cmd)

def ip_link(cmd):
    '''Echo an "ip link" command line, then run it through the shell.

    cmd: argument string placed after "ip link ".
    '''
    full_cmd = "ip link " + cmd
    print(full_cmd)
    os.system(full_cmd)

def eprint(*args, **kw):
    '''Behave like print(), but write to standard error.'''
    stream = sys.stderr
    print(*args, file=stream, **kw)

def device_exists(dev):
    '''Tell whether an interface called *dev* has an entry in /sys/class/net.'''
    net_dir = "/sys/class/net/"
    return os.path.exists(net_dir + dev)

def show_addrs(device=""):
    '''Print interfaces with their flags, state, MAC and IP addresses.

    Data comes from the JSON output of "ip --json addr show", limited to
    *device* when one is given.  If that command produces no valid JSON
    (e.g. an unknown device), nothing is printed; ip's own error message
    goes to stderr, which is not captured here.
    '''
    if device:
        device = " dev "+device
    try:
        # "with" closes the pipe and reaps the ip process
        with os.popen("ip --json addr show" + device) as pipe:
            ipa = json.load(pipe)
    except ValueError:
        return
    for eth in ipa:
        ifname = eth["ifname"]
        if "altnames" in eth:
            # show alternative interface names next to the primary one
            ifname = "%s(%s)" % (ifname, ",".join(eth["altnames"]))
        print("%s: <%s> state %s"
              % (ifname, ",".join(eth["flags"]), eth["operstate"]))
        if "address" in eth:
            print("%10s: %s" % ("ether", eth["address"]))
        for addr in eth.get("addr_info", []):
            if addr["family"]=="inet6":
                print("%10s: %s/%s" % (
                    addr["family"],
                    addr["local"], addr["prefixlen"]
                ))
            else:
                # NOTE(review): the local address is printed twice here
                # ("inet: A (A/len)"); the parenthesised part may have been
                # meant to show something else — confirm.
                print("%10s: %s (%s/%s)" % (
                    addr["family"],
                    addr["local"],
                    addr["local"], addr["prefixlen"]
                ))

def show_route(device=""):
    '''Print the IPv4 and IPv6 routing tables as aligned columns.

    Each table comes from "ip --json -4|-6 route show", limited to
    *device* when given.  Each route is printed as destination, gateway,
    device and metric; missing fields print empty (metric as 0).  A table
    whose command yields no valid JSON is skipped.
    '''
    if device:
        device = " dev "+device
    for ipv in [4, 6]:
        try:
            # "with" closes the pipe and reaps the ip process
            with os.popen("ip --json -%d route show%s" % (ipv, device)) as pipe:
                ipr = json.load(pipe)
        except ValueError:
            # ip's own error message goes to stderr, which is not captured
            continue
        print("%-10s%20s  %-30s %8s %6s" % (
            "IPv%d" % ipv, "Address",
            "Gateway",
            "Device",
            "Metric"
        ))
        for route in ipr:
            print("%30s  %-30s %8s %6d" % (
                route.get("dst", ""),
                route.get("gateway", ""),
                route.get("dev", ""),
                route.get("metric", 0)
            ))

def show_bridge(device=""):
    '''Print every bridge master followed by its ports.

    Port data comes from "bridge --json link show".  Entries that carry a
    "hwmode" print that value; the others print state, priority and cost.
    When *device* is given, only that bridge is printed.
    '''
    # "with" closes the pipe and reaps the bridge process
    with os.popen("bridge --json link show") as pipe:
        bridges = json.load(pipe)
    # sorted() gives a stable output order; iterating a set does not
    masters = sorted({x["master"] for x in bridges if "master" in x})
    for master in masters:
        if device and master!=device:
            continue
        print(master+":")
        for slave in bridges:
            if slave.get("master")!=master:
                continue
            if "hwmode" in slave:
                print("\t%-15s\t\t\thwmode=%s" % (
                    slave["ifname"], slave.get("hwmode")
                ))
            else:
                print("\t%-15s\t%s\tprio=%s\tcost=%s" % (
                    slave["ifname"], slave.get("state"),
                    slave.get("priority"), slave.get("cost")
                ))

def show_team(device=""):
    '''Print teamd devices with their runner and per-port link state.

    device: team device to show.  By default, every interface in
            /proc/net/dev whose name starts with "team" is shown, except
            VLAN sub-interfaces (names containing ".").
    State is read from "teamdctl DEV state dump".  A device whose dump is
    not valid JSON is reported on stderr and skipped.
    '''
    if device:
        teamdevs = [device]
    else:
        with open("/proc/net/dev") as f:
            teamdevs = [
                x.split(":", 1)[0].strip()
                for x in f.readlines()[2:]  # skip the two header lines
                if x.strip().startswith("team")
                   and "." not in x # ignore vlan devices
            ]
    for teamdev in teamdevs:
        try:
            with os.popen("teamdctl %s state dump" % teamdev) as pipe:
                state = json.load(pipe)
        except ValueError:
            print("%s: unable to read teamd state" % teamdev, file=sys.stderr)
            continue
        ifinfo = state["team_device"]["ifinfo"]
        print("%s: %s [%s]" % (
            ifinfo["ifname"],
            ifinfo["dev_addr"],
            state["setup"]["runner_name"]
        ))
        for port in state.get("ports", {}).values():
            ifinfo = port["ifinfo"]
            link = port["link"]
            print("\t%s: %s" % (
                ifinfo["ifname"],
                ifinfo["dev_addr"]
            ))
            print("\t\t%s, %s Mbit/s, %s duplex" % (
                "UP" if link["up"] else "DOWN",
                link["speed"],
                link["duplex"]
            ))

def _parse_bond_state(lines):
    '''Parse the lines of a bonding status file into per-interface dicts.

    Returns a dict that maps "" to the bond's own "key: value" settings,
    and each slave interface name to that slave's settings.  Keys are kept
    exactly as written before the first ":".
    '''
    ifinfo = {"": {}}
    iface = ""
    for row in lines:
        if row.startswith("Slave Interface: "):
            # every following "key: value" line belongs to this slave
            iface = row.split(":", 1)[-1].strip()
            ifinfo[iface] = {}
            continue
        if ":" in row:
            key, value = row.split(":", 1)
            ifinfo[iface][key] = value.strip()
    return ifinfo


def show_bond(device="", proc="/proc/net/bonding"):
    '''Print Linux bonding devices and the state of their slaves.

    device: bond to show (default: every file in *proc*, sorted by name).
    proc:   directory holding one status file per bond; overridable for
            testing.
    Values missing from a status file are printed as "-".  For example,
    the system MAC and churn states may be absent for bonds that do not
    use LACP — those previously raised KeyError.
    '''
    if not os.path.exists(proc):
        print("No bonding devices!")
        return
    bonddevs = [device] if device else sorted(os.listdir(proc))
    for bonddev in bonddevs:
        with open(os.path.join(proc, bonddev)) as f:
            ifinfo = _parse_bond_state(f)
        bond = ifinfo[""]
        print("%s: %s [%s]" % (
            bonddev,
            bond.get("System MAC address", "-"),
            "LACP" if bond.get("LACP active")=="on" else "UNKNOWN"
        ))
        for port, link in ifinfo.items():
            if port=="":
                continue
            speed = link.get("Speed", "-")
            if speed.endswith(" Mbps"):
                # the file carries its own unit; "Mbit/s" is printed below
                speed = speed[:-len(" Mbps")]
            print("\t%s: %s %s" % (
                port,
                link.get("MII Status", "-").upper(),
                link.get("Permanent HW addr", "-")
            ))
            print("\t\tstate: %s/%s, %s Mbit/s, %s duplex" % (
                link.get("Actor Churn State", "-"),
                link.get("Partner Churn State", "-"),
                speed,
                link.get("Duplex", "-")
            ))

def show(what="", device=""):
    '''Dispatch "nw show WHAT [DEVICE]" to the matching show_* function.

    what:   "ip", "addr" or "device" (addresses), "route", "bridge",
            "team", "bond", "lacp" (bond and team), or "connection" /
            "connections" / "" (nmcli connection list).
    device: optional interface name passed on to the show_* function.
    An unknown *what* is reported on stderr instead of being ignored.
    '''
    if what in ("ip", "addr", "device"):
        show_addrs(device)
    elif what=="route":
        show_route(device)
    elif what=="bridge":
        show_bridge(device)
    elif what=="team":
        show_team(device)
    elif what=="bond":
        show_bond(device)
    elif what=="lacp":
        # link aggregation may come from the bonding driver or from teamd
        show_bond(device)
        show_team(device)
    elif what in ("connection", "connections", ""):
        nmcli("con")
    else:
        print("Unknown item to show: %s" % what, file=sys.stderr)


# set_ip() arguments that disable both IPv4 and IPv6 addressing
# (they produce " ipv4.method disabled ipv6.method disabled" for nmcli).
# Treat as read-only: pass a copy, e.g. list(IP_DISABLED), to anything
# that consumes (pops from) its argument list.
IP_DISABLED = ["addr4", "disabled", "addr6", "disabled"]

def set_ip(args):
    '''Translate "nw modify" style IP arguments into nmcli properties.

    args: sequence of KEYWORD VALUE pairs; it is not modified.
        addr|addr4|address  auto|disabled|IP/MASK   (IP/MASK repeatable)
        addr6               auto|disabled|IP/MASK   (IP/MASK repeatable)
        gw|gw4, gw6         GATEWAY
        dns|dns4, dns6      IP                      (repeatable)

    Returns a string of " ipv4.* ..." / " ipv6.* ..." nmcli properties,
    each with a leading space, or "" for empty args.

    Raises ValueError for an unknown keyword or a keyword without a value.
    '''
    # work on a private copy so shared lists such as IP_DISABLED survive
    args = list(args)
    v4_addr_keys = ("addr", "addr4", "address")
    known = v4_addr_keys + ("addr6", "gw", "gw4", "gw6", "dns", "dns4", "dns6")
    ipaddr, ipaddr6 = [], []
    gateway = gateway6 = None
    dns, dns6 = [], []
    while args:
        key = args.pop(0)
        # every word must be consumed, or the loop would never terminate
        if key not in known:
            raise ValueError("unknown IP parameter: %s" % key)
        if not args:
            raise ValueError("missing value for IP parameter: %s" % key)
        value = args.pop(0)
        if key in v4_addr_keys:
            if value in ("auto", "disabled"):
                ipaddr = value
            else:
                if not isinstance(ipaddr, list):
                    ipaddr = []  # IP/MASK overrides an earlier auto/disabled
                ipaddr.append(value)
        elif key=="addr6":
            if value in ("auto", "disabled"):
                ipaddr6 = value
            else:
                if not isinstance(ipaddr6, list):
                    ipaddr6 = []  # IP/MASK overrides an earlier auto/disabled
                ipaddr6.append(value)
        elif key in ("gw", "gw4"):
            gateway = value
        elif key=="gw6":
            gateway6 = value
        elif key in ("dns", "dns4"):
            dns.append(value)
        else:  # "dns6"
            dns6.append(value)
    cmd = ""
    # IPv4
    if ipaddr=="auto" or ipaddr=="disabled":
        cmd += " ipv4.method " + ipaddr
    elif ipaddr:
        cmd += " ipv4.method manual ipv4.addresses " + ",".join(ipaddr)
        if gateway:
            cmd += " ipv4.gateway " + gateway
        else:
            cmd += " ipv4.never-default yes"
        # ignore auto routes and DNS for manual settings
        cmd += " ipv4.ignore-auto-routes yes"
        cmd += " ipv4.ignore-auto-dns yes"
        cmd += " ipv4.may-fail no"
    if dns:
        cmd += " ipv4.dns " + ",".join(dns)
    # IPv6
    if ipaddr6=="auto" or ipaddr6=="disabled":
        cmd += " ipv6.method " + ipaddr6
    elif ipaddr6:
        cmd += " ipv6.method manual ipv6.addresses " + ",".join(ipaddr6)
        if gateway6:
            cmd += " ipv6.gateway " + gateway6
        else:
            cmd += " ipv6.never-default yes"
        # ignore auto routes and DNS for manual settings
        cmd += " ipv6.ignore-auto-routes yes"
        cmd += " ipv6.ignore-auto-dns yes"
        cmd += " ipv6.may-fail no"
    if dns6:
        cmd += " ipv6.dns " + ",".join(dns6)
    return cmd

def add_ethernet(ifname, ip_args=()):
    '''Create an nmcli ethernet connection named after interface *ifname*.

    ip_args: optional set_ip() arguments (addr/gw/dns ...).  The default
             is an immutable empty tuple, not a shared mutable list.
    '''
    cmd = "con add con-name %s type ethernet ifname %s" % (ifname, ifname)
    cmd += set_ip(ip_args)
    nmcli(cmd)

def add_vlan(vlan_dev, vlan_name="", bridge="", ip_args=()):
    '''Create an nmcli VLAN connection from a "DEVICE.VLANID" name.

    vlan_dev:  parent device and VLAN id, separated by the first ".".
    vlan_name: connection/interface name (default: "DEVICE.VLANID").
    bridge:    if given, the VLAN is added with "master BRIDGE".
    ip_args:   optional set_ip() arguments.

    Raises ValueError if vlan_dev contains no ".".
    '''
    device, sep, vlan_id = vlan_dev.partition(".")
    if not sep:
        raise ValueError("VLAN device must be DEVICE.VLANID: %s" % vlan_dev)
    name = vlan_name or "%s.%s" % (device, vlan_id)
    cmd = "con add con-name %s type vlan ifname %s" % (name, name)
    cmd += " dev %s id %s" % (device, vlan_id)
    if bridge:
        cmd += " master %s" % bridge
    cmd += set_ip(ip_args)
    nmcli(cmd)

def add_bridge(master, type_verb, *slaves, ip_args=()):
    '''Create bridge connection *master* and attach *slaves* to it.

    type_verb: "ports" - slaves are interfaces.  If the device exists,
                         "con modify SLAVE master MASTER" is run (this
                         assumes a connection named like the device);
                         otherwise a MASTER-SLAVE bridge-slave connection
                         is created.
               "vlan"  - slaves are DEVICE.VLANID names; a VLAN connection
                         with IP disabled is added to the bridge for each.
    ip_args:   set_ip() arguments for the bridge (default: IP disabled).
    '''
    if type_verb not in ["ports", "vlan"]:
        print("Wrong parameter: %s. Must be ports or vlan!" % type_verb)
        return
    # hand out copies: set_ip() may consume the list it is given, which
    # would otherwise empty the shared IP_DISABLED for the VLANs below
    ip_args = list(ip_args) if ip_args else list(IP_DISABLED)
    nmcli(
        "con add con-name %s type bridge ifname %s %s"
        % (master, master, set_ip(ip_args))
    )
    for slave in slaves:
        if type_verb=="vlan":
            add_vlan(slave, bridge=master, ip_args=list(IP_DISABLED))
        elif device_exists(slave):
            nmcli(
                "con modify %s master %s"
                % (slave, master)
            )
        else:
            nmcli(
                "con add type bridge-slave ifname %s con-name %s-%s master %s"
                % (slave, master, slave, master)
            )

def add_iface(master, *slaves):
    '''Attach each interface in *slaves* to bridge *master*.

    A bridge-slave connection named "MASTER-SLAVE" is created per port.
    '''
    template = "con add type bridge-slave ifname %s con-name %s-%s master %s"
    for port in slaves:
        nmcli(template % (port, master, port, master))

def add_team(master, ports_verb, *slaves, ip_args=()):
    '''Create team connection *master* (LACP runner) with the given ports.

    ports_verb: must be the literal word "ports", as in the usage text.
                Otherwise a forgotten "ports" would silently swallow the
                first port name.
    slaves:     interfaces to add as team ports ("MASTER-SLAVE" connections).
    ip_args:    set_ip() arguments for the team (default: IP disabled).
    '''
    if ports_verb!="ports":
        print("Wrong parameter: %s. Must be ports!" % ports_verb)
        return
    json_config = '{"runner": {"name": "lacp"}}'
    # copy: set_ip() may consume its argument list
    ip_args = list(ip_args) if ip_args else list(IP_DISABLED)
    nmcli(
        "con add con-name %s type team ifname %s team.config '%s' %s"
        % (master, master, json_config, set_ip(ip_args))
    )
    for slave in slaves:
        nmcli(
            "con add type ethernet con-name %s-%s ifname %s slave-type team master %s"
            % (master, slave, slave, master)
        )

def add_bond(master, ports_verb, *slaves, ip_args=()):
    '''Create an 802.3ad bond connection *master* with the given ports.

    ports_verb: must be the literal word "ports", as in the usage text.
                Otherwise a forgotten "ports" would silently swallow the
                first port name.
    slaves:     interfaces to add as bond ports ("MASTER-SLAVE" connections).
    ip_args:    set_ip() arguments for the bond (default: IP disabled).
    '''
    if ports_verb!="ports":
        print("Wrong parameter: %s. Must be ports!" % ports_verb)
        return
    options = 'mode=802.3ad,miimon=100,lacp_rate=1,xmit_hash_policy=1'
    # copy: set_ip() may consume its argument list
    ip_args = list(ip_args) if ip_args else list(IP_DISABLED)
    nmcli(
        "con add con-name %s type bond ifname %s bond.options '%s' %s"
        % (master, master, options, set_ip(ip_args))
    )
    for slave in slaves:
        nmcli(
            "con add type ethernet con-name %s-%s ifname %s slave-type bond master %s"
            % (master, slave, slave, master)
        )

def init_vlan(vlan_dev, vlan_name="", ip_args=()):
    '''Create a temporary VLAN interface with "ip link" (not persistent).

    vlan_dev:  parent device and VLAN id as "DEVICE.VLANID".
    vlan_name: name of the new interface (default: vlan_dev itself).
               Previously this argument was computed but ignored.
    ip_args:   accepted but currently unused.

    Raises ValueError if vlan_dev contains no ".".
    '''
    device, sep, vlan_id = vlan_dev.partition(".")
    if not sep:
        raise ValueError("VLAN device must be DEVICE.VLANID: %s" % vlan_dev)
    name = vlan_name or "%s.%s" % (device, vlan_id)
    ip_link("add link %s name %s type vlan id %s" % (
      device, name, vlan_id
    ))

def init_bridge(master, ports_verb, *slaves, ip_args=()):
    '''Create a temporary bridge with "ip link", bring it up, add ports.

    ports_verb: must be the literal word "ports", as in the usage text.
    slaves:     interfaces to enslave to the bridge.
    ip_args:    accepted but currently unused.
    '''
    if ports_verb!="ports":
        print("Wrong parameter: %s. Must be ports!" % ports_verb)
        return
    ip_link("add name %s type bridge" % master)
    ip_link("set dev %s up" % master)
    for slave in slaves:
        ip_link("set dev %s master %s" % (slave, master))

def modify_ip(device, ip_args):
    '''Change the IP settings of connection *device* and re-activate it.

    ip_args: set_ip() style KEYWORD VALUE arguments.
    '''
    settings = set_ip(ip_args)
    nmcli("con modify %s%s" % (device, settings))
    # bringing the connection up again applies the modified settings
    nmcli("con up %s" % device)

def connection_up(devices):
    '''Activate each named nmcli connection, in the order given.'''
    for con_name in devices:
        nmcli("con up %s" % con_name)

def connection_down(devices):
    '''Deactivate each named nmcli connection, in the order given.'''
    for con_name in devices:
        nmcli("con down %s" % con_name)

def connection_restart(devices):
    '''Restart connections: take all of them down, then bring them up.

    Connections come back up in the reverse of the order they were taken
    down.
    '''
    for con_name in devices:
        nmcli("con down %s" % con_name)
    for con_name in reversed(devices):
        nmcli("con up %s" % con_name)

def delete_connections(devices):
    '''Delete each named nmcli connection profile.'''
    for con_name in devices:
        nmcli("con delete %s" % con_name)

def complete(cword, words):
    '''Print shell-completion candidates for "nw" arguments.

    cword: 1-based position in *words* of the word being completed
           (words[cword-1] is the current, possibly partial, word).
    words: command-line words following the program name.

    Walks a tree of expected arguments (arg_values below) word by word.
    Every candidate for the current position that starts with the current
    word is printed, one per line.  Nothing is printed when an earlier
    word is not a known keyword.
    '''
    if len(words)>cword-1:
        cur = words[cword-1]
    else:
        cur = ""

    class connections(list):
        '''NetworkManager connection names.

        use() removes a name once it has been typed, so it is not offered
        again.
        '''
        def __init__(self, **kw):
            super().__init__(self.values(**kw))
            self.used = set()

        def values(self):
            if not os.path.exists("/usr/bin/nmcli"):
                return []
            with os.popen("nmcli --get NAME --color no con show") as pipe:
                output = pipe.read()
            # drop empty entries (e.g. after the trailing newline)
            return [x for x in output.split("\n") if x]

        def use(self, word):
            if word in self:
                self.remove(word)

    class ifaces(connections):
        '''Interface names from /proc/net/dev, "lo" excluded.

        ext:     suffix appended to every name (e.g. "." for VLANs).
        exclude: regex; /proc/net/dev lines matching it are skipped.
        '''
        def values(self, ext="", exclude="^$"):
            if not os.path.exists("/proc/net/dev"):
                return []
            with open("/proc/net/dev") as f:
                rows = f.readlines()[2:]  # skip the two header lines
            return [
                x.split(":", 1)[0].strip()+ext
                for x in rows
                if x.split(":", 1)[0].strip() not in ["", "lo"]
                   and not re.search(exclude, x.strip())
            ]

    def cprint(words):
        '''Print every candidate in *words* that starts with cur.'''
        if isinstance(words, str):
            words = words.split(" ")
        elif isinstance(words, arg_dict):
            words = words.keys()
        for word in words:
            if word.startswith(cur):
                print(word)

    def next_device(prefix):
        '''Return the first PREFIXn name not used by a connection.'''
        cons = connections()
        for i in range(999):
            if "%s%d" % (prefix, i) not in cons:
                return "%s%d" % (prefix, i)
        return prefix+"X"

    class arg_dict(dict):
        '''Keyword choice: the candidates are the keys; the value of the
        typed key describes the following words.'''
        def __init__(self, **kw):
            super().__init__(**kw)
        def __call__(self, key):
            return self.keys(), self.get(key)

    class arg_list(list):
        '''Positional sequence: item i gives candidates for word i.
        Once exhausted, *repeat* (if any) is cycled.'''
        def __init__(self, *args, repeat=None):
            super().__init__(args)
            self.repeat = repeat
        def __call__(self, key):
            if not self:
                if self.repeat:
                    return self.repeat[0], arg_list(*self.repeat[1:], repeat=self.repeat)
                else:
                    return [], arg_list()
            return self[0], arg_list(*self[1:], repeat=self.repeat)

    arg_ip_set = [
        ["addr", "addr4", "addr6"],
            ["auto", "disabled",
             "192.168.0.1/24", "172.16.0.1/16", "10.0.0.1/8"],
        ["gw", "gw4", "gw6"],
            ["192.168.0.1", "172.16.0.1", "10.0.0.1"],
        ["dns", "dns4", "dns6"],
            ["192.168.0.1", "172.16.0.1", "10.0.0.1"]
    ]

    arg_values = arg_dict(
        show = arg_dict(
            ip = arg_list(ifaces()),
            addr = arg_list(ifaces()),
            route = arg_list(),
            device = arg_list(ifaces()),
            bridge = arg_list(ifaces()),
            team = arg_list(ifaces()),
            bond = arg_list(ifaces()),
            lacp = arg_list(ifaces()),
            # must match the keyword handled by show()
            connection = arg_list()
        ),
        addr = arg_list(ifaces()),
        connection = arg_list(["show"]),
        bridge = arg_list(["show"]),
        team = arg_list(["show"]),
        bond = arg_list(["show"]),
        lacp = arg_list(["show"]),
        add = arg_dict(
            ethernet = arg_list(ifaces(exclude=r"\.")),
            vlan = arg_list(ifaces(ext=".", exclude="^vlan")),
            bridge = arg_list(next_device("br"), ["ports", "vlan"],
                              repeat=[ifaces(exclude="^br")]),
            iface = arg_list(ifaces(), repeat=[ifaces()]),
            team = arg_list(next_device("team"), "ports",
                            repeat=[ifaces(exclude="^team")]),
            bond = arg_list(next_device("bond"), "ports",
                            repeat=[ifaces(exclude="^bond")])
        ),
        init = arg_dict(
            vlan = arg_list(ifaces(ext=".", exclude="^vlan")),
            bridge = arg_list(next_device("br"), "ports",
                              repeat=[ifaces(exclude="^br")])
        ),
        modify = arg_list(ifaces(), repeat=arg_ip_set),
        up = arg_list(repeat=[connections()]),
        down = arg_list(repeat=[connections()]),
        restart = arg_list(repeat=[connections()]),
        delete = arg_list(repeat=[connections()])
    )

    arg_opts = arg_values
    for i in range(cword):
        if arg_values is None:
            # an earlier word was not a known keyword: nothing to offer
            return
        if i>=len(words):
            arg_opts, arg_values = arg_values("")
            break
        arg_opts, arg_values = arg_values(words[i])
        if isinstance(arg_opts, connections):
            # do not offer a name that has already been typed
            arg_opts.use(words[i])
    cprint(arg_opts)

if __name__=="__main__":
    # sys.argv: [program, COMMAND, SUBCOMMAND-or-ARGUMENT, ...]
    arg1 = sys.argv[1] if len(sys.argv) > 1 else ""
    arg2 = sys.argv[2] if len(sys.argv) > 2 else ""
    if not arg1 or arg1=="help":
        print(__doc__.strip())
    elif arg1=="show":
        show(*sys.argv[2:])
    elif arg1 in ("addr", "a"):
        show("addr", *sys.argv[2:])
    elif arg1 in ("connection", "bridge", "team", "bond", "lacp") and arg2=="show":
        show(arg1)
    elif arg1=="add":
        if arg2=="ethernet":
            add_ethernet(sys.argv[3], ip_args=sys.argv[4:])
        elif arg2=="vlan":
            add_vlan(*sys.argv[3:])
        elif arg2=="bridge":
            add_bridge(*sys.argv[3:])
        elif arg2=="iface":
            add_iface(*sys.argv[3:])
        elif arg2=="team":
            add_team(*sys.argv[3:])
        elif arg2=="bond":
            add_bond(*sys.argv[3:])
        else:
            print(f"Unknown command: {arg2}")
            sys.exit(1)
    elif arg1=="init":
        if arg2=="vlan":
            init_vlan(*sys.argv[3:])
        elif arg2=="bridge":
            init_bridge(*sys.argv[3:])
        else:
            print(f"Unknown command: {arg2}")
            sys.exit(1)
    elif arg1=="modify":
        modify_ip(arg2, sys.argv[3:])
    elif arg1=="up":
        connection_up(sys.argv[2:])
    elif arg1=="down":
        connection_down(sys.argv[2:])
    elif arg1=="restart":
        connection_restart(sys.argv[2:])
    elif arg1=="delete":
        delete_connections(sys.argv[2:])
    elif arg1=="complete":
        # "nw complete CWORD X WORDS...": sys.argv[3] is skipped, presumably
        # the command name passed by the shell completion hook — confirm
        complete(int(arg2), sys.argv[4:])
    else:
        print("Parse error:", " ".join(sys.argv[1:]))
        # non-zero status, like the other error paths above
        sys.exit(1)
