Successfully training models based on a collection of images

This commit is contained in:
watsonb8
2020-11-08 20:57:57 -05:00
parent e1902a362e
commit fda68e7144
7 changed files with 176 additions and 113 deletions

View File

@ -1,20 +1,24 @@
import * as faceapi from "@vladmandic/face-api";
import canvas from "canvas";
import fs from "fs";
import fs, { lstatSync } from "fs";
import * as path from "path";
import { TNetInput } from "@vladmandic/face-api";
import { LabeledFaceDescriptors, TNetInput } from "@vladmandic/face-api";
import * as mime from "mime-types";
import dotenv from "dotenv-extended";
require("@tensorflow/tfjs-node");
const { Canvas, Image, ImageData } = canvas;
//@ts-ignore
faceapi.env.monkeyPatch({ Canvas, Image, ImageData });
const REFERENCE_IMAGE =
"/Users/brandonwatson/Documents/Git/Gitea/homebridge-face-location/images/brandon/IMG_1958.jpg";
const QUERY_IMAGE =
"/Users/brandonwatson/Documents/Git/Gitea/homebridge-face-location/images/brandon/IMG_0001.JPG";
const main = async () => {
dotenv.load({
silent: false,
errorOnMissing: true,
});
const inputDir = process.env.REF_IMAGE_DIR as string;
const outDir = process.env.TRAINED_MODEL_DIR as string;
const faceDetectionNet = faceapi.nets.ssdMobilenetv1;
await faceDetectionNet.loadFromDisk(path.join(__dirname, "../weights"));
await faceapi.nets.faceLandmark68Net.loadFromDisk(
@ -24,41 +28,66 @@ const main = async () => {
path.join(__dirname, "../weights")
);
const referenceImage = (await canvas.loadImage(REFERENCE_IMAGE)) as unknown;
const queryImage = (await canvas.loadImage(QUERY_IMAGE)) as unknown;
const options = getFaceDetectorOptions(faceDetectionNet);
const resultsRef = await faceapi
.detectAllFaces(referenceImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptors();
const dirs = fs.readdirSync(inputDir);
const resultsQuery = await faceapi
.detectAllFaces(queryImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptors();
for (const dir of dirs) {
if (!lstatSync(path.join(inputDir, dir)).isDirectory()) {
continue;
}
const files = fs.readdirSync(path.join(inputDir, dir));
let referenceResults = await Promise.all(
files.map(async (file: string) => {
const mimeType = mime.contentType(
path.extname(path.join(inputDir, dir, file))
);
if (!mimeType || !mimeType.startsWith("image")) {
return;
}
console.log(path.join(inputDir, dir, file));
const faceMatcher = new faceapi.FaceMatcher(resultsRef);
try {
const referenceImage = (await canvas.loadImage(
path.join(inputDir, dir, file)
)) as unknown;
const labels = faceMatcher.labeledDescriptors.map((ld) => ld.label);
const refDrawBoxes = resultsRef
.map((res) => res.detection.box)
.map((box, i) => new faceapi.draw.DrawBox(box, { label: labels[i] }));
const outRef = faceapi.createCanvasFromMedia(referenceImage as ImageData);
refDrawBoxes.forEach((drawBox) => drawBox.draw(outRef));
const descriptor = await faceapi
.detectAllFaces(referenceImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptors();
saveFile("referenceImage.jpg", (outRef as any).toBuffer("image/jpeg"));
return descriptor.length > 0 ? descriptor : undefined;
} catch (err) {
console.log(
"An error occurred loading image at path: " +
path.join(inputDir, dir, file)
);
}
return undefined;
})
);
const queryDrawBoxes = resultsQuery.map((res) => {
const bestMatch = faceMatcher.findBestMatch(res.descriptor);
return new faceapi.draw.DrawBox(res.detection.box, {
label: bestMatch.toString(),
});
});
const outQuery = faceapi.createCanvasFromMedia(queryImage as ImageData);
queryDrawBoxes.forEach((drawBox) => drawBox.draw(outQuery));
saveFile("queryImage.jpg", (outQuery as any).toBuffer("image/jpeg"));
console.log("done, saved results to out/queryImage.jpg");
const items = [];
for (const item of referenceResults) {
if (item) {
items.push(...item);
}
}
const faceMatcher = new faceapi.FaceMatcher(items);
fs.writeFile(
path.join(outDir, dir + ".json"),
JSON.stringify(faceMatcher.toJSON()),
"utf8",
(err) => {
if (err) {
console.log(`An error occurred while writing ${dir} model to file`);
}
console.log(`Successfully wrote ${dir} model to file`);
}
);
}
};
// SsdMobilenetv1Options
@ -76,12 +105,4 @@ function getFaceDetectorOptions(net: faceapi.NeuralNetwork<any>) {
const baseDir = path.resolve(__dirname, "../out");
function saveFile(fileName: string, buf: Buffer) {
if (!fs.existsSync(baseDir)) {
fs.mkdirSync(baseDir);
}
fs.writeFileSync(path.resolve(baseDir, fileName), buf);
}
main();