mirror of https://github.com/immich-app/immich.git
feat(server,ml): remove image tagging (#5903)
* remove image tagging * updated lock * fixed tests, improved logging * be nice * fixed testspull/5904/head
parent
154292242f
commit
092a23fd7f
@ -1,75 +0,0 @@
|
|||||||
from io import BytesIO
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from optimum.onnxruntime import ORTModelForImageClassification
|
|
||||||
from optimum.pipelines import pipeline
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import AutoImageProcessor
|
|
||||||
|
|
||||||
from ..config import log
|
|
||||||
from ..schemas import ModelType
|
|
||||||
from .base import InferenceModel
|
|
||||||
|
|
||||||
|
|
||||||
class ImageClassifier(InferenceModel):
|
|
||||||
_model_type = ModelType.IMAGE_CLASSIFICATION
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
min_score: float = 0.9,
|
|
||||||
cache_dir: Path | str | None = None,
|
|
||||||
**model_kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self.min_score = model_kwargs.pop("minScore", min_score)
|
|
||||||
super().__init__(model_name, cache_dir, **model_kwargs)
|
|
||||||
|
|
||||||
def _download(self) -> None:
|
|
||||||
snapshot_download(
|
|
||||||
cache_dir=self.cache_dir,
|
|
||||||
repo_id=self.model_name,
|
|
||||||
allow_patterns=["*.bin", "*.json", "*.txt"],
|
|
||||||
local_dir=self.cache_dir,
|
|
||||||
local_dir_use_symlinks=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _load(self) -> None:
|
|
||||||
processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
|
|
||||||
model_path = self.cache_dir / "model.onnx"
|
|
||||||
model_kwargs = {
|
|
||||||
"cache_dir": self.cache_dir,
|
|
||||||
"provider": self.providers[0],
|
|
||||||
"provider_options": self.provider_options[0],
|
|
||||||
"session_options": self.sess_options,
|
|
||||||
}
|
|
||||||
|
|
||||||
if model_path.exists():
|
|
||||||
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
|
|
||||||
self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
|
|
||||||
else:
|
|
||||||
log.info(
|
|
||||||
(
|
|
||||||
f"ONNX model not found in cache directory for '{self.model_name}'."
|
|
||||||
"Exporting optimized model for future use."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.sess_options.optimized_model_filepath = model_path.as_posix()
|
|
||||||
self.model = pipeline(
|
|
||||||
self.model_type.value,
|
|
||||||
self.model_name,
|
|
||||||
model_kwargs=model_kwargs,
|
|
||||||
feature_extractor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _predict(self, image: Image.Image | bytes) -> list[str]:
|
|
||||||
if isinstance(image, bytes):
|
|
||||||
image = Image.open(BytesIO(image))
|
|
||||||
predictions: list[dict[str, Any]] = self.model(image)
|
|
||||||
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
|
|
||||||
|
|
||||||
return tags
|
|
||||||
|
|
||||||
def configure(self, **model_kwargs: Any) -> None:
|
|
||||||
self.min_score = model_kwargs.pop("minScore", self.min_score)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,18 +0,0 @@
|
|||||||
# openapi.model.ClassificationConfig
|
|
||||||
|
|
||||||
## Load the model package
|
|
||||||
```dart
|
|
||||||
import 'package:openapi/api.dart';
|
|
||||||
```
|
|
||||||
|
|
||||||
## Properties
|
|
||||||
Name | Type | Description | Notes
|
|
||||||
------------ | ------------- | ------------- | -------------
|
|
||||||
**enabled** | **bool** | |
|
|
||||||
**minScore** | **int** | |
|
|
||||||
**modelName** | **String** | |
|
|
||||||
**modelType** | [**ModelType**](ModelType.md) | | [optional]
|
|
||||||
|
|
||||||
[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,131 +0,0 @@
|
|||||||
//
|
|
||||||
// AUTO-GENERATED FILE, DO NOT MODIFY!
|
|
||||||
//
|
|
||||||
// @dart=2.12
|
|
||||||
|
|
||||||
// ignore_for_file: unused_element, unused_import
|
|
||||||
// ignore_for_file: always_put_required_named_parameters_first
|
|
||||||
// ignore_for_file: constant_identifier_names
|
|
||||||
// ignore_for_file: lines_longer_than_80_chars
|
|
||||||
|
|
||||||
part of openapi.api;
|
|
||||||
|
|
||||||
class ClassificationConfig {
|
|
||||||
/// Returns a new [ClassificationConfig] instance.
|
|
||||||
ClassificationConfig({
|
|
||||||
required this.enabled,
|
|
||||||
required this.minScore,
|
|
||||||
required this.modelName,
|
|
||||||
this.modelType,
|
|
||||||
});
|
|
||||||
|
|
||||||
bool enabled;
|
|
||||||
|
|
||||||
int minScore;
|
|
||||||
|
|
||||||
String modelName;
|
|
||||||
|
|
||||||
///
|
|
||||||
/// Please note: This property should have been non-nullable! Since the specification file
|
|
||||||
/// does not include a default value (using the "default:" property), however, the generated
|
|
||||||
/// source code must fall back to having a nullable type.
|
|
||||||
/// Consider adding a "default:" property in the specification file to hide this note.
|
|
||||||
///
|
|
||||||
ModelType? modelType;
|
|
||||||
|
|
||||||
@override
|
|
||||||
bool operator ==(Object other) => identical(this, other) || other is ClassificationConfig &&
|
|
||||||
other.enabled == enabled &&
|
|
||||||
other.minScore == minScore &&
|
|
||||||
other.modelName == modelName &&
|
|
||||||
other.modelType == modelType;
|
|
||||||
|
|
||||||
@override
|
|
||||||
int get hashCode =>
|
|
||||||
// ignore: unnecessary_parenthesis
|
|
||||||
(enabled.hashCode) +
|
|
||||||
(minScore.hashCode) +
|
|
||||||
(modelName.hashCode) +
|
|
||||||
(modelType == null ? 0 : modelType!.hashCode);
|
|
||||||
|
|
||||||
@override
|
|
||||||
String toString() => 'ClassificationConfig[enabled=$enabled, minScore=$minScore, modelName=$modelName, modelType=$modelType]';
|
|
||||||
|
|
||||||
Map<String, dynamic> toJson() {
|
|
||||||
final json = <String, dynamic>{};
|
|
||||||
json[r'enabled'] = this.enabled;
|
|
||||||
json[r'minScore'] = this.minScore;
|
|
||||||
json[r'modelName'] = this.modelName;
|
|
||||||
if (this.modelType != null) {
|
|
||||||
json[r'modelType'] = this.modelType;
|
|
||||||
} else {
|
|
||||||
// json[r'modelType'] = null;
|
|
||||||
}
|
|
||||||
return json;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a new [ClassificationConfig] instance and imports its values from
|
|
||||||
/// [value] if it's a [Map], null otherwise.
|
|
||||||
// ignore: prefer_constructors_over_static_methods
|
|
||||||
static ClassificationConfig? fromJson(dynamic value) {
|
|
||||||
if (value is Map) {
|
|
||||||
final json = value.cast<String, dynamic>();
|
|
||||||
|
|
||||||
return ClassificationConfig(
|
|
||||||
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
|
||||||
minScore: mapValueOfType<int>(json, r'minScore')!,
|
|
||||||
modelName: mapValueOfType<String>(json, r'modelName')!,
|
|
||||||
modelType: ModelType.fromJson(json[r'modelType']),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
static List<ClassificationConfig> listFromJson(dynamic json, {bool growable = false,}) {
|
|
||||||
final result = <ClassificationConfig>[];
|
|
||||||
if (json is List && json.isNotEmpty) {
|
|
||||||
for (final row in json) {
|
|
||||||
final value = ClassificationConfig.fromJson(row);
|
|
||||||
if (value != null) {
|
|
||||||
result.add(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result.toList(growable: growable);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Map<String, ClassificationConfig> mapFromJson(dynamic json) {
|
|
||||||
final map = <String, ClassificationConfig>{};
|
|
||||||
if (json is Map && json.isNotEmpty) {
|
|
||||||
json = json.cast<String, dynamic>(); // ignore: parameter_assignments
|
|
||||||
for (final entry in json.entries) {
|
|
||||||
final value = ClassificationConfig.fromJson(entry.value);
|
|
||||||
if (value != null) {
|
|
||||||
map[entry.key] = value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return map;
|
|
||||||
}
|
|
||||||
|
|
||||||
// maps a json object with a list of ClassificationConfig-objects as value to a dart map
|
|
||||||
static Map<String, List<ClassificationConfig>> mapListFromJson(dynamic json, {bool growable = false,}) {
|
|
||||||
final map = <String, List<ClassificationConfig>>{};
|
|
||||||
if (json is Map && json.isNotEmpty) {
|
|
||||||
// ignore: parameter_assignments
|
|
||||||
json = json.cast<String, dynamic>();
|
|
||||||
for (final entry in json.entries) {
|
|
||||||
map[entry.key] = ClassificationConfig.listFromJson(entry.value, growable: growable,);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return map;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The list of required keys that must be present in a JSON.
|
|
||||||
static const requiredKeys = <String>{
|
|
||||||
'enabled',
|
|
||||||
'minScore',
|
|
||||||
'modelName',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
@ -1,42 +0,0 @@
|
|||||||
//
|
|
||||||
// AUTO-GENERATED FILE, DO NOT MODIFY!
|
|
||||||
//
|
|
||||||
// @dart=2.12
|
|
||||||
|
|
||||||
// ignore_for_file: unused_element, unused_import
|
|
||||||
// ignore_for_file: always_put_required_named_parameters_first
|
|
||||||
// ignore_for_file: constant_identifier_names
|
|
||||||
// ignore_for_file: lines_longer_than_80_chars
|
|
||||||
|
|
||||||
import 'package:openapi/api.dart';
|
|
||||||
import 'package:test/test.dart';
|
|
||||||
|
|
||||||
// tests for ClassificationConfig
|
|
||||||
void main() {
|
|
||||||
// final instance = ClassificationConfig();
|
|
||||||
|
|
||||||
group('test ClassificationConfig', () {
|
|
||||||
// bool enabled
|
|
||||||
test('to test the property `enabled`', () async {
|
|
||||||
// TODO
|
|
||||||
});
|
|
||||||
|
|
||||||
// int minScore
|
|
||||||
test('to test the property `minScore`', () async {
|
|
||||||
// TODO
|
|
||||||
});
|
|
||||||
|
|
||||||
// String modelName
|
|
||||||
test('to test the property `modelName`', () async {
|
|
||||||
// TODO
|
|
||||||
});
|
|
||||||
|
|
||||||
// ModelType modelType
|
|
||||||
test('to test the property `modelType`', () async {
|
|
||||||
// TODO
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
});
|
|
||||||
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue