Skip to content

Commit

Permalink
chore: move maia2 onnx models + add stateful maia hook
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjosethomas committed Feb 5, 2025
1 parent 26a2e40 commit 31af432
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 49 deletions.
9 changes: 4 additions & 5 deletions src/hooks/useAnalysisController/useAnalysisController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ export const useAnalysisController = (
})
}

const maia = useMaiaEngine()
const { maia, status } = useMaiaEngine()
const engine = useStockfishEngine(parseStockfishEvaluation)
const [currentMove, setCurrentMove] = useState<null | [string, string]>(null)
const [currentMove, setCurrentMove] = useState<[string, string] | null>()
const [stockfishEvaluations, setStockfishEvaluations] = useState<
StockfishEvaluation[]
>([])
Expand All @@ -59,8 +59,7 @@ export const useAnalysisController = (
const board = new Chess(game.moves[controller.currentIndex].board)

;(async () => {
if (maia?.status !== 'ready' || maiaEvaluations[controller.currentIndex])
return
if (status !== 'ready' || maiaEvaluations[controller.currentIndex]) return

const { result } = await maia.batchEvaluate(
Array(9).fill(board.fen()),
Expand All @@ -87,7 +86,7 @@ export const useAnalysisController = (
return newEvaluations
})
})()
}, [controller.currentIndex, game.type, maia?.status])
}, [controller.currentIndex, game.type, status])

useEffect(() => {
if (game.type === 'tournament') return
Expand Down
File renamed without changes.
86 changes: 65 additions & 21 deletions src/utils/maia2/model.ts → src/hooks/useMaiaEngine/model.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,74 @@
import { MaiaStatus } from 'src/types'
import { InferenceSession, Tensor } from 'onnxruntime-web'

import { mirrorMove, preprocess, allPossibleMovesReversed } from './utils'

interface MaiaOptions {
model: string
type: 'rapid' | 'blitz'
setStatus: (status: MaiaStatus) => void
setProgress: (progress: number) => void
setError: (error: string) => void
}

class Maia {
public model!: InferenceSession
public type: 'rapid' | 'blitz'
public status: 'loading' | 'no-cache' | 'downloading' | 'ready'

constructor(options: { model: string; type: 'rapid' | 'blitz' }) {
this.status = 'loading'
this.type = options.type ?? 'rapid'
;(async () => {
try {
console.log('Getting cached')
const buffer = await this.getCachedModel(options.model, options.type)
await this.initializeModel(buffer)
} catch (e) {
console.log('Missing cache')
this.status = 'no-cache'
}
})()
private model!: InferenceSession
private type: 'rapid' | 'blitz'
private modelUrl: string
private options: MaiaOptions

constructor(options: MaiaOptions) {
this.type = options.type
this.modelUrl = options.model
this.options = options

this.initialize()
}

public getStatus() {
return this.status
private async initialize() {
try {
const buffer = await this.getCachedModel(this.modelUrl, this.type)
await this.initializeModel(buffer)
this.options.setStatus('ready')
} catch (e) {
this.options.setStatus('no-cache')
}
}

public async downloadModel() {
const response = await fetch(this.modelUrl)
if (!response.ok) throw new Error('Failed to fetch model')

const reader = response.body?.getReader()
const contentLength = +(response.headers.get('Content-Length') ?? 0)

if (!reader) throw new Error('No response body')

const chunks: Uint8Array[] = []
let receivedLength = 0

while (true) {
const { done, value } = await reader.read()
if (done) break

chunks.push(value)
receivedLength += value.length

this.options.setProgress((receivedLength / contentLength) * 100)
}

const buffer = new Uint8Array(receivedLength)
let position = 0
for (const chunk of chunks) {
buffer.set(chunk, position)
position += chunk.length
}

const cache = await caches.open(`MAIA2-${this.type.toUpperCase()}-MODEL`)
await cache.put(this.modelUrl, new Response(buffer.buffer))

await this.initializeModel(buffer.buffer)
this.options.setStatus('ready')
}

public async getCachedModel(
Expand Down Expand Up @@ -52,8 +97,7 @@ class Maia {

public async initializeModel(buffer: ArrayBuffer) {
this.model = await InferenceSession.create(buffer)
this.status = 'ready'
console.log('initialized')
this.options.setStatus('ready')
}

/**
Expand Down
42 changes: 31 additions & 11 deletions src/hooks/useMaiaEngine/useMaiaEngine.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
import { useState, useMemo, useEffect } from 'react'

import Maia from 'src/utils/maia2'
import Maia from './model'
import { MaiaStatus } from 'src/types'
import { useState, useMemo } from 'react'

export const useMaiaEngine = () => {
const [maia, setMaia] = useState<Maia>()
const [status, setStatus] = useState<MaiaStatus>('loading')
const [progress, setProgress] = useState(0)
const [error, setError] = useState<string | null>(null)

useEffect(() => {
setMaia(new Maia({ model: '/maia2/maia_rapid.onnx', type: 'rapid' }))
const maia = useMemo(() => {
const model = new Maia({
model: '/maia2/maia_rapid.onnx',
type: 'rapid',
setStatus: setStatus,
setProgress: setProgress,
setError: setError,
})
return model
}, [])

// const maia = useMemo(() => {
// const model = new Maia({ model: '/maia2/maia_rapid.onnx', type: 'rapid' })
// return model
// }, [])
const downloadModel = async () => {
try {
setStatus('downloading')
await maia.downloadModel()
} catch (err) {
setError(err instanceof Error ? err.message : 'Failed to download model')
setStatus('error')
}
}

return maia
return {
maia,
status,
progress,
error,
downloadModel,
}
}
File renamed without changes.
10 changes: 1 addition & 9 deletions src/pages/analysis/[...id].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -577,15 +577,7 @@ const Analysis: React.FC<Props> = ({
content="Collection of chess training and analysis tools centered around Maia."
/>
</Head>
{maia?.status !== 'ready' ? (
<>
<div className="absolute left-0 top-0 z-50 flex h-screen w-screen flex-col bg-black">
<p className="text-white">{maia?.status}</p>
</div>
</>
) : (
<></>
)}

<GameControllerContext.Provider value={{ ...controller }}>
{analyzedGame && (isMobile ? mobileLayout : desktopLayout)}
</GameControllerContext.Provider>
Expand Down
7 changes: 7 additions & 0 deletions src/types/analysis/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,10 @@ export interface StockfishEvaluation {
cp_vec: { [key: string]: number }
cp_relative_vec: { [key: string]: number }
}

export type MaiaStatus =
| 'loading'
| 'no-cache'
| 'downloading'
| 'ready'
| 'error'
3 changes: 0 additions & 3 deletions src/utils/maia2/index.ts

This file was deleted.

0 comments on commit 31af432

Please sign in to comment.