analyze.py

import json
import math
import os
import time
from collections import defaultdict
from concurrent.futures import Executor
from dataclasses import dataclass, asdict
from itertools import combinations
from pathlib import Path

import numpy as np
from PIL import Image
from scipy.cluster import vq

# https://en.wikipedia.org/wiki/SRGB#Transformation
linearize_srgb = np.vectorize(
    lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
)
delinearize_lrgb = np.vectorize(
    lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
)

# https://mina86.com/2019/srgb-xyz-matrix/
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
# kept for reference; not used below (LMS_TO_RGB is derived by inversion)
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])

# https://bottosson.github.io/posts/oklab/
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
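
# Putting the pieces together, the forward conversion chain used below is:
#   sRGB --linearize--> linear RGB --RGB_TO_LMS--> LMS --cbrt--> L'M'S' --LMS_TO_OKLAB--> Oklab
# and oklab2hex runs the same chain backwards via the inverse matrices.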

# round output to this many decimals
OUTPUT_PRECISION = 8


def oklab2hex(pixel: np.ndarray) -> str:
    # no need for a vectorized version, this is only for providing the mean hex
    return "#" + "".join(
        f"{int(x * 255):02X}"
        for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T)
    )


def srgb2oklab(pixels: np.ndarray) -> np.ndarray:
    return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
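
# Sanity check (illustrative): Oklab is constructed so that pure sRGB white
# lands at roughly L=1, a=0, b=0, up to floating point error:
#   srgb2oklab(np.array([[255, 255, 255]]))  # ~ [[1.0, 0.0, 0.0]]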


@dataclass
class Stats:
    # vector statistics
    centroid: list[float]  # (L, a, b)
    stddev: list[float]    # (L, a, b)
    tilt: list[float]      # (L, a, b)
    chroma: list[float]    # (mean, stddev)
    # scalar statistics
    hue: float
    size: int
    # sRGB hex code of the centroid
    hex: str


def calc_statistics(pixels: np.ndarray, output_precision: int) -> Stats:
    # centroid, the arithmetic mean pixel of the image
    centroid = pixels.mean(axis=0)
    # raw second moment, for each channel of the pixels
    raw_second_moment = (pixels ** 2).mean(axis=0)
    # stddev, the sqrt of the variance of each channel of the image
    # variance_x = mean(p_x^2) - mean(p_x)^2 = rsm_x - centroid_x^2
    # note, summing those gives a total "variance" in color
    stddev = np.sqrt(raw_second_moment - centroid ** 2)
    # tilt, the arithmetic mean pixel of the normalized image
    tilt = (pixels / np.linalg.norm(pixels, axis=1)[:, np.newaxis]).mean(axis=0)
    # chroma^2 = a^2 + b^2
    chromas = np.hypot(pixels[:, 1], pixels[:, 2])
    chroma_mean = chromas.mean()
    # variance in chroma is E[a^2 + b^2] - E[sqrt(a^2 + b^2)]^2, where the
    # first term is just rsm_a + rsm_b since chroma^2 = a^2 + b^2
    # max(0, x) deals with floating point error in extremely low chroma images;
    # a more robust solution could use log space instead, but this is fine for this dataset
    chroma_dev = math.sqrt(
        max(0, raw_second_moment[1] + raw_second_moment[2] - (chroma_mean ** 2))
    )
    # hue = atan2(b, a), but averaging hues requires a circular mean
    # https://en.wikipedia.org/wiki/Circular_mean#Definition
    # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
    # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
    # and because atan2(y/c, x/c) = atan2(y, x), a sum works as well as a mean here
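    # e.g. two pixels with hues +175 and -175 degrees should average to 180
    # degrees, not the 0 degrees a naive mean of the two angles would give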
    hue = math.atan2(*(pixels[:, [2, 1]] / chromas[:, np.newaxis]).sum(axis=0))
    return Stats(
        centroid=list(np.round(centroid, output_precision)),
        stddev=list(np.round(stddev, output_precision)),
        tilt=list(np.round(tilt, output_precision)),
        chroma=[
            round(chroma_mean, output_precision),
            round(chroma_dev, output_precision),
        ],
        hue=round(hue % (2 * math.pi), output_precision),
        size=len(pixels),
        hex=oklab2hex(centroid),
    )


def calc_clusters(pixels: np.ndarray, output_precision: int, cluster_attempts=5, seed=0) -> list[Stats]:
    means, labels = min(
        (
            # try k = 2, 3, and 4, with a few attempts at each
            vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
            for k in (2, 3, 4)
            for i in range(cluster_attempts)
        ),
        key=lambda c:
            # Evaluate a clustering by the hue separation of its centers:
            # normalize the ab-plane projection of each mean, then take
            # pairwise dot products to get the cosine of the angle between
            # them in that plane, which is the hue difference. Minimizing the
            # mean cosine maximizes the hue differences, so the chosen k is
            # the one whose clusters are the most mutually distinct.
            # A different clustering algorithm may be more suited here, but
            # this is comparatively cheap while still producing reasonable
            # results.
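            # e.g. two centers with hues 180 degrees apart score cosine -1
            # (best), while centers sharing a hue score +1 (worst),
            # regardless of their chroma or lightness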
            np.array([
                m1 @ m2
                for m1, m2 in combinations(
                    c[0][:, 1:] / np.linalg.norm(c[0][:, 1:], axis=1)[:, np.newaxis], 2
                )
            ]).mean()
    )
    return [calc_statistics(pixels[labels == i], output_precision) for i in range(len(means))]
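

# Collect every opaque, non-black sRGB pixel across all frames of an image;
# n_frames covers animated formats like GIF, and static images default to 1.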
def get_srgb_pixels(img: Image.Image) -> np.ndarray:
    rgb = []
    for fr in range(getattr(img, "n_frames", 1)):
        img.seek(fr)
        rgb += [
            [r, g, b]
            for r, g, b, a in img.convert("RGBA").getdata()
            if a > 0 and (r, g, b) != (0, 0, 0)
        ]
    return np.array(rgb)


def search_files(image_dir: str) -> dict[str, list[Path]]:
    files = defaultdict(list)
    for image_filename in os.listdir(image_dir):
        form_name = image_filename.rsplit("-", maxsplit=1)[0]
        files[form_name].append(Path(image_dir, image_filename))
    return files
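

# flush every log line immediately so output from worker processes and
# threads interleaves in real time instead of sitting in buffers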
def log(*a, **kw):
    print(*a, **kw, flush=True)


class Ingester:
    def __init__(self, dex: dict, output_precision: int, seed: int) -> None:
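        # dex values are expected to look like (inferred from the usage below):
        #   {"num": int, "species": str, "forms": [{"name": str, ...}, ...]}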
        self.lookup = {
            form["name"]: {
                "num": pkmn["num"],
                "species": pkmn["species"],
                **form,
            }
            for pkmn in dex.values()
            for form in pkmn["forms"]
        }
        self.seed = seed
        self.output_precision = output_precision

    def __call__(self, args: tuple[str, list[Path]]) -> dict | Exception:
        form, filenames = args
        log(f"Ingesting {form}...")
        start_time = time.time()
        try:
            all_pixels = np.concatenate([
                get_srgb_pixels(Image.open(fn)) for fn in filenames
            ])
        except Exception as e:
            log(f"Error loading images for {form}: {e}")
            return e
        try:
            oklab = srgb2oklab(all_pixels)
            conv_time = time.time()
            total = calc_statistics(oklab, self.output_precision)
            calc_time = time.time()
            clusters = calc_clusters(oklab, self.output_precision, seed=self.seed)
            cluster_time = time.time()
        except Exception as e:
            log(f"Error calculating statistics for {form}: {e}")
            return e
        log(
            f"Completed {form}:",
            f"{(cluster_time - start_time):.02f}s in total,",
            f"{(cluster_time - calc_time):.02f}s on clustering,",
            f"{(calc_time - conv_time):.02f}s on total calcs,",
            f"{(conv_time - start_time):.02f}s on read and conversion,",
            f"centroid {total.hex} and {len(clusters)} clusters",
        )
        return {
            **self.lookup[form],
            "total": asdict(total),
            "clusters": [asdict(c) for c in clusters],
        }
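

# output_db writes the results as a JS module (or, for db_file "-", prints
# JSON to stdout instead); the emitted file has the shape:
#   const database = [
#     {"num": ..., "species": ..., "name": ..., "total": {...}, "clusters": [...]},
#   ]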
def output_db(results: list[dict], db_file: str):
    if db_file == "-":
        log(json.dumps(results, indent=2))
        return
    with open(db_file, "w") as output:
        output.write("const database = [\n")
        for entry in results:
            output.write(" ")
            output.write(json.dumps(entry))
            output.write(",\n")
        output.write("]\n")


def run_ingest(ingest: Ingester, filenames: list[Path], executor: Executor, db_file: str):
    to_process = defaultdict(list)
    missing = []
    for path in filenames:
        if path.is_file():
            form_name = path.name.rsplit("-", 1)[0]
            to_process[form_name].append(path)
        else:
            missing.append(path)
            log(f"Missing file: {path}")
    start = time.time()
    results = list(executor.map(ingest, to_process.items()))
    end = time.time()
    success = [r for r in results if not isinstance(r, Exception)]
    errors = [e for e in results if isinstance(e, Exception)]
    log(
        f"Finished ingest of {len(to_process)} forms",
        f"and {sum(len(fns) for fns in to_process.values())} files",
        f"in {(end - start):.2f}s",
        f"with {len(missing)} missing file(s)",
        f"and {len(errors)} error(s)",
    )
    for e in errors:
        log(f"Error: {e}")
    for m in missing:
        log(f"Missing: {m}")
    output_db(sorted(success, key=lambda e: (e["num"], e["name"])), db_file)
    log(f"Output {len(success)} entries to {db_file}")


if __name__ == "__main__":
    from argparse import ArgumentParser
    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

    parser = ArgumentParser(
        prog="Image Analyzer",
        description="Analyze and summarize images based on color",
    )
    parser.add_argument(
        "-p", "--precision", type=int, default=8,
        help="Round output to this many decimal places",
    )
    parser.add_argument(
        "-s", "--seed", type=int, default=230308, help="Clustering seed"
    )
    parser.add_argument(
        "-w", "--workers", type=int, default=4, help="Worker process count"
    )
    parser.add_argument(
        "-o", "--output", default="data/latest.db", help="Database file"
    )
    parser.add_argument(
        "-d", "--pokedex", default="data/pokedex.json", help="Pokedex file"
    )
    parser.add_argument(
        "--threading", action="store_true",
        help="Use threads instead of multiprocessing (slower, but more stable on Python 3.10)",
    )
    parser.add_argument("images", metavar="file", type=Path, nargs="+")
    args = parser.parse_args()

    with open(args.pokedex) as infile:
        dex = json.load(infile)
    ingest = Ingester(dex, args.precision, args.seed)
    pool_cls = ThreadPoolExecutor if args.threading else ProcessPoolExecutor
    with pool_cls(max_workers=args.workers) as pool:
        run_ingest(ingest, args.images, pool, args.output)
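
# Example invocation (illustrative; the sprite filenames are assumptions,
# but form names are taken from everything before the last "-"):
#   python analyze.py -w 8 -o data/latest.db images/bulbasaur-front.png images/bulbasaur-back.png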