analyze.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. import math
  2. import os
  3. import time
  4. import json
  5. from collections import defaultdict
  6. from concurrent.futures import Executor
  7. from pathlib import Path
  8. from dataclasses import dataclass, asdict
  9. from itertools import combinations
  10. import numpy as np
  11. from PIL import Image
  12. from scipy.cluster import vq
  13. # https://en.wikipedia.org/wiki/SRGB#Transformation
  14. linearize_srgb = np.vectorize(
  15. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  16. )
  17. delinearize_lrgb = np.vectorize(
  18. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  19. )
  20. # https://mina86.com/2019/srgb-xyz-matrix/
  21. RGB_TO_XYZ = np.array([
  22. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  23. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  24. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  25. ])
  26. XYZ_TO_RGB = [
  27. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  28. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  29. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  30. ]
  31. # https://bottosson.github.io/posts/oklab/
  32. XYZ_TO_LMS = np.array([
  33. [0.8189330101, 0.3618667424, -0.1288597137],
  34. [0.0329845436, 0.9293118715, 0.0361456387],
  35. [0.0482003018, 0.2643662691, 0.6338517070],
  36. ])
  37. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  38. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  39. LMS_TO_OKLAB = np.array([
  40. [0.2104542553, 0.7936177850, -0.0040720468],
  41. [1.9779984951, -2.4285922050, 0.4505937099],
  42. [0.0259040371, 0.7827717662, -0.8086757660],
  43. ])
  44. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  45. # round output to this many decimals
  46. OUTPUT_PRECISION = 8
  47. def oklab2hex(pixel: np.array) -> str:
  48. # no need for a vectorized version, this is only for providing the mean hex
  49. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  50. def srgb2oklab(pixels: np.array) -> np.array:
  51. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
@dataclass
class Stats:
    """Colour statistics for a set of OKLab pixels, filled in by calc_statistics.

    Field order matters: it fixes the positional constructor signature and
    the key order of the ``asdict`` output that ends up in the database.
    """

    # points (L, a, b)
    # mean OKLab pixel
    centroid: list[float]
    # mean of the pixels after scaling each one to unit Euclidean length
    tilt: list[float]
    # scalar statistics
    # mean(||p||^2) - ||mean(p)||^2, i.e. mean squared distance from the centroid
    variance: float
    # mean chroma, i.e. mean of hypot(a, b) over the pixels
    chroma: float
    # circular mean hue angle in radians, reduced modulo 2*pi
    hue: float
    # number of pixels summarized
    size: int
    # sRGB hex code of the centroid
    hex: str
  64. def calc_statistics(pixels: np.array, output_precision: int) -> Stats:
  65. # Euclidean norm squared by summing squared components
  66. sqnorms = (pixels ** 2).sum(axis=1)
  67. # centroid, the arithmetic mean pixel of the image
  68. centroid = pixels.mean(axis=0)
  69. # tilt, the arithmetic mean pixel of normalized image
  70. tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
  71. # variance = mean(||p||^2) - ||mean(p)||^2
  72. variance = sqnorms.mean(axis=0) - sum(centroid ** 2)
  73. # chroma^2 = a^2 + b^2
  74. chroma = np.hypot(pixels[:, 1], pixels[:, 2])
  75. # hue = atan2(b, a), but we need a circular mean
  76. # https://en.wikipedia.org/wiki/Circular_mean#Definition
  77. # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
  78. # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
  79. # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
  80. hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).sum(axis=0))
  81. return Stats(
  82. centroid=list(np.round(centroid, output_precision)),
  83. tilt=list(np.round(tilt, output_precision)),
  84. variance=round(variance, output_precision),
  85. chroma=round(chroma.mean(axis=0), output_precision),
  86. hue=round(hue % (2 * math.pi), output_precision),
  87. size=len(pixels),
  88. hex=oklab2hex(centroid),
  89. )
  90. def calc_clusters(pixels: np.array, output_precision: int, cluster_attempts=5, seed=0) -> list[Stats]:
  91. means, labels = min(
  92. (
  93. # Try k = 2, 3, and 4, and try a few times for each
  94. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  95. for k in (2, 3, 4)
  96. for i in range(cluster_attempts)
  97. ),
  98. key=lambda c:
  99. # Evaluate clustering by seeing the average difference in hue angle
  100. # between the centers. Maximizing this means the clusters are highly
  101. # distinct, which gives a sense of which k was best.
  102. # This is computed by normalizing the ab-plane projection of the means,
  103. # then applying a dot product to get the cosine of the angle
  104. # between them in that plane, which is the hue difference. Minimizing
  105. # this maximizes the differences in hues.
  106. # A different clustering algorithm may be more suited here, but this
  107. # is comparatively cheap while still producing reasonable results.
  108. (np.array([
  109. m1 @ m2
  110. for m1, m2 in combinations(
  111. c[0][:, 1:] / np.linalg.norm(c[0][:, 1:], axis=1)[:, np.newaxis], 2
  112. )
  113. ])).mean(axis=0)
  114. )
  115. return [calc_statistics(pixels[labels == i], output_precision) for i in range(len(means))]
  116. def get_srgb_pixels(img: Image.Image) -> np.array:
  117. rgb = []
  118. for fr in range(getattr(img, "n_frames", 1)):
  119. img.seek(fr)
  120. rgb += [
  121. [r, g, b]
  122. for r, g, b, a in img.convert("RGBA").getdata()
  123. if a > 0 and (r, g, b) != (0, 0, 0)
  124. ]
  125. return np.array(rgb)
  126. def search_files(image_dir: str) -> dict[str, list[str]]:
  127. files = defaultdict(list)
  128. for image_filename in os.listdir(image_dir):
  129. form_name = image_filename.rsplit("-", maxsplit=1)[0]
  130. files[form_name].append(Path(image_dir, image_filename))
  131. return files
  132. def log(*a, **kw):
  133. print(*a, **kw, flush=True)
  134. class Ingester:
  135. def __init__(self, dex: dict, output_precision: int, seed: int) -> None:
  136. self.lookup = {
  137. form["name"]: {
  138. "num": pkmn["num"],
  139. "species": pkmn["species"],
  140. **form,
  141. }
  142. for pkmn in dex.values()
  143. for form in pkmn["forms"]
  144. }
  145. self.seed = seed
  146. self.output_precision = output_precision
  147. def __call__(self, args: tuple[str, list[str]]) -> dict | Exception:
  148. form, filenames = args
  149. log(f"Ingesting {form}...")
  150. start_time = time.time()
  151. try:
  152. all_pixels = np.concatenate([
  153. get_srgb_pixels(Image.open(fn)) for fn in filenames
  154. ])
  155. except Exception as e:
  156. log(f"Error loading images for {form}: {e}")
  157. return e
  158. try:
  159. oklab = srgb2oklab(all_pixels)
  160. conv_time = time.time()
  161. total = calc_statistics(oklab, self.output_precision)
  162. calc_time = time.time()
  163. clusters = calc_clusters(oklab, self.output_precision, seed=self.seed)
  164. cluster_time = time.time()
  165. except Exception as e:
  166. log(f"Error calculating statistics for {form}: {e}")
  167. return e
  168. log(
  169. f"Completed {form}: ",
  170. f"{(cluster_time - start_time):.02f}s in total,",
  171. f"{(cluster_time - calc_time):.02f}s on clustering,",
  172. f"{(calc_time - conv_time):.02f}s on total calcs,",
  173. f"{(conv_time - start_time):.02f}s on read and conversion,",
  174. f"centroid {total.hex} and {len(clusters)} clusters"
  175. )
  176. return {
  177. **self.lookup[form],
  178. "total": asdict(total),
  179. "clusters": [asdict(c) for c in clusters],
  180. }
  181. def output_db(results: list[dict], db_file: str):
  182. if db_file == "-":
  183. log(json.dumps(results, indent=2))
  184. return
  185. with open(db_file, "w") as output:
  186. output.write("const database = [\n")
  187. for entry in results:
  188. output.write(" ")
  189. output.write(json.dumps(entry))
  190. output.write(",\n")
  191. output.write("]\n")
  192. def run_ingest(ingest: Ingester, filenames: list[Path], exec: Executor, db_file: str):
  193. to_process = defaultdict(list)
  194. missing = []
  195. for path in filenames:
  196. if path.is_file():
  197. form_name = path.name.rsplit("-", 1)[0]
  198. to_process[form_name].append(path)
  199. else:
  200. missing.append(path)
  201. log(f"Missing file: {path}")
  202. start = time.time()
  203. results = list(exec.map(ingest, to_process.items()))
  204. end = time.time()
  205. success = [r for r in results if not isinstance(r, Exception)]
  206. errors = [e for e in results if isinstance(e, Exception)]
  207. log(
  208. f"Finished ingest of {len(to_process)} forms",
  209. f"and {sum(len(fns) for fns in to_process.values())} files",
  210. f"in {(end - start):.2f}s",
  211. f"with {len(missing)} missing file(s)"
  212. f"and {len(errors)} error(s)"
  213. )
  214. for e in errors:
  215. log(f"Error: {e}")
  216. for m in missing:
  217. log(f"Missing: {e}")
  218. output_db(sorted(success, key=lambda e: (e["num"], e["name"])), db_file)
  219. log(f"Output {len(success)} entries to {db_file}")
  220. if __name__ == "__main__":
  221. from argparse import ArgumentParser
  222. parser = ArgumentParser(
  223. prog="Image Analyzer",
  224. description="Analyze and summarize images based on color",
  225. )
  226. parser.add_argument(
  227. "-p", "--precision", type=int, default=8, help="Round output to this many decimal places"
  228. )
  229. parser.add_argument(
  230. "-s", "--seed", type=int, default=230308, help="Clustering seed"
  231. )
  232. parser.add_argument(
  233. "-w", "--workers", type=int, default=4, help="Worker process count"
  234. )
  235. parser.add_argument(
  236. "-o", "--output", default="data/latest.db", help="Database file"
  237. )
  238. parser.add_argument(
  239. "-d", "--pokedex", default="data/pokedex.json", help="Pokedex file"
  240. )
  241. parser.add_argument("images", metavar="file", type=Path, nargs="+")
  242. args = parser.parse_args()
  243. with open(args.pokedex) as infile:
  244. dex = json.load(infile)
  245. ingest = Ingester(dex, args.precision, args.seed)
  246. from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
  247. with ThreadPoolExecutor(max_workers=args.workers) as pool:
  248. run_ingest(ingest, args.images, pool, args.output)