215 lines
7.2 KiB
Dart
215 lines
7.2 KiB
Dart
import 'dart:io';
|
|
|
|
import 'package:archive/archive_io.dart';
|
|
import 'package:dio/dio.dart';
|
|
import 'package:path/path.dart' as p;
|
|
import 'package:path_provider/path_provider.dart';
|
|
import 'package:riverpod_annotation/riverpod_annotation.dart';
|
|
import 'package:trainhub_flutter/core/constants/ai_constants.dart';
|
|
import 'package:trainhub_flutter/data/services/ai_process_manager.dart';
|
|
import 'package:trainhub_flutter/injection.dart' as di;
|
|
import 'package:trainhub_flutter/presentation/settings/ai_model_settings_state.dart';
|
|
|
|
part 'ai_model_settings_controller.g.dart';
|
|
|
|
/// Builds the GitHub release URL of the llama.cpp archive for the host OS.
///
/// The release tag comes from [AiConstants.llamaBuild]. Linux and Windows map
/// to a fixed Vulkan x64 asset. On macOS the machine type reported by
/// `uname -m` selects the asset: `arm64` picks the arm64 build and any other
/// value falls back to the x64 build.
///
/// Throws an [UnsupportedError] on any other operating system.
Future<String> _llamaArchiveUrl() async {
  final build = AiConstants.llamaBuild;

  // The platform flags are mutually exclusive, so the single-asset platforms
  // can return early before the macOS architecture probe.
  if (Platform.isLinux) {
    return 'https://github.com/ggml-org/llama.cpp/releases/download/$build/llama-$build-bin-ubuntu-vulkan-x64.tar.gz';
  }
  if (Platform.isWindows) {
    return 'https://github.com/ggml-org/llama.cpp/releases/download/$build/llama-$build-bin-win-vulkan-x64.zip';
  }
  if (!Platform.isMacOS) {
    throw UnsupportedError('Unsupported platform: ${Platform.operatingSystem}');
  }

  // macOS ships separate archives per CPU architecture.
  final uname = await Process.run('uname', ['-m']);
  final machine = (uname.stdout as String).trim();
  final String suffix;
  if (machine == 'arm64') {
    suffix = 'macos-arm64';
  } else {
    suffix = 'macos-x64';
  }
  return 'https://github.com/ggml-org/llama.cpp/releases/download/$build/llama-$build-bin-$suffix.tar.gz';
}
|
|
|
|
/// Riverpod notifier behind the AI model settings screen.
///
/// Checks whether the llama.cpp server binary and the Nomic and Qwen model
/// files exist in the application documents directory, and downloads and
/// installs them on request. Progress, the current task label and any error
/// text are published through [AiModelSettingsState].
@riverpod
class AiModelSettingsController extends _$AiModelSettingsController {
  /// HTTP client shared by every download started from this controller.
  final _dio = Dio();

  /// Number of attempts made to replace an installed file before giving up.
  static const _maxCopyAttempts = 5;

  /// Pause between two attempts to replace an installed file.
  static const _copyRetryDelay = Duration(milliseconds: 500);

  @override
  AiModelSettingsState build() => const AiModelSettingsState();

  /// Checks that the server binary and both model files exist on disk.
  ///
  /// Only existence is checked, not size or integrity. The result is stored in
  /// `areModelsValidated`; any exception is reported through `errorMessage`
  /// instead of being thrown.
  Future<void> validateModels() async {
    state = state.copyWith(
      currentTask: 'Checking installed files…',
      errorMessage: null,
    );
    try {
      final dir = await getApplicationDocumentsDirectory();
      final base = dir.path;
      final serverBin = File(p.join(base, AiConstants.serverBinaryName));
      final nomicModel = File(p.join(base, AiConstants.nomicModelFile));
      final qwenModel = File(p.join(base, AiConstants.qwenModelFile));
      final validated =
          serverBin.existsSync() &&
          nomicModel.existsSync() &&
          qwenModel.existsSync();
      state = state.copyWith(
        areModelsValidated: validated,
        currentTask: validated ? 'All files present.' : 'Files missing.',
      );
    } catch (e) {
      state = state.copyWith(
        areModelsValidated: false,
        currentTask: 'Validation failed.',
        errorMessage: e.toString(),
      );
    }
  }

  /// Downloads and installs the llama.cpp binary and both models.
  ///
  /// Does nothing when a download is already running. Overall progress is
  /// split into three windows: 0–0.2 for the binary archive, 0.2–0.55 for the
  /// Nomic model and 0.55–1.0 for the Qwen model. On success
  /// [validateModels] is run; on failure the error is stored in
  /// `errorMessage` and nothing is thrown.
  Future<void> downloadAll() async {
    if (state.isDownloading) return;
    // Raise the busy flag before the first await so that a second call made
    // while the servers are still stopping is rejected by the guard above.
    state = state.copyWith(
      isDownloading: true,
      progress: 0.0,
      areModelsValidated: false,
      errorMessage: null,
    );
    try {
      await di.getIt<AiProcessManager>().stopServers();
    } catch (_) {
      // Best effort: a failure to stop the servers must not block the
      // download. If a file is still in use, installation reports it later.
    }
    try {
      final dir = await getApplicationDocumentsDirectory();
      final archiveUrl = await _llamaArchiveUrl();
      final archiveExt = archiveUrl.endsWith('.zip') ? '.zip' : '.tar.gz';
      final archiveFile = File(p.join(dir.path, 'llama_binary$archiveExt'));
      try {
        await _downloadFile(
          url: archiveUrl,
          savePath: archiveFile.path,
          taskLabel: 'Downloading llama.cpp binary…',
          overallStart: 0.0,
          overallEnd: 0.2,
        );
        state = state.copyWith(
          currentTask: 'Extracting llama.cpp binary…',
          progress: 0.2,
        );
        await _extractBinary(archiveFile.path, dir.path);
      } finally {
        // Remove the archive even when extraction fails so it does not
        // linger in the documents directory.
        if (archiveFile.existsSync()) archiveFile.deleteSync();
      }
      await _downloadFile(
        url: AiConstants.nomicModelUrl,
        savePath: p.join(dir.path, AiConstants.nomicModelFile),
        taskLabel: 'Downloading Nomic embedding model…',
        overallStart: 0.2,
        overallEnd: 0.55,
      );
      await _downloadFile(
        url: AiConstants.qwenModelUrl,
        savePath: p.join(dir.path, AiConstants.qwenModelFile),
        taskLabel: 'Downloading Qwen 2.5 7B model…',
        overallStart: 0.55,
        overallEnd: 1.0,
      );
      state = state.copyWith(
        isDownloading: false,
        progress: 1.0,
        currentTask: 'Download complete.',
      );
      await validateModels();
    } on DioException catch (e) {
      state = state.copyWith(
        isDownloading: false,
        currentTask: 'Download failed.',
        // `message` is nullable; fall back to the underlying error or the
        // failure category so the user never sees "null".
        errorMessage: 'Network error: ${e.message ?? e.error ?? e.type.name}',
      );
    } catch (e) {
      state = state.copyWith(
        isDownloading: false,
        currentTask: 'Download failed.',
        errorMessage: e.toString(),
      );
    }
  }

  /// Downloads [url] to [savePath] and reports progress into the state.
  ///
  /// [taskLabel] becomes the current task. Per-file progress is mapped
  /// linearly onto the overall range from [overallStart] to [overallEnd].
  /// Progress is not updated while the total size is unknown.
  ///
  /// The data is written to `<savePath>.part` and renamed once the transfer
  /// has finished, so an interrupted transfer never leaves a truncated file
  /// under the final name that an existence check would accept.
  Future<void> _downloadFile({
    required String url,
    required String savePath,
    required String taskLabel,
    required double overallStart,
    required double overallEnd,
  }) async {
    state = state.copyWith(currentTask: taskLabel, progress: overallStart);
    final partPath = '$savePath.part';
    await _dio.download(
      url,
      partPath,
      onReceiveProgress: (received, total) {
        // A non-positive total means the size is unknown.
        if (total <= 0) return;
        final fileProgress = received / total;
        final overall =
            overallStart + fileProgress * (overallEnd - overallStart);
        state = state.copyWith(progress: overall);
      },
      options: Options(
        followRedirects: true,
        maxRedirects: 5,
        // NOTE(review): presumably sized for multi-gigabyte model files.
        receiveTimeout: const Duration(hours: 2),
      ),
    );
    await File(partPath).rename(savePath);
  }

  /// Extracts the archive at [archivePath] and installs files into [destDir].
  ///
  /// The archive is unpacked into a temporary `_llama_extract_tmp` directory
  /// inside [destDir]. The server binary and every `.dll`, `.so` or `.dylib`
  /// file found at any depth is copied flat into [destDir]. On macOS and
  /// Linux the server binary is marked executable. The temporary directory is
  /// always removed afterwards.
  ///
  /// Throws a [FileSystemException] when the archive contains no server
  /// binary, and an [Exception] when an installed file cannot be replaced.
  Future<void> _extractBinary(String archivePath, String destDir) async {
    final extractDir = p.join(destDir, '_llama_extract_tmp');
    final extractDirObj = Directory(extractDir);
    // Start from an empty directory in case an earlier run left one behind.
    if (extractDirObj.existsSync()) extractDirObj.deleteSync(recursive: true);
    extractDirObj.createSync(recursive: true);
    try {
      await extractFileToDisk(archivePath, extractDir);
      var foundServer = false;
      final binaryName = AiConstants.serverBinaryName;
      for (final entity in extractDirObj.listSync(recursive: true)) {
        if (entity is! File) continue;
        final ext = p.extension(entity.path).toLowerCase();
        final name = p.basename(entity.path);
        final isServer = name == binaryName;
        final isLibrary = ext == '.dll' || ext == '.so' || ext == '.dylib';
        if (!isServer && !isLibrary) continue;

        final destFile = p.join(destDir, name);
        await _copyWithRetry(entity, destFile, name);
        if (isServer) {
          foundServer = true;
          if (Platform.isMacOS || Platform.isLinux) {
            await Process.run('chmod', ['+x', destFile]);
          }
        }
      }
      if (!foundServer) {
        throw FileSystemException(
          'llama-server binary not found in archive.',
          archivePath,
        );
      }
    } finally {
      if (extractDirObj.existsSync()) {
        extractDirObj.deleteSync(recursive: true);
      }
    }
  }

  /// Copies [source] to [destPath], replacing any file already there.
  ///
  /// A [FileSystemException] triggers a retry after [_copyRetryDelay], up to
  /// [_maxCopyAttempts] attempts in total. After the last failed attempt an
  /// [Exception] naming [name] is thrown.
  Future<void> _copyWithRetry(File source, String destPath, String name) async {
    for (var attempt = 1; attempt <= _maxCopyAttempts; attempt++) {
      try {
        final existing = File(destPath);
        if (existing.existsSync()) existing.deleteSync();
        source.copySync(destPath);
        return;
      } on FileSystemException catch (_) {
        if (attempt == _maxCopyAttempts) {
          throw Exception(
            'Failed to overwrite $name. Ensure no other applications are using it.',
          );
        }
        await Future.delayed(_copyRetryDelay);
      }
    }
  }
}
|