/**
* VSOMTraining.js - Training Procedures and Convergence Detection
*
* This module manages the training process for VSOM including learning rate
* schedules, convergence detection, and training quality metrics. It provides
* different training strategies and monitoring capabilities.
*
* Key Features:
* - Multiple learning rate schedules
* - Convergence detection algorithms
* - Training quality metrics
* - Batch and online training modes
* - Training progress monitoring
*/
import { logger } from '../../../Utils.js'
export default class VSOMTraining {
constructor(options = {}) {
this.options = {
// Learning rate schedule
initialLearningRate: options.initialLearningRate || 0.1,
finalLearningRate: options.finalLearningRate || 0.01,
learningRateSchedule: options.learningRateSchedule || 'exponential', // 'linear', 'exponential', 'inverse', 'step'
// Neighborhood radius schedule
initialRadius: options.initialRadius || 5.0,
finalRadius: options.finalRadius || 0.5,
radiusSchedule: options.radiusSchedule || 'exponential',
// Training parameters
maxIterations: options.maxIterations || 1000,
batchSize: options.batchSize || 100,
// Convergence detection
convergenceThreshold: options.convergenceThreshold || 1e-4,
convergenceWindow: options.convergenceWindow || 10,
minIterations: options.minIterations || 100,
// Quality metrics
trackQuantizationError: options.trackQuantizationError !== false,
trackTopographicError: options.trackTopographicError !== false,
qualityCheckInterval: options.qualityCheckInterval || 50,
// Monitoring
logProgress: options.logProgress !== false,
progressInterval: options.progressInterval || 100,
...options
}
// Training state
this.currentIteration = 0
this.isTraining = false
this.trainingStartTime = null
this.trainingHistory = []
// Convergence tracking
this.errorHistory = []
this.converged = false
this.convergenceIteration = null
// Quality metrics
this.qualityMetrics = {
quantizationError: [],
topographicError: [],
neighborhoodPreservation: []
}
logger.debug('VSOMTraining initialized with options:', this.options)
}
/**
* Execute complete training process
* @param {Object} vsomCore - VSOM core algorithm instance
* @param {Object} topology - VSOM topology instance
* @param {Array} trainingData - Array of training vectors
* @param {Object} callbacks - Optional training callbacks
* @returns {Object} Training results
*/
async train(vsomCore, topology, trainingData, callbacks = {}) {
logger.info(`Starting VSOM training: ${trainingData.length} samples, ${this.options.maxIterations} iterations`)
this.isTraining = true
this.trainingStartTime = Date.now()
this.currentIteration = 0
this.converged = false
this.convergenceIteration = null
// Initialize training history
this.trainingHistory = []
this.errorHistory = []
this.qualityMetrics = {
quantizationError: [],
topographicError: [],
neighborhoodPreservation: []
}
// Create neighborhood function
const neighborhoodFunction = topology.createNeighborhoodFunction('gaussian')
try {
// Training loop
for (let iteration = 0; iteration < this.options.maxIterations; iteration++) {
this.currentIteration = iteration
// Calculate current learning parameters
const learningRate = this.calculateLearningRate(iteration)
const neighborhoodRadius = this.calculateNeighborhoodRadius(iteration)
// Perform training step
const iterationResults = await this.trainingStep(
vsomCore,
topology,
trainingData,
learningRate,
neighborhoodRadius,
neighborhoodFunction
)
// Record training progress
this.recordIteration(iteration, learningRate, neighborhoodRadius, iterationResults)
// Check convergence
if (iteration >= this.options.minIterations) {
if (this.checkConvergence()) {
this.converged = true
this.convergenceIteration = iteration
logger.info(`Training converged at iteration ${iteration}`)
break
}
}
// Quality metrics
if (iteration % this.options.qualityCheckInterval === 0) {
await this.calculateQualityMetrics(vsomCore, trainingData, iteration)
}
// Progress logging
if (this.options.logProgress && iteration % this.options.progressInterval === 0) {
this.logTrainingProgress(iteration, learningRate, neighborhoodRadius, iterationResults)
}
// Execute callbacks
if (callbacks.onIteration) {
await callbacks.onIteration(iteration, iterationResults)
}
// Early stopping check
if (callbacks.shouldStop && callbacks.shouldStop(iteration, iterationResults)) {
logger.info(`Training stopped early at iteration ${iteration} by callback`)
break
}
}
// Final quality assessment
await this.calculateQualityMetrics(vsomCore, trainingData, this.currentIteration)
const trainingTime = Date.now() - this.trainingStartTime
const results = this.compileTrainingResults(trainingTime)
logger.info(`Training completed in ${trainingTime}ms after ${this.currentIteration + 1} iterations`)
if (callbacks.onComplete) {
await callbacks.onComplete(results)
}
return results
} catch (error) {
logger.error('Training failed:', error)
throw error
} finally {
this.isTraining = false
}
}
/**
* Perform single training step
* @param {Object} vsomCore - VSOM core algorithm instance
* @param {Object} topology - VSOM topology instance
* @param {Array} trainingData - Training data
* @param {number} learningRate - Current learning rate
* @param {number} neighborhoodRadius - Current neighborhood radius
* @param {Function} neighborhoodFunction - Neighborhood function
* @returns {Object} Iteration results
*/
async trainingStep(vsomCore, topology, trainingData, learningRate, neighborhoodRadius, neighborhoodFunction) {
const stepStartTime = Date.now()
// Shuffle training data for this epoch
const shuffledData = this.shuffleArray([...trainingData])
let totalQuantizationError = 0
let batchCount = 0
// Process data in batches
for (let i = 0; i < shuffledData.length; i += this.options.batchSize) {
const batch = shuffledData.slice(i, i + this.options.batchSize)
// Find BMUs for batch
const bmuIndices = vsomCore.findBestMatchingUnits(batch)
// Update weights
vsomCore.updateWeights(batch, bmuIndices, learningRate, neighborhoodRadius, neighborhoodFunction)
// Calculate batch quantization error
for (let j = 0; j < batch.length; j++) {
const distance = vsomCore.calculateDistance(batch[j], vsomCore.getNodeWeights(bmuIndices[j]))
totalQuantizationError += distance
}
batchCount++
}
const averageQuantizationError = totalQuantizationError / shuffledData.length
const stepTime = Date.now() - stepStartTime
return {
quantizationError: averageQuantizationError,
processingTime: stepTime,
batchCount: batchCount,
samplesProcessed: shuffledData.length
}
}
/**
* Calculate learning rate for current iteration
* @param {number} iteration - Current iteration
* @returns {number} Learning rate
*/
calculateLearningRate(iteration) {
const progress = iteration / this.options.maxIterations
const initial = this.options.initialLearningRate
const final = this.options.finalLearningRate
switch (this.options.learningRateSchedule) {
case 'linear':
return initial * (1 - progress) + final * progress
case 'exponential':
const decayFactor = Math.log(final / initial)
return initial * Math.exp(decayFactor * progress)
case 'inverse':
return initial / (1 + iteration * 0.01)
case 'step':
const stepSize = this.options.maxIterations / 4
const step = Math.floor(iteration / stepSize)
return initial * Math.pow(0.5, step)
default:
return initial * Math.exp(-iteration / (this.options.maxIterations / 3))
}
}
/**
* Calculate neighborhood radius for current iteration
* @param {number} iteration - Current iteration
* @returns {number} Neighborhood radius
*/
calculateNeighborhoodRadius(iteration) {
const progress = iteration / this.options.maxIterations
const initial = this.options.initialRadius
const final = this.options.finalRadius
switch (this.options.radiusSchedule) {
case 'linear':
return initial * (1 - progress) + final * progress
case 'exponential':
const decayFactor = Math.log(final / initial)
return initial * Math.exp(decayFactor * progress)
case 'inverse':
return initial / (1 + iteration * 0.02)
default:
return initial * Math.exp(-iteration / (this.options.maxIterations / 2))
}
}
/**
* Check if training has converged
* @returns {boolean} True if converged
*/
checkConvergence() {
if (this.errorHistory.length < this.options.convergenceWindow) {
return false
}
// Get recent errors
const recentErrors = this.errorHistory.slice(-this.options.convergenceWindow)
// Calculate error variance over convergence window
const mean = recentErrors.reduce((sum, error) => sum + error, 0) / recentErrors.length
const variance = recentErrors.reduce((sum, error) => sum + Math.pow(error - mean, 2), 0) / recentErrors.length
const standardDeviation = Math.sqrt(variance)
// Check if standard deviation is below threshold
return standardDeviation < this.options.convergenceThreshold
}
/**
* Calculate training quality metrics
* @param {Object} vsomCore - VSOM core algorithm instance
* @param {Array} trainingData - Training data
* @param {number} iteration - Current iteration
*/
async calculateQualityMetrics(vsomCore, trainingData, iteration) {
if (this.options.trackQuantizationError) {
const qError = vsomCore.calculateQuantizationError(trainingData)
this.qualityMetrics.quantizationError.push({
iteration: iteration,
value: qError
})
}
if (this.options.trackTopographicError) {
const tError = vsomCore.calculateTopographicError(trainingData)
this.qualityMetrics.topographicError.push({
iteration: iteration,
value: tError
})
}
}
/**
* Record iteration results
* @param {number} iteration - Current iteration
* @param {number} learningRate - Learning rate used
* @param {number} neighborhoodRadius - Neighborhood radius used
* @param {Object} results - Iteration results
*/
recordIteration(iteration, learningRate, neighborhoodRadius, results) {
this.trainingHistory.push({
iteration: iteration,
learningRate: learningRate,
neighborhoodRadius: neighborhoodRadius,
quantizationError: results.quantizationError,
processingTime: results.processingTime,
timestamp: Date.now()
})
this.errorHistory.push(results.quantizationError)
// Limit history size to prevent memory issues
const maxHistorySize = this.options.maxIterations + 100
if (this.trainingHistory.length > maxHistorySize) {
this.trainingHistory = this.trainingHistory.slice(-maxHistorySize)
}
if (this.errorHistory.length > maxHistorySize) {
this.errorHistory = this.errorHistory.slice(-maxHistorySize)
}
}
/**
* Log training progress
* @param {number} iteration - Current iteration
* @param {number} learningRate - Learning rate
* @param {number} neighborhoodRadius - Neighborhood radius
* @param {Object} results - Iteration results
*/
logTrainingProgress(iteration, learningRate, neighborhoodRadius, results) {
const progress = ((iteration + 1) / this.options.maxIterations * 100).toFixed(1)
const elapsed = Date.now() - this.trainingStartTime
const eta = elapsed / (iteration + 1) * (this.options.maxIterations - iteration - 1)
logger.info(`Training ${progress}%: iteration ${iteration + 1}/${this.options.maxIterations}, ` +
`QE: ${results.quantizationError.toFixed(6)}, ` +
`LR: ${learningRate.toFixed(4)}, ` +
`R: ${neighborhoodRadius.toFixed(2)}, ` +
`ETA: ${Math.round(eta / 1000)}s`)
}
/**
* Compile final training results
* @param {number} trainingTime - Total training time
* @returns {Object} Complete training results
*/
compileTrainingResults(trainingTime) {
return {
// Training summary
totalIterations: this.currentIteration + 1,
trainingTime: trainingTime,
converged: this.converged,
convergenceIteration: this.convergenceIteration,
// Final state
finalQuantizationError: this.errorHistory[this.errorHistory.length - 1] || null,
finalLearningRate: this.calculateLearningRate(this.currentIteration),
finalNeighborhoodRadius: this.calculateNeighborhoodRadius(this.currentIteration),
// Training history
trainingHistory: this.trainingHistory,
errorHistory: this.errorHistory,
qualityMetrics: this.qualityMetrics,
// Performance metrics
averageIterationTime: trainingTime / (this.currentIteration + 1),
iterationsPerSecond: (this.currentIteration + 1) / (trainingTime / 1000),
// Configuration used
trainingOptions: this.options
}
}
/**
* Shuffle array in place using Fisher-Yates algorithm
* @param {Array} array - Array to shuffle
* @returns {Array} Shuffled array
*/
shuffleArray(array) {
for (let i = array.length - 1; i > 0; i--) {
const j = Math.floor(Math.random() * (i + 1))
;[array[i], array[j]] = [array[j], array[i]]
}
return array
}
/**
* Stop training if currently running
*/
stopTraining() {
if (this.isTraining) {
logger.info(`Training stopped at iteration ${this.currentIteration}`)
this.isTraining = false
}
}
/**
* Get current training status
* @returns {Object} Training status information
*/
getTrainingStatus() {
return {
isTraining: this.isTraining,
currentIteration: this.currentIteration,
maxIterations: this.options.maxIterations,
progress: this.currentIteration / this.options.maxIterations,
converged: this.converged,
convergenceIteration: this.convergenceIteration,
elapsedTime: this.trainingStartTime ? Date.now() - this.trainingStartTime : 0
}
}
/**
* Get training statistics
* @returns {Object} Training statistics
*/
getStatistics() {
return {
trainingHistorySize: this.trainingHistory.length,
errorHistorySize: this.errorHistory.length,
qualityMetricsCount: Object.values(this.qualityMetrics).reduce((sum, arr) => sum + arr.length, 0),
memoryUsage: this.estimateMemoryUsage()
}
}
/**
* Estimate memory usage
* @returns {number} Estimated memory usage in bytes
*/
estimateMemoryUsage() {
const historySize = this.trainingHistory.length * 200 // Rough estimate per entry
const errorHistorySize = this.errorHistory.length * 8 // Float64
const qualityMetricsSize = Object.values(this.qualityMetrics).reduce((sum, arr) => sum + arr.length * 16, 0)
return historySize + errorHistorySize + qualityMetricsSize
}
/**
* Reset training state
*/
reset() {
this.currentIteration = 0
this.isTraining = false
this.trainingStartTime = null
this.trainingHistory = []
this.errorHistory = []
this.converged = false
this.convergenceIteration = null
this.qualityMetrics = {
quantizationError: [],
topographicError: [],
neighborhoodPreservation: []
}
logger.debug('VSOMTraining state reset')
}
}