V4.2.1a - added standard deviation threshold to ML Algorithm
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
package uk.org.openseizuredetector;
|
||||
|
||||
import android.content.Context;
|
||||
import android.content.SharedPreferences;
|
||||
import android.preference.PreferenceManager;
|
||||
import android.util.Log;
|
||||
import android.widget.Toast;
|
||||
|
||||
import com.android.volley.AuthFailureError;
|
||||
import com.android.volley.Request;
|
||||
@@ -31,28 +34,50 @@ public class SdAlgNn {
|
||||
private Context mContext;
|
||||
RequestQueue mQueue;
|
||||
|
||||
private double mSdThresh; // Acceleration Standard Deviation Threshold required to activate analysis (%)
|
||||
private int mModelId; // ID of ML Model to be used (refers to information in MlModels.json for details).
|
||||
private int mInputFormat; // ID of input format required for model (populated from MlModels.json).
|
||||
|
||||
|
||||
public SdAlgNn(Context context) {
|
||||
Log.d(TAG, "SdAlgNn Constructor");
|
||||
mContext = context;
|
||||
|
||||
SharedPreferences SP = PreferenceManager
|
||||
.getDefaultSharedPreferences(mContext);
|
||||
try {
|
||||
String threshStr = SP.getString("CnnAlarmThreshold", "5");
|
||||
mSdThresh = Double.parseDouble(threshStr);
|
||||
Log.v(TAG, "SdAlgNn Constructor mSdThresh = " + mSdThresh);
|
||||
threshStr = SP.getString("CnnModelId", "1");
|
||||
mModelId = Integer.parseInt(threshStr);
|
||||
Log.v(TAG, "SdAlgNn Constructor mModelId = " + mModelId);
|
||||
} catch (Exception ex) {
|
||||
Log.v(TAG, "SdAlgNn Constructor - problem parsing preferences. " + ex.toString());
|
||||
Toast toast = Toast.makeText(mContext, "Problem Parsing ML Algorithm Preferences", Toast.LENGTH_SHORT);
|
||||
toast.show();
|
||||
}
|
||||
|
||||
mInputFormat = 1; // FIXME - this needs to be determined from the model ID specified by retrieving a configuration file.
|
||||
|
||||
Task<Void> initializeTask = TfLite.initialize(mContext);
|
||||
|
||||
initializeTask.addOnSuccessListener(a -> {
|
||||
MappedByteBuffer modelBuffer;
|
||||
try {
|
||||
Log.d(TAG, "onSuccessListener - loading model");
|
||||
modelBuffer = FileUtil.loadMappedFile(context, MODEL_PATH);
|
||||
Log.d(TAG, "onSuccessListener - model loaded");
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Error Loading Model File");
|
||||
return;
|
||||
}
|
||||
Log.d(TAG, "onSuccessListener - creating interpreter");
|
||||
interpreter = InterpreterApi.create(modelBuffer,
|
||||
new InterpreterApi.Options().setRuntime(
|
||||
InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY));
|
||||
Log.d(TAG, "onSuccessListener - interpreter created ok");
|
||||
})
|
||||
MappedByteBuffer modelBuffer;
|
||||
try {
|
||||
Log.d(TAG, "onSuccessListener - loading model");
|
||||
modelBuffer = FileUtil.loadMappedFile(context, MODEL_PATH);
|
||||
Log.d(TAG, "onSuccessListener - model loaded");
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Error Loading Model File");
|
||||
return;
|
||||
}
|
||||
Log.d(TAG, "onSuccessListener - creating interpreter");
|
||||
interpreter = InterpreterApi.create(modelBuffer,
|
||||
new InterpreterApi.Options().setRuntime(
|
||||
InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY));
|
||||
Log.d(TAG, "onSuccessListener - interpreter created ok");
|
||||
})
|
||||
.addOnFailureListener(e -> {
|
||||
Log.e(TAG, String.format("Cannot initialize interpreter: %s",
|
||||
e.getMessage()));
|
||||
@@ -62,25 +87,81 @@ public class SdAlgNn {
|
||||
}
|
||||
|
||||
public void close() {
|
||||
Log.d(TAG,"close()");
|
||||
Log.d(TAG, "close()");
|
||||
if (interpreter != null) {
|
||||
interpreter.close();
|
||||
}
|
||||
}
|
||||
|
||||
public float getPseizure(SdData sdData) {
|
||||
/**
|
||||
* getPseizureFmt1 - calculate probability of sdData representing seizure-like movement
|
||||
* using a model with input format #1, which is a simple vector of 125 accelerometer vector
|
||||
* magnitude readings.
|
||||
* @param sdData - seizure detector data as input to the model
|
||||
* @return probability of data representing seizure-like movement.
|
||||
*/
|
||||
private float getPseizureFmt1(SdData sdData) {
|
||||
int i;
|
||||
float[][][] modelInput = new float[1][125][1];
|
||||
float[][] modelOutput = new float[1][2];
|
||||
for (int j = 0; j < 125; j++) {
|
||||
modelInput[0][j][0] = (float)sdData.rawData[j];
|
||||
modelInput[0][j][0] = (float) sdData.rawData[j];
|
||||
}
|
||||
if (interpreter == null) {
|
||||
Log.d(TAG,"getPSeizure() - interpreter is null - returning zero seizure probability");
|
||||
Log.d(TAG, "getPSeizure() - interpreter is null - returning zero seizure probability");
|
||||
return (0.0f);
|
||||
}
|
||||
interpreter.run(modelInput, modelOutput);
|
||||
Log.d(TAG,"run - pSeizure="+modelOutput[0][1]);
|
||||
return(modelOutput[0][1]);
|
||||
Log.d(TAG, "run - pSeizure=" + modelOutput[0][1]);
|
||||
return (modelOutput[0][1]);
|
||||
}
|
||||
|
||||
public float getPseizure(SdData sdData) {
|
||||
int i;
|
||||
|
||||
// First check that we have enough movement to analyse by comparing the acceleration standard deviation to a threshold.
|
||||
double stdDev;
|
||||
stdDev = calcRawDataStd(sdData);
|
||||
if (stdDev < mSdThresh) {
|
||||
Log.d(TAG, "getPseizure - acceleration stdev below movement threshold: std="+stdDev+", thresh="+mSdThresh);
|
||||
return (0);
|
||||
}
|
||||
|
||||
float pSeizure;
|
||||
switch (mModelId) {
|
||||
case 1:
|
||||
pSeizure = getPseizureFmt1(sdData);
|
||||
break;
|
||||
default:
|
||||
Log.e(TAG,"getPSeizure - invalid model ID "+mModelId);
|
||||
pSeizure = 0;
|
||||
}
|
||||
|
||||
return(pSeizure);
|
||||
}
|
||||
|
||||
private double calcRawDataStd(SdData sdData) {
|
||||
/**
|
||||
* Calculate the standard deviation in % of the rawData array in the SdData instance provided.
|
||||
* It assumes that rawdata will contain 125 samples.
|
||||
* Returns the standard deviation in %.
|
||||
*/
|
||||
// FIXME - assumes length of rawdata array is 125 data points
|
||||
int j;
|
||||
double sum = 0.0;
|
||||
for (j = 0; j < 125; j++) { // FIXME - assumed length!
|
||||
sum += sdData.rawData[j];
|
||||
}
|
||||
double mean = sum / 125;
|
||||
|
||||
double standardDeviation = 0.0;
|
||||
for (j = 0; j < 125; j++) { // FIXME - assumed length!
|
||||
standardDeviation += Math.pow(sdData.rawData[j] - mean, 2);
|
||||
}
|
||||
standardDeviation = Math.sqrt(standardDeviation / 125); // FIXME - assumed length!
|
||||
|
||||
// Convert standard deviation from milli-g to %
|
||||
standardDeviation = 100. * standardDeviation / mean;
|
||||
return (standardDeviation);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user