mirror of https://github.com/immich-app/immich.git
feat(server,ml): remove image tagging (#5903)
* remove image tagging * updated lock * fixed tests, improved logging * be nice * fixed testspull/5904/head
parent
154292242f
commit
092a23fd7f
@ -1,75 +0,0 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from optimum.onnxruntime import ORTModelForImageClassification
|
||||
from optimum.pipelines import pipeline
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor
|
||||
|
||||
from ..config import log
|
||||
from ..schemas import ModelType
|
||||
from .base import InferenceModel
|
||||
|
||||
|
||||
class ImageClassifier(InferenceModel):
|
||||
_model_type = ModelType.IMAGE_CLASSIFICATION
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
min_score: float = 0.9,
|
||||
cache_dir: Path | str | None = None,
|
||||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
self.min_score = model_kwargs.pop("minScore", min_score)
|
||||
super().__init__(model_name, cache_dir, **model_kwargs)
|
||||
|
||||
def _download(self) -> None:
|
||||
snapshot_download(
|
||||
cache_dir=self.cache_dir,
|
||||
repo_id=self.model_name,
|
||||
allow_patterns=["*.bin", "*.json", "*.txt"],
|
||||
local_dir=self.cache_dir,
|
||||
local_dir_use_symlinks=True,
|
||||
)
|
||||
|
||||
def _load(self) -> None:
|
||||
processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
|
||||
model_path = self.cache_dir / "model.onnx"
|
||||
model_kwargs = {
|
||||
"cache_dir": self.cache_dir,
|
||||
"provider": self.providers[0],
|
||||
"provider_options": self.provider_options[0],
|
||||
"session_options": self.sess_options,
|
||||
}
|
||||
|
||||
if model_path.exists():
|
||||
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
|
||||
self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
|
||||
else:
|
||||
log.info(
|
||||
(
|
||||
f"ONNX model not found in cache directory for '{self.model_name}'."
|
||||
"Exporting optimized model for future use."
|
||||
),
|
||||
)
|
||||
self.sess_options.optimized_model_filepath = model_path.as_posix()
|
||||
self.model = pipeline(
|
||||
self.model_type.value,
|
||||
self.model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
feature_extractor=processor,
|
||||
)
|
||||
|
||||
def _predict(self, image: Image.Image | bytes) -> list[str]:
|
||||
if isinstance(image, bytes):
|
||||
image = Image.open(BytesIO(image))
|
||||
predictions: list[dict[str, Any]] = self.model(image)
|
||||
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
|
||||
|
||||
return tags
|
||||
|
||||
def configure(self, **model_kwargs: Any) -> None:
|
||||
self.min_score = model_kwargs.pop("minScore", self.min_score)
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,18 +0,0 @@
|
||||
# openapi.model.ClassificationConfig
|
||||
|
||||
## Load the model package
|
||||
```dart
|
||||
import 'package:openapi/api.dart';
|
||||
```
|
||||
|
||||
## Properties
|
||||
Name | Type | Description | Notes
|
||||
------------ | ------------- | ------------- | -------------
|
||||
**enabled** | **bool** | |
|
||||
**minScore** | **int** | |
|
||||
**modelName** | **String** | |
|
||||
**modelType** | [**ModelType**](ModelType.md) | | [optional]
|
||||
|
||||
[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)
|
||||
|
||||
|
||||
@ -1,131 +0,0 @@
|
||||
//
|
||||
// AUTO-GENERATED FILE, DO NOT MODIFY!
|
||||
//
|
||||
// @dart=2.12
|
||||
|
||||
// ignore_for_file: unused_element, unused_import
|
||||
// ignore_for_file: always_put_required_named_parameters_first
|
||||
// ignore_for_file: constant_identifier_names
|
||||
// ignore_for_file: lines_longer_than_80_chars
|
||||
|
||||
part of openapi.api;
|
||||
|
||||
class ClassificationConfig {
|
||||
/// Returns a new [ClassificationConfig] instance.
|
||||
ClassificationConfig({
|
||||
required this.enabled,
|
||||
required this.minScore,
|
||||
required this.modelName,
|
||||
this.modelType,
|
||||
});
|
||||
|
||||
bool enabled;
|
||||
|
||||
int minScore;
|
||||
|
||||
String modelName;
|
||||
|
||||
///
|
||||
/// Please note: This property should have been non-nullable! Since the specification file
|
||||
/// does not include a default value (using the "default:" property), however, the generated
|
||||
/// source code must fall back to having a nullable type.
|
||||
/// Consider adding a "default:" property in the specification file to hide this note.
|
||||
///
|
||||
ModelType? modelType;
|
||||
|
||||
@override
|
||||
bool operator ==(Object other) => identical(this, other) || other is ClassificationConfig &&
|
||||
other.enabled == enabled &&
|
||||
other.minScore == minScore &&
|
||||
other.modelName == modelName &&
|
||||
other.modelType == modelType;
|
||||
|
||||
@override
|
||||
int get hashCode =>
|
||||
// ignore: unnecessary_parenthesis
|
||||
(enabled.hashCode) +
|
||||
(minScore.hashCode) +
|
||||
(modelName.hashCode) +
|
||||
(modelType == null ? 0 : modelType!.hashCode);
|
||||
|
||||
@override
|
||||
String toString() => 'ClassificationConfig[enabled=$enabled, minScore=$minScore, modelName=$modelName, modelType=$modelType]';
|
||||
|
||||
Map<String, dynamic> toJson() {
|
||||
final json = <String, dynamic>{};
|
||||
json[r'enabled'] = this.enabled;
|
||||
json[r'minScore'] = this.minScore;
|
||||
json[r'modelName'] = this.modelName;
|
||||
if (this.modelType != null) {
|
||||
json[r'modelType'] = this.modelType;
|
||||
} else {
|
||||
// json[r'modelType'] = null;
|
||||
}
|
||||
return json;
|
||||
}
|
||||
|
||||
/// Returns a new [ClassificationConfig] instance and imports its values from
|
||||
/// [value] if it's a [Map], null otherwise.
|
||||
// ignore: prefer_constructors_over_static_methods
|
||||
static ClassificationConfig? fromJson(dynamic value) {
|
||||
if (value is Map) {
|
||||
final json = value.cast<String, dynamic>();
|
||||
|
||||
return ClassificationConfig(
|
||||
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
||||
minScore: mapValueOfType<int>(json, r'minScore')!,
|
||||
modelName: mapValueOfType<String>(json, r'modelName')!,
|
||||
modelType: ModelType.fromJson(json[r'modelType']),
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
static List<ClassificationConfig> listFromJson(dynamic json, {bool growable = false,}) {
|
||||
final result = <ClassificationConfig>[];
|
||||
if (json is List && json.isNotEmpty) {
|
||||
for (final row in json) {
|
||||
final value = ClassificationConfig.fromJson(row);
|
||||
if (value != null) {
|
||||
result.add(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
return result.toList(growable: growable);
|
||||
}
|
||||
|
||||
static Map<String, ClassificationConfig> mapFromJson(dynamic json) {
|
||||
final map = <String, ClassificationConfig>{};
|
||||
if (json is Map && json.isNotEmpty) {
|
||||
json = json.cast<String, dynamic>(); // ignore: parameter_assignments
|
||||
for (final entry in json.entries) {
|
||||
final value = ClassificationConfig.fromJson(entry.value);
|
||||
if (value != null) {
|
||||
map[entry.key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
// maps a json object with a list of ClassificationConfig-objects as value to a dart map
|
||||
static Map<String, List<ClassificationConfig>> mapListFromJson(dynamic json, {bool growable = false,}) {
|
||||
final map = <String, List<ClassificationConfig>>{};
|
||||
if (json is Map && json.isNotEmpty) {
|
||||
// ignore: parameter_assignments
|
||||
json = json.cast<String, dynamic>();
|
||||
for (final entry in json.entries) {
|
||||
map[entry.key] = ClassificationConfig.listFromJson(entry.value, growable: growable,);
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
/// The list of required keys that must be present in a JSON.
|
||||
static const requiredKeys = <String>{
|
||||
'enabled',
|
||||
'minScore',
|
||||
'modelName',
|
||||
};
|
||||
}
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
//
|
||||
// AUTO-GENERATED FILE, DO NOT MODIFY!
|
||||
//
|
||||
// @dart=2.12
|
||||
|
||||
// ignore_for_file: unused_element, unused_import
|
||||
// ignore_for_file: always_put_required_named_parameters_first
|
||||
// ignore_for_file: constant_identifier_names
|
||||
// ignore_for_file: lines_longer_than_80_chars
|
||||
|
||||
import 'package:openapi/api.dart';
|
||||
import 'package:test/test.dart';
|
||||
|
||||
// tests for ClassificationConfig
|
||||
void main() {
|
||||
// final instance = ClassificationConfig();
|
||||
|
||||
group('test ClassificationConfig', () {
|
||||
// bool enabled
|
||||
test('to test the property `enabled`', () async {
|
||||
// TODO
|
||||
});
|
||||
|
||||
// int minScore
|
||||
test('to test the property `minScore`', () async {
|
||||
// TODO
|
||||
});
|
||||
|
||||
// String modelName
|
||||
test('to test the property `modelName`', () async {
|
||||
// TODO
|
||||
});
|
||||
|
||||
// ModelType modelType
|
||||
test('to test the property `modelType`', () async {
|
||||
// TODO
|
||||
});
|
||||
|
||||
|
||||
});
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue