RubixML: основы

Продолжаю тему машинного обучения, начатую здесь и здесь. Поскольку проект PHP-ML, похоже, почил в бозе, а иметь технологии машинного обучения на php иногда хочется, я оглянулся вокруг в поисках альтернативы. И нашёл чудесный проект RubixML. Будем осваивать.

Задача та же самая: научиться классифицировать части имени (отличить друг от друга фамилию, имя и отчество). В качестве датасета - текстовый файл такого формата:

Миклухо-Маклай:surname
Ломоносов:surname
Менделеев:surname
Николай:name
Михаил:name
Дмитрий:name
Николаевич:patronymic
Васильевич:patronymic
Иванович:patronymic

Разделителем значения и метки служит двоеточие. Обучение модели лучше проводить на локальном компьютере, мне на Debian 12 было достаточно установить php-cli:

sudo apt install php8.2-cli

Теперь нужно скачать RubixML через Composer:

php composer.phar require rubix/ml

Приступаем к обучению. Код для обучения предиктивной модели:

train.php

<?php

require_once 'vendor/autoload.php';

use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Extractors\CSV;
use Rubix\ML\PersistentModel;
use Rubix\ML\Pipeline;
use Rubix\ML\Transformers\WordCountVectorizer;
use Rubix\ML\Tokenizers\NGram;
use Rubix\ML\Transformers\TfIdfTransformer;
use Rubix\ML\Classifiers\SoftmaxClassifier;
use Rubix\ML\Persisters\Filesystem;
use Rubix\ML\CrossValidation\Metrics\Accuracy;

//Загрузка датасета
$dataset = Labeled::fromIterator(new CSV(path: 'names.txt', delimiter: ':'));

//Разделение датасета на обучающий и тестовый
[$training, $testing] = $dataset->randomize()->stratifiedSplit(0.85);

//Проектирование модели
$estimator = new PersistentModel(
new Pipeline([
new WordCountVectorizer(tokenizer: new NGram(1, 3)),
new TfIdfTransformer()
], new SoftmaxClassifier(epochs: 1000000, batchSize: 1)),
new Filesystem('model.rbx')
);

//Обучение модели
$estimator->train($training);

//Сохранение модели
$estimator->save();

//Проверка модели
$predictions = $estimator->predict($testing);

//Создание метрики
$metric = new Accuracy();

//Вычисление оценки
$score = $metric->score($predictions, $testing->labels());

//Печать оценки
print $score . "\n";
?>

Я перебрал все классификаторы, предлагаемые проектом RubixML и остановился на SoftmaxClassifier как на дающем наилучшие результаты в данном случае. В принципе, я запускал его и со стандартными параметрами (batchSize = 256 и epochs = 1000), но выбранные мною в итоге параметры дают чуть более точный результат предсказаний.

Чтобы воспользоваться обученной моделью, можно использовать такой код:

predict.php

<?php

require_once 'vendor/autoload.php';

use Rubix\ML\PersistentModel;
use Rubix\ML\Persisters\Filesystem;
use Rubix\ML\Datasets\Unlabeled;

//Загрузка обученной модели
$model = PersistentModel::load(new Filesystem('model.rbx'));

//Подготовка термов для классификации
$dataset_1 = new Unlabeled(['Владимир', 'Михайлович', 'Бехтерев']);

//Простое предсказание
print_r($model->predict($dataset_1));

//Подготовка термов для классификации
$dataset_2 = new Unlabeled(['Александр', 'Степанович', 'Попов']);

//Предсказание с оценкой точности
print_r($model->proba($dataset_2));
?>

Результат выполнения:

Array
(
[0] => name
[1] => patronymic
[2] => surname
)
Array
(
[0] => Array
(
[patronymic] => 0.00030342071279194
[name] => 0.99907243550825
[surname] => 0.00062414377895908
)

[1] => Array
(
[patronymic] => 0.99262299047189
[name] => 0.0019626332334701
[surname] => 0.0054143762946424
)

[2] => Array
(
[patronymic] => 0.0013845917377264
[name] => 0.0012370582966856
[surname] => 0.99737834996559
)

)

В первом случае модель выдаст наиболее вероятные метки, во втором - набор всех меток с оценкой вероятности каждой из них для данного терма.

RubixML ощутимо мощнее PHP-ML, но есть и нюансы. Например, логистическая регрессия, которую я применял в предыдущих статьях, здесь требует только двух классов, а у нас три класса: name, surname, patronymic. Поэтому пришлось применить SoftmaxClassifier, способный работать с большим числом меток. Кроме того, как я ни старался, получить модель с точностью выше 95% мне не удалось, тогда как модель, обученная в PHP-ML, имела точность 98%. И ещё одно: похоже, функции predict() и proba() модифицируют входящий массив термов, так что использовать один и тот же $dataset последовательно в predict() и proba() не получится, именно поэтому в моём примере есть $dataset_1 и $dataset_2.

А в целом, инструмент хороший.
2025-01-18