# anim_ingest.py
  1. import io
  2. import math
  3. import itertools
  4. import multiprocessing
  5. from typing import Callable, NamedTuple
  6. from PIL import Image
  7. from bs4 import BeautifulSoup
  8. from colorspacious import cspace_convert
  9. from scipy.cluster import vq
  10. import requests
  11. import numpy as np
  12. import ingest
  13. extension = ".gif"
  14. cluster_seed = 20220328
  15. cluster_attempts = 10
  16. base = "https://play.pokemonshowdown.com/sprites/ani/"
  17. back_base = "https://play.pokemonshowdown.com/sprites/ani-back/"
  18. # removing all forms of a pokemon, and also pokestars
  19. start_with_filters = [
  20. # no significant visual changes
  21. "arceus-", "silvally-", "genesect-", "pumpkaboo-", "gourgeist-", "unown-", "giratina-",
  22. # cannot start the battle in alternate form
  23. "castform-", "cherrim-", "aegislash-", "xerneas-", "wishiwashi-",
  24. "eiscue-", "mimikyu-", "cramorant-", "morpeko-",
  25. # weird event thing
  26. "greninja-", "eevee-", "pikachu-", "zarude-", "magearna-",
  27. # pokestars
  28. "pokestar",
  29. ]
  30. # removing all forms of a type
  31. end_with_filters = [
  32. "-mega", "-megax", "-megay", "-primal", "-ultra",
  33. "-gmax", "-eternamax", "-totem", "-f", "-b", "-old", "-shiny",
  34. "-eternalflower", "-rapidstrikegmax",
  35. ]
  36. # removing pokemon entirely
  37. full_filters = [
  38. # darmanitan zen forms (cannot start in zen)
  39. "darmanitan-galarzen", "darmanitan-zen",
  40. # minior core forms (cannot start in anything but -meteor, renamed below)
  41. "minior", "minior-blue", "minior-green", "minior-indigo",
  42. "minior-orange", "minior-red", "minior-violet", "minior-yellow",
  43. # gimmighoul roaming (cannot start roaming)
  44. "gimmighoul-roaming",
  45. # palafin hero (cannot start as hero)
  46. "palafin-hero",
  47. # because it is a create-a-pokemon
  48. "argalis", "arghonaut", "brattler", "breezi", "caimanoe", "cawdet",
  49. "colossoil", "coribalis", "cupra", "cyclohm", "dorsoil", "duohm",
  50. "electrelk", "embirch", "fawnifer", "flarelm", "floatoy", "krillowatt",
  51. "krolowatt", "miasmite", "monohm", "necturine", "nohface", "privatyke",
  52. "pyroak", "rebble", "revenankh", "saharaja", "snugglow", "solotl",
  53. "swirlpool", "syclant", "syclar", "tactite", "venomicon",
  54. "venomicon-epilogue", "volkritter", "voodoll",
  55. "astrolotl", "aurumoth", "caribolt", "cawmodore", "chromera", "crucibelle",
  56. "equilibra", "fidgit", "jumbao", "justyke", "kerfluffle", "kitsunoh",
  57. "krilowatt", "malaconda", "miasmaw", "mollux", "naviathan", "necturna",
  58. "pajantom", "plasmanta", "pluffle", "protowatt", "scratchet", "smogecko",
  59. "smoguana", "smokomodo", "snaelstrom", "stratagem", "tomohawk", "volkraken", "voodoom",
  60. # typos/duplicates
  61. "0", "arctovolt", "buffalant", "burmy-plant", "darmanitan-standard",
  62. "deerling-spring", "deoxys-rs", "gastrodon-west",
  63. "klinklang-back", "krikretot", "marenie", "marowak-alolan", "meloetta-aria",
  64. "pichu-spikyeared", "polteageist-chipped", "rattata-alolan", "regidragon",
  65. "sawsbuck-spring", "shaymin-land", "shellos-west", "sinistea-chipped", "wormadam-plant",
  66. "pumpkabo-super", "magcargo%20", "meowstic-female",
  67. "ratatta-a", "ratatta-alola", "raticate-a",
  68. "rotom-h", "rotom-m", "rotom-s", "rotom-w",
  69. # not a pokemon
  70. "substitute", "egg", "egg-manaphy", "missingno",
  71. ]
  72. # force certain pokemon to stay
  73. force_keep = [ "meowstic-f", "unfezant-f", "pyroar-f" ]
  74. # rename certain pokemon after the fact
  75. rename = {
  76. # dash consistency
  77. "nidoranm": "nidoran-m",
  78. "nidoranf": "nidoran-f",
  79. "porygonz": "porygon-z",
  80. "tapubulu": "tapu-bulu",
  81. "tapufini": "tapu-fini",
  82. "tapukoko": "tapu-koko",
  83. "tapulele": "tapu-lele",
  84. "hooh": "ho-oh",
  85. "mimejr": "mime-jr",
  86. "mrmime": "mr-mime",
  87. "mrmime-galar": "mr-mime-galar",
  88. "mrrime": "mr-rime",
  89. "jangmoo": "jangmo-o",
  90. "hakamoo": "hakamo-o",
  91. "kommoo": "kommo-o",
  92. "typenull": "type-null",
  93. "oricorio-pompom": "oricorio-pom-pom",
  94. "necrozma-duskmane": "necrozma-dusk-mane",
  95. "necrozma-dawnwings": "necrozma-dawn-wings",
  96. "toxtricity-lowkey": "toxtricity-low-key",
  97. # rename forms
  98. "shellos": "shellos-west",
  99. "shaymin": "shaymin-land",
  100. "meloetta": "meloetta-aria",
  101. "keldeo": "keldeo-ordinary",
  102. "hoopa": "hoopa-confined",
  103. "burmy": "burmy-plant",
  104. "wormadam": "wormadam-plant",
  105. "deerling": "deerling-spring",
  106. "sawsbuck": "sawsbuck-spring",
  107. "vivillon": "vivillon-meadow",
  108. "basculin": "basculin-redstriped",
  109. "meowstic": "meowstic-male",
  110. "meowstic-f": "meowstic-female",
  111. "pyroar-f": "pyroar-female",
  112. "flabebe": "flabebe-red",
  113. "floette": "floette-red",
  114. "florges": "florges-red",
  115. "minior-meteor": "minior",
  116. "sinistea": "sinistea-phony",
  117. "polteageist": "polteageist-phony",
  118. "gastrodon": "gastrodon-west",
  119. "furfrou": "furfrou-natural",
  120. "wishiwashi": "wishiwashi-school",
  121. "tornadus": "tornadus-incarnate",
  122. "landorus": "landorus-incarnate",
  123. "thundurus": "thundurus-incarnate",
  124. "calyrex-ice": "calyrex-ice-rider",
  125. "calyrex-shadow": "calyrex-shadow-rider",
  126. "urshifu-rapidstrike": "urshifu-rapid-strike",
  127. "zacian": "zacian-hero",
  128. "zamazenta": "zamazenta-hero",
  129. }
  130. def get_all_pokemon(url: str, ext: str = extension) -> list[str]:
  131. # TODO clean this up
  132. soup = BeautifulSoup(requests.get(url).text, "html.parser")
  133. imgs = [href for a in soup.find_all("a") if (href := a.get("href")).endswith(ext)]
  134. return [
  135. g[:-4]
  136. for g in imgs
  137. if g in [name + ext for name in force_keep] or (
  138. g not in [full + ext for full in full_filters]
  139. and not any(g.startswith(f) for f in start_with_filters)
  140. and not any(g.endswith(f) for f in [ending + ext for ending in end_with_filters])
  141. )
  142. ]
  143. def load_image(base: str, name: str, ext: str = extension) -> Image:
  144. return Image.open(io.BytesIO(requests.get(base + name + ext).content))
  145. def get_all_pixels(im: Image) -> list[tuple[int, int, int]]:
  146. rgb_pixels = []
  147. for fr in range(getattr(im, "n_frames", 1)):
  148. im.seek(fr)
  149. rgb_pixels += [
  150. (r, g, b)
  151. for r, g, b, a in im.convert("RGBA").getdata()
  152. if not ingest.is_outline(r, g, b, a)
  153. ]
  154. return rgb_pixels
  155. def merge_dist_jab(p: np.array, q: np.array) -> float:
  156. pj, pa, pb = p
  157. qj, qa, qb = q
  158. light_diff = abs(pj - qj)
  159. hue_angle = math.acos((pa * qa + pb * qb) / math.sqrt((pa ** 2 + pb ** 2) * (qa ** 2 + qb ** 2))) * 180 / math.pi
  160. return light_diff if hue_angle <= 10 and light_diff <= 20 else None
  161. def merge_dist_rgb(p: np.array, q: np.array) -> float:
  162. return merge_dist_jab(*cspace_convert(np.array([p, q]), "sRGB255", "CAM02-UCS"))
  163. def score_clustering_jab(means: list[np.array]) -> float:
  164. score = 0
  165. count = 0
  166. for p, q in itertools.combinations(means, 2):
  167. # squared dist in the a-b plane
  168. _, pa, pb = p
  169. _, qa, qb = q
  170. score += (pa - qa) ** 2 + (pb - qb) ** 2
  171. count += 1
  172. return score / count
  173. def score_clustering_rgb(means: list[np.array]) -> float:
  174. return score_clustering_jab(list(cspace_convert(np.array(means), "sRGB255", "CAM02-UCS")))
  175. Stats = NamedTuple("Stats", [("size", int), ("inertia", float), ("mu", np.array), ("nu", np.array)])
  176. def merge_stats(s1: Stats, s2: Stats) -> Stats:
  177. ts = s1.size + s2.size
  178. f1 = s1.size / ts
  179. f2 = s2.size / ts
  180. return Stats(
  181. size=ts,
  182. inertia=s1.inertia * f1 + s2.inertia * f2,
  183. mu=s1.mu * f1 + s2.mu * f2,
  184. nu=s1.nu * f1 + s2.nu * f2,
  185. )
  186. def flatten_stats(ss: list[Stats], target_len: int = 40) -> list[float]:
  187. to_return = []
  188. for s in ss:
  189. to_return += [s.size, s.inertia, *s.mu, *s.nu]
  190. return to_return + ([0] * (target_len - len(to_return)))
  191. def compute_stats(
  192. pixels: np.array,
  193. clustering_scorer: Callable[[list[np.array]], float],
  194. merge_dist: Callable[[np.array, np.array], float],
  195. ) -> list[Stats]:
  196. total_stats = Stats(
  197. size=len(pixels),
  198. inertia=ingest.inertia(pixels),
  199. mu=ingest.mu(pixels),
  200. nu=ingest.nu(pixels),
  201. )
  202. # run k-means multiple times, for multiple k's, trying to maximize the clustering_scorer
  203. best = None
  204. for k in (2, 3, 4):
  205. for i in range(cluster_attempts):
  206. means, labels = vq.kmeans2(pixels.astype(float), k, minit="++", seed=cluster_seed + i)
  207. score = clustering_scorer(means)
  208. if best is None or best[0] < score:
  209. best = (score, means, labels)
  210. _, best_means, best_labels = best
  211. cluster_stats = []
  212. for i in range(len(best_means)):
  213. cluster_pixels = pixels[best_labels == i]
  214. cluster_stats.append(Stats(
  215. size=len(cluster_pixels),
  216. inertia=ingest.inertia(cluster_pixels),
  217. mu=best_means[i],
  218. nu=ingest.nu(cluster_pixels),
  219. ))
  220. # assuming there are still more than two clusters,
  221. # attempt to merge the closest if they're close enough
  222. if len(cluster_stats) > 2:
  223. # first, find all the options
  224. options = []
  225. for i, j in itertools.combinations(range(len(cluster_stats)), 2):
  226. ci = cluster_stats[i]
  227. cj = cluster_stats[j]
  228. if (dist := merge_dist(ci.mu, cj.mu)) is not None:
  229. rest = [c for k, c in enumerate(cluster_stats) if k not in (i, j)]
  230. options.append((dist, [merge_stats(ci, cj), *rest]))
  231. # if there are multiple options, use the closest,
  232. # otherwise leaves cluster_stats the same
  233. if len(options) > 0:
  234. cluster_stats = min(options, key=lambda x: x[0])[1]
  235. return [total_stats, *cluster_stats]
  236. def get_stats(name: str) -> list[float]:
  237. front = get_all_pixels(load_image(base, name))
  238. back = get_all_pixels(load_image(back_base, name))
  239. rgb_pixels = np.array(front + back)
  240. jab_pixels = cspace_convert(rgb_pixels, "sRGB255", "CAM02-UCS")
  241. jab_stats = flatten_stats(compute_stats(
  242. jab_pixels,
  243. score_clustering_jab,
  244. merge_dist_jab,
  245. ))[1:]
  246. rgb_stats = flatten_stats(compute_stats(
  247. rgb_pixels,
  248. score_clustering_rgb,
  249. merge_dist_rgb,
  250. ))[1:]
  251. return [len(rgb_pixels), *jab_stats, *rgb_stats]
  252. if __name__ == "__main__":
  253. pkmn = get_all_pokemon(back_base)
  254. print("Found", len(pkmn), "sprites...")
  255. errors = []
  256. def ingest_and_format(pair: tuple[int, str]) -> str:
  257. index, name = pair
  258. try:
  259. print(f"Ingesting #{index+1}: {name}...")
  260. stats = get_stats(name)
  261. format_name = rename.get(name, name)
  262. print(f"Finished #{index+1}: {name}, saving under {format_name}")
  263. return f' [ "{format_name}", {", ".join(str(n) for n in stats)} ],\n'
  264. except Exception as e:
  265. print(e)
  266. errors.append((name, e))
  267. with multiprocessing.Pool(4) as pool:
  268. stats = sorted(res for res in pool.imap_unordered(ingest_and_format, enumerate(pkmn), 100) if res is not None)
  269. print(f"Calculated {len(stats)} statistics, writing...")
  270. with open("database-v3.js", "w") as outfile:
  271. outfile.write("const databaseV3 = [\n")
  272. for line in sorted(stats):
  273. outfile.write(line)
  274. outfile.write("];\n")
  275. print("Errors:", errors)