104 lines
3.1 KiB
TypeScript

import * as faceapi from "@vladmandic/face-api";
import canvas from "canvas";
import fs, { lstatSync } from "fs";
import * as path from "path";
import { LabeledFaceDescriptors, TNetInput } from "@vladmandic/face-api";
import * as mime from "mime-types";
import { getFaceDetectorOptions } from "./common";
// Side-effect import (no binding is kept); presumably registers the native
// TensorFlow backend that face-api runs on under Node — confirm.
require("@tensorflow/tfjs-node");
// node-canvas classes that stand in for the browser DOM classes face-api uses.
const { Canvas, Image, ImageData } = canvas;
// Register the node-canvas classes with face-api's environment so images loaded
// via canvas.loadImage can be processed outside a browser. The ts-ignore below
// presumably silences a mismatch between node-canvas types and the DOM types
// face-api declares — NOTE(review): confirm, and prefer @ts-expect-error.
//@ts-ignore
faceapi.env.monkeyPatch({ Canvas, Image, ImageData });
/**
 * Builds a face-api `FaceMatcher` from a directory of labelled reference images.
 *
 * Expected layout of `_refImageDir`: one sub-directory per person. The
 * sub-directory name becomes the label, and the sub-directory holds that
 * person's reference images. Top-level entries that are not directories are
 * skipped.
 */
export class Trainer {
  /**
   * @param _refImageDir Directory whose sub-directories hold the reference images.
   * @param _trainedModelDir Directory that receives `data.json` when `train(true)` is called.
   */
  constructor(
    private readonly _refImageDir: string,
    private readonly _trainedModelDir: string
  ) {}

  /**
   * Loads the detection, landmark and recognition networks from `../weights`.
   * It then computes one descriptor per reference image and returns a matcher
   * built from them.
   *
   * @param writeToDisk When true, the matcher is serialised to
   *   `<_trainedModelDir>/data.json` before this method resolves. A write
   *   failure is logged, not thrown.
   * @returns A matcher with one entry per label that yielded at least one descriptor.
   * @throws If a reference image cannot be loaded or contains no detectable
   *   face, or if no label produced any descriptor.
   */
  public async train(writeToDisk: boolean): Promise<faceapi.FaceMatcher> {
    const weightsDir = path.join(__dirname, "../weights");
    const faceDetectionNet = faceapi.nets.ssdMobilenetv1;
    await faceDetectionNet.loadFromDisk(weightsDir);
    await faceapi.nets.faceLandmark68Net.loadFromDisk(weightsDir);
    await faceapi.nets.faceRecognitionNet.loadFromDisk(weightsDir);
    const options = getFaceDetectorOptions(faceDetectionNet);

    const refs: LabeledFaceDescriptors[] = [];
    for (const dir of fs.readdirSync(this._refImageDir)) {
      const labeled = new LabeledFaceDescriptors(dir, []);
      await this.getLabeledFaceDescriptorFromDir(
        path.join(this._refImageDir, dir),
        labeled,
        options
      );
      // A label with no descriptors has nothing to compare against. Adding it
      // would make the matcher's mean-distance degenerate for that entry, so
      // it must be skipped. This covers non-directory entries and folders with
      // no images.
      if (labeled.descriptors.length > 0) {
        refs.push(labeled);
      }
    }
    if (refs.length === 0) {
      throw new Error(`No reference faces found in ${this._refImageDir}`);
    }

    const faceMatcher = new faceapi.FaceMatcher(refs);
    if (writeToDisk) {
      await this.writeModel(faceMatcher);
    }
    return faceMatcher;
  }

  /**
   * Serialises `faceMatcher` as JSON to `<_trainedModelDir>/data.json` and
   * logs the outcome. Failures are reported but not rethrown.
   */
  private async writeModel(faceMatcher: faceapi.FaceMatcher): Promise<void> {
    const target = path.join(this._trainedModelDir, "data.json");
    try {
      await fs.promises.writeFile(
        target,
        JSON.stringify(faceMatcher.toJSON()),
        "utf8"
      );
      console.log(`Successfully wrote data model to file`);
    } catch (err) {
      console.error(`An error occurred while writing data model to file`, err);
    }
  }

  /**
   * Computes a face descriptor for every image file directly inside `dir`.
   * The descriptors are appended to `labeldFaceDescriptors.descriptors` in
   * directory-listing order. Does nothing if `dir` is not a directory. Files
   * whose extension does not map to an `image/*` content type are skipped.
   *
   * @throws If any image fails to load or contains no detectable face. The
   *   message names the file and the underlying cause.
   */
  private getLabeledFaceDescriptorFromDir = async (
    dir: string,
    labeldFaceDescriptors: LabeledFaceDescriptors,
    options: faceapi.TinyFaceDetectorOptions | faceapi.SsdMobilenetv1Options
  ): Promise<void> => {
    // lstat does not follow symlinks, so a symlinked directory is skipped too.
    if (!lstatSync(dir).isDirectory()) {
      return;
    }
    const found = await Promise.all(
      fs.readdirSync(dir).map(async (file: string) => {
        const mimeType = mime.contentType(path.extname(file));
        if (!mimeType || !mimeType.startsWith("image")) {
          return undefined;
        }
        const imagePath = path.join(dir, file);
        console.log(imagePath);
        try {
          // node-canvas' Image is not typed as a face-api TNetInput. The
          // canvas classes are registered via faceapi.env.monkeyPatch at module
          // load, which is why the cast is used here.
          const referenceImage = (await canvas.loadImage(imagePath)) as unknown;
          const detection = await faceapi
            .detectSingleFace(referenceImage as TNetInput, options)
            .withFaceLandmarks()
            .withFaceDescriptor();
          if (!detection || !detection.descriptor) {
            throw new Error("No face found");
          }
          return detection.descriptor;
        } catch (err) {
          // Keep the root cause (e.g. "No face found") in the message.
          const reason = err instanceof Error ? err.message : String(err);
          throw new Error(
            "An error occurred loading image at path: " +
              imagePath +
              " (" +
              reason +
              ")"
          );
        }
      })
    );
    // Append after all detections finish, so the order follows the directory
    // listing rather than completion order.
    for (const descriptor of found) {
      if (descriptor) {
        labeldFaceDescriptors.descriptors.push(descriptor);
      }
    }
  };
}