analyze.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
import json
import math
import os
import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from dataclasses import asdict, dataclass
from itertools import combinations
from pathlib import Path
from typing import Callable

import numpy as np
from PIL import Image
from scipy.cluster import vq
  13. # https://en.wikipedia.org/wiki/SRGB#Transformation
  14. linearize_srgb = np.vectorize(
  15. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  16. )
  17. delinearize_lrgb = np.vectorize(
  18. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  19. )
  20. # https://mina86.com/2019/srgb-xyz-matrix/
  21. RGB_TO_XYZ = np.array([
  22. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  23. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  24. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  25. ])
  26. XYZ_TO_RGB = [
  27. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  28. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  29. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  30. ]
  31. # https://bottosson.github.io/posts/oklab/
  32. XYZ_TO_LMS = np.array([
  33. [0.8189330101, 0.3618667424, -0.1288597137],
  34. [0.0329845436, 0.9293118715, 0.0361456387],
  35. [0.0482003018, 0.2643662691, 0.6338517070],
  36. ])
  37. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  38. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  39. LMS_TO_OKLAB = np.array([
  40. [0.2104542553, 0.7936177850, -0.0040720468],
  41. [1.9779984951, -2.4285922050, 0.4505937099],
  42. [0.0259040371, 0.7827717662, -0.8086757660],
  43. ])
  44. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  45. # round output to this many decimals
  46. OUTPUT_PRECISION = 6
  47. def oklab2hex(pixel: np.array) -> str:
  48. # no need for a vectorized version, this is only for providing the mean hex
  49. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  50. def srgb2oklab(pixels: np.array) -> np.array:
  51. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
@dataclass
class Stats:
    """Summary statistics for one group of Oklab pixels, built by calc_statistics.

    Instances are turned into dicts with dataclasses.asdict and written to
    the database as JSON, so field order here is the serialised key order.
    """
    # points (L, a, b)
    # mean pixel
    centroid: list[float]
    # mean of the pixels after each is scaled to unit length
    tilt: list[float]
    # scalar statistics
    # mean squared distance of the pixels from the centroid
    variance: float
    # mean of sqrt(a^2 + b^2) over the pixels
    chroma: float
    # circular mean of atan2(b, a), in radians, wrapped into [0, 2*pi)
    hue: float
    # number of pixels summarised
    size: int
    # sRGB hex code of the centroid
    hex: str
  64. def calc_statistics(pixels: np.array) -> Stats:
  65. # Euclidean norm squared by summing squared components
  66. sqnorms = (pixels ** 2).sum(axis=1)
  67. # centroid, the arithmetic mean pixel of the image
  68. centroid = pixels.mean(axis=0)
  69. # tilt, the arithmetic mean pixel of normalized image
  70. tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
  71. # variance = mean(||p||^2) - ||mean(p)||^2
  72. variance = sqnorms.mean(axis=0) - sum(centroid ** 2)
  73. # chroma^2 = a^2 + b^2
  74. chroma = np.hypot(pixels[:, 1], pixels[:, 2])
  75. # hue = atan2(b, a), but we need a circular mean
  76. # https://en.wikipedia.org/wiki/Circular_mean#Definition
  77. # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
  78. # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
  79. # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
  80. hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).sum(axis=0))
  81. return Stats(
  82. centroid=list(np.round(centroid, OUTPUT_PRECISION)),
  83. tilt=list(np.round(tilt, OUTPUT_PRECISION)),
  84. variance=round(variance, OUTPUT_PRECISION),
  85. chroma=round(chroma.mean(axis=0), OUTPUT_PRECISION),
  86. hue=round(hue % (2 * math.pi), OUTPUT_PRECISION),
  87. size=len(pixels),
  88. hex=oklab2hex(centroid),
  89. )
  90. def calc_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
  91. means, labels = max(
  92. (
  93. # Try k = 2, 3, and 4, and try a few times for each
  94. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  95. for k in (2, 3, 4)
  96. for i in range(cluster_attempts)
  97. ),
  98. key=lambda c:
  99. # Evaluate clustering by seeing the average distance in the ab-plane
  100. # between the centers. Maximizing this means the clusters are highly
  101. # distinct, which gives a sense of which k was best.
  102. # A different clustering algorithm may be more suited here, but this
  103. # is comparatively cheap while still producing reasonable results.
  104. (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
  105. .sum(axis=1)
  106. .mean(axis=0)
  107. )
  108. return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
  109. def get_srgb_pixels(img: Image.Image) -> np.array:
  110. rgb = []
  111. for fr in range(getattr(img, "n_frames", 1)):
  112. img.seek(fr)
  113. rgb += [
  114. [r, g, b]
  115. for r, g, b, a in img.convert("RGBA").getdata()
  116. if a > 0 and (r, g, b) != (0, 0, 0)
  117. ]
  118. return np.array(rgb)
  119. def search_files(image_dir: str) -> dict[str, list[str]]:
  120. files = defaultdict(list)
  121. for image_filename in os.listdir(image_dir):
  122. form_name = image_filename.rsplit("-", maxsplit=1)[0]
  123. files[form_name].append(Path(image_dir, image_filename))
  124. return files
  125. class Ingester:
  126. def __init__(self, dex: dict) -> None:
  127. self.lookup = {
  128. form["name"]: {
  129. "num": pkmn["num"],
  130. "species": pkmn["species"],
  131. **form,
  132. }
  133. for pkmn in dex.values()
  134. for form in pkmn["forms"]
  135. }
  136. def __call__(self, args: tuple[str, list[str]]) -> tuple[Stats, list[Stats]]:
  137. form, filenames = args
  138. print(f"Ingesting {form}...")
  139. start_time = time.time()
  140. all_pixels = np.concatenate([
  141. get_srgb_pixels(Image.open(fn)) for fn in filenames
  142. ])
  143. oklab = srgb2oklab(all_pixels)
  144. conv_time = time.time()
  145. total = calc_statistics(oklab)
  146. calc_time = time.time()
  147. clusters = [asdict(c) for c in calc_clusters(oklab, seed=seed)]
  148. cluster_time = time.time()
  149. print(
  150. f"Completed {form}: ",
  151. f"{(cluster_time - start_time):.02f}s in total,",
  152. f"{(cluster_time - calc_time):.02f}s on clustering,",
  153. f"{(calc_time - conv_time):.02f}s on total calcs,",
  154. f"{(conv_time - start_time):.02f}s on read and conversion,",
  155. f"centroid {total.hex} and {len(clusters)} clusters"
  156. )
  157. return {
  158. **self.lookup[form],
  159. "total": asdict(total),
  160. "clusters": clusters,
  161. }
  162. def output_db(results: list[dict], db_file: str):
  163. with open(db_file, "w") as output:
  164. output.write("const database = [\n")
  165. for entry in results:
  166. output.write(" ")
  167. output.write(json.dumps(entry))
  168. output.write(",\n")
  169. output.write("]\n")
  170. if __name__ == "__main__":
  171. from sys import argv
  172. dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
  173. image_dir = argv[2] if len(argv) > 2 else "images"
  174. seed = int(argv[3]) if len(argv) > 3 else 230308
  175. db_file = argv[4] if len(argv) > 4 else "data/latest.db"
  176. import json
  177. with open(dex_file) as infile:
  178. dex = json.load(infile)
  179. ingest = Ingester(dex)
  180. to_process = search_files(image_dir)
  181. start = time.time()
  182. with ProcessPoolExecutor(4) as pool:
  183. results = list(pool.map(ingest, to_process.items()))
  184. end = time.time()
  185. print(
  186. f"Finished ingest of {len(to_process)} forms",
  187. f"and {sum(len(fns) for fns in to_process.values())} files",
  188. f"in {(end - start):.2f}s"
  189. )
  190. output_db(sorted(results, key=lambda e: (e["num"], e["name"])), db_file)
  191. print(f"Output {len(results)} entries to {db_file}")