#!/usr/bin/python3

import subprocess
import sys

# GPU temperature thresholds in degrees Celsius. A GPU at or above WARN_TEMP
# makes the check report WARNING (exit code 1); at or above CRIT_TEMP it
# reports CRITICAL (exit code 2). Both values are also emitted as the
# warn/crit fields of the temperature perfdata.
WARN_TEMP = 80
CRIT_TEMP = 90

def _parse_metric(raw, convert):
    """Convert one nvidia-smi CSV field, or return None if it is unavailable.

    nvidia-smi prints a bracketed placeholder (for example "[N/A]" or
    "[Not Supported]") instead of a number when a GPU or driver does not
    expose a metric. Those placeholders are mapped to None so that a single
    missing metric does not abort the whole check.

    Args:
        raw: The stripped field text.
        convert: Callable used to convert a real value (``int`` or ``float``).

    Raises:
        ValueError: If the field is not a placeholder and cannot be converted.
    """
    if not raw or raw.startswith("["):
        return None
    return convert(raw)


def _describe_gpu(index, line):
    """Evaluate one nvidia-smi CSV row.

    Args:
        index: 1-based GPU number used in the status text and perfdata labels.
        line: One row of ``temperature.gpu, memory.used, memory.total,
            utilization.gpu, power.draw`` values.

    Returns:
        A ``(exit_code, status_text, perfdata_items)`` tuple, where
        ``exit_code`` is 0, 1 or 2 depending on the temperature thresholds.

    Raises:
        ValueError: If the row does not have exactly five fields, a field is
            malformed, or the temperature is unavailable.
    """
    fields = [field.strip() for field in line.split(",")]
    if len(fields) != 5:
        raise ValueError(
            f"unexpected nvidia-smi output for gpu{index}: {line!r}"
        )

    temp = _parse_metric(fields[0], int)
    mem_used = _parse_metric(fields[1], int)
    mem_total = _parse_metric(fields[2], int)
    gpu_util = _parse_metric(fields[3], int)
    power_draw = _parse_metric(fields[4], float)

    # Temperature is the only thresholded metric; without it the state of
    # this GPU cannot be determined, so the caller reports UNKNOWN.
    if temp is None:
        raise ValueError(
            f"temperature not available for gpu{index}: {fields[0]!r}"
        )

    if temp >= CRIT_TEMP:
        code = 2
    elif temp >= WARN_TEMP:
        code = 1
    else:
        code = 0

    status = [f"gpu{index}:", f"temp={temp}C"]
    perf = [f"temp{index}={temp};{WARN_TEMP};{CRIT_TEMP}"]

    # Memory percentage is shown in the status text only; it is not emitted
    # as perfdata.
    if mem_used is None or mem_total is None:
        status.append("mem=N/A")
    else:
        mem_pct = (mem_used / mem_total * 100) if mem_total > 0 else 0
        status.append(f"mem={mem_used}/{mem_total}MiB ({mem_pct:.1f}%)")
        perf.append(f"mem_used{index}={mem_used}MiB;;;0;{mem_total}")

    if gpu_util is None:
        status.append("util=N/A")
    else:
        status.append(f"util={gpu_util}%")
        perf.append(f"gpu_util{index}={gpu_util}%;;;0;100")

    if power_draw is None:
        status.append("power=N/A")
    else:
        status.append(f"power={power_draw}")
        perf.append(f"power_draw{index}={power_draw}W")

    return code, " ".join(status), perf


def main():
    """Query nvidia-smi and report GPU health as a Nagios-style plugin.

    Prints a single line of the form ``STATE - <status> | <perfdata>`` and
    terminates the process via ``sys.exit`` with:

    * 0 (OK)       - every GPU is below ``WARN_TEMP``
    * 1 (WARNING)  - at least one GPU is at or above ``WARN_TEMP``
    * 2 (CRITICAL) - at least one GPU is at or above ``CRIT_TEMP``
    * 3 (UNKNOWN)  - nvidia-smi failed, timed out, listed no GPUs, or
      produced output that could not be interpreted

    Metrics other than temperature that nvidia-smi reports as unavailable
    are shown as ``N/A`` in the status text and left out of the perfdata.
    """
    try:
        out = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=temperature.gpu,memory.used,memory.total,utilization.gpu,power.draw",
                "--format=csv,noheader,nounits",
            ],
            text=True,
            timeout=5,
        )

        lines = [line.strip() for line in out.splitlines() if line.strip()]
        if not lines:
            print("UNKNOWN - no NVIDIA GPUs found")
            sys.exit(3)

        overall_code = 0
        status_parts = []
        perf_parts = []

        for i, line in enumerate(lines, start=1):
            code, status, perf = _describe_gpu(i, line)
            # The overall state is the worst state of any single GPU.
            overall_code = max(overall_code, code)
            status_parts.append(status)
            perf_parts.extend(perf)

        state = ("OK", "WARNING", "CRITICAL")[overall_code]

        print(f"{state} - {'; '.join(status_parts)} | {' '.join(perf_parts)}")
        sys.exit(overall_code)

    # Top-level boundary: any failure must still produce a valid plugin
    # result. SystemExit is not an Exception subclass, so the sys.exit()
    # calls above pass through this handler untouched.
    except Exception as e:
        print(f"UNKNOWN - failed to query NVIDIA GPU: {e}")
        sys.exit(3)

# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
