Model jako API (w R)

Budujemy model w R. Daje nawet dobre wyniki. Ale jak go użyć produkcyjnie? Gdyby tak pytać o wynik korzystając z API?
Dzisiaj o tym jak przygotować w R web-serwis, czyli narzędzie które zapytane (w odpowiedni sposób) przez internet zwróci nam jakiś wynik. W naszym przypadku odpowiedzią będzie wynik działania modelu.
Zadanie podzielimy na kilka etapów:
- przygotowanie modelu
- przygotowanie funkcji, które będą używać naszego modelu
- web-service korzystający z tych funkcji
Przygotowanie modelu
Na warsztat weźmiemy bardzo prosty przykład i prosty model. Będziemy przewidywać gatunek irysa na podstawie rozmiarów płatka (petal) i kielicha (sepal) – czyli wykorzystamy popularny w uczeniu maszynowym zbiór iris.
Model też nie będzie wydumany. Na tym zbiorze sprawdza się las losowy, więc bez większych kombinacji po prostu takiego modelu użyjemy. W R to banalne:
library(randomForest)
# budujemy model
model <- randomForest(Species ~ ., data = iris)
# zapisujemy go w postaci pliku
saveRDS(model, "model_rf.RDS")
I to wszystko – model jest gotowy. Oczywiście na inne potrzeby może bardziej skomplikowany model, z przetworzeniem danych wejściowych, szukaniem dodatkowych cech i tak dalej – ale nie o tym jest ten wpis.
Przygotowanie predykcji
Teraz czas na przygotowanie funkcji, która zwróci wynik działania modelu. Jako parametry podamy rozmiary płatka i kielicha, w odpowiedzi oczekujemy gatunku irysa.
# wczytujemy wytrenowany model
model <- readRDS("model_rf.RDS")
predict_class <- function(sl, sw, pl, pw) {
# przygotowanie danych na które odpowie model
new_df <- data.frame(
Sepal.Length = sl,
Sepal.Width = sw,
Petal.Length = pl,
Petal.Width = pw
)
# predykcja
pred <- as.character(predict(model, new_df))
return(pred)
}
Sprawdźmy jak nasz model odpowiada na przykładowe zapytanie:
predict_class(5.1, 3.5, 1.4, 0.2)
## [1] "setosa"
I znowu – to wszystko. Teraz wystarczy to ubrać w API.
Przygotowanie web-service'u
Z pomocą przychodzi pakiet plumber, który właściwie całą pracę wykona za nas. O ile mu pomożemy.W pierwszej kolejności budujemy skrypt (w R), w którym zdefiniujemy funkcje, jakie będą dostępne przez API.
Zacznijmy od prostego przykładu: przepisanie podanego jako argument tekstu. Odpowiedni kod rozumiany przez plumber to:
#* Wypisanie tekstu
#* @param msg Tekst do wypisania
#* @get /echo
function(msg = "") {
list(msg = paste0("Podany tekst to: '", msg, "'"))
}
W pierwszej linii opisujemy krótko co funkcja (metoda API) robi.
W drugiej (i ewentualnie kolejnych) – jakie są parametry (ich nazwy) i do czego służą.
Trzecia linia mówi o sposobie pobrania parametrów (metoda HTTP: w tym przypadku GET, ale może by POST) oraz pod jakim adresem dostępna jest metoda.
Kolejne linie to już sama funkcja. Zwróćcie uwagę, że nie musimy już podawać jej nazwy.
W powyższym przykładzie, mając API uruchomione na adresie http://serwer.api możemy wywołać naszą funkcję poprzez zapytanie:
http://serwer.api/echo?msg=tekst bedacy parametrem
co w odpowiedzi da nam:
{"msg":["Podany tekst to: 'tekst bedacy parametrem'"]}
Jak widać jest to JSON.
Możemy teraz zbudować funkcję, która odpowie nam wynikiem z modelu. Nieco inaczej zdefiniujemy sposób podania jej parametrów poprzez API:
#* Predict Class
#* @get /irisclass/<sl>/<sw>/<pl>/<pw>
function(sl, sw, pl, pw) {
new_df <- tibble(
Sepal.Length = sl,
Sepal.Width = sw,
Petal.Length = pl,
Petal.Width = pw
)
pred <- as.character(predict(model, new_df))
return(list(Sepal.Length = sl,
Sepal.Width = sw,
Petal.Length = pl,
Petal.Width = pw,
Spices = pred))
}
Zwróćcie uwagę, że nie wczytujemy w funkcji zapisanego modelu – zrobimy to globalnie, aby zrobić to tylko raz przy uruchomieniu web service’u, a nie za każdym wywołaniem funkcji z API. Bo przecież model może by duży, a API powinno odpowiadać szybko.
W powyższym przypadku podanie parametrów wygląda inaczej niż wcześniej – podajemy je jako element ścieżki (URLa) zapytania. Wywołanie wyglądać będzie zatem:
http://serwer.api/irisclass/5.1/3.5/1.4/0.2
a odpowiedź:
{"Sepal.Length":["5.1"],"Sepal.Width":["3.5"],"Petal.Length":["1.4"],"Petal.Width":["0.2"],"Spices":["setosa"]}
Dostajemy więc ten sam wynik, jaki dostaliśmy w przypadku użycia funkcji predict_class()
(podając oczywiście te same parametry). Wszystko zatem działa.
Uruchomienie web-service'u
Czas na uruchomienie web service’u. W tym właśnie miejscu pojawia się plumber. Wszystkie przygotowane funkcje pakujemy do jednego skryptu, niech to będzie plik plumbel.R:
# plumber.R
library(tidyverse)
library(randomForest)
# wczytanie modelu
model <- readRDS("model_rf.RDS")
#* Wypisanie tekstu
#* @param msg Tekst do wypisania
#* @get /echo
function(msg = "") {
list(msg = paste0("Podany tekst to: '", msg, "'"))
}
#* Predict Class
#* @get /irisclass/<sl>/<sw>/<pl>/<pw>
function(sl, sw, pl, pw) {
new_df <- tibble(
Sepal.Length = sl,
Sepal.Width = sw,
Petal.Length = pl,
Petal.Width = pw
)
pred <- as.character(predict(model, new_df))
return(list(Sepal.Length = sl,
Sepal.Width = sw,
Petal.Length = pl,
Petal.Width = pw,
Spices = pred))
}
#* Plot out data from the iris dataset
#* @serializer contentType list(type='image/png')
#* @param spec If provided, filter the data to only this species (e.g. 'setosa')
#* @get /irisplot/<spec>
function(spec) {
myData <- iris
title <- "All Species"
# Filter if the species was specified
if (!missing(spec) ) {
title <- paste0("Only the '", spec, "' Species")
myData <- iris %>%
filter(Species == spec)
}
p <- ggplot(myData) +
geom_point(aes(Sepal.Length, Petal.Length)) +
labs(title = title, x = "Sepal Length", y = "Petal Length")
tmp <- tempfile()
ggsave(filename = tmp, plot = p, device = "png")
readBin(tmp, 'raw', n = file.info(tmp)$size)
}
Powyższy skrypt definiuje nam metody dostępne przez API. W tym przypadku mamy trzy:
- /echo
- /irisclass
- /irisplot(o niej za chwilę)
Drugi element to uruchomienie web service’u. Wystarczą trzy linijki kodu:
library(plumber)
r <- plumb("plumber.R")
r$run(host = "0.0.0.0", port = 8080, swagger = FALSE) # domyślnie host = "127.0.0.1"
Host “0.0.0.0” pozwala nam na obsługę zapytań spływających z internetu. Domyślnie (bez tego parametru) obsługiwane są tylko zapytania z lokalnej maszyny. Możesz więc pominą ten parametr, uruchomić powyższy kod i w przeglądarce wpisać:
http://localhost:8080/irisplot/setosa
powinieneś w odpowiedzi dostać wykres:
Te trzy linijki kodu możemy oczywiście zapisać do pliku (np. “run_api.R”) i uruchomić taki skrypt przez
Rscript run_api.R
Dopóki nie przerwiemy wykonywania skryptu nasze API będzie działało.
Wróćmy do /irisplot – co to robi? Rysuje wykres i zwraca go jako obrazek. Czytelnicy widzą zapewne, że jest to fragment zbioru iris (konkretny gatunek irysów). Różnica tej metody to inny serializer zdefiniowany w linii:
#* @serializer contentType list(type='image/png')
Funkcja przygotowuje obrazek, zapisuje go do tymczasowego pliku, a później zwraca ten plik jako ciąg bajtów (poprzez readBin()), bez żadnych dodatkowych informacji jednocześnie mówiąc przeglądarce (w nagłówku), że to obrazek w formacie png.
Na początek budowania API w R to powinno wystarczyć. Więcej informacji znajdziecie w dokumentacji pakietu plumber.