123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- import math
- import os
- import time
- import json
- from collections import defaultdict
- from concurrent.futures import Executor
- from pathlib import Path
- from dataclasses import dataclass, asdict
- from itertools import combinations
- import numpy as np
- from PIL import Image
- from scipy.cluster import vq
- from scipy.spatial.distance import cdist, euclidean
# https://en.wikipedia.org/wiki/SRGB#Transformation
def linearize_srgb(v: np.ndarray) -> np.ndarray:
    """Convert gamma-encoded sRGB components in [0, 1] to linear light.

    Accepts a scalar or array and returns a float ndarray of the same shape
    (including size 0). Implemented with ``np.where`` so the whole array is
    processed in one numpy pass rather than one Python call per element.
    """
    v = np.asarray(v, dtype=float)
    # Clamp the base of the power branch: np.where evaluates both branches, and
    # the unused one must not raise a negative number to a fractional power.
    curve = ((np.maximum(v, 0.04045) + 0.055) / 1.055) ** 2.4
    return np.where(v <= 0.04045, v / 12.92, curve)


def delinearize_lrgb(v: np.ndarray) -> np.ndarray:
    """Convert linear-light RGB components in [0, 1] back to gamma-encoded sRGB.

    Inverse of ``linearize_srgb``; same shape/dtype behaviour.
    """
    v = np.asarray(v, dtype=float)
    # Same clamping trick as linearize_srgb, for the (1 / 2.4) power.
    curve = (np.maximum(v, 0.0031308) ** (1 / 2.4)) * 1.055 - 0.055
    return np.where(v <= 0.0031308, v * 12.92, curve)


# https://mina86.com/2019/srgb-xyz-matrix/
# Linear sRGB -> CIE XYZ, as exact rationals.
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
# CIE XYZ -> linear sRGB (inverse of RGB_TO_XYZ); an ndarray like its siblings.
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])

# https://bottosson.github.io/posts/oklab/
# XYZ -> LMS cone response (Oklab's M1).
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
# Linear sRGB -> LMS directly, and its inverse.
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
# Cube-rooted LMS -> OKLab (Oklab's M2), and its inverse.
LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)

# round output to this many decimals
OUTPUT_PRECISION = 8
def oklab2hex(pixel: np.array) -> str:
    """Return the ``#RRGGBB`` sRGB hex code for a single OKLab colour (L, a, b).

    The colour is clipped to the sRGB gamut before encoding, since a mean or
    median computed in OKLab can land slightly outside it; unclipped values
    would format as negative or three-digit hex. Each channel is rounded to the
    nearest 8-bit value (not truncated), so float round-trip error such as
    0.99999999 * 255 does not drop a channel to 254.
    """
    # no need for a vectorized version, this is only for providing the mean hex
    lms = (pixel @ OKLAB_TO_LMS.T) ** 3
    linear_rgb = np.clip(lms @ LMS_TO_RGB.T, 0.0, 1.0)
    channels = np.clip(np.rint(delinearize_lrgb(linear_rgb) * 255), 0, 255).astype(int)
    return "#" + "".join(f"{c:02X}" for c in channels)
def srgb2oklab(pixels: np.array) -> np.array:
    """Convert 8-bit sRGB colours to OKLab.

    ``pixels`` holds 0-255 sRGB components, one (r, g, b) row per colour
    (presumably shape (n, 3) as built by get_srgb_pixels — confirm); the
    result has one (L, a, b) row per input row.
    """
    linear = linearize_srgb(pixels / 255)
    cone_response = linear @ RGB_TO_LMS.T
    # OKLab applies a cube-root non-linearity before the final matrix.
    compressed = cone_response ** (1 / 3)
    return compressed @ LMS_TO_OKLAB.T
# https://stackoverflow.com/a/30305181
def geometric_median(X: np.array, eps=1e-5) -> np.array:
    """Approximate the geometric median of the rows of ``X`` iteratively.

    Starts from the arithmetic mean and repeats a Weiszfeld-style re-weighted
    average (each point weighted by 1 / its distance to the current estimate)
    until successive estimates move less than ``eps``. Points that coincide
    with the current estimate (distance 0) are excluded from the weights and
    handled by the ``num_zeros`` branch. Adapted from the linked answer.

    Args:
        X: points, one per row (must be 2-D for ``cdist``).
        eps: convergence threshold on the Euclidean step between estimates.

    Returns:
        The estimate as a 1-D array. If every point coincides with the current
        estimate, that estimate is returned unchanged.

    NOTE(review): there is no iteration cap; termination relies on the step
    size eventually dropping below ``eps``.
    """
    y = np.mean(X, 0)
    while True:
        # distance from every point to the current estimate, shape (len(X), 1)
        D = cdist(X, [y])
        # points that do not sit exactly on the estimate
        nonzeros = (D != 0)[:, 0]
        Dinv = 1 / D[nonzeros]
        Dinvs = np.sum(Dinv)
        W = Dinv / Dinvs
        # inverse-distance weighted average of the non-coincident points
        T = np.sum(W * X[nonzeros], 0)
        num_zeros = len(X) - np.sum(nonzeros)
        if num_zeros == 0:
            y1 = T
        elif num_zeros == len(X):
            # every point equals the estimate: nothing left to move towards
            return y
        else:
            # estimate coincides with some data points: blend the weighted
            # average T with the current y, using num_zeros / |R| as the mix
            R = (T - y) * Dinvs
            r = np.linalg.norm(R)
            rinv = 0 if r == 0 else num_zeros / r
            y1 = max(0, 1 - rinv) * T + min(1, rinv) * y
        if euclidean(y, y1) < eps:
            return y1
        y = y1
@dataclass
class Stats:
    """Colour summary of a set of OKLab pixels, as built by calc_statistics.

    Serialised with dataclasses.asdict, so field names and order are the
    output schema.
    """
    # vector statistics
    centroid: list[float]  # (L, a, b) arithmetic mean pixel
    median: list[float]  # (L, a, b) approximate geometric median
    stddev: list[float]  # (L, a, b) per-channel standard deviation
    tilt: list[float]  # (L, a, b) mean of the pixels scaled to unit length
    chroma: list[float]  # (mean, stddev) of sqrt(a^2 + b^2)
    # scalar statistics
    hue: float  # circular mean hue angle in radians, reduced modulo 2*pi
    size: int  # number of pixels summarised
    # sRGB hex code of the centroid and median
    centroidHex: str
    medianHex: str
def calc_statistics(pixels: np.array, output_precision: int) -> Stats:
    """Compute colour statistics for a set of OKLab pixels.

    Args:
        pixels: OKLab colours, one pixel per row with columns (L, a, b)
            (the indexing below assumes a 2-D array with at least 3 columns).
        output_precision: number of decimal places every float is rounded to.

    Returns:
        A Stats record. ``hue`` is the circular mean of the per-pixel hue angle
        atan2(b, a) in radians within [0, 2*pi); ``chroma`` is the (mean,
        stddev) of sqrt(a^2 + b^2).
    """
    # centroid, the arithmetic mean pixel of the image
    centroid = pixels.mean(axis=0)
    # stddev, the per-channel population standard deviation; summing the
    # squares gives a total "variance" in colour. np.std centres the data
    # before squaring, avoiding the cancellation in sqrt(E[x^2] - E[x]^2),
    # which for near-uniform images can round to a tiny negative and give NaN.
    stddev = pixels.std(axis=0)
    # tilt, the arithmetic mean of the pixels scaled to unit length
    # (assumes no pixel is exactly (0, 0, 0))
    tilt = (pixels / np.linalg.norm(pixels, axis=1)[:, np.newaxis]).mean(axis=0)
    # chroma = sqrt(a^2 + b^2)
    chromas = np.hypot(pixels[:, 1], pixels[:, 2])
    chroma_mean = chromas.mean(axis=0)
    # Because chroma^2 = a^2 + b^2 this equals sqrt(E[a^2 + b^2] - E[chroma]^2),
    # but is computed without that cancellation-prone subtraction.
    chroma_dev = float(chromas.std())
    # hue = atan2(b, a), but we need a circular mean
    # https://en.wikipedia.org/wiki/Circular_mean#Definition
    # cos(atan2(b, a)) = a / chroma and sin(atan2(b, a)) = b / chroma; since
    # atan2(y/n, x/n) = atan2(y, x) for n > 0, a sum is as good as a mean.
    # Achromatic pixels (chroma == 0) have no hue, and dividing by their zero
    # chroma would make the whole sum NaN, so they are left out.
    chromatic = chromas > 0
    unit_ab = pixels[chromatic][:, 1:3] / chromas[chromatic][:, np.newaxis]
    cos_sum, sin_sum = unit_ab.sum(axis=0)
    hue = math.atan2(sin_sum, cos_sum)
    # approximation of geometric median, primarily for display purposes
    median = geometric_median(pixels)
    return Stats(
        centroid=list(np.round(centroid, output_precision)),
        median=list(np.round(median, output_precision)),
        stddev=list(np.round(stddev, output_precision)),
        tilt=list(np.round(tilt, output_precision)),
        chroma=[
            round(chroma_mean, output_precision),
            round(chroma_dev, output_precision),
        ],
        hue=round(hue % (2 * math.pi), output_precision),
        size=len(pixels),
        centroidHex=oklab2hex(centroid),
        medianHex=oklab2hex(median),
    )
def _mean_hue_cosine(means: np.ndarray) -> float:
    """Mean cosine of the ab-plane angle between every pair of cluster centres."""
    ab = means[:, 1:]
    unit_ab = ab / np.linalg.norm(ab, axis=1)[:, np.newaxis]
    return np.array([u @ v for u, v in combinations(unit_ab, 2)]).mean(axis=0)


def calc_clusters(pixels: np.array, output_precision: int, cluster_attempts=5, seed=0) -> list[Stats]:
    """Cluster OKLab pixels with k-means and summarise each cluster.

    k = 2, 3 and 4 are each tried ``cluster_attempts`` times (seeds ``seed``,
    ``seed + 1``, ...). The clustering whose centres are most spread out in
    hue is kept; clusterings that left a cluster empty are chosen only if
    every attempt did.

    Args:
        pixels: OKLab pixels, one (L, a, b) row each.
        output_precision: decimals passed through to calc_statistics.
        cluster_attempts: k-means restarts per value of k.
        seed: base random seed for kmeans2.

    Returns:
        One Stats per non-empty cluster of the chosen clustering.
    """
    # kmeans2 wants floats; convert once rather than on every attempt.
    data = pixels.astype(float)
    candidates = (
        # Try k = 2, 3, and 4, and try a few times for each
        vq.kmeans2(data, k, minit="++", seed=seed + i)
        for k in (2, 3, 4)
        for i in range(cluster_attempts)
    )

    def score(candidate):
        # Evaluate clustering by seeing the average difference in hue angle
        # between the centers. Maximizing this means the clusters are highly
        # distinct, which gives a sense of which k was best.
        # This is computed by normalizing the ab-plane projection of the means,
        # then applying a dot product to get the cosine of the angle
        # between them in that plane, which is the hue difference. Minimizing
        # this maximizes the differences in hues.
        # A different clustering algorithm may be more suited here, but this
        # is comparatively cheap while still producing reasonable results.
        means, labels = candidate
        # kmeans2 (missing="warn", the default) only warns when a cluster ends
        # up with no members; its stale centre would distort the score.
        has_empty = bool((np.bincount(labels, minlength=len(means)) == 0).any())
        return (has_empty, _mean_hue_cosine(means))

    means, labels = min(candidates, key=score)
    # Skip empty clusters: their statistics would be NaN and crash oklab2hex.
    return [
        calc_statistics(pixels[labels == i], output_precision)
        for i in range(len(means))
        if np.any(labels == i)
    ]
def get_srgb_pixels(img: Image.Image) -> np.array:
    """Collect the RGB values of every visible, non-black pixel of ``img``.

    All frames of a multi-frame image are included (``img`` is left seeked to
    its last frame). Pixels with alpha 0 or colour exactly (0, 0, 0) are
    dropped.

    Returns:
        int array of shape (n, 3), one (r, g, b) row per kept pixel, frames in
        order and pixels in row-major order. When nothing is kept the shape is
        (0, 3), so results from several images can always be concatenated.
    """
    kept = []
    for frame in range(getattr(img, "n_frames", 1)):
        img.seek(frame)
        # (height, width, 4) -> one RGBA row per pixel, in getdata() order.
        rgba = np.asarray(img.convert("RGBA")).reshape(-1, 4)
        visible = (rgba[:, 3] > 0) & rgba[:, :3].any(axis=1)
        kept.append(rgba[visible, :3].astype(int))
    return np.concatenate(kept) if kept else np.empty((0, 3), dtype=int)
def log(*args, **kwargs):
    """``print`` wrapper that always flushes, so progress lines are not held
    back in an output buffer."""
    print(*args, flush=True, **kwargs)
- class Ingester:
- def __init__(self, dex: dict, output_precision: int, seed: int) -> None:
- self.lookup = {
- form["name"]: {
- "num": pkmn["num"],
- "species": pkmn["species"],
- **form,
- }
- for pkmn in dex.values()
- for form in pkmn["forms"]
- }
- self.seed = seed
- self.output_precision = output_precision
- def __call__(self, args: tuple[str, list[str]]) -> dict | Exception:
- form, filenames = args
- log(f"Ingesting {form}...")
- start_time = time.time()
- try:
- all_pixels = np.concatenate([
- get_srgb_pixels(Image.open(fn)) for fn in filenames
- ])
- except Exception as e:
- log(f"Error loading images for {form}: {e}")
- return e
- try:
- oklab = srgb2oklab(all_pixels)
- conv_time = time.time()
- total = calc_statistics(oklab, self.output_precision)
- calc_time = time.time()
- clusters = calc_clusters(oklab, self.output_precision, seed=self.seed)
- cluster_time = time.time()
- except Exception as e:
- log(f"Error calculating statistics for {form}: {e}")
- return e
- log(
- f"Completed {form}: ",
- f"{(cluster_time - start_time):.02f}s in total,",
- f"{(cluster_time - calc_time):.02f}s on clustering,",
- f"{(calc_time - conv_time):.02f}s on total calcs,",
- f"{(conv_time - start_time):.02f}s on read and conversion,",
- f"median {total.medianHex} and {len(clusters)} clusters"
- )
- return {
- **self.lookup[form],
- "total": asdict(total),
- "clusters": [asdict(c) for c in clusters],
- }
def output_db(results: list[dict], db_file: str):
    """Write ``results`` as a JavaScript ``const database = [...]`` file.

    Each entry is emitted as compact JSON on its own line, followed by a
    comma. Passing ``"-"`` as ``db_file`` instead logs the results as
    indented JSON and writes no file.
    """
    if db_file != "-":
        with open(db_file, "w") as handle:
            handle.write("const database = [\n")
            handle.writelines(f" {json.dumps(entry)},\n" for entry in results)
            handle.write("]\n")
        return
    log(json.dumps(results, indent=2))
def run_ingest(ingest: Ingester, filenames: list[Path], exec: Executor, db_file: str):
    """Group image files by form, ingest each form on ``exec``, write the DB.

    A file's form name is everything before the last ``-`` in its file name.
    Paths that are not regular files are reported and skipped; forms whose
    ingest returned an exception are reported and left out of the output,
    which is sorted by (num, name).
    """
    groups = defaultdict(list)
    absent = []
    for candidate in filenames:
        if not candidate.is_file():
            absent.append(candidate)
            log(f"Missing file: {candidate}")
            continue
        groups[candidate.name.rsplit("-", maxsplit=1)[0]].append(candidate)

    began = time.time()
    outcomes = list(exec.map(ingest, groups.items()))
    finished = time.time()

    success, errors = [], []
    for outcome in outcomes:
        (errors if isinstance(outcome, Exception) else success).append(outcome)

    file_count = sum(map(len, groups.values()))
    log(
        f"Finished ingest of {len(groups)} forms",
        f"and {file_count} files",
        f"in {(finished - began):.2f}s",
        f"with {len(absent)} missing file(s)",
        f"and {len(errors)} error(s)"
    )
    for error in errors:
        log(f"Error: {error}")
    for path in absent:
        log(f"Missing: {path}")
    ordered = sorted(success, key=lambda entry: (entry["num"], entry["name"]))
    output_db(ordered, db_file)
    log(f"Output {len(success)} entries to {db_file}")
if __name__ == "__main__":
    from argparse import ArgumentParser
    from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

    parser = ArgumentParser(
        prog="Image Analyzer",
        description="Analyze and summarize images based on color",
    )
    parser.add_argument(
        "-p", "--precision", type=int, default=OUTPUT_PRECISION, help="Round output to this many decimal places"
    )
    parser.add_argument(
        "-s", "--seed", type=int, default=230308, help="Clustering seed"
    )
    parser.add_argument(
        "-w", "--workers", type=int, default=4, help="Worker process count"
    )
    parser.add_argument(
        "-o", "--output", default="data/latest.db", help="Database file"
    )
    parser.add_argument(
        "-d", "--pokedex", default="data/pokedex.json", help="Pokedex file"
    )
    parser.add_argument(
        "--threading", action="store_true", help="Use threads instead of multiproc (slower but more stable on 3.10)"
    )
    parser.add_argument("images", metavar="file", type=Path, nargs="+")
    args = parser.parse_args()

    # JSON is UTF-8 by specification; pin it so non-ASCII names load the same
    # regardless of the platform's default locale encoding.
    with open(args.pokedex, encoding="utf-8") as infile:
        dex = json.load(infile)
    ingest = Ingester(dex, args.precision, args.seed)

    executor_cls = ThreadPoolExecutor if args.threading else ProcessPoolExecutor
    with executor_cls(max_workers=args.workers) as pool:
        run_ingest(ingest, args.images, pool, args.output)
|