Delen via


Kunst verkennen in cultuur en gemiddeld met snelle, voorwaardelijke, k-dichtstbijzijnde buren

Dit artikel fungeert als richtlijn voor match-finding via k-dichtstbijzijnde buren. U stelt code in waarmee query's met betrekking tot culturen en mediums kunst uit het Metropolitan Museum of Art in NYC en het Rijksmuseum in Amsterdam mogelijk zijn.

Vereisten

  • Koppel uw notitieblok aan een lakehouse. Selecteer aan de linkerkant Toevoegen om een bestaand lakehouse toe te voegen of een lakehouse te maken.

Overzicht van de BallTree

De structuur die achter het KNN-model functioneert, is een BallTree, een recursieve binaire structuur waarin elk knooppunt (of "bal") een partitie bevat van de punten van gegevens waarop een query moet worden uitgevoerd. Het bouwen van een BallTree omvat het toewijzen van gegevenspunten aan de "bal" waarvan het midden het dichtst bij hen ligt (met betrekking tot een bepaalde opgegeven functie), wat resulteert in een structuur die binair-boomachtige traversal toestaat en zichzelf leent voor het vinden van k-dichtstbijzijnde buren op een BallTree-blad.

Instellingen

Importeer de benodigde Python-bibliotheken en bereid de gegevensset voor.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Onze gegevensset is afkomstig uit een tabel met kunstwerkinformatie van zowel de musea Met als Rijks. Het -schema is als volgt:

  • id: Een unieke id voor een kunstwerk
    • Voorbeeld van met-id: 388395
    • Voorbeeld rijks-id: SK-A-2344
  • Titel: Titel van kunstwerk, zoals geschreven in de database van het museum
  • Kunstenaar: Kunststukkunstenaar, zoals geschreven in de database van het museum
  • Thumbnail_Url: Locatie van een JPEG-miniatuur van het kunstwerk
  • Image_Url Locatie van een afbeelding van het kunstwerk dat wordt gehost op de website van Met/Rijks
  • Cultuur: Categorie cultuur onder het kunststuk
    • Voorbeeldcultuurcategorieën: Latijns-Amerikaans, Egyptisch, enz.
  • Classificatie: Categorie van medium waarvan het kunststuk onder valt
    • Voorbeeld medium categorieën: houtwerk, schilderijen, enz.
  • Museum_Page: Link naar het kunstwerk op de website van Met/Rijks
  • Norm_Features: Insluiten van de afbeelding van het kunstwerk
  • Museum: Geeft aan uit welk museum het stuk afkomstig is
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Categorieën definiëren waarop een query moet worden uitgevoerd

Er worden twee KNN-modellen gebruikt: één voor cultuur en één voor medium.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

ConditionalKNN-modellen definiëren en aanpassen

VoorwaardelijkeKNN-modellen maken voor zowel de medium- als cultuurkolommen; elk model neemt een uitvoerkolom, kenmerken kolom (functievector), waardenkolom (celwaarden onder de uitvoerkolom) en labelkolom (de kwaliteit waarop de respectieve KNN is geconditioneerd).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Overeenkomende methoden definiëren en visualiseren

Nadat de eerste gegevensset en categorie zijn ingesteld, bereidt u methoden voor waarmee de resultaten van de voorwaardelijke KNN worden opgevraagd en gevisualiseerd.

addMatches() maakt een Dataframe met een handvol overeenkomsten per categorie.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() aanroepen plot_img om de belangrijkste overeenkomsten voor elke categorie in een raster te visualiseren.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Alles samenvoegen

Definieer test_all() om de gegevens, CKNN-modellen, de waarden voor de kunst-id op te vragen en het bestandspad om de uitvoervisualisatie op te slaan. De medium- en cultuurmodellen werden eerder getraind en geladen.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

De volgende cel voert batchquery's uit op de gewenste afbeeldings-id's en een bestandsnaam om de visualisatie op te slaan.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")