7 Commits

12 changed files with 480 additions and 386 deletions

deploy.sh Executable file

@@ -0,0 +1,23 @@
#!/bin/bash
remote_user="bmw"
remote_server="linuxhost.me"
deploy_location="/home/bmw/homebridge-face-location"
#build
tsc --build
#copy files to remote machine
scp -r bin $remote_user@$remote_server:$deploy_location
scp -r out $remote_user@$remote_server:$deploy_location
scp -r weights $remote_user@$remote_server:$deploy_location
scp -r trainedModels $remote_user@$remote_server:$deploy_location
scp package.json $remote_user@$remote_server:$deploy_location
#install package
ssh -t $remote_user@$remote_server "sudo npm install -g --unsafe-perm $deploy_location"
#restart service
ssh -t $remote_user@$remote_server "sudo systemctl restart homebridge.service"
echo done
exit
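Note that every scp/ssh call above assumes key-based, non-interactive auth to linuxhost.me; a one-time setup sketch using standard OpenSSH tooling (not part of this commit):

ssh-keygen -t ed25519            # skip if a key already exists
ssh-copy-id bmw@linuxhost.me     # authorize the key on the deploy target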


@@ -1,100 +1,16 @@
import * as faceapi from "@vladmandic/face-api";
import canvas from "canvas";
import fs, { lstatSync } from "fs";
import * as path from "path";
import { LabeledFaceDescriptors, TNetInput } from "@vladmandic/face-api";
import * as mime from "mime-types";
import dotenv from "dotenv-extended";
import { getFaceDetectorOptions } from "../src/common";
require("@tensorflow/tfjs-node");
const { Canvas, Image, ImageData } = canvas;
//@ts-ignore
faceapi.env.monkeyPatch({ Canvas, Image, ImageData });
import { Trainer } from "../src/trainer";
const main = async () => {
dotenv.load({
silent: false,
errorOnMissing: true,
});
const inputDir = process.env.REF_IMAGE_DIR as string;
const outDir = process.env.TRAINED_MODEL_DIR as string;
const faceDetectionNet = faceapi.nets.ssdMobilenetv1;
await faceDetectionNet.loadFromDisk(path.join(__dirname, "../weights"));
await faceapi.nets.faceLandmark68Net.loadFromDisk(
path.join(__dirname, "../weights")
);
await faceapi.nets.faceRecognitionNet.loadFromDisk(
path.join(__dirname, "../weights")
);
const options = getFaceDetectorOptions(faceDetectionNet);
const dirs = fs.readdirSync(inputDir);
const refs: Array<LabeledFaceDescriptors> = [];
for (const dir of dirs) {
if (!lstatSync(path.join(inputDir, dir)).isDirectory()) {
continue;
}
const files = fs.readdirSync(path.join(inputDir, dir));
let referenceResults = await Promise.all(
files.map(async (file: string) => {
const mimeType = mime.contentType(
path.extname(path.join(inputDir, dir, file))
);
if (!mimeType || !mimeType.startsWith("image")) {
return;
}
console.log(path.join(inputDir, dir, file));
try {
const referenceImage = (await canvas.loadImage(
path.join(inputDir, dir, file)
)) as unknown;
const descriptor = await faceapi
.detectSingleFace(referenceImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptor();
if (!descriptor || !descriptor.descriptor) {
throw new Error("No face found");
}
const faceDescriptors = [descriptor.descriptor];
return new faceapi.LabeledFaceDescriptors(dir, faceDescriptors);
} catch (err) {
console.log(
"An error occurred loading image at path: " +
path.join(inputDir, dir, file)
);
}
return undefined;
})
);
if (referenceResults) {
refs.push(
...(referenceResults.filter((e) => e) as LabeledFaceDescriptors[])
);
}
}
const faceMatcher = new faceapi.FaceMatcher(refs);
fs.writeFile(
path.join(outDir, "data.json"),
JSON.stringify(faceMatcher.toJSON()),
"utf8",
(err) => {
if (err) {
console.log(`An error occurred while writing data model to file`);
}
console.log(`Successfully wrote data model to file`);
}
);
const trainer = new Trainer(
process.env.REF_IMAGE_DIR as string,
process.env.TRAINED_MODEL_DIR as string
);
await trainer.train(true);
};
main();
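Since dotenv.load is called with errorOnMissing: true, the script refuses to run unless both variables resolve; a sketch of the .env file it expects (paths illustrative):

REF_IMAGE_DIR=/home/bmw/refImages
TRAINED_MODEL_DIR=/home/bmw/trainedModels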

src/common.ts

@@ -3,10 +3,10 @@ import * as path from "path";
import fs from "fs";
// SsdMobilenetv1Options
-export const minConfidence = 0.5;
+export const minConfidence = 0.4;
// TinyFaceDetectorOptions
-export const inputSize = 408;
+export const inputSize = 416;
export const scoreThreshold = 0.5;
export const getFaceDetectorOptions = (net: faceapi.NeuralNetwork<any>) => {
@@ -15,34 +15,48 @@ export const getFaceDetectorOptions = (net: faceapi.NeuralNetwork<any>) => {
: new faceapi.TinyFaceDetectorOptions({ inputSize, scoreThreshold });
};
-export function saveFile(
+export const saveFile = async (
basePath: string,
fileName: string,
buf: Buffer
-): Promise<void> {
-const writeFile = (): Promise<void> => {
-return new Promise((resolve, reject) => {
-fs.writeFile(path.resolve(basePath, fileName), buf, "base64", (err) => {
-if (err) {
-return reject(err);
-}
-resolve();
-});
-});
-};
-return new Promise(async (resolve, reject) => {
-if (!fs.existsSync(basePath)) {
-fs.mkdir(basePath, async (err) => {
-if (err) {
-return reject(err);
-}
-resolve(await writeFile());
-});
-} else {
-resolve(await writeFile());
-}
-});
-}
+): Promise<void> => {
+return new Promise(async (resolve, reject) => {
+try {
+//Create directory if it does not exist
+await makeDirectory(basePath);
+} catch (err) {
+return reject(err);
+}
+//Write file to directory
+try {
+fs.writeFileSync(path.join(basePath, fileName), buf, "base64");
+} catch (err) {
+return reject(err);
+}
+return resolve();
+});
+};
+
+export const makeDirectory = (path: string): Promise<void> => {
+return new Promise((resolve, reject) => {
+if (!fs.existsSync(path)) {
+fs.mkdir(path, (err) => {
+if (err) {
+return reject(err);
+}
+return resolve();
+});
+} else {
+return resolve();
+}
+});
+};
export const delay = (ms: number): Promise<void> => {
return new Promise((resolve) => {
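With this change saveFile creates the target directory on demand and rejects if either the mkdir or the write fails. A minimal usage sketch (directory, file name, and buffer contents are illustrative, not from this commit):

import { saveFile } from "./common";

// "..." stands in for real base64 image data, e.g. one frame from an rtsp stream
const frame = Buffer.from("...", "base64");

saveFile("/var/lib/homebridge/output", "livingRoom.jpg", frame)
  .then(() => console.log("frame written"))
  .catch((err) => console.error("write failed:", err));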

src/config.ts

@@ -7,9 +7,11 @@ export interface IConfig extends PlatformConfig {
outputDirectory: string;
trainOnStartup: boolean;
rooms: Array<IRoom>;
-detectionTimeout: number;
-debug: boolean;
-writeOutput: boolean;
+detectionTimeout?: number;
+watchdogTimeout?: number;
+debug?: boolean;
+writeOutput?: boolean;
+rate?: number;
}
export interface IRoom {
@@ -25,14 +27,13 @@ export const isConfig = (object: any): object is IConfig => {
const roomsOkay =
object["rooms"].filter((room: any) => isRoom(room)).length ===
object["rooms"].length;
return (
"refImageDirectory" in object &&
"trainedModelDirectory" in object &&
"weightDirectory" in object &&
"outputDirectory" in object &&
"trainOnStartup" in object &&
-"detectionTimeout" in object &&
-"writeOutput" in object &&
"rooms" in object &&
roomsOkay
);
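All of the new fields are optional, so existing configs keep validating. A sketch of a platform config exercising them, assuming IRoom carries just the name and rtspConnectionStrings fields the monitor reads; every value is illustrative:

import { IConfig } from "./config";

const example: Partial<IConfig> = {
  refImageDirectory: "/var/lib/homebridge/refImages",
  trainedModelDirectory: "/var/lib/homebridge/trainedModels",
  weightDirectory: "/var/lib/homebridge/weights",
  outputDirectory: "/var/lib/homebridge/output",
  trainOnStartup: false,
  rooms: [
    {
      name: "livingRoom",
      rtspConnectionStrings: ["rtsp://user:pass@cam.local/stream"],
    },
  ],
  detectionTimeout: 180000, // ms before a motion sensor clears (see locationAccessory.ts)
  watchdogTimeout: 30000, // ms without frames before a stream restarts (see monitor.ts)
  rate: 0.7, // output frame rate requested from ffmpeg (see rtsp.ts)
  debug: false,
  writeOutput: true,
};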

src/homeLocationPlatform.ts

@@ -10,17 +10,12 @@ import {
import { IConfig, isConfig } from "./config";
import * as faceapi from "@vladmandic/face-api";
import canvas from "canvas";
-import fs, { lstatSync } from "fs";
+import fs from "fs";
import * as path from "path";
import { nets } from "@vladmandic/face-api";
-import {
-LabeledFaceDescriptors,
-TNetInput,
-FaceMatcher,
-} from "@vladmandic/face-api";
-import * as mime from "mime-types";
-import { Monitor } from "./monitor";
-import { getFaceDetectorOptions } from "./common";
+import { FaceMatcher } from "@vladmandic/face-api";
+import { Monitor } from "./monitor/monitor";
+import { Trainer } from "./trainer";
require("@tensorflow/tfjs-node");
const { Canvas, Image, ImageData } = canvas;
@@ -82,17 +77,20 @@ export class HomeLocationPlatform implements DynamicPlatformPlugin {
* must not be registered again to prevent "duplicate UUID" errors.
*/
public async discoverDevices() {
-const faceDetectionNet = nets.ssdMobilenetv1;
-await faceDetectionNet.loadFromDisk(this.config.weightDirectory);
-await nets.faceLandmark68Net.loadFromDisk(this.config.weightDirectory);
-await nets.faceRecognitionNet.loadFromDisk(this.config.weightDirectory);
//Train facial recognition model
let faceMatcher: FaceMatcher;
if (this.config.trainOnStartup) {
-faceMatcher = await this.trainModels();
+const trainer = new Trainer(
+this.config.refImageDirectory,
+this.config.trainedModelDirectory
+);
+faceMatcher = await trainer.train(true);
} else {
+const faceDetectionNet = nets.ssdMobilenetv1;
+await faceDetectionNet.loadFromDisk(this.config.weightDirectory);
+await nets.faceLandmark68Net.loadFromDisk(this.config.weightDirectory);
+await nets.faceRecognitionNet.loadFromDisk(this.config.weightDirectory);
const raw = fs.readFileSync(
path.join(this.config.trainedModelDirectory, "data.json"),
"utf-8"
@@ -142,88 +140,4 @@ export class HomeLocationPlatform implements DynamicPlatformPlugin {
}
}
}
private async trainModels(): Promise<FaceMatcher> {
const faceDetectionNet = faceapi.nets.ssdMobilenetv1;
await faceDetectionNet.loadFromDisk(this.config.weightDirectory);
await faceapi.nets.faceLandmark68Net.loadFromDisk(
this.config.weightDirectory
);
await faceapi.nets.faceRecognitionNet.loadFromDisk(
this.config.weightDirectory
);
const options = getFaceDetectorOptions(faceDetectionNet);
const dirs = fs.readdirSync(this.config.refImageDirectory);
const refs: Array<LabeledFaceDescriptors> = [];
for (const dir of dirs) {
if (
!lstatSync(path.join(this.config.refImageDirectory, dir)).isDirectory()
) {
continue;
}
const files = fs.readdirSync(
path.join(this.config.refImageDirectory, dir)
);
let referenceResults = await Promise.all(
files.map(async (file: string) => {
const mimeType = mime.contentType(
path.extname(path.join(this.config.refImageDirectory, dir, file))
);
if (!mimeType || !mimeType.startsWith("image")) {
return;
}
console.log(path.join(this.config.refImageDirectory, dir, file));
try {
const referenceImage = (await canvas.loadImage(
path.join(this.config.refImageDirectory, dir, file)
)) as unknown;
const descriptor = await faceapi
.detectSingleFace(referenceImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptor();
if (!descriptor || !descriptor.descriptor) {
throw new Error("No face found");
}
const faceDescriptors = [descriptor.descriptor];
return new faceapi.LabeledFaceDescriptors(dir, faceDescriptors);
} catch (err) {
console.log(
"An error occurred loading image at path: " +
path.join(this.config.refImageDirectory, dir, file)
);
}
return undefined;
})
);
if (referenceResults) {
refs.push(
...(referenceResults.filter((e) => e) as LabeledFaceDescriptors[])
);
}
}
const faceMatcher = new faceapi.FaceMatcher(refs);
fs.writeFile(
path.join(this.config.trainedModelDirectory, "data.json"),
JSON.stringify(faceMatcher.toJSON()),
"utf8",
(err) => {
if (err) {
console.log(`An error occurred while writing data model to file`);
}
console.log(`Successfully wrote data model to file`);
}
);
return faceMatcher;
}
}

src/locationAccessory.ts

@@ -3,17 +3,24 @@ import {
CharacteristicGetCallback,
PlatformAccessory,
} from "homebridge";
-import { Monitor, IStateChangeEventArgs } from "./monitor";
+import { Monitor, IStateChangeEventArgs } from "./monitor/monitor";
import { HomeLocationPlatform } from "./homeLocationPlatform";
import { IRoom } from "./config";
const defaultDetectionTimeout = 180000;
interface IMotionDetectionService {
service: Service;
detectionTimeout: NodeJS.Timeout | null;
}
/**
* Platform Accessory
* An instance of this class is created for each accessory your platform registers
* Each accessory may expose multiple services of different service types.
*/
export class LocationAccessory {
-private _services: Array<Service>;
+private _services: Array<IMotionDetectionService>;
constructor(
private readonly _platform: HomeLocationPlatform,
@@ -54,7 +61,10 @@ export class LocationAccessory {
this.onMotionDetectedGet(label, callback)
);
-this._services.push(newService);
+this._services.push({
+service: newService,
+detectionTimeout: null,
+});
}
//Register monitor state change events
@@ -78,14 +88,31 @@
sender: Monitor,
args: IStateChangeEventArgs
) => {
-const service = this._services.find(
-(service) => service.displayName == args.label
+const motionService = this._services.find(
+(motionService) => motionService.service.displayName == args.label
);
-if (service) {
-service.setCharacteristic(
+if (motionService) {
+//Set accessory state
+motionService.service.setCharacteristic(
this._platform.Characteristic.MotionDetected,
args.new === this._room.name
);
+//Reset detectionTimeout
+clearTimeout(motionService.detectionTimeout!);
+motionService.detectionTimeout = setTimeout(
+() => this.onDetectionTimeout(motionService),
+this._platform.config.detectionTimeout ?? defaultDetectionTimeout
+);
}
};
private onDetectionTimeout = (motionService: IMotionDetectionService) => {
//Set accessory state
motionService.service.setCharacteristic(
this._platform.Characteristic.MotionDetected,
0
);
this._monitor.resetState(motionService.service.displayName);
};
}

src/monitor.ts

@@ -1,161 +0,0 @@
import { FaceMatcher } from "@vladmandic/face-api";
import { IRoom } from "./config";
import {
Rtsp,
IStreamEventArgs,
ICloseEventArgs,
IErrorEventArgs,
IMessageEventArgs,
} from "./rtsp/rtsp";
import canvas from "canvas";
import * as faceapi from "@vladmandic/face-api";
import { getFaceDetectorOptions, saveFile } from "./common";
import { nets } from "@vladmandic/face-api";
import { Logger } from "homebridge";
import { Event } from "./events";
import { IConfig } from "./config";
const { Canvas, Image, ImageData } = canvas;
export type MonitorState = { [label: string]: string | null };
export interface IStateChangeEventArgs {
label: string;
old: string | null;
new: string;
}
export class Monitor {
private _state: MonitorState = {};
private _streamsByRoom: { [roomName: string]: Array<Rtsp> } = {};
private _faceDetectionNet = nets.ssdMobilenetv1;
private _stateChangedEvent: Event<this, IStateChangeEventArgs>;
constructor(
private _rooms: Array<IRoom>,
private _matcher: FaceMatcher,
private _logger: Logger,
private _config: IConfig
) {
this._stateChangedEvent = new Event();
//Initialize state
for (const room of this._rooms) {
this._streamsByRoom[room.name] = [
...room.rtspConnectionStrings.map((connectionString) => {
const rtsp = new Rtsp(connectionString, {
rate: 0.7,
image: true,
});
rtsp.dataEvent.push((sender: Rtsp, args: IStreamEventArgs) =>
this.onData(room.name, args)
);
rtsp.closeEvent.push((sender: Rtsp, args: ICloseEventArgs) =>
this.onExit(connectionString, args)
);
rtsp.errorEvent.push((sender: Rtsp, args: IErrorEventArgs) =>
this.onError(args, connectionString)
);
if (this._config.debug) {
rtsp.messageEvent.push((sender: Rtsp, args: IMessageEventArgs) => {
this._logger.info(`[${connectionString}] ${args.message}`);
});
}
return rtsp;
}),
];
_matcher.labeledDescriptors.forEach((descriptor) => {
this._state[descriptor.label] = null;
});
}
}
/**
* @method getState
*
* @param label The name of the label to retrieve state for
*
* The last known room of the requested label
*/
public getState(label: string): string | null {
return this._state[label];
}
/**
* @property labels
*
* Gets the list of labels associated with the monitor
*/
public get labels(): Array<string> {
return this._matcher.labeledDescriptors
.map((descriptor) => descriptor.label)
.filter(
(label: string, index: number, array: Array<string>) =>
array.indexOf(label) === index
);
}
public get stateChangedEvent(): Event<this, IStateChangeEventArgs> {
return this._stateChangedEvent;
}
/**
* @method startStreams
*
* Starts monitoring rtsp streams
*/
public startStreams() {
for (const key in this._streamsByRoom) {
for (const stream of this._streamsByRoom[key]) {
stream.start();
}
}
}
/**
* @method closeStreams
*
* Stops monitoring rtsp streams
*/
public closeStreams() {
for (const key in this._streamsByRoom) {
for (const stream of this._streamsByRoom[key]) {
stream.close();
}
}
}
private onData = async (room: string, args: IStreamEventArgs) => {
const input = ((await canvas.loadImage(args.data)) as unknown) as ImageData;
const out = faceapi.createCanvasFromMedia(input);
const resultsQuery = await faceapi
.detectAllFaces(out, getFaceDetectorOptions(this._faceDetectionNet))
.withFaceLandmarks()
.withFaceDescriptors();
//Write to output image
if (this._config.writeOutput) {
await saveFile(this._config.outputDirectory, room + ".jpg", args.data);
}
for (const res of resultsQuery) {
const bestMatch = this._matcher.matchDescriptor(res.descriptor);
const old = this._state[bestMatch.label];
this._state[bestMatch.label] = room;
this._stateChangedEvent.fire(this, {
old: old,
new: room,
label: bestMatch.label,
});
this._logger.info(`Face Detected: ${bestMatch.label} in room ${room}`);
}
};
private onError = (args: IErrorEventArgs, streamName: string) => {
this._logger.info(`[${streamName}] ${args.message}`);
};
private onExit = (streamName: string, args: ICloseEventArgs) => {
this._logger.info(`[${streamName}] Stream has exited: ${args.message}`);
};
}

src/monitor/monitor.ts Normal file

@@ -0,0 +1,238 @@
import { FaceMatcher } from "@vladmandic/face-api";
import { IRoom } from "../config";
import {
Rtsp,
IStreamEventArgs,
ICloseEventArgs,
IErrorEventArgs,
IMessageEventArgs,
} from "../rtsp/rtsp";
import canvas from "canvas";
import * as faceapi from "@vladmandic/face-api";
import { getFaceDetectorOptions, saveFile } from "../common";
import { nets } from "@vladmandic/face-api";
import { Logger } from "homebridge";
import { Event } from "../events";
import { IConfig } from "../config";
import { MonitorState } from "./monitorState";
import { IStream } from "./stream";
const { Canvas, Image, ImageData } = canvas;
const defaultWatchDog = 30000;
const defaultRate = 0.7;
export interface IStateChangeEventArgs {
label: string;
old: string | null;
new: string;
}
export class Monitor {
private _state: MonitorState = {};
private _streamsByRoom: { [roomName: string]: Array<IStream> } = {};
private _faceDetectionNet = nets.ssdMobilenetv1;
private _stateChangedEvent: Event<this, IStateChangeEventArgs>;
constructor(
rooms: Array<IRoom>,
private _matcher: FaceMatcher,
private _logger: Logger,
private _config: IConfig
) {
this._stateChangedEvent = new Event();
//Initialize state
for (const room of rooms) {
this._streamsByRoom[room.name] = [
...room.rtspConnectionStrings.map((connectionString) => {
return this.getNewStream(connectionString, room.name);
}),
];
_matcher.labeledDescriptors.forEach((descriptor) => {
this._state[descriptor.label] = null;
});
}
}
/**
* @method getState
*
* @param label The name of the label to retrieve state for
*
* The last known room of the requested label
*/
public getState(label: string): string | null {
return this._state[label];
}
public resetState(label: string): Monitor {
this._state[label] = null;
return this;
}
/**
* @property labels
*
* Gets the list of labels associated with the monitor
*/
public get labels(): Array<string> {
return this._matcher.labeledDescriptors
.map((descriptor) => descriptor.label)
.filter(
(label: string, index: number, array: Array<string>) =>
array.indexOf(label) === index
);
}
public get stateChangedEvent(): Event<this, IStateChangeEventArgs> {
return this._stateChangedEvent;
}
/**
* @method startStreams
*
* Starts monitoring rtsp streams
*/
public startStreams(): Monitor {
for (const key in this._streamsByRoom) {
for (const stream of this._streamsByRoom[key]) {
//Start stream
stream.rtsp.start();
//Start watchdog timer
stream.watchdogTimer = setTimeout(
() => this.onWatchdogTimeout(stream, key),
this._config.watchdogTimeout ?? defaultWatchDog
);
}
}
return this;
}
/**
* @method closeStreams
*
* Stops monitoring rtsp streams
*/
public closeStreams(): Monitor {
for (const key in this._streamsByRoom) {
for (const stream of this._streamsByRoom[key]) {
stream.rtsp.close();
//Stop watchdog timer
if (stream.watchdogTimer) {
clearTimeout(stream.watchdogTimer);
}
}
}
return this;
}
private onData = async (
room: string,
stream: IStream,
args: IStreamEventArgs
) => {
//Reset watchdog timer for the stream
clearTimeout(stream.watchdogTimer!);
stream.watchdogTimer = setTimeout(
() => this.onWatchdogTimeout(stream, room),
this._config.watchdogTimeout ?? defaultWatchDog
);
//Detect faces in image
const input = ((await canvas.loadImage(args.data)) as unknown) as ImageData;
const out = faceapi.createCanvasFromMedia(input);
const resultsQuery = await faceapi
.detectAllFaces(out, getFaceDetectorOptions(this._faceDetectionNet))
.withFaceLandmarks()
.withFaceDescriptors();
//Write to output image
if (this._config.writeOutput) {
await saveFile(this._config.outputDirectory, room + ".jpg", args.data);
}
for (const res of resultsQuery) {
const bestMatch = this._matcher.matchDescriptor(res.descriptor);
const old = this._state[bestMatch.label];
this._state[bestMatch.label] = room;
this._stateChangedEvent.fire(this, {
old: old,
new: room,
label: bestMatch.label,
});
if (this._config.debug) {
this._logger.info(`Face Detected: ${bestMatch.label} in room ${room}`);
}
}
};
private getNewStream(connectionString: string, roomName: string): IStream {
const stream = {
rtsp: new Rtsp(connectionString, {
rate: this._config.rate ?? defaultRate,
image: true,
}),
watchdogTimer: null,
detectionTimer: null,
connectionString: connectionString,
};
connectionString = this.getRedactedConnectionString(connectionString);
//Subscribe to rtsp events
stream.rtsp.dataEvent.push((sender: Rtsp, args: IStreamEventArgs) =>
this.onData(roomName, stream, args)
);
//Only subscribe to these events if debug
if (this._config.debug) {
stream.rtsp.messageEvent.push((sender: Rtsp, args: IMessageEventArgs) => {
this._logger.info(`[${connectionString}] ${args.message}`);
});
stream.rtsp.errorEvent.push((sender: Rtsp, args: IErrorEventArgs) => {
this._logger.info(`[${connectionString}] ${args.message}`);
});
stream.rtsp.closeEvent.push((sender: Rtsp, args: ICloseEventArgs) => {
this._logger.info(
`[${connectionString}] Stream has exited: ${args.message}`
);
});
}
return stream;
}
private onWatchdogTimeout = async (stream: IStream, roomName: string) => {
this._logger.info(
`[${this.getRedactedConnectionString(
stream.connectionString
)}] Watchdog timeout: restarting stream`
);
//Close and remove old stream
stream.rtsp.close();
this._streamsByRoom[roomName].splice(
this._streamsByRoom[roomName].indexOf(stream),
1
);
//Create, start, and watch a replacement stream
const newStream = this.getNewStream(stream.connectionString, roomName);
this._streamsByRoom[roomName].push(newStream);
newStream.rtsp.start();
//Re-arm the watchdog; the old code restarted the closed stream and never re-armed it
newStream.watchdogTimer = setTimeout(
() => this.onWatchdogTimeout(newStream, roomName),
this._config.watchdogTimeout ?? defaultWatchDog
);
};
private getRedactedConnectionString(connectionString: string) {
//Search backwards from the "@" so a ":port" suffix is not mistaken for the password separator
const pwEndIdx = connectionString.indexOf("@");
const pwSepIdx = connectionString.lastIndexOf(":", pwEndIdx) + 1;
return (
connectionString.substring(0, pwSepIdx) +
connectionString.substring(pwEndIdx)
);
}
}
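For illustration, the redaction helper keeps everything but the password segment of a connection string; with the bounded lastIndexOf above it also survives URLs that include a port (values illustrative):

getRedactedConnectionString("rtsp://admin:secret@192.168.1.10:554/stream")
// -> "rtsp://admin:@192.168.1.10:554/stream"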

src/monitor/monitorState.ts Normal file

@@ -0,0 +1 @@
export type MonitorState = { [label: string]: string | null };

src/monitor/stream.ts Normal file

@@ -0,0 +1,8 @@
import { Rtsp } from "../rtsp/rtsp";
export interface IStream {
rtsp: Rtsp;
connectionString: string;
watchdogTimer: NodeJS.Timeout | null;
detectionTimer: NodeJS.Timeout | null;
}

src/rtsp/rtsp.ts

@@ -76,6 +76,7 @@ export class Rtsp {
public start(): void {
const argStrings = [
`-rtsp_transport tcp`,
`-i ${this._connecteionString}`,
`-r ${this._options.rate ?? 10}`,
`-vf mpdecimate,setpts=N/FRAME_RATE/TB`,
@@ -92,11 +93,20 @@
}
this._childProcess.stdout?.on("data", this.onData);
-this._childProcess.stdout?.on("error", (err) =>
-console.log("And error occurred" + err)
+this._childProcess.stdout?.on("error", (error: Error) =>
+this._errorEvent.fire(this, { err: error })
);
+this._childProcess.stdout?.on("close", () =>
+this._closeEvent.fire(this, {
+message: "Stream closed",
+})
+);
+this._childProcess.stdout?.on("end", () =>
+this._closeEvent.fire(this, {
+message: "Stream ended",
+})
+);
-this._childProcess.stdout?.on("close", () => console.log("Stream closed"));
-this._childProcess.stdout?.on("end", () => console.log("Stream ended"));
//Only register this event if there are subscribers
if (this._childProcess.stderr && this._messageEvent.length > 0) {
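Read together, the flags visible in this hunk make ffmpeg pull the RTSP feed over TCP, cap the output frame rate at options.rate, and use mpdecimate with setpts=N/FRAME_RATE/TB to drop near-duplicate frames and regenerate timestamps. The assembled invocation therefore starts roughly as follows (connection string illustrative; the flags contributed by the unchanged code below this hunk are elided):

ffmpeg -rtsp_transport tcp -i rtsp://user:pass@cam.local/stream -r 10 -vf mpdecimate,setpts=N/FRAME_RATE/TB ...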

src/trainer.ts Normal file

@@ -0,0 +1,103 @@
import * as faceapi from "@vladmandic/face-api";
import canvas from "canvas";
import fs, { lstatSync } from "fs";
import * as path from "path";
import { LabeledFaceDescriptors, TNetInput } from "@vladmandic/face-api";
import * as mime from "mime-types";
import { getFaceDetectorOptions } from "./common";
require("@tensorflow/tfjs-node");
const { Canvas, Image, ImageData } = canvas;
//@ts-ignore
faceapi.env.monkeyPatch({ Canvas, Image, ImageData });
export class Trainer {
constructor(private _refImageDir: string, private _trainedModelDir: string) {}
public async train(writeToDisk: boolean): Promise<faceapi.FaceMatcher> {
const faceDetectionNet = faceapi.nets.ssdMobilenetv1;
await faceDetectionNet.loadFromDisk(path.join(__dirname, "../weights"));
await faceapi.nets.faceLandmark68Net.loadFromDisk(
path.join(__dirname, "../weights")
);
await faceapi.nets.faceRecognitionNet.loadFromDisk(
path.join(__dirname, "../weights")
);
const options = getFaceDetectorOptions(faceDetectionNet);
const dirs = fs.readdirSync(this._refImageDir);
const refs: Array<LabeledFaceDescriptors> = [];
for (const dir of dirs) {
const descriptor = new LabeledFaceDescriptors(dir, []);
await this.getLabeledFaceDescriptorFromDir(
path.join(this._refImageDir, dir),
descriptor,
options
);
//Only keep labels that produced at least one face descriptor
if (descriptor.descriptors.length > 0) {
refs.push(descriptor);
}
}
const faceMatcher = new faceapi.FaceMatcher(refs);
if (writeToDisk) {
fs.writeFile(
path.join(this._trainedModelDir, "data.json"),
JSON.stringify(faceMatcher.toJSON()),
"utf8",
(err) => {
if (err) {
return console.log(`An error occurred while writing data model to file`);
}
console.log(`Successfully wrote data model to file`);
}
);
}
return faceMatcher;
}
private getLabeledFaceDescriptorFromDir = async (
dir: string,
labeledFaceDescriptors: LabeledFaceDescriptors,
options: faceapi.TinyFaceDetectorOptions | faceapi.SsdMobilenetv1Options
): Promise<void> => {
if (!lstatSync(dir).isDirectory()) {
return;
}
const files = fs.readdirSync(dir);
await Promise.all(
files.map(async (file: string) => {
const mimeType = mime.contentType(path.extname(path.join(dir, file)));
if (!mimeType || !mimeType.startsWith("image")) {
return;
}
console.log(path.join(dir, file));
try {
const referenceImage = (await canvas.loadImage(
path.join(dir, file)
)) as unknown;
const descriptor = await faceapi
.detectSingleFace(referenceImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptor();
if (!descriptor || !descriptor.descriptor) {
throw new Error("No face found");
}
labeledFaceDescriptors.descriptors.push(descriptor.descriptor);
} catch (err) {
throw new Error(
"An error occurred loading image at path: " + path.join(dir, file)
);
}
})
);
};
}
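train(true) persists the matcher with FaceMatcher.toJSON into data.json; reading it back, as the platform does when trainOnStartup is false, is the mirror image. A sketch (directory illustrative):

import fs from "fs";
import * as path from "path";
import { FaceMatcher } from "@vladmandic/face-api";

const trainedModelDir = "/var/lib/homebridge/trainedModels"; // illustrative
const raw = fs.readFileSync(path.join(trainedModelDir, "data.json"), "utf-8");
const matcher = FaceMatcher.fromJSON(JSON.parse(raw));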