diff --git a/discojs/src/index.ts b/discojs/src/index.ts index 04537e7c9..26fa556ec 100644 --- a/discojs/src/index.ts +++ b/discojs/src/index.ts @@ -17,6 +17,7 @@ export { EpochLogs, Tokenizer, ValidationMetrics, + ModelMetadata, } from "./models/index.js"; export * as models from './models/index.js' diff --git a/discojs/src/models/index.ts b/discojs/src/models/index.ts index 96d253f0b..d2497e894 100644 --- a/discojs/src/models/index.ts +++ b/discojs/src/models/index.ts @@ -1,4 +1,4 @@ -export { Model } from './model.js' +export { Model, ModelMetadata } from './model.js' export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js"; export { Tokenizer } from "./tokenizer.js"; diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts index dd7c0477c..9d617c621 100644 --- a/discojs/src/models/model.ts +++ b/discojs/src/models/model.ts @@ -7,6 +7,11 @@ import type { } from "../index.js"; import type { BatchLogs, EpochLogs } from "./logs.js"; +import type { StandardizationStats } from "../processing/tabular.js"; + +export type ModelMetadata = { + tabularStandardization?: StandardizationStats; +}; /** * Trainable predictor @@ -21,6 +26,9 @@ export abstract class Model implements Disposable { /** Set training state */ abstract set weights(ws: WeightsContainer); + /** Optional metadata for tabular task data standardization */ + metadata?: ModelMetadata; + /** * Improve predictor * diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts index b60060f49..8a62dd5ca 100644 --- a/discojs/src/models/tfjs.ts +++ b/discojs/src/models/tfjs.ts @@ -12,17 +12,20 @@ import { import { BatchLogs } from './index.js' import { Model } from './index.js' import { EpochLogs } from './logs.js' +import { ModelMetadata } from "./model.js"; -type Serialized = [D, tf.io.ModelArtifacts]; +type Serialized = [D, tf.io.ModelArtifacts, ModelMetadata?]; /** TensorFlow JavaScript model with standard training */ export class TFJS extends Model { /** Wrap the given trainable model */ constructor ( public readonly datatype: D, - private readonly model: tf.LayersModel + private readonly model: tf.LayersModel, + metadata?: ModelMetadata, ) { super() + this.metadata = metadata; if (model.loss === undefined) { throw new Error('TFJS models need to be compiled to be used') @@ -176,12 +179,14 @@ export class TFJS extends Model { static async deserialize([ datatype, artifacts, + metadata ]: Serialized): Promise> { return new this( datatype, await tf.loadLayersModel({ load: () => Promise.resolve(artifacts), }), + metadata ); } @@ -204,7 +209,7 @@ export class TFJS extends Model { includeOptimizer: true // keep model compiled }) - return [this.datatype, await ret] + return [this.datatype, await ret, this.metadata] } [Symbol.dispose](): void{ diff --git a/discojs/src/processing/index.ts b/discojs/src/processing/index.ts index 4011f40d1..44ad656c2 100644 --- a/discojs/src/processing/index.ts +++ b/discojs/src/processing/index.ts @@ -9,6 +9,7 @@ import type { Tabular, Task, Network, + ModelMetadata, } from "../index.js"; import * as processing from "./index.js"; @@ -19,6 +20,7 @@ export * from "./tabular.js"; export function preprocess( task: Task, dataset: Dataset, + metadata?: ModelMetadata, ): Dataset { switch (task.dataType) { case "image": { @@ -37,12 +39,17 @@ export function preprocess( // cast as typescript doesn't reduce generic type const d = dataset as Dataset; const { inputColumns, outputColumn } = task.trainingInformation; + const stats = metadata?.tabularStandardization; return d.map((row) => { const output = processing.extractColumn(row, outputColumn); + const inputs = stats + ? List(processing.standardizeRow(row, inputColumns, stats)) + : extractToNumbers(inputColumns, row); + return [ - extractToNumbers(inputColumns, row), + inputs, // TODO sanitization doesn't care about column distribution output !== "" ? processing.convertToNumber(output) : 0, ]; @@ -68,6 +75,7 @@ export function preprocess( export function preprocessWithoutLabel( task: Task, dataset: Dataset, + metadata?: ModelMetadata, ): Dataset { switch (task.dataType) { case "image": { @@ -85,8 +93,13 @@ export function preprocessWithoutLabel( // cast as typescript doesn't reduce generic type const d = dataset as Dataset; const { inputColumns } = task.trainingInformation; + const stats = metadata?.tabularStandardization; - return d.map((row) => extractToNumbers(inputColumns, row)); + return d.map((row) => + stats + ? List(processing.standardizeRow(row, inputColumns, stats)) + : extractToNumbers(inputColumns, row) + ); } case "text": { // cast as typescript doesn't reduce generic type diff --git a/discojs/src/processing/tabular.ts b/discojs/src/processing/tabular.ts index 685baca03..79ffa96a5 100644 --- a/discojs/src/processing/tabular.ts +++ b/discojs/src/processing/tabular.ts @@ -1,5 +1,10 @@ import { List } from "immutable"; +export type StandardizationStats = { + means: Record; + stds: Record; +}; + /** * Convert a string to a number * @@ -38,3 +43,63 @@ export function indexInList( if (ret === -1) throw new Error(`${element} not found in list`); return ret; } + +/** + * Return the mean, std value of each column + */ +export function computeStandardizationStats( + rows: Array>>, + columns: Array, +): StandardizationStats{ + const means: Record = {}; + const stds: Record = {}; + + for (const col of columns){ + const values = rows.map((row)=> { + const rawValue = extractColumn(row, col); + return convertToNumber(rawValue !== "" ? rawValue : "0"); + }); + const mean = values.reduce((a, b)=> a+b, 0) / values.length; + const variance = values.reduce((acc, val) => acc + (val-mean)**2, 0) / values.length; + + const std = Math.sqrt(variance); + + means[col] = mean; + stds[col] = std; + } + + return {means, stds}; +} + +/** + * Apply standardization for a single value + */ +export function standardizeValue( + value: number, + mean: number, + std: number, +): number{ + if (std == 0) return 0; // avoid divide by 0 + return (value - mean) / std; +} + +/** + * Apply standardization for a row + * + * standardization function is called for each row in dataset + */ +export function standardizeRow( + row: Partial>, + columns: Array, + stats: StandardizationStats, +): Array{ + return columns.map((col) => { + const rawValue = extractColumn(row, col) + // Handle cases where the dataset contains empty strings. + // This only occurs in test cases, as empty strings are not allowed in the web app. + const value = convertToNumber(rawValue !== "" ? rawValue : "0"); + const mean = stats.means[col]; + const std = stats.stds[col]; + return standardizeValue(value, mean, std); + }) +} \ No newline at end of file diff --git a/discojs/src/serialization/model.ts b/discojs/src/serialization/model.ts index 020d147af..b04f50f76 100644 --- a/discojs/src/serialization/model.ts +++ b/discojs/src/serialization/model.ts @@ -1,6 +1,6 @@ import type tf from '@tensorflow/tfjs' -import type { DataType, Model } from '../index.js' +import type { DataType, Model, ModelMetadata } from '../index.js' import { models, serialization } from '../index.js' import { GPTConfig } from '../models/index.js' @@ -41,11 +41,11 @@ export async function decode(encoded: Encoded): Promise> { const rawModel = raw[1] as unknown switch (type) { case Type.TFJS: { - if (raw.length !== 3) + if (raw.length !== 3 && raw.length !== 4) throw new Error( - "invalid TFJS model encoding: should be an array of length 3", + "invalid TFJS model encoding: should be an array of length 3 or 4", ); - const [rawDatatype, rawModel] = raw.slice(1) as unknown[]; + const [rawDatatype, rawModel, rawMetadata] = raw.slice(1) as unknown[]; let datatype; switch (rawDatatype) { @@ -63,6 +63,8 @@ export async function decode(encoded: Encoded): Promise> { datatype, // TODO totally unsafe casting rawModel as tf.io.ModelArtifacts, + // metadata for tabular task standardization + rawMetadata as ModelMetadata, ]); } case Type.GPT: { diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 0d182fc43..7f808ba83 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -155,13 +155,13 @@ export class Disco extends EventEmitter<{ > { this.#logger.success("Training started"); - const [trainingDataset, validationDataset] = - await this.#preprocessSplitAndBatch(dataset); - // the client fetches the latest weights upon connection // TODO unsafe cast this.trainer.model = (await this.#client.connect()) as Model; + const [trainingDataset, validationDataset] = + await this.#preprocessSplitAndBatch(dataset); + for await (const [round, epochs] of enumerate( this.trainer.train(trainingDataset, validationDataset), )) { @@ -213,21 +213,78 @@ export class Disco extends EventEmitter<{ > { const { batchSize, validationSplit } = this.#task.trainingInformation; - let preprocessed = processing.preprocess(this.#task, dataset); + if (validationSplit === 0){ + if (this.#task.dataType === "tabular"){ + const rows = await arrayFromAsync(dataset as Dataset); + const inputColumns = this.#task.trainingInformation.inputColumns; + + const stats = processing.computeStandardizationStats(rows, inputColumns); + this.trainer.model.metadata = { + tabularStandardization: stats, + }; + + const preprocessed = processing.preprocess( + this.#task, + dataset, + this.trainer.model.metadata, + ); + return [preprocessed.batch(batchSize).cached(), undefined]; + } + // If task datatype is not tabular + let preprocessed = processing.preprocess(this.#task, dataset); + + preprocessed = ( + this.#preprocessOnce + ? new Dataset(await arrayFromAsync(preprocessed)) + : preprocessed + ) + return [preprocessed.batch(batchSize).cached(), undefined]; + } + + // If training/validation splitting ratio is defined + const [training, validation] = dataset.split(validationSplit); + + if (this.#task.dataType == "tabular"){ + const trainingRows = await arrayFromAsync(training as Dataset); + const inputColumns = this.#task.trainingInformation.inputColumns; + const stats = processing.computeStandardizationStats(trainingRows, inputColumns); + + this.trainer.model.metadata = { + tabularStandardization: stats, + }; + + let preprocessedTraining = processing.preprocess(this.#task, training, this.trainer.model.metadata); + let preprocessedValidation = processing.preprocess(this.#task, validation, this.trainer.model.metadata); + preprocessedTraining = this.#preprocessOnce + ? new Dataset(await arrayFromAsync(preprocessedTraining)) + : preprocessedTraining; + + preprocessedValidation = this.#preprocessOnce + ? new Dataset(await arrayFromAsync(preprocessedValidation)) + : preprocessedValidation; + + return [ + preprocessedTraining.batch(batchSize).cached(), + preprocessedValidation.batch(batchSize).cached(), + ]; + } + + // if task datatype is not tabular + let preprocessedTraining = processing.preprocess(this.#task, training); + let preprocessedValidation = processing.preprocess(this.#task, validation); - preprocessed = ( - this.#preprocessOnce - ? new Dataset(await arrayFromAsync(preprocessed)) - : preprocessed - ) - if (validationSplit === 0) return [preprocessed.batch(batchSize).cached(), undefined]; + preprocessedTraining = this.#preprocessOnce + ? new Dataset(await arrayFromAsync(preprocessedTraining)) + : preprocessedTraining; - const [training, validation] = preprocessed.split(validationSplit); + preprocessedValidation = this.#preprocessOnce + ? new Dataset(await arrayFromAsync(preprocessedValidation)) + : preprocessedValidation; return [ - training.batch(batchSize).cached(), - validation.batch(batchSize).cached(), - ]; + preprocessedTraining.batch(batchSize).cached(), + preprocessedValidation.batch(batchSize).cached(), + ]; } } diff --git a/webapp/src/components/dataset_input/validate.ts b/webapp/src/components/dataset_input/validate.ts index 38237bee3..886eb2711 100644 --- a/webapp/src/components/dataset_input/validate.ts +++ b/webapp/src/components/dataset_input/validate.ts @@ -2,15 +2,30 @@ import { Range, Set } from "immutable"; import type { LabeledDataset } from "./types"; +function isNaNValue(value: string | undefined): boolean{ + if (value === undefined) + return true; + + const trimmed = value.trim(); + return trimmed === "" || trimmed.toLowerCase() === "nan"; +} + export async function tabular( wantedColumns: Set, dataset: LabeledDataset["tabular"], ): Promise { - for await (const [columns, i] of dataset - .map((row) => Set(Object.keys(row))) - .zip(Range(1, Number.POSITIVE_INFINITY))) - if (!columns.isSuperset(wantedColumns)) - throw new Error( - `row ${i} is missing columns ${wantedColumns.subtract(columns).join(", ")}`, - ); + for await (const [row, i] of dataset + .zip(Range(1, Number.POSITIVE_INFINITY))){ + const columns = Set(Object.keys(row)); + + if (!columns.isSuperset(wantedColumns)) + throw new Error( + `row ${i} is missing columns ${wantedColumns.subtract(columns).join(", ")}`, + ); + + for (const col of wantedColumns){ + if (isNaNValue(row[col])) + throw new Error(`row ${i} column "${col}" contains NaN`); + } + } }