143 lines
4.4 KiB
Dart
143 lines
4.4 KiB
Dart
import 'dart:convert';
|
|
import 'dart:math' show sqrt;
|
|
|
|
import 'package:drift/drift.dart';
|
|
import 'package:trainhub_flutter/data/database/app_database.dart';
|
|
import 'package:trainhub_flutter/data/database/daos/knowledge_chunk_dao.dart';
|
|
import 'package:trainhub_flutter/domain/repositories/note_repository.dart';
|
|
import 'package:trainhub_flutter/data/services/embedding_service.dart';
|
|
import 'package:uuid/uuid.dart';
|
|
|
|
// Shared UUID generator for this library; `v4()` is called on it to create
// the per-note source id and the per-chunk row ids.
const _uuid = Uuid();
|
|
|
|
/// [NoteRepository] implementation that stores notes as embedded text chunks.
///
/// A note is split into chunks by [_chunkText]; each chunk is embedded via the
/// injected [EmbeddingService] and written through [KnowledgeChunkDao] with
/// its embedding serialised as a JSON array. [searchSimilar] ranks every
/// stored chunk against the query by cosine similarity.
class NoteRepositoryImpl implements NoteRepository {
  NoteRepositoryImpl(this._dao, this._embeddingService);

  final KnowledgeChunkDao _dao;
  final EmbeddingService _embeddingService;

  // -------------------------------------------------------------------------
  // Public interface
  // -------------------------------------------------------------------------

  /// Chunks [text], embeds every chunk and stores the results.
  ///
  /// All chunks of one call share a single `sourceId` and a single creation
  /// timestamp (`DateTime.now()` as an ISO-8601 string). Does nothing when
  /// [text] yields no chunks (empty or whitespace-only input).
  ///
  /// Every chunk is embedded before the first row is written, so an error
  /// thrown by the embedding service cannot leave a partially stored note.
  /// The inserts themselves are issued one at a time and are not wrapped in
  /// a transaction by this method.
  @override
  Future<void> addNote(String text) async {
    final chunks = _chunkText(text);
    if (chunks.isEmpty) return;

    final sourceId = _uuid.v4();
    final now = DateTime.now().toIso8601String();

    // Sequential on purpose: the original code never issued concurrent
    // embedding requests, and the service's concurrency limits are unknown.
    final embeddings = [
      for (final chunk in chunks) await _embeddingService.embed(chunk),
    ];

    for (var i = 0; i < chunks.length; i++) {
      await _dao.insertChunk(
        KnowledgeChunksCompanion(
          id: Value(_uuid.v4()),
          sourceId: Value(sourceId),
          content: Value(chunks[i]),
          embedding: Value(jsonEncode(embeddings[i])),
          createdAt: Value(now),
        ),
      );
    }
  }

  /// Returns the contents of the [topK] stored chunks most similar to [query].
  ///
  /// Results are ordered by descending cosine similarity between the query
  /// embedding and each stored embedding. Returns an empty list when nothing
  /// is stored or when [topK] is zero.
  ///
  /// Throws a [RangeError] if [topK] is negative.
  @override
  Future<List<String>> searchSimilar(String query, {int topK = 3}) async {
    // Fail fast: `Iterable.take` would reject a negative count anyway, but
    // only after all rows were loaded and the query had been embedded.
    RangeError.checkNotNegative(topK, 'topK');
    if (topK == 0) return [];

    final allRows = await _dao.getAllChunks();
    if (allRows.isEmpty) return [];

    final queryEmbedding = await _embeddingService.embed(query);

    final scored = [
      for (final row in allRows)
        _Scored(
          score: _cosineSimilarity(
            queryEmbedding,
            _decodeEmbedding(row.embedding),
          ),
          text: row.content,
        ),
    ]..sort((a, b) => b.score.compareTo(a.score));

    return [for (final s in scored.take(topK)) s.text];
  }

  /// Returns the number of stored chunks as reported by the DAO.
  @override
  Future<int> getChunkCount() => _dao.getCount();

  /// Deletes every stored chunk via the DAO.
  @override
  Future<void> clearAll() => _dao.deleteAll();

  // -------------------------------------------------------------------------
  // Embedding (de)serialisation
  // -------------------------------------------------------------------------

  /// Decodes an embedding stored as a JSON array of numbers.
  ///
  /// Throws if [json] is not valid JSON, is not a list, or contains a
  /// non-numeric element.
  static List<double> _decodeEmbedding(String json) => [
        for (final e in jsonDecode(json) as List<dynamic>) (e as num).toDouble(),
      ];

  // -------------------------------------------------------------------------
  // Text chunking
  // -------------------------------------------------------------------------

  /// Splits [text] into chunks that normally hold at most [maxChars]
  /// characters.
  ///
  /// Strategy:
  /// 1. Split on blank lines (paragraph boundaries). A blank line may contain
  ///    whitespace, and both `\n` and `\r\n` line endings are recognised.
  /// 2. A paragraph that fits in [maxChars] becomes one chunk.
  /// 3. A longer paragraph is split after `.`, `!` or `?` followed by
  ///    whitespace, and sentences are accumulated (joined by a single space)
  ///    until adding the next one would exceed [maxChars].
  ///
  /// A single sentence longer than [maxChars] is emitted whole as its own
  /// chunk rather than being truncated or dropped, so such a chunk can exceed
  /// [maxChars].
  static List<String> _chunkText(String text, {int maxChars = 500}) {
    final chunks = <String>[];

    // `\s*` between the two newlines absorbs `\r` and whitespace-only lines,
    // so CRLF input and "visually blank" lines still separate paragraphs.
    for (final paragraph in text.split(RegExp(r'\n\s*\n'))) {
      final p = paragraph.trim();
      if (p.isEmpty) continue;

      if (p.length <= maxChars) {
        chunks.add(p);
        continue;
      }

      // Split the long paragraph at sentence boundaries (. ! ?).
      final sentences = p.split(RegExp(r'(?<=[.!?])\s+'));
      var current = '';

      for (final sentence in sentences) {
        final candidate = current.isEmpty ? sentence : '$current $sentence';
        if (candidate.length <= maxChars) {
          current = candidate;
          continue;
        }

        // The accumulated chunk is full: flush it before handling `sentence`.
        if (current.isNotEmpty) chunks.add(current);

        if (sentence.length > maxChars) {
          // Oversized sentence: keep it as-is rather than discarding it.
          chunks.add(sentence);
          current = '';
        } else {
          current = sentence;
        }
      }
      if (current.isNotEmpty) chunks.add(current);
    }

    return chunks;
  }

  // -------------------------------------------------------------------------
  // Cosine similarity
  // -------------------------------------------------------------------------

  /// Returns the cosine similarity of [a] and [b].
  ///
  /// Only the first `min(a.length, b.length)` components are compared; any
  /// extra trailing components of the longer vector are ignored. Returns
  /// `0.0` when either vector has zero norm over that range (including when
  /// either vector is empty).
  static double _cosineSimilarity(List<double> a, List<double> b) {
    var dot = 0.0;
    var normA = 0.0;
    var normB = 0.0;
    final len = a.length < b.length ? a.length : b.length;
    for (var i = 0; i < len; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    // Avoid dividing by zero for all-zero (or empty) vectors.
    if (normA == 0.0 || normB == 0.0) return 0.0;
    return dot / (sqrt(normA) * sqrt(normB));
  }
}
|
|
|
|
/// Immutable pairing of a piece of text with its relevance score.
///
/// Library-private; exists only so scored items can be sorted together with
/// the text they belong to.
class _Scored {
  const _Scored({required this.score, required this.text});

  /// Relevance score assigned to [text]; used as the sort key.
  final double score;

  /// The text the [score] was computed for.
  final String text;
}
|