Kunst verkennen in cultuur en gemiddeld met snelle, voorwaardelijke, k-dichtstbijzijnde buren
Dit artikel fungeert als richtlijn voor match-finding via k-dichtstbijzijnde buren. U stelt code in waarmee query's met betrekking tot culturen en mediums kunst uit het Metropolitan Museum of Art in NYC en het Rijksmuseum in Amsterdam mogelijk zijn.
Vereisten
- Koppel uw notitieblok aan een lakehouse. Selecteer aan de linkerkant Toevoegen om een bestaand lakehouse toe te voegen of een lakehouse te maken.
Overzicht van de BallTree
De structuur die achter het KNN-model functioneert, is een BallTree, een recursieve binaire structuur waarin elk knooppunt (of "bal") een partitie bevat van de punten van gegevens waarop een query moet worden uitgevoerd. Het bouwen van een BallTree omvat het toewijzen van gegevenspunten aan de "bal" waarvan het midden het dichtst bij hen ligt (met betrekking tot een bepaalde opgegeven functie), wat resulteert in een structuur die binair-boomachtige traversal toestaat en zichzelf leent voor het vinden van k-dichtstbijzijnde buren op een BallTree-blad.
Instellingen
Importeer de benodigde Python-bibliotheken en bereid de gegevensset voor.
from synapse.ml.core.platform import *
if running_on_binder():
from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO
import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
Onze gegevensset is afkomstig uit een tabel met kunstwerkinformatie van zowel de musea Met als Rijks. Het -schema is als volgt:
- id: Een unieke id voor een kunstwerk
- Voorbeeld van met-id: 388395
- Voorbeeld rijks-id: SK-A-2344
- Titel: Titel van kunstwerk, zoals geschreven in de database van het museum
- Kunstenaar: Kunststukkunstenaar, zoals geschreven in de database van het museum
- Thumbnail_Url: Locatie van een JPEG-miniatuur van het kunstwerk
- Image_Url Locatie van een afbeelding van het kunstwerk dat wordt gehost op de website van Met/Rijks
- Cultuur: Categorie cultuur onder het kunststuk
- Voorbeeldcultuurcategorieën: Latijns-Amerikaans, Egyptisch, enz.
- Classificatie: Categorie van medium waarvan het kunststuk onder valt
- Voorbeeld medium categorieën: houtwerk, schilderijen, enz.
- Museum_Page: Link naar het kunstwerk op de website van Met/Rijks
- Norm_Features: Insluiten van de afbeelding van het kunstwerk
- Museum: Geeft aan uit welk museum het stuk afkomstig is
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))
Categorieën definiëren waarop een query moet worden uitgevoerd
Er worden twee KNN-modellen gebruikt: één voor cultuur en één voor medium.
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork",
# "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]
mediums = ["paintings", "glass", "ceramics"]
# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
# 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
# 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
# 'spanish', 'swiss', 'various']
cultures = ["japanese", "american", "african (general)"]
# Uncomment the above for more robust and large scale searches!
classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}
small_df = df.where(
udf(
lambda medium, culture, id_val: (medium in medium_set)
or (culture in culture_set)
or (id_val in selected_ids),
BooleanType(),
)("Classification", "Culture", "id")
)
small_df.count()
ConditionalKNN-modellen definiëren en aanpassen
VoorwaardelijkeKNN-modellen maken voor zowel de medium- als cultuurkolommen; elk model neemt een uitvoerkolom, kenmerken kolom (functievector), waardenkolom (celwaarden onder de uitvoerkolom) en labelkolom (de kwaliteit waarop de respectieve KNN is geconditioneerd).
medium_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Classification")
.fit(small_df)
)
culture_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Culture")
.fit(small_df)
)
Overeenkomende methoden definiëren en visualiseren
Nadat de eerste gegevensset en categorie zijn ingesteld, bereidt u methoden voor waarmee de resultaten van de voorwaardelijke KNN worden opgevraagd en gevisualiseerd.
addMatches()
maakt een Dataframe met een handvol overeenkomsten per categorie.
def add_matches(classes, cknn, df):
results = df
for label in classes:
results = cknn.transform(
results.withColumn("conditioner", array(lit(label)))
).withColumnRenamed("Matches", "Matches_{}".format(label))
return results
plot_urls()
aanroepen plot_img
om de belangrijkste overeenkomsten voor elke categorie in een raster te visualiseren.
def plot_img(axis, url, title):
try:
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")
axis.imshow(img, aspect="equal")
except:
pass
if title is not None:
axis.set_title(title, fontsize=4)
axis.axis("off")
def plot_urls(url_arr, titles, filename):
nx, ny = url_arr.shape
plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
fig, axes = plt.subplots(ny, nx)
# reshape required in the case of 1 image query
if len(axes.shape) == 1:
axes = axes.reshape(1, -1)
for i in range(nx):
for j in range(ny):
if j == 0:
plot_img(axes[j, i], url_arr[i, j], titles[i])
else:
plot_img(axes[j, i], url_arr[i, j], None)
plt.savefig(filename, dpi=1600) # saves the results as a PNG
display(plt.show())
Alles samenvoegen
Definieer test_all()
om de gegevens, CKNN-modellen, de waarden voor de kunst-id op te vragen en het bestandspad om de uitvoervisualisatie op te slaan. De medium- en cultuurmodellen werden eerder getraind en geladen.
# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png
def test_all(data, cknn_medium, cknn_culture, test_ids, root):
is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
test_df = data.where(is_nice_obj("id"))
results_df_medium = add_matches(mediums, cknn_medium, test_df)
results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)
results = results_df_culture.collect()
original_urls = [row["Thumbnail_Url"] for row in results]
culture_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in cultures
]
culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")
medium_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in mediums
]
medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")
return results_df_culture
Demo
De volgende cel voert batchquery's uit op de gewenste afbeeldings-id's en een bestandsnaam om de visualisatie op te slaan.
# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")