Esplorazione dell'arte attraverso la cultura e media con vicini veloci, condizionali, k-nearest
Questo articolo funge da linea guida per la ricerca di corrispondenze tramite k-near-neighbors. Si configura il codice che consente di eseguire query che coinvolgono culture e mezzi d'arte accumulati dal Metropolitan Museum of Art di Nyc e dal Rijks museum di Amsterdam.
Prerequisiti
- Collegare il notebook a un lakehouse. Sul lato sinistro, selezionare Aggiungi per aggiungere un lakehouse esistente o creare un lakehouse.
Panoramica di BallTree
La struttura che funziona dietro il modello KNN è un BallTree, che è un albero binario ricorsivo in cui ogni nodo (o "palla") contiene una partizione dei punti di dati su cui eseguire una query. La creazione di un BallTree comporta l'assegnazione di punti dati alla "palla" il cui centro è più vicino (rispetto a una determinata caratteristica specificata), con conseguente struttura che consente l'attraversamento binario simile all'albero binario e si presta a trovare i vicini k più vicini a una foglia di BallTree.
Attrezzaggio
Importare le librerie Python necessarie e preparare il set di dati.
from synapse.ml.core.platform import *
if running_on_binder():
from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO
import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
Il set di dati proviene da una tabella contenente informazioni sull'opera d'arte provenienti dai musei Met e Rijks. Lo schema è il seguente:
- id: identificatore univoco per un pezzo d'arte
- Id met di esempio: 388395
- Id Rijks di esempio: SK-A-2344
- Titolo: Titolo dell’opera d’arte, come scritto nel database del museo
- Artista: artista del pezzo d’arte, come scritto nel database del museo
- Thumbnail_Url: posizione di un'anteprima JPEG del pezzo d'arte
- Image_Url Posizione di un'immagine del pezzo d'arte ospitato sul sito Web Met/Rijks
- Cultura: categoria di cultura in cui rientra l’opera d’arte
- Categorie cultura di esempio: latino americano, egiziano e così via.
- Classificazione: categoria del medium in cui rientra l’opera d'arte
- Categorie medie di esempio: opere in legno, dipinti e così via.
- Museum_Page: Collegamento all'opera d'arte sul sito Web Met/Rijks
- Norm_Features: Incorporamento dell'immagine dell'arte
- Museo: specifica il museo da cui proviene il pezzo
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))
Definire le categorie su cui eseguire query
Vengono usati due modelli KNN: uno per le impostazioni cultura e uno per il supporto.
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork",
# "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]
mediums = ["paintings", "glass", "ceramics"]
# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
# 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
# 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
# 'spanish', 'swiss', 'various']
cultures = ["japanese", "american", "african (general)"]
# Uncomment the above for more robust and large scale searches!
classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}
small_df = df.where(
udf(
lambda medium, culture, id_val: (medium in medium_set)
or (culture in culture_set)
or (id_val in selected_ids),
BooleanType(),
)("Classification", "Culture", "id")
)
small_df.count()
Definire e adattare i modelli ConditionalKNN
Creare modelli ConditionalKNN per le colonne medie e cultura; ogni modello accetta una colonna di output, una colonna feature (vettore di funzionalità), una colonna di valori (valori di cella nella colonna di output) e una colonna etichetta (la qualità su cui è condizionata la rispettiva chiave KNN).
medium_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Classification")
.fit(small_df)
)
culture_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Culture")
.fit(small_df)
)
Definire metodi di corrispondenza e visualizzazione
Dopo la configurazione iniziale del set di dati e della categoria, preparare i metodi che eseguiranno query e visualizzeranno i risultati del knn condizionale.
addMatches()
crea un dataframe con una manciata di corrispondenze per categoria.
def add_matches(classes, cknn, df):
results = df
for label in classes:
results = cknn.transform(
results.withColumn("conditioner", array(lit(label)))
).withColumnRenamed("Matches", "Matches_{}".format(label))
return results
plot_urls()
chiama plot_img
per visualizzare le corrispondenze principali per ogni categoria in una griglia.
def plot_img(axis, url, title):
try:
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")
axis.imshow(img, aspect="equal")
except:
pass
if title is not None:
axis.set_title(title, fontsize=4)
axis.axis("off")
def plot_urls(url_arr, titles, filename):
nx, ny = url_arr.shape
plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
fig, axes = plt.subplots(ny, nx)
# reshape required in the case of 1 image query
if len(axes.shape) == 1:
axes = axes.reshape(1, -1)
for i in range(nx):
for j in range(ny):
if j == 0:
plot_img(axes[j, i], url_arr[i, j], titles[i])
else:
plot_img(axes[j, i], url_arr[i, j], None)
plt.savefig(filename, dpi=1600) # saves the results as a PNG
display(plt.show())
Combinazione delle funzionalità
Definire test_all()
per accettare i dati, i modelli CKNN, i valori di ID immagine su cui eseguire query e il percorso del file in cui salvare la visualizzazione di output. I modelli di media e cultura sono stati precedentemente sottoposti a training e caricati.
# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png
def test_all(data, cknn_medium, cknn_culture, test_ids, root):
is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
test_df = data.where(is_nice_obj("id"))
results_df_medium = add_matches(mediums, cknn_medium, test_df)
results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)
results = results_df_culture.collect()
original_urls = [row["Thumbnail_Url"] for row in results]
culture_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in cultures
]
culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")
medium_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in mediums
]
medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")
return results_df_culture
Demo
La cella seguente esegue query in batch in base agli ID immagine desiderati e un nome file per salvare la visualizzazione.
# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")