analyze.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. import math
  2. import os
  3. import time
  4. import json
  5. from collections import defaultdict
  6. from concurrent.futures import Executor
  7. from pathlib import Path
  8. from dataclasses import dataclass, asdict
  9. from itertools import combinations
  10. import numpy as np
  11. from PIL import Image
  12. from scipy.cluster import vq
  13. from scipy.spatial.distance import cdist, euclidean
  14. # https://en.wikipedia.org/wiki/SRGB#Transformation
  15. linearize_srgb = np.vectorize(
  16. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  17. )
  18. delinearize_lrgb = np.vectorize(
  19. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  20. )
  21. # https://mina86.com/2019/srgb-xyz-matrix/
  22. RGB_TO_XYZ = np.array([
  23. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  24. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  25. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  26. ])
  27. XYZ_TO_RGB = [
  28. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  29. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  30. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  31. ]
  32. # https://bottosson.github.io/posts/oklab/
  33. XYZ_TO_LMS = np.array([
  34. [0.8189330101, 0.3618667424, -0.1288597137],
  35. [0.0329845436, 0.9293118715, 0.0361456387],
  36. [0.0482003018, 0.2643662691, 0.6338517070],
  37. ])
  38. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  39. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  40. LMS_TO_OKLAB = np.array([
  41. [0.2104542553, 0.7936177850, -0.0040720468],
  42. [1.9779984951, -2.4285922050, 0.4505937099],
  43. [0.0259040371, 0.7827717662, -0.8086757660],
  44. ])
  45. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  46. # round output to this many decimals
  47. OUTPUT_PRECISION = 8
  48. def oklab2hex(pixel: np.array) -> str:
  49. # no need for a vectorized version, this is only for providing the mean hex
  50. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  51. def srgb2oklab(pixels: np.array) -> np.array:
  52. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
  53. # https://stackoverflow.com/a/30305181
  54. def geometric_median(X: np.array, eps=1e-5) -> np.array:
  55. y = np.mean(X, 0)
  56. while True:
  57. D = cdist(X, [y])
  58. nonzeros = (D != 0)[:, 0]
  59. Dinv = 1 / D[nonzeros]
  60. Dinvs = np.sum(Dinv)
  61. W = Dinv / Dinvs
  62. T = np.sum(W * X[nonzeros], 0)
  63. num_zeros = len(X) - np.sum(nonzeros)
  64. if num_zeros == 0:
  65. y1 = T
  66. elif num_zeros == len(X):
  67. return y
  68. else:
  69. R = (T - y) * Dinvs
  70. r = np.linalg.norm(R)
  71. rinv = 0 if r == 0 else num_zeros / r
  72. y1 = max(0, 1 - rinv) * T + min(1, rinv) * y
  73. if euclidean(y, y1) < eps:
  74. return y1
  75. y = y1
@dataclass
class Stats:
    """Color statistics for one group of pixels, as returned by calc_statistics.

    Field order is preserved by ``asdict`` and therefore by the JSON output.
    """
    # vector statistics
    centroid: list[float]  # (L, a, b) arithmetic mean of the pixels
    median: list[float]  # (L, a, b) approximate geometric median
    stddev: list[float]  # (L, a, b) per-channel standard deviation
    tilt: list[float]  # (L, a, b) mean of the pixels after scaling each to unit length
    chroma: list[float]  # (mean, stddev) of sqrt(a^2 + b^2)
    # scalar statistics
    hue: float  # circular mean of atan2(b, a), reduced modulo 2*pi
    size: int  # number of pixels summarized
    # sRGB hex code of the centroid and median
    centroidHex: str
    medianHex: str
  90. def calc_statistics(pixels: np.array, output_precision: int) -> Stats:
  91. # centroid, the arithmetic mean pixel of the image
  92. centroid = pixels.mean(axis=0)
  93. # raw second moment, for each channel of the pixels
  94. raw_second_moment = (pixels ** 2).mean(axis=0)
  95. # stddev, the sqrt of the variance of each channel of the image
  96. # variance_x = mean(p_x^2) - mean(p_x)^2 = rsm_x - centroid_x^2
  97. # note, summing those gives a total "variance" in color
  98. stddev = np.sqrt(raw_second_moment - centroid ** 2)
  99. # tilt, the arithmetic mean pixel of normalized image
  100. tilt = (pixels / np.linalg.norm(pixels, axis=1)[:, np.newaxis]).mean(axis=0)
  101. # chroma^2 = a^2 + b^2
  102. chromas = np.hypot(pixels[:, 1], pixels[:, 2])
  103. chroma_mean = chromas.mean(axis=0)
  104. # variance in chroma is E[a^2 + b^2] - E[sqrt(a^2 + b^2)]^2
  105. # max(0, x) is present to deal with floating point error for extremely low chroma images
  106. # a more robust solution could use log space instead but this is fine for this dataset
  107. chroma_dev = math.sqrt(
  108. max(0, raw_second_moment[1] + raw_second_moment[2] - (chroma_mean ** 2)))
  109. # hue = atan2(b, a), but we need a circular mean
  110. # https://en.wikipedia.org/wiki/Circular_mean#Definition
  111. # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
  112. # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
  113. # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
  114. hue = math.atan2(*(pixels[:, [2, 1]] / chromas[:, np.newaxis]).sum(axis=0))
  115. # approximation of geometric median, primarily for display purposes
  116. median = geometric_median(pixels)
  117. return Stats(
  118. centroid=list(np.round(centroid, output_precision)),
  119. median=list(np.round(median, output_precision)),
  120. stddev=list(np.round(stddev, output_precision)),
  121. tilt=list(np.round(tilt, output_precision)),
  122. chroma=[
  123. round(chroma_mean, output_precision),
  124. round(chroma_dev, output_precision)
  125. ],
  126. hue=round(hue % (2 * math.pi), output_precision),
  127. size=len(pixels),
  128. centroidHex=oklab2hex(centroid),
  129. medianHex=oklab2hex(median)
  130. )
  131. def calc_clusters(pixels: np.array, output_precision: int, cluster_attempts=5, seed=0) -> list[Stats]:
  132. means, labels = min(
  133. (
  134. # Try k = 2, 3, and 4, and try a few times for each
  135. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  136. for k in (2, 3, 4)
  137. for i in range(cluster_attempts)
  138. ),
  139. key=lambda c:
  140. # Evaluate clustering by seeing the average difference in hue angle
  141. # between the centers. Maximizing this means the clusters are highly
  142. # distinct, which gives a sense of which k was best.
  143. # This is computed by normalizing the ab-plane projection of the means,
  144. # then applying a dot product to get the cosine of the angle
  145. # between them in that plane, which is the hue difference. Minimizing
  146. # this maximizes the differences in hues.
  147. # A different clustering algorithm may be more suited here, but this
  148. # is comparatively cheap while still producing reasonable results.
  149. (np.array([
  150. m1 @ m2
  151. for m1, m2 in combinations(
  152. c[0][:, 1:] / np.linalg.norm(c[0][:, 1:], axis=1)[:, np.newaxis], 2
  153. )
  154. ])).mean(axis=0)
  155. )
  156. return [calc_statistics(pixels[labels == i], output_precision) for i in range(len(means))]
  157. def get_srgb_pixels(img: Image.Image) -> np.array:
  158. rgb = []
  159. for fr in range(getattr(img, "n_frames", 1)):
  160. img.seek(fr)
  161. rgb += [
  162. [r, g, b]
  163. for r, g, b, a in img.convert("RGBA").getdata()
  164. if a > 0 and (r, g, b) != (0, 0, 0)
  165. ]
  166. return np.array(rgb)
  167. def log(*a, **kw):
  168. print(*a, **kw, flush=True)
  169. class Ingester:
  170. def __init__(self, dex: dict, output_precision: int, seed: int) -> None:
  171. self.lookup = {
  172. form["name"]: {
  173. "num": pkmn["num"],
  174. "species": pkmn["species"],
  175. **form,
  176. }
  177. for pkmn in dex.values()
  178. for form in pkmn["forms"]
  179. }
  180. self.seed = seed
  181. self.output_precision = output_precision
  182. def __call__(self, args: tuple[str, list[str]]) -> dict | Exception:
  183. form, filenames = args
  184. log(f"Ingesting {form}...")
  185. start_time = time.time()
  186. try:
  187. all_pixels = np.concatenate([
  188. get_srgb_pixels(Image.open(fn)) for fn in filenames
  189. ])
  190. except Exception as e:
  191. log(f"Error loading images for {form}: {e}")
  192. return e
  193. try:
  194. oklab = srgb2oklab(all_pixels)
  195. conv_time = time.time()
  196. total = calc_statistics(oklab, self.output_precision)
  197. calc_time = time.time()
  198. clusters = calc_clusters(oklab, self.output_precision, seed=self.seed)
  199. cluster_time = time.time()
  200. except Exception as e:
  201. log(f"Error calculating statistics for {form}: {e}")
  202. return e
  203. log(
  204. f"Completed {form}: ",
  205. f"{(cluster_time - start_time):.02f}s in total,",
  206. f"{(cluster_time - calc_time):.02f}s on clustering,",
  207. f"{(calc_time - conv_time):.02f}s on total calcs,",
  208. f"{(conv_time - start_time):.02f}s on read and conversion,",
  209. f"median {total.medianHex} and {len(clusters)} clusters"
  210. )
  211. return {
  212. **self.lookup[form],
  213. "total": asdict(total),
  214. "clusters": [asdict(c) for c in clusters],
  215. }
  216. def output_db(results: list[dict], db_file: str):
  217. if db_file == "-":
  218. log(json.dumps(results, indent=2))
  219. return
  220. with open(db_file, "w") as output:
  221. output.write("const database = [\n")
  222. for entry in results:
  223. output.write(" ")
  224. output.write(json.dumps(entry))
  225. output.write(",\n")
  226. output.write("]\n")
  227. def run_ingest(ingest: Ingester, filenames: list[Path], exec: Executor, db_file: str):
  228. to_process = defaultdict(list)
  229. missing = []
  230. for path in filenames:
  231. if path.is_file():
  232. form_name = path.name.rsplit("-", 1)[0]
  233. to_process[form_name].append(path)
  234. else:
  235. missing.append(path)
  236. log(f"Missing file: {path}")
  237. start = time.time()
  238. results = list(exec.map(ingest, to_process.items()))
  239. end = time.time()
  240. success = [r for r in results if not isinstance(r, Exception)]
  241. errors = [e for e in results if isinstance(e, Exception)]
  242. log(
  243. f"Finished ingest of {len(to_process)} forms",
  244. f"and {sum(len(fns) for fns in to_process.values())} files",
  245. f"in {(end - start):.2f}s",
  246. f"with {len(missing)} missing file(s)",
  247. f"and {len(errors)} error(s)"
  248. )
  249. for e in errors:
  250. log(f"Error: {e}")
  251. for m in missing:
  252. log(f"Missing: {m}")
  253. output_db(sorted(success, key=lambda e: (e["num"], e["name"])), db_file)
  254. log(f"Output {len(success)} entries to {db_file}")
  255. if __name__ == "__main__":
  256. from argparse import ArgumentParser
  257. parser = ArgumentParser(
  258. prog="Image Analyzer",
  259. description="Analyze and summarize images based on color",
  260. )
  261. parser.add_argument(
  262. "-p", "--precision", type=int, default=8, help="Round output to this many decimal places"
  263. )
  264. parser.add_argument(
  265. "-s", "--seed", type=int, default=230308, help="Clustering seed"
  266. )
  267. parser.add_argument(
  268. "-w", "--workers", type=int, default=4, help="Worker process count"
  269. )
  270. parser.add_argument(
  271. "-o", "--output", default="data/latest.db", help="Database file"
  272. )
  273. parser.add_argument(
  274. "-d", "--pokedex", default="data/pokedex.json", help="Pokedex file"
  275. )
  276. parser.add_argument(
  277. "--threading", action="store_true", help="Use threads instead of multiproc (slower but more stable on 3.10)"
  278. )
  279. parser.add_argument("images", metavar="file", type=Path, nargs="+")
  280. args = parser.parse_args()
  281. with open(args.pokedex) as infile:
  282. dex = json.load(infile)
  283. ingest = Ingester(dex, args.precision, args.seed)
  284. from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
  285. with (ThreadPoolExecutor if args.threading else ProcessPoolExecutor)(max_workers=args.workers) as pool:
  286. run_ingest(ingest, args.images, pool, args.output)