analyze.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import math
  2. from pathlib import Path
  3. from dataclasses import dataclass, asdict
  4. from itertools import combinations
  5. import numpy as np
  6. from PIL import Image
  7. from scipy.cluster import vq
  8. # https://en.wikipedia.org/wiki/SRGB#Transformation
  9. linearize_srgb = np.vectorize(
  10. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  11. )
  12. delinearize_lrgb = np.vectorize(
  13. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  14. )
  15. # https://mina86.com/2019/srgb-xyz-matrix/
  16. RGB_TO_XYZ = np.array([
  17. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  18. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  19. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  20. ])
  21. XYZ_TO_RGB = [
  22. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  23. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  24. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  25. ]
  26. # https://bottosson.github.io/posts/oklab/
  27. XYZ_TO_LMS = np.array([
  28. [0.8189330101, 0.3618667424, -0.1288597137],
  29. [0.0329845436, 0.9293118715, 0.0361456387],
  30. [0.0482003018, 0.2643662691, 0.6338517070],
  31. ])
  32. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  33. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  34. LMS_TO_OKLAB = np.array([
  35. [0.2104542553, 0.7936177850, -0.0040720468],
  36. [1.9779984951, -2.4285922050, 0.4505937099],
  37. [0.0259040371, 0.7827717662, -0.8086757660],
  38. ])
  39. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  40. # round output to this many decimals
  41. OUTPUT_PRECISION = 6
  42. def oklab2hex(pixel: np.array) -> str:
  43. # no need for a vectorized version, this is only for providing the mean hex
  44. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  45. def srgb2oklab(pixels: np.array) -> np.array:
  46. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
  47. @dataclass
  48. class Stats:
  49. # points (L, a, b)
  50. centroid: list[float]
  51. tilt: list[float]
  52. # scalar statistics
  53. size: int
  54. variance: float
  55. chroma: float
  56. hue: float
  57. # sRGB hex code of the centroid
  58. hex: str
  59. def calc_statistics(pixels: np.array) -> Stats:
  60. # Euclidean norm squared by summing squared components
  61. sqnorms = (pixels ** 2).sum(axis=1)
  62. # centroid, the arithmetic mean pixel of the image
  63. centroid = pixels.mean(axis=0)
  64. # tilt, the arithmetic mean pixel of normalized image
  65. tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
  66. # variance = mean(||p||^2) - ||mean(p)||^2
  67. variance = sqnorms.mean(axis=0) - sum(centroid ** 2)
  68. # chroma^2 = a^2 + b^2
  69. chroma = np.hypot(pixels[:, 1], pixels[:, 2])
  70. # hue = atan2(b, a), but we need a circular mean
  71. # https://en.wikipedia.org/wiki/Circular_mean#Definition
  72. # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
  73. # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
  74. # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
  75. hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).sum(axis=0))
  76. return Stats(
  77. centroid=list(np.round(centroid, OUTPUT_PRECISION)),
  78. tilt=list(np.round(tilt, OUTPUT_PRECISION)),
  79. size=len(pixels),
  80. variance=round(variance, OUTPUT_PRECISION),
  81. chroma=round(chroma.mean(axis=0), OUTPUT_PRECISION),
  82. hue=round(hue % (2 * math.pi), OUTPUT_PRECISION),
  83. hex=oklab2hex(centroid),
  84. )
  85. def calc_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
  86. means, labels = max(
  87. (
  88. # Try k = 2, 3, and 4, and try a few times for each
  89. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  90. for k in (2, 3, 4)
  91. for i in range(cluster_attempts)
  92. ),
  93. key=lambda c:
  94. # Evaluate clustering by seeing the average distance in the ab-plane
  95. # between the centers. Maximizing this means the clusters are highly
  96. # distinct, which gives a sense of which k was best.
  97. # A different clustering algorithm may be more suited here, but this
  98. # is comparatively cheap while still producing reasonable results.
  99. (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
  100. .sum(axis=1)
  101. .mean(axis=0)
  102. )
  103. return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
  104. def get_srgb_pixels(img: Image.Image) -> np.array:
  105. rgb = []
  106. for fr in range(getattr(img, "n_frames", 1)):
  107. img.seek(fr)
  108. rgb += [
  109. [r, g, b]
  110. for r, g, b, a in img.convert("RGBA").getdata()
  111. if a > 0 and (r, g, b) != (0, 0, 0)
  112. ]
  113. return np.array(rgb)
  114. if __name__ == "__main__":
  115. from sys import argv
  116. dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
  117. image_dir = argv[2] if len(argv) > 2 else "images"
  118. seed = int(argv[3]) if len(argv) > 3 else 230308
  119. import os
  120. from collections import defaultdict
  121. to_process = defaultdict(list)
  122. for image_filename in os.listdir(image_dir):
  123. form_name = image_filename.rsplit("-", maxsplit=1)[0]
  124. to_process[form_name].append(Path(image_dir, image_filename))
  125. # TODO multiproc
  126. database = []
  127. for form, image_files in to_process.items():
  128. all_pixels = np.concatenate([
  129. get_srgb_pixels(Image.open(fn)) for fn in image_files
  130. ])
  131. oklab = srgb2oklab(all_pixels)
  132. database.append({
  133. "name": form,
  134. # TODO also get dex info - species, color, etc.
  135. "total": asdict(calc_statistics(oklab)),
  136. "clusters": [asdict(c) for c in calc_clusters(oklab, seed=seed)],
  137. })
  138. # TODO real output
  139. import json
  140. print(json.dumps(database, indent=2))