Successfully training models based on a collection of images
scripts/train.ts (107 changed lines)
@@ -1,20 +1,24 @@
import * as faceapi from "@vladmandic/face-api";
import canvas from "canvas";
import fs from "fs";
import fs, { lstatSync } from "fs";
import * as path from "path";
import { TNetInput } from "@vladmandic/face-api";
import { LabeledFaceDescriptors, TNetInput } from "@vladmandic/face-api";
import * as mime from "mime-types";
import dotenv from "dotenv-extended";
require("@tensorflow/tfjs-node");

const { Canvas, Image, ImageData } = canvas;
//@ts-ignore
faceapi.env.monkeyPatch({ Canvas, Image, ImageData });

const REFERENCE_IMAGE =
"/Users/brandonwatson/Documents/Git/Gitea/homebridge-face-location/images/brandon/IMG_1958.jpg";
const QUERY_IMAGE =
"/Users/brandonwatson/Documents/Git/Gitea/homebridge-face-location/images/brandon/IMG_0001.JPG";

const main = async () => {
dotenv.load({
silent: false,
errorOnMissing: true,
});
const inputDir = process.env.REF_IMAGE_DIR as string;
const outDir = process.env.TRAINED_MODEL_DIR as string;

const faceDetectionNet = faceapi.nets.ssdMobilenetv1;
await faceDetectionNet.loadFromDisk(path.join(__dirname, "../weights"));
await faceapi.nets.faceLandmark68Net.loadFromDisk(
@@ -24,41 +28,66 @@ const main = async () => {
path.join(__dirname, "../weights")
);

const referenceImage = (await canvas.loadImage(REFERENCE_IMAGE)) as unknown;
const queryImage = (await canvas.loadImage(QUERY_IMAGE)) as unknown;
const options = getFaceDetectorOptions(faceDetectionNet);

const resultsRef = await faceapi
.detectAllFaces(referenceImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptors();
const dirs = fs.readdirSync(inputDir);

const resultsQuery = await faceapi
.detectAllFaces(queryImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptors();
for (const dir of dirs) {
if (!lstatSync(path.join(inputDir, dir)).isDirectory()) {
continue;
}
const files = fs.readdirSync(path.join(inputDir, dir));
let referenceResults = await Promise.all(
files.map(async (file: string) => {
const mimeType = mime.contentType(
path.extname(path.join(inputDir, dir, file))
);
if (!mimeType || !mimeType.startsWith("image")) {
return;
}
console.log(path.join(inputDir, dir, file));

const faceMatcher = new faceapi.FaceMatcher(resultsRef);
try {
const referenceImage = (await canvas.loadImage(
path.join(inputDir, dir, file)
)) as unknown;

const labels = faceMatcher.labeledDescriptors.map((ld) => ld.label);
const refDrawBoxes = resultsRef
.map((res) => res.detection.box)
.map((box, i) => new faceapi.draw.DrawBox(box, { label: labels[i] }));
const outRef = faceapi.createCanvasFromMedia(referenceImage as ImageData);
refDrawBoxes.forEach((drawBox) => drawBox.draw(outRef));
const descriptor = await faceapi
.detectAllFaces(referenceImage as TNetInput, options)
.withFaceLandmarks()
.withFaceDescriptors();

saveFile("referenceImage.jpg", (outRef as any).toBuffer("image/jpeg"));
return descriptor.length > 0 ? descriptor : undefined;
} catch (err) {
console.log(
"An error occurred loading image at path: " +
path.join(inputDir, dir, file)
);
}
return undefined;
})
);

const queryDrawBoxes = resultsQuery.map((res) => {
const bestMatch = faceMatcher.findBestMatch(res.descriptor);
return new faceapi.draw.DrawBox(res.detection.box, {
label: bestMatch.toString(),
});
});
const outQuery = faceapi.createCanvasFromMedia(queryImage as ImageData);
queryDrawBoxes.forEach((drawBox) => drawBox.draw(outQuery));
saveFile("queryImage.jpg", (outQuery as any).toBuffer("image/jpeg"));
console.log("done, saved results to out/queryImage.jpg");
const items = [];
for (const item of referenceResults) {
if (item) {
items.push(...item);
}
}
const faceMatcher = new faceapi.FaceMatcher(items);
fs.writeFile(
path.join(outDir, dir + ".json"),
JSON.stringify(faceMatcher.toJSON()),
"utf8",
(err) => {
if (err) {
console.log(`An error occurred while writing ${dir} model to file`);
}

console.log(`Successfully wrote ${dir} model to file`);
}
);
}
};

// SsdMobilenetv1Options
@@ -76,12 +105,4 @@ function getFaceDetectorOptions(net: faceapi.NeuralNetwork<any>) {

const baseDir = path.resolve(__dirname, "../out");

function saveFile(fileName: string, buf: Buffer) {
if (!fs.existsSync(baseDir)) {
fs.mkdirSync(baseDir);
}

fs.writeFileSync(path.resolve(baseDir, fileName), buf);
}

main();
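The rewritten script no longer hard-codes REFERENCE_IMAGE and QUERY_IMAGE; it pulls its configuration from the environment through dotenv-extended (errorOnMissing: true makes both keys mandatory), walks REF_IMAGE_DIR treating each subdirectory as one person, and writes a <directory>.json matcher into TRAINED_MODEL_DIR. A local .env might look like the following sketch; the paths are illustrative assumptions, only the two variable names come from the script:

    # .env, read by dotenv.load({ errorOnMissing: true })
    REF_IMAGE_DIR=./images          # one subdirectory of photos per person, e.g. images/brandon/
    TRAINED_MODEL_DIR=./models      # receives one <person>.json FaceMatcher dump per subdirectory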
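The JSON written for each directory is the output of faceMatcher.toJSON(), so a saved model can later be turned back into a matcher for recognition. Below is a minimal sketch, not part of this commit, of how one of those files might be loaded and used against a new photo; it assumes the same ../weights directory that train.ts uses and that FaceMatcher.fromJSON accepts the toJSON() output, and the function and variable names here are hypothetical:

    import * as faceapi from "@vladmandic/face-api";
    import canvas from "canvas";
    import fs from "fs";
    import * as path from "path";
    require("@tensorflow/tfjs-node");

    const { Canvas, Image, ImageData } = canvas;
    //@ts-ignore
    faceapi.env.monkeyPatch({ Canvas, Image, ImageData });

    const matchImage = async (modelPath: string, imagePath: string) => {
      // Restore the labeled descriptors that train.ts serialized with toJSON().
      const matcher = faceapi.FaceMatcher.fromJSON(
        JSON.parse(fs.readFileSync(modelPath, "utf8"))
      );

      // Same weights as train.ts: detector, landmarks, and the recognition net
      // that computes the 128-value descriptors the matcher compares.
      const weights = path.join(__dirname, "../weights");
      await faceapi.nets.ssdMobilenetv1.loadFromDisk(weights);
      await faceapi.nets.faceLandmark68Net.loadFromDisk(weights);
      await faceapi.nets.faceRecognitionNet.loadFromDisk(weights);

      const image = (await canvas.loadImage(imagePath)) as unknown;
      const result = await faceapi
        .detectSingleFace(image as faceapi.TNetInput)
        .withFaceLandmarks()
        .withFaceDescriptor();

      // findBestMatch returns the nearest stored label, or "unknown" beyond the distance threshold.
      return result ? matcher.findBestMatch(result.descriptor) : undefined;
    };

findBestMatch here is the same call the removed single-image code used to label query faces, only driven by a matcher restored from disk instead of one built in memory.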