Нейронные сети в Android, Google ML Kit и не только

Итак, вы разработали и натренировали свою нейронную сеть, для выполнения какой-то задачи (например то же распознавание объектов через камеру) и хотите внедрить ее в свое приложение на андроид? Тогда добро пожаловать под кат!

Для начала, следует понять, что андроид в данный момент умеет работать только с сетями формата TensorFlowLite, а это значит, нам необходимо провести какие-то манипуляции с исходной сетью. Предположим, у вас есть уже обученная сеть на фреймворке Keras или Tensorflow. Необходимо сохранить сетку в формате pb.

Начнем со случая, когда вы пишите на Tensorflow, тогда все чуть проще.

saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt")

Если же вы пишете на Keras, нужно в начале файла, где вы обучаете сеть, создать новый объект сессии, сохранить ссылку на него, и передать в функцию set_session

import keras.backend as K  session = K.get_session()  K.set_session(session)

Отлично, вы сохранили сеть, теперь нужно перевести ее в формат tflite. Для этого нам нужно запустить два небольших скрипта, первый «заморозит» сеть, второй уже переведет в нужный формат. Суть «заморозки» в том, что tf не хранит веса слоев в сохраненном файле pb, а сохраняет их в специальных чекпоинтах. Для последующей конвертации в tflite нужно, чтобы вся информация о нейронной сети была в одном файле.

freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt 

Обратите внимание, что вам нужно знать имя выходного тензора. В tensorflow вы можете задавать его сами, в случае использования Keras — задаете имя в конструкторе слоя

model.add(Dense(10,activation="softmax",name="result")) 

В таком случае имя тензора обычно выглядит как «result/Softmax»

Если в вашем случае не так, можете найти имя следующим образом

[print(n.name) for n in session.graph.as_graph_def().node] 

Осталось запустить второй скрипт

toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784 

Ура! Теперь у вас в папке есть модель TensorFlowLite, дело за малым — правильно интегрировать ее в ваше андроид приложение. Вы можете сделать это с помощью новомодного Firebase ML Kit, но есть и другой способ, о нем чуть позже. Добавляем зависимость в наш файл gradle

dependencies {   // ...   implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' } 

Теперь вам нужно решить, будете ли вы держать модель где-то у себя на сервере, либо поставлять с приложением.

Рассмотрим первый случай: модель на сервере. Первым делом не забываем добавить в манифест

<uses-permission android:name="android.permission.INTERNET" /> 

    // Создаем объект для задания специальных условий, требуемых для загрузки/обновления модели FirebaseModelDownloadConditions.Builder conditionsBuilder =         new FirebaseModelDownloadConditions.Builder().requireWifi(); if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {     conditionsBuilder = conditionsBuilder             .requireCharging(); } FirebaseModelDownloadConditions conditions = conditionsBuilder.build(); // Создаем объект FirebaseCloudModelSource , задаем имя (должно совпадать с именем модели, загруженной // в консоль Firebase) FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model")         .enableModelUpdates(true)         .setInitialDownloadConditions(conditions)         .setUpdatesDownloadConditions(conditions)         .build(); FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource); 

Если вы используете модель, включенную в приложение локально, не забудьте добавить в build.gradle файл следующую запись для того, чтобы файл модели не сжимался

android {      // ...     aaptOptions {         noCompress "tflite"     } } 

После чего, по аналогии с моделью в облаке, нашу локальную нейронку нужно зарегистрировать.

FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model")         .setAssetFilePath("mymodel.tflite")         .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource); 

В коде выше предполагается, что ваша модель лежит в папке assets, если это не так, вместо

        .setAssetFilePath("mymodel.tflite") 

используйте

        .seFilePath(filePath) 

После чего создаем новые объекты FirebaseModelOptions и FirebaseModelInterpreter

FirebaseModelOptions options = new FirebaseModelOptions.Builder()         .setCloudModelName("my_cloud_model")         .setLocalModelName("my_local_model")         .build(); FirebaseModelInterpreter firebaseInterpreter =         FirebaseModelInterpreter.getInstance(options);

Вы можете использовать одновременно как локальную, так и находящуюся на сервере модель. При этом, по умолчанию будет использоваться облачная, если она доступна, в противном случае локальная.

Почти все, осталось создать массивы для входных/выходных данных, и запустить!

FirebaseModelInputOutputOptions inputOutputOptions =     new FirebaseModelInputOutputOptions.Builder()         .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3})         .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784})         .build();  byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()     .add(input)  // add() as many input arrays as your model requires     .build();  Task<FirebaseModelOutputs> result =     firebaseInterpreter.run(inputs, inputOutputOptions)         .addOnSuccessListener(           new OnSuccessListener<FirebaseModelOutputs>() {             @Override             public void onSuccess(FirebaseModelOutputs result) {               // ...             }           })         .addOnFailureListener(           new OnFailureListener() {             @Override             public void onFailure(@NonNull Exception e) {               // Task failed with an exception               // ...             }           });  float[][] output = result.<float[][]>getOutput(0); float[] probabilities = output[0]; 

Если же вы не хотите по каким-то причинам использовать Firebase, есть и другой способ, вызывать интерпретатор tflite и скармливать ему данные напрямую.

Добавляем в build/gradle строчку

    implementation 'org.tensorflow:tensorflow-lite:+' 

Создаем интерпретатор и массивы

          Interpreter  tflite = new Interpreter(loadModelFile(getContext(), "model.tflite")); // создаем массивы и заполняем inputs           tflite.run(inputs,outputs) 

Кода в этом случае значительно меньше, как вы видите.

Вот и все, что нужно для использования вашей нейронной сети в андроид.

Полезные ссылки:

Офф доки по ML Kit
Tensorflow Lite

FavoriteLoadingДобавить в избранное
Posted in Без рубрики

Добавить комментарий

Ваш e-mail не будет опубликован. Обязательные поля помечены *