V4.1.0a - neural network model runs and prints pSeizure to logcat.

This commit is contained in:
Graham Jones
2022-09-20 11:55:34 +01:00
parent ec704029bd
commit 34eb763ccd
8 changed files with 108 additions and 3 deletions

View File

@@ -57,6 +57,7 @@ import com.github.mikephil.charting.components.YAxis;
import com.github.mikephil.charting.data.BarData;
import com.github.mikephil.charting.data.BarDataSet;
import com.github.mikephil.charting.data.BarEntry;
import com.github.mikephil.charting.utils.ValueFormatter;
import com.rohitss.uceh.UCEHandler;

View File

@@ -0,0 +1,83 @@
package uk.org.openseizuredetector;
import android.content.Context;
import android.util.Log;
import com.android.volley.AuthFailureError;
import com.android.volley.Request;
import com.android.volley.RequestQueue;
import com.android.volley.Response;
import com.android.volley.VolleyError;
import com.android.volley.toolbox.StringRequest;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tflite.java.TfLite;
//import com.google.android.gms.tflite.java.TfLite;
import org.json.JSONException;
import org.json.JSONObject;
import org.tensorflow.lite.InterpreterApi;
import org.tensorflow.lite.support.common.FileUtil;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.HashMap;
import java.util.Map;
public class SdAlgNn {
private final static String TAG = "SdAlgNn";
private final static String MODEL_PATH = "best_model_v0.02.tflite";
private String mUrlBase = "https://osdApi.ddns.net";
private InterpreterApi interpreter;
private Context mContext;
RequestQueue mQueue;
public SdAlgNn(Context context) {
Log.d(TAG, "SdAlgNn Constructor");
mContext = context;
Task<Void> initializeTask = TfLite.initialize(mContext);
initializeTask.addOnSuccessListener(a -> {
MappedByteBuffer modelBuffer;
try {
Log.d(TAG, "onSuccessListener - loading model");
modelBuffer = FileUtil.loadMappedFile(context, MODEL_PATH);
Log.d(TAG, "onSuccessListener - model loaded");
} catch (IOException e) {
Log.e(TAG, "Error Loading Model File");
return;
}
Log.d(TAG, "onSuccessListener - creating interpreter");
interpreter = InterpreterApi.create(modelBuffer,
new InterpreterApi.Options().setRuntime(
InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY));
Log.d(TAG, "onSuccessListener - interpreter created ok");
})
.addOnFailureListener(e -> {
Log.e(TAG, String.format("Cannot initialize interpreter: %s",
e.getMessage()));
});
// FIXME - Add hardware acceleration - see https://www.tensorflow.org/lite/android/play_services
Log.d(TAG, "constructor finished - Note, NOT using hardware acceleration yet!!!!");
}
public void close() {
Log.d(TAG,"close()");
interpreter.close();
}
public int run(SdData sdData) {
int i;
float[][][] modelInput = new float[1][125][1];
float[][] modelOutput = new float[1][2];
for (int j = 0; j < 125; j++) {
modelInput[0][j][0] = (float)sdData.rawData[j];
}
interpreter.run(modelInput, modelOutput);
Log.d(TAG,"run - pSeizure="+modelOutput[0][1]);
if (modelOutput[0][1]>0.5)
return 1;
else
return 0;
}
}

View File

@@ -91,6 +91,7 @@ public abstract class SdDataSource {
private short mFallThreshMax;
private short mFallWindow;
private int mMute; // !=0 means muted by keypress on watch.
private SdAlgNn mSdAlgNn;
// Values for SD_MODE
private int SIMPLE_SPEC_FMAX = 10;
@@ -110,6 +111,8 @@ public abstract class SdDataSource {
mUtil = new OsdUtil(mContext, mHandler);
mSdDataReceiver = sdDataReceiver;
mSdData = new SdData();
mSdAlgNn = new SdAlgNn(mContext);
}
/**
@@ -484,6 +487,7 @@ public abstract class SdDataSource {
// Check this data to see if it represents an alarm state.
alarmCheck();
nnCheck();
hrCheck();
o2SatCheck();
fallCheck();
@@ -724,6 +728,13 @@ public abstract class SdDataSource {
}
}
void nnCheck() {
//Check the current set of data using the neural network model to look for alarms.
Log.d(TAG,"nnCheck");
int nnResult = mSdAlgNn.run(mSdData);
Log.d(TAG,"nnCheck - nnResult="+nnResult);
}
/**
* updatePrefs() - update basic settings from the SharedPreferences
* - defined in res/xml/SdDataSourceNetworkPassivePrefs.xml