Träna modeller med PyTorch
PyTorch är ett vanligt maskininlärningsramverk för att träna djupinlärningsmodeller. I Azure Databricks är PyTorch förinstallerat i ML-kluster .
Kommentar
Kodfragmenten i den här lektionen tillhandahålls som exempel för att framhäva viktiga punkter. Du får chansen att köra kod för ett fullständigt, fungerande exempel i övningen senare i den här modulen.
Definiera ett PyTorch-nätverk
I PyTorch baseras modellerna på ett nätverk som du definierar. Nätverket består av flera lager, var och en med angivna indata och utdata. Dessutom definierar arbetet en framåtriktad funktion som tillämpar funktioner på varje lager när data skickas via nätverket.
Följande exempelkod definierar ett nätverk.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.layer1 = nn.Linear(4, 5)
self.layer2 = nn.Linear(5, 5)
self.layer3 = nn.Linear(5, 3)
def forward(self, x):
layer1_output = torch.relu(self.layer1(x))
layer2_output = torch.relu(self.layer2(layer1_output))
y = self.layer3(layer2_output)
return y
Koden kan verka komplex till en början, men den här klassen definierar ett relativt enkelt nätverk med tre lager:
- Ett indatalager som accepterar fyra indatavärden och genererar fem utdatavärden för nästa lager.
- Ett lager som accepterar fem indata och genererar fem utdata.
- Ett slutligt utdatalager som accepterar fem indata och genererar tre utdata.
Funktionen forward tillämpar lagren på indata (x), skickar utdata från varje lager till nästa och returnerar slutligen utdata från det sista lagret (som innehåller etikettens förutsägelsevektor, y). En rektifierad aktiveringsfunktion för linjär enhet (ReLU) tillämpas på utdata från lager 1 och 2 för att begränsa utdatavärdena till positiva tal.
Kommentar
Beroende på vilken typ av förlustvillkor som används kan du välja att tillämpa en aktiveringsfunktion, till exempel en log_softmax på returvärdet för att tvinga det till intervallet 0 till 1. Vissa förlustkriterier (till exempel CrossEntropyLoss, som ofta används för multiklassklassificering) tillämpar dock automatiskt en lämplig funktion.
Om du vill skapa en modell för träning behöver du bara skapa en instans av nätverksklassen så här:
myModel = MyNet()
Förbereda data för modellering
PyTorch-lager fungerar på data som är formaterade som tensorer – matrisliknande strukturer. Det finns olika funktioner för att konvertera andra vanliga dataformat till tensorer, och du kan definiera en PyTorch-datainläsare för att läsa data tensorer till en modell för träning eller slutsatsdragning.
Precis som med de flesta övervakade maskininlärningstekniker bör du definiera separata datauppsättningar för träning och validering. Med den här separationen kan du verifiera att modellen förutsäger korrekt när den presenteras med data som den inte har tränats på.
Följande kod definierar två datainläsare. en för träning och den andra för testning. Källdata för varje inläsare i det här exemplet antas vara en Numpy-matris med funktionsvärden och en Numpy-matris med motsvarande etikettvärden.
# Create a dataset and loader for the training data and labels
train_x = torch.Tensor(x_train).float()
train_y = torch.Tensor(y_train).long()
train_ds = td.TensorDataset(train_x,train_y)
train_loader = td.DataLoader(train_ds, batch_size=20,
shuffle=False, num_workers=1)
# Create a dataset and loader for the test data and labels
test_x = torch.Tensor(x_test).float()
test_y = torch.Tensor(y_test).long()
test_ds = td.TensorDataset(test_x,test_y)
test_loader = td.DataLoader(test_ds, batch_size=20,
shuffle=False, num_workers=1)
Inläsarna i det här exemplet delar upp data i batchar med 30, som skickas till funktionen forward under träning eller slutsatsdragning.
Välj ett förlustvillkor och optimeringsalgoritm
Modellen tränas genom att mata in träningsdata i nätverket, mäta förlusten (den aggregerade skillnaden mellan förutsagda och faktiska värden) och optimera nätverket genom att justera vikter och saldon för att minimera förlusten. Den specifika informationen om hur förlust beräknas och minimeras styrs av det förlustvillkor och den optimeraralgoritm som du väljer.
Förlustvillkor
PyTorch stöder flera funktioner för förlustkriterier, inklusive (bland många andra):
- cross_entropy: En funktion som mäter den aggregerade skillnaden mellan förutsagda och faktiska värden för flera variabler (används vanligtvis för att mäta förlust för klassannolikheter i klassificering med flera klasser).
- binary_cross_entropy: En funktion som mäter skillnaden mellan förutsagda och faktiska sannolikheter (används vanligtvis för att mäta förlust för klassannolikheter i binär klassificering).
- mse_loss: En funktion som mäter den genomsnittliga kvadratfelförlusten för förutsagda och faktiska numeriska värden (används vanligtvis för regression).
Om du vill ange det förlustvillkor som du vill använda när du tränar din modell skapar du en instans av lämplig funktion. Gillar det här:
import torch.nn as nn
loss_criteria = nn.CrossEntropyLoss
Dricks
Mer information om tillgängliga förlustkriterier i PyTorch finns i Förlustfunktioner i PyTorch-dokumentationen.
Optimeraralgoritmer
Efter att ha beräknat förlusten används en optimerare för att avgöra hur du bäst justerar vikter och saldon för att minimera den. Optimerare är specifika implementeringar av en gradient descent-metod för att minimera en funktion. Tillgängliga optimerare i PyTorch inkluderar (bland annat):
- Adadelta: En optimering baserad på algoritmen för anpassningsbar inlärningsfrekvens .
- Adam: En beräkningseffektiv optimerare baserat på Adam-algoritmen.
- SGD: En optimerare baserat på algoritmen för stokastisk gradient.
Om du vill använda någon av dessa algoritmer för att träna en modell måste du skapa en instans av optimeraren och ange nödvändiga parametrar. De specifika parametrarna varierar beroende på vilken optimerare som valts, men de flesta kräver att du anger en inlärningsfrekvens som styr storleken på de justeringar som görs med varje optimering.
Följande kod skapar en instans av Adam-optimeraren .
import torch.optim as opt
learning_rate = 0.001
optimizer = opt.Adam(model.parameters(), lr=learning_rate)
Dricks
Mer information om tillgängliga optimerare i PyTorch finns i Algoritmer i PyTorch-dokumentationen.
Skapa tränings- och testfunktioner
När du har definierat ett nätverk och förberett data för det kan du använda data för att träna och testa en modell genom att skicka träningsdata via nätverket, beräkna förlusten, optimera nätverksvikter och fördomar och verifiera nätverkets prestanda med testdata. Det är vanligt att definiera en funktion som skickar data via nätverket för att träna modellen med träningsdata och en separat funktion för att testa modellen med testdata.
Skapa en träningsfunktion
I följande exempel visas en funktion för att träna en modell.
def train(model, data_loader, optimizer):
# Use GPU if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Set the model to training mode (to enable backpropagation)
model.train()
train_loss = 0
# Feed the batches of data forward through the network
for batch, tensor in enumerate(data_loader):
data, target = tensor # Specify features and labels in a tensor
optimizer.zero_grad() # Reset optimizer state
out = model(data) # Pass the data through the network
loss = loss_criteria(out, target) # Calculate the loss
train_loss += loss.item() # Keep a running total of loss for each batch
# backpropagate adjustments to weights/bias
loss.backward()
optimizer.step()
#Return average loss for all batches
avg_loss = train_loss / (batch+1)
print('Training set: Average loss: {:.6f}'.format(avg_loss))
return avg_loss
I följande exempel visas en funktion för att testa modellen.
def test(model, data_loader):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Switch the model to evaluation mode (so we don't backpropagate)
model.eval()
test_loss = 0
correct = 0
# Pass the data through with no gradient computation
with torch.no_grad():
batch_count = 0
for batch, tensor in enumerate(data_loader):
batch_count += 1
data, target = tensor
# Get the predictions
out = model(data)
# calculate the loss
test_loss += loss_criteria(out, target).item()
# Calculate the accuracy
_, predicted = torch.max(out.data, 1)
correct += torch.sum(target==predicted).item()
# Calculate the average loss and total accuracy for all batches
avg_loss = test_loss/batch_count
print('Validation set: Average loss: {:.6f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
avg_loss, correct, len(data_loader.dataset),
100. * correct / len(data_loader.dataset)))
return avg_loss
Träna modellen över flera epoker
För att träna en djupinlärningsmodell kör du vanligtvis träningsfunktionen flera gånger (kallas epoker), med målet att minska förlusten som beräknas från träningsdata varje epok. Du kan använda testfunktionen för att verifiera att förlusten från testdata (där modellen inte tränades) också minskar i linje med träningsförlusten , med andra ord att modellträningen inte producerar en modell som är överanpassad till träningsdata.
Dricks
Du behöver inte köra testfunktionen för varje epok. Du kan välja att köra den varje sekund, eller en gång i slutet. Att testa modellen som den tränas kan dock vara till hjälp när du ska avgöra hur många epoker en modell börjar bli överanpassad.
Följande kod tränar en modell över 50 epoker.
epochs = 50
for epoch in range(1, epochs + 1):
# print the epoch number
print('Epoch: {}'.format(epoch))
# Feed training data into the model to optimize the weights
train_loss = train(model, train_loader, optimizer)
print(train_loss)
# Feed the test data into the model to check its performance
test_loss = test(model, test_loader)
print(test_loss)
Spara det tränade modelltillståndet
När du har tränat en modell kan du spara dess vikter och fördomar så här:
model_file = '/dbfs/my_model.pkl'
torch.save(model.state_dict(), model_file)
Om du vill läsa in och använda modellen vid ett senare tillfälle skapar du en instans av nätverksklassen som modellen baseras på och läser in de sparade vikterna och fördomarna.
model = myNet()
model.load_state_dict(torch.load(model_file))