Prechádzať zdrojové kódy

Update nearest python script for feature parity

Kirk Trombley 3 rokov pred
rodič
commit
34ae7f6978
1 zmenil súbory, kde vykonal 29 pridanie a 6 odobranie
  1. 29 6
      nearest.py

+ 29 - 6
nearest.py

@@ -1,14 +1,24 @@
 #!/usr/bin/env python3
 import csv
-import colorsys
+import math
 import random
 from argparse import ArgumentParser
 
+from convert import rgb_to_cieluv
+
+
+def norm(r, g, b):
+    return math.sqrt(r * r + g * g + b * b)
+
+
 parser = ArgumentParser()
 parser.add_argument("color", nargs="?", default=None, help="Target color, randomized if not provided")
 parser.add_argument("-n", "--number", default=1, type=int, help="Number of Pokemon to find")
 parser.add_argument("-c", "--closeness", default=2, type=int, help="Closeness coefficient")
-parser.add_argument("--database", default="database.csv", help="Database file")
+parser.add_argument("-d", "--database", default="database.csv", help="Database file")
+parser.add_argument("-x", "--exclude-x", action="store_true", help="Exclude X")
+parser.add_argument("-z", "--normalize", action="store_true", help="Normalize q and Y")
+parser.add_argument("-l", "--convert-luv", action="store_true", help="Convert input color to CIELUV before calculation")
 parser.add_argument("-v", "--verbose", action="store_true", help="Print raw scores")
 args = parser.parse_args()
 
@@ -16,17 +26,30 @@ if args.number <= 0:
     raise ValueError("Must request a number greater than 0")
 
 if args.color is not None:
-    if len(args.color) != 6:
+    cleaned_color = args.color.strip("#")
+    if len(cleaned_color) != 6:
         raise ValueError("Color must be a 6 digit hex")
-    color = (int(args.color[0:2], base=16), int(args.color[2:4], base=16), int(args.color[4:6], base=16))
+    color = (int(cleaned_color[0:2], base=16), int(cleaned_color[2:4], base=16), int(cleaned_color[4:6], base=16))
 else:
-    color = tuple(int(c * 255) for c in colorsys.hsv_to_rgb(random.random(), 0.9, 0.9))
+    color = tuple(int(random.random() * 255) for _ in range(3))
     print(f"Generated random color: #{''.join(hex(c)[2:] for c in color)} / {color}")
 
+if args.convert_luv:
+    color = rgb_to_cieluv(*color)
+
+yfactor = args.closeness
+if args.normalize:
+    yfactor /= norm(*color)
+
 results = []
 with open(args.database) as infile:
     for name, x, *y in csv.reader(infile, delimiter=",", quotechar="'"):
-        score = float(x) - args.closeness * sum(float(y_c) * c for y_c, c in zip(y, color))
+        xval = 0 if args.exclude_x else float(x)
+        yvec = [float(y_c) for y_c in y]
+        yval = sum(y_c * c for y_c, c in zip(yvec, color))
+        if args.normalize:
+            yval /= norm(*yvec)
+        score = xval - yfactor * yval
         results.append((score, name))
 
 if args.number > 1: