analyze.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import math
  2. from dataclasses import dataclass
  3. from itertools import combinations
  4. import numpy as np
  5. from PIL import Image
  6. from scipy.cluster import vq
  7. # https://en.wikipedia.org/wiki/SRGB#Transformation
  8. linearize_srgb = np.vectorize(
  9. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  10. )
  11. delinearize_lrgb = np.vectorize(
  12. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  13. )
  14. # https://mina86.com/2019/srgb-xyz-matrix/
  15. RGB_TO_XYZ = np.array([
  16. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  17. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  18. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  19. ])
  20. XYZ_TO_RGB = [
  21. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  22. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  23. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  24. ]
  25. # https://bottosson.github.io/posts/oklab/
  26. XYZ_TO_LMS = np.array([
  27. [0.8189330101, 0.3618667424, -0.1288597137],
  28. [0.0329845436, 0.9293118715, 0.0361456387],
  29. [0.0482003018, 0.2643662691, 0.6338517070],
  30. ])
  31. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  32. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  33. LMS_TO_OKLAB = np.array([
  34. [0.2104542553, 0.7936177850, -0.0040720468],
  35. [1.9779984951, -2.4285922050, 0.4505937099],
  36. [0.0259040371, 0.7827717662, -0.8086757660],
  37. ])
  38. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  39. # round output to this many decimals
  40. OUTPUT_PRECISION = 6
  41. def oklab2hex(pixel: np.array) -> str:
  42. # no need for a vectorized version, this is only for providing the mean hex
  43. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  44. def srgb2oklab(pixels: np.array) -> np.array:
  45. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
  46. @dataclass
  47. class Stats:
  48. # points (L, a, b)
  49. centroid: list[float]
  50. tilt: list[float]
  51. # scalar statistics
  52. size: int
  53. variance: float
  54. chroma: float
  55. hue: float
  56. # sRGB hex code of the centroid
  57. hex: str
  58. def calc_statistics(pixels: np.array) -> Stats:
  59. # Euclidean norm squared by summing squared components
  60. sqnorms = (pixels ** 2).sum(axis=1)
  61. # centroid, the arithmetic mean pixel of the image
  62. centroid = pixels.mean(axis=0)
  63. # tilt, the arithmetic mean pixel of normalized image
  64. tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
  65. # variance = mean(||p||^2) - ||mean(p)||^2
  66. variance = sqnorms.mean(axis=0) - sum(centroid ** 2)
  67. # chroma^2 = a^2 + b^2
  68. chroma = np.hypot(pixels[:, 1], pixels[:, 2])
  69. # hue = atan2(b, a), but we need a circular mean
  70. # https://en.wikipedia.org/wiki/Circular_mean#Definition
  71. # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
  72. # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
  73. # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
  74. hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).sum(axis=0))
  75. return Stats(
  76. centroid=list(np.round(centroid, OUTPUT_PRECISION)),
  77. tilt=list(np.round(tilt, OUTPUT_PRECISION)),
  78. size=len(pixels),
  79. variance=math.round(variance, OUTPUT_PRECISION),
  80. chroma=math.round(chroma.mean(axis=0), OUTPUT_PRECISION),
  81. hue=math.round(hue % (2 * math.pi), OUTPUT_PRECISION),
  82. hex=oklab2hex(centroid),
  83. )
  84. def find_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
  85. means, labels = max(
  86. (
  87. # Try k = 2, 3, and 4, and try a few times for each
  88. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  89. for k in (2, 3, 4)
  90. for i in range(cluster_attempts)
  91. ),
  92. key=lambda c:
  93. # Evaluate clustering by seeing the average distance in the ab-plane
  94. # between the centers. Maximizing this means the clusters are highly
  95. # distinct, which gives a sense of which k was best.
  96. # A different clustering algorithm may be more suited here, but this
  97. # is comparatively cheap while still producing reasonable results.
  98. (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
  99. .sum(axis=1)
  100. .mean(axis=0)
  101. )
  102. return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
  103. def get_pixels(img: Image.Image) -> np.array:
  104. rgb = []
  105. for fr in range(getattr(img, "n_frames", 1)):
  106. img.seek(fr)
  107. rgb += [
  108. [r, g, b]
  109. for r, g, b, a in img.convert("RGBA").getdata()
  110. if a > 0 and (r, g, b) != (0, 0, 0)
  111. ]
  112. return srgb2oklab(np.array(rgb))
  113. if __name__ == "__main__":
  114. print("TODO")