浏览代码

wip, just commenting out everything for now

Kirk Trombley 2 年之前
父节点
当前提交
0002890d93
共有 1 个文件被更改,包括 104 次插入308 次删除
  1. 104 308
      ingest2.py

+ 104 - 308
ingest2.py

@@ -1,20 +1,5 @@
-import math
-import asyncio
-import multiprocessing
-import json
-from collections import defaultdict
-from io import BytesIO
-from typing import NamedTuple, Generator
-from itertools import combinations
-
-import numpy as np
-from PIL import Image
-from aiohttp import ClientSession
-from scipy.cluster import vq
-
 """
 Goals:
- + Single module
  + Use OKLab
  + Improved clustering logic
  + Parallel, in the same way as anim-ingest
@@ -22,301 +7,112 @@ Goals:
  + Include more info about the pokemon (form, display name, icon sprite source)
  + Include megas/gmax/etc, tagged so the UI can filter them
  * Include more images (get more stills from pokemondb + serebii)
- * Include shinies 
+ * Include shinies
  * Fallback automatically (try showdown animated, then showdown gen5, then pdb)
  * Filtering system more explicit and easier to work around
  * Output a record of ingest for auditing
  * Automatic retry of a partially failed ingest, using record
 """
-
-# https://en.wikipedia.org/wiki/SRGB#Transformation
-linearize_srgb = np.vectorize(
-  lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
-)
-delinearize_lrgb = np.vectorize(
-  lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
-)
-# https://mina86.com/2019/srgb-xyz-matrix/
-RGB_TO_XYZ = np.array([
-  [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
-  [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
-  [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
-])
-XYZ_TO_RGB = [
-  [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
-  [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
-  [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
-]
-
-# https://bottosson.github.io/posts/oklab/
-XYZ_TO_LMS = np.array([
-  [0.8189330101, 0.3618667424, -0.1288597137],
-  [0.0329845436, 0.9293118715, 0.0361456387],
-  [0.0482003018, 0.2643662691, 0.6338517070],
-])
-RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
-LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
-LMS_TO_OKLAB = np.array([
-  [0.2104542553, 0.7936177850, -0.0040720468],
-  [1.9779984951, -2.4285922050, 0.4505937099],
-  [0.0259040371, 0.7827717662, -0.8086757660],
-])
-OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
-
-
-def oklab2hex(pixel: np.array) -> str:
-  # no need for a vectorized version, this is only for providing the mean hex
-  return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
-
-
-def srgb2oklab(pixels: np.array) -> np.array:
-  return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
-
-
-Stats = NamedTuple("Stats", [
-  ("size", int),
-  ("variance", float),
-  ("stddev", float),
-  ("hex", str),
-  ("Lbar", float),
-  ("abar", float),
-  ("bbar", float),
-  ("Cbar", float),
-  ("hbar", float),
-  ("Lhat", float),
-  ("ahat", float),
-  ("bhat", float),
-])
-
-Data = NamedTuple("Data", [
-  ("total", Stats),
-  ("clusters", list[Stats]),
-])
-
-FormInfo = NamedTuple("FormData", [
-  ("name", str),
-  ("traits", list[str]),
-  ("types", list[str]),
-  ("color", str),
-  ("data", Data | None),
-])
-
-Pokemon = NamedTuple("Pokemon", [
-  ("num", int),
-  ("species", str),
-  ("sprite", str | None),
-  ("forms", list[FormInfo]),
-])
-
-
-def calc_statistics(pixels: np.array) -> Stats:
-  # mean pixel of the image, (L-bar, a-bar, b-bar)
-  mean = pixels.mean(axis=0)
-  # square each component
-  squared = pixels ** 2
-  # Euclidean norm squared by summing squared components
-  sqnorms = squared.sum(axis=1)
-  # mean pixel of normalized image, (L-hat, a-hat, b-hat)
-  tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
-  # variance = mean(||p||^2) - ||mean(p)||^2
-  variance = sqnorms.mean(axis=0) - sum(mean ** 2)
-  # chroma^2 = a^2 + b^2
-  chroma = np.sqrt(squared[:, 1:].sum(axis=1))
-  # hue = atan2(b, a), but we need a circular mean
-  # https://en.wikipedia.org/wiki/Circular_mean#Definition
-  # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
-  # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
-  hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).mean(axis=0))
-  return Stats(
-    size=len(pixels),
-    variance=variance,
-    stddev=math.sqrt(variance),
-    hex=oklab2hex(mean),
-    Lbar=mean[0],
-    abar=mean[1],
-    bbar=mean[2],
-    Cbar=chroma.mean(axis=0),
-    hbar=hue * 180 / math.pi,
-    Lhat=tilt[0],
-    ahat=tilt[1],
-    bhat=tilt[2],
-  )
-
-
-def find_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
-  means, labels = max(
-    (
-      # Try k = 2, 3, and 4, and try a few times for each
-      vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
-      for k in (2, 3, 4)
-      for i in range(cluster_attempts)
-    ),
-    key=lambda c:
-      # Evaluate clustering by seeing the average distance in the ab-plane
-      # between the centers. Maximizing this means the clusters are highly
-      # distinct, which gives a sense of which k was best.
-      (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
-        .sum(axis=1)
-        .mean(axis=0)
-  )
-  return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
-
-
-def get_pixels(img: Image) -> np.array:
-  rgb = []
-  for fr in range(getattr(img, "n_frames", 1)):
-    img.seek(fr)
-    rgb += [
-      [r, g, b]
-      for r, g, b, a in img.convert("RGBA").getdata()
-      if a > 0 and (r, g, b) != (0, 0, 0)
-    ]
-  return srgb2oklab(np.array(rgb))
-
-
-async def load_image(session: ClientSession, url: str) -> Image.Image:
-  async with session.get(url) as res:
-    return Image.open(BytesIO(await res.read()))
-
-
-async def load_all_images(urls: list[str]) -> list[Image.Image]:
-  async with ClientSession() as session:
-    # TODO error handling
-    return await asyncio.gather(*(load_image(session, url) for url in urls))
-
-
-def get_data(urls: list[str], seed=0) -> Data:
-  images = asyncio.get_event_loop().run_until_complete(load_all_images(urls))
-  # TODO error handling
-  pixels = np.concatenate([get_pixels(img) for img in images])
-  return Data(
-    total=calc_statistics(pixels),
-    clusters=find_clusters(pixels, seed=seed),
-  )
-
-
-def get_traits(species: str, form: dict) -> list[str]:
-  kind = form["formeKind"]
-  traits = []
-  if kind in ("mega", "mega-x", "mega-y", "primal"):
-    traits.extend(("mega", "nostart"))
-  if kind in ("gmax", "eternamax", "rapid-strike-gmax"):
-    traits.extend(("gmax", "nostart"))
-  if kind in ("alola", "galar", "hisui", "galar", "paldea"):
-    traits.extend(("regional", kind))
-
-  # special cases
-  if species == "Tauros" and "-paldea" in kind:
-    # paldean tauros has dumb names
-    traits.extend(("regional", "paldea"))
-  if species == "Minior" and kind != "meteor":
-    # minior can only start the battle in meteor form
-    traits.append("nostart")
-  if species == "Darmanitan" and "zen" in kind:
-    # darmanitan cannot start in zen form
-    traits.append("nostart")
-    if "galar" in kind:
-      # also there's a galar-zen form to handle
-      traits.extend(("regional", "galar"))
-  if species == "Palafin" and kind == "hero":
-    # palafin can only start in zero form
-    traits.append("nostart")
-  if species == "Gimmighoul" and kind == "roaming":
-    # gimmighoul roaming is only in PGO
-    traits.append("nostart")
-
-  return list(set(traits))
-
-
-# https://bulbapedia.bulbagarden.net/wiki/List_of_Pok%C3%A9mon_with_gender_differences
-# there are some pokemon with notable gender diffs that the dex doesn't cover
-# judgement calls made arbitrarily
-GENDER_DIFFS = (
-  "hippopotas", "hippowdon", 
-  "unfezant", "frillish", "jellicent",
-  "pyroar",
-  # meowstic, indeedee, basculegion, oinkologne are already handled in the dex
-)
-
-
-def load_pokedex(path: str) -> Generator[Pokemon, None, None]:
-  with open(path) as infile:
-    pkdx_raw = json.load(infile)
-
-  pkdx = defaultdict(list)
-
-  for key, entry in pkdx_raw.items():
-    num = entry["num"]
-    # non-cosmetic forms get separate entries automatically
-    # but keeping the separate unown forms would be ridiculous
-    if key != "unown" and len(cosmetic := entry.get("cosmeticFormes", [])) > 0:
-      cosmetic.append(f'{entry["name"]}-{entry["baseForme"]}')
-      if key == "alcremie":
-        # oh god this thing
-        cosmetic = [
-          f"{cf}-{sweet}"
-          for cf in cosmetic
-          for sweet in [
-            "Strawberry", "Berry", "Love", "Star",
-            "Clover", "Flower", "Ribbon",
-          ]
-        ]
-      pkdx[num].extend({
-        **entry,
-        "forme": cf.replace(" ", "-"),
-        "formeKind": "cosmetic",
-      } for cf in cosmetic)
-    elif key in GENDER_DIFFS:
-      pkdx[num].append({
-        **entry,
-        "forme": f'{entry["name"]}-M',
-        "formeKind": "cosmetic",
-      })
-      pkdx[num].append({
-        **entry,
-        "forme": f'{entry["name"]}-F',
-        "formeKind": "cosmetic",
-      })
-    else:
-      pkdx[num].append({
-        **entry,
-        "forme": entry["name"],
-        "formeKind": entry.get("forme", "base").lower(),
-      })
-
-  for i in range(1, max(pkdx.keys()) + 1):
-    forms = pkdx[i]
-    # double check there's no skipped entries
-    assert len(forms) > 0
-    # yield forms
-    species = forms[0].get("baseSpecies", forms[0]["name"])
-    yield Pokemon(
-      num=i,
-      species=species,
-      sprite=None,  # found later
-      forms=[
-        FormInfo(
-          name=f.get("forme", f["name"]),
-          traits=get_traits(species, f),
-          types=f["types"],
-          color=f["color"],
-          data=None,  # found later
-        ) for f in forms
-      ]
-    )
-
-
-if __name__ == "__main__":
-  from sys import argv
-  dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
-  out_file = argv[2] if len(argv) > 2 else "data/database-latest.js"
-  log_file = argv[3] if len(argv) > 2 else "ingest.log"
-
-  pkdx = list(load_pokedex())
-
-  print(json.dumps(pkdx[5], indent=2))
-  print(json.dumps(pkdx[285], indent=2))
-  print(json.dumps(pkdx[773], indent=2))
-
-  # with multiprocessing.Pool(4) as pool:
-  #   yield from pool.imap_unordered(lambda n: get_data(n, seed=seed), pokemon, 100)
+# async def load_image(session: ClientSession, url: str) -> Image.Image:
+#   async with session.get(url) as res:
+#     res.raise_for_status()
+#     return Image.open(BytesIO(await res.read()))
+
+
+# async def load_all_images(urls: list[str]) -> tuple[list[Image.Image], list[Exception]]:
+#   async with ClientSession() as session:
+#     results = await asyncio.gather(
+#       *(load_image(session, url) for url in urls),
+#       return_exceptions=True
+#     )
+#   success = []
+#   errors = []
+#   for r in results:
+#     (success if isinstance(r, Image.Image) else errors).append(r)
+#   return success, errors
+
+
+# def get_urls(target: Pokemon, form: FormInfo) -> list[str]:
+#   lower_name = form.name.lower()
+#   return [
+#     f"https://play.pokemonshowdown.com/sprites/ani/{lower_name}.gif",
+#     f"https://play.pokemonshowdown.com/sprites/ani-back/{lower_name}.gif",
+#     f"https://play.pokemonshowdown.com/sprites/gen5/{lower_name}.png",
+#     f"https://play.pokemonshowdown.com/sprites/gen5-back/{lower_name}.png",
+#     f"https://img.pokemondb.net/sprites/home/normal/{lower_name}.png",
+#     # TODO other sources - want to make sure we never cross contaminate though...
+#     # if we pull the wrong form for something it will be a nightmare to debug
+#     # f"https://www.serebii.net/scarletviolet/pokemon/new/{target.num}-{???}.png"
+#     # f"https://www.serebii.net/pokemon/art/{target.num}-{???}.png"
+#   ]
+
+
+# async def set_data(target: Pokemon, seed=0) -> list[Exception]:
+#   all_errors = []
+#   for form in target.forms:
+#     print(f" #{target.num} - Ingesting Form: {form.name}")
+#     urls = get_urls(target, form)
+#     print(f"  #{target.num} - Attempting {len(urls)} potential sources")
+#     images, errors = await load_all_images(urls)
+#     all_errors.extend(errors)
+#     print(f"  #{target.num} - Loaded {len(images)} sources")
+#     try:
+#       pixels = np.concatenate([get_pixels(img) for img in images])
+#       print(f"  #{target.num} - Summarizing {len(pixels)} total pixels")
+#       total = calc_statistics(pixels)
+#       print(f"  #{target.num} - Begin clustering")
+#       clusters = find_clusters(pixels, seed=seed)
+#       print(f"  #{target.num} - End clustering, chose k={len(clusters)}")
+#       form.data = Data(total=total, clusters=clusters)
+#     except Exception as e:
+#       all_errors.append(e)
+#   return all_errors
+
+
+
+# async def ingest(pool_size: int, seed: int) -> tuple[list[str], list[str]]:
+#   computed = []
+#   errors = []
+#   loop = asyncio.get_event_loop()
+#   with ProcessPoolExecutor(pool_size) as exec:
+#     print(f"Ingesting #{start} - #{end}")
+#     for pkmn in pkdx[start - 1:end]:
+#       print(f"Ingesting #{pkmn.num}: {pkmn.species}...")
+#       new_errors = await set_data(pkmn, seed)
+#       loop.run_in_executor(exec, set_data, pkmn, seed)
+
+    
+
+#       computed.append(loop.run_in_executor(pool, ingest(p)))
+
+#   try:
+#     errors.extend(new_errors)
+#     print(f"Finished #{pkmn.num}: {len(new_errors)} error(s)")
+#     return json.dumps(asdict(pkmn))
+#   except Exception as e:
+#     print(e)
+#     errors.append(e)
+
+# if __name__ == "__main__":
+  # from sys import argv
+  # dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
+  # out_file = argv[2] if len(argv) > 2 else "data/database-latest.db"
+  # dex_span = argv[3] if len(argv) > 3 else "1-151"
+  # log_file = argv[4] if len(argv) > 4 else "errors-latest.log"
+  # set_seed = argv[5] if len(argv) > 5 else "20230304"
+
+  # start, end = map(int, dex_span.split("-", maxsplit=1))
+  # seed = int(set_seed)
+  # errors = []
+
+  # pkdx = list(load_pokedex(dex_file))
+  # loop = asyncio.new_event_loop()
+
+  # with open(log_file, "w") as log:
+  #   # TODO better logging
+  #   log.writelines(str(e) for e in errors)
+
+  # with open(out_file, "a") as db:
+  #   for _, line in computed:
+  #     db.write(line)
+  #     db.write("\n")