homebridge-face-location/scripts/train.ts

import * as faceapi from "@vladmandic/face-api";
import { TNetInput } from "@vladmandic/face-api";
import canvas from "canvas";
import fs from "fs";
import * as path from "path";
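
// Registering the tfjs-node backend lets TensorFlow.js run on native
// bindings instead of the much slower pure-JavaScript backend.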
require("@tensorflow/tfjs-node");
const { Canvas, Image, ImageData } = canvas;
//@ts-ignore
faceapi.env.monkeyPatch({ Canvas, Image, ImageData });
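
// Hard-coded test images: a labeled reference photo and a second photo of
// the same person to match against it.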
const REFERENCE_IMAGE =
  "/Users/brandonwatson/Documents/Git/Gitea/homebridge-face-location/images/brandon/IMG_1958.jpg";
const QUERY_IMAGE =
  "/Users/brandonwatson/Documents/Git/Gitea/homebridge-face-location/images/brandon/IMG_0001.JPG";

const main = async () => {
  // Load the SSD MobileNet detector plus the landmark and recognition
  // models from the local weights directory.
  const faceDetectionNet = faceapi.nets.ssdMobilenetv1;
  await faceDetectionNet.loadFromDisk(path.join(__dirname, "../weights"));
  await faceapi.nets.faceLandmark68Net.loadFromDisk(
    path.join(__dirname, "../weights")
  );
  await faceapi.nets.faceRecognitionNet.loadFromDisk(
    path.join(__dirname, "../weights")
  );

  const referenceImage = (await canvas.loadImage(REFERENCE_IMAGE)) as unknown;
  const queryImage = (await canvas.loadImage(QUERY_IMAGE)) as unknown;
  const options = getFaceDetectorOptions(faceDetectionNet);
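
  // Detect every face in both images and compute a 128-dimensional
  // descriptor for each detection.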
  const resultsRef = await faceapi
    .detectAllFaces(referenceImage as TNetInput, options)
    .withFaceLandmarks()
    .withFaceDescriptors();
  const resultsQuery = await faceapi
    .detectAllFaces(queryImage as TNetInput, options)
    .withFaceLandmarks()
    .withFaceDescriptors();
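
  // Build a matcher from the reference descriptors; findBestMatch compares
  // query descriptors against them by Euclidean distance.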
  const faceMatcher = new faceapi.FaceMatcher(resultsRef);
  const labels = faceMatcher.labeledDescriptors.map((ld) => ld.label);

  // Draw a labeled box around each face found in the reference image.
  const refDrawBoxes = resultsRef
    .map((res) => res.detection.box)
    .map((box, i) => new faceapi.draw.DrawBox(box, { label: labels[i] }));
  const outRef = faceapi.createCanvasFromMedia(referenceImage as ImageData);
  refDrawBoxes.forEach((drawBox) => drawBox.draw(outRef));
  saveFile("referenceImage.jpg", (outRef as any).toBuffer("image/jpeg"));
  const queryDrawBoxes = resultsQuery.map((res) => {
    const bestMatch = faceMatcher.findBestMatch(res.descriptor);
    return new faceapi.draw.DrawBox(res.detection.box, {
      label: bestMatch.toString(),
    });
  });
  const outQuery = faceapi.createCanvasFromMedia(queryImage as ImageData);
  queryDrawBoxes.forEach((drawBox) => drawBox.draw(outQuery));
  saveFile("queryImage.jpg", (outQuery as any).toBuffer("image/jpeg"));

  console.log("done, saved results to out/queryImage.jpg");
};

// SsdMobilenetv1Options
const minConfidence = 0.5;

// TinyFaceDetectorOptions
const inputSize = 408;
const scoreThreshold = 0.5;

// Pick detector options that match whichever net was loaded above.
function getFaceDetectorOptions(net: faceapi.NeuralNetwork<any>) {
  return net === faceapi.nets.ssdMobilenetv1
    ? new faceapi.SsdMobilenetv1Options({ minConfidence })
    : new faceapi.TinyFaceDetectorOptions({ inputSize, scoreThreshold });
}

const baseDir = path.resolve(__dirname, "../out");

// Write a buffer into the out/ directory, creating it on first use.
function saveFile(fileName: string, buf: Buffer) {
  if (!fs.existsSync(baseDir)) {
    fs.mkdirSync(baseDir);
  }
  fs.writeFileSync(path.resolve(baseDir, fileName), buf);
}

// Surface async failures instead of leaving them as unhandled rejections.
main().catch((err) => {
  console.error(err);
  process.exit(1);
});