百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

NSFW 图片分类(图片cnn分类)

nanshan 2025-03-28 19:24 8 浏览 0 评论

NSFW指的是 不适宜工作场所("Not Safe (or Suitable) For Work;")。在本文中,我介绍如何创建一个检测NSFW图像的图像分类模型。

数据集

由于数据集的性质,我们无法从一些数据集的网站(如Kaggle等)获得所有图像。

但是我们找到了一个专门抓取这种类型图片的github库,所以我们可以直接使用。clone项目后可以运行下面的代码来创建文件夹,并将每个图像下载到其特定的文件夹中。

folders = ['drawings','hentai','neutral','porn','sexy']
urls = ['urls_drawings.txt','urls_hentai.txt','urls_neutral.txt','urls_porn.txt','urls_sexy.txt']
names = ['d','h','n','p','s']
for i,j,k in zip(folders,urls,names):
try:
#Specify the path of the folder that has to be made
folder_path = os.path.join('your directory',i)
os.mkdir(folder_path)
except:
pass
#setup the path of url text file
url_path = os.path.join('Datasets_Urls',j)
my_file = open(url_path, "r")
data = my_file.read()
#create a list with all urls
data_into_list = data.split("\n")
my_file.close()
icount = 0
for ii in data_into_list:
try:
#create a unique image names for each images
image_name = 'image'+str(icount)+str(k)+'.png'
image_path = os.path.join(folder_path,image_name)
#download it using the library
urllib.request.urlretrieve(ii, image_path)
icount+=1
except Exception as e:
pass
#this below code is done to make the count of the image same for all the data 
#you can use a big number if you are building a more complex model or if you have a good system
if icount == 2000:
break

这里的folder变量表示类的名称,urls变量用于获取URL文本文件(可以根据文本文件名更改它),name变量用于为每个图像创建唯一的名称。

上面代码将为每个类下载2000张图像,可以编辑最后一个“if”条件来更改下载图像的个数。

数据准备

我们下载的文件夹可能包含其他类型的文件,所以首先必须删除不需要的类型的文件。

image_exts = ['jpeg','.jpg','bmp','png']
path_list = ['drawings','hentai','neutral','porn','sexy']
cwd = os.getcwd()
def remove_other_images(path_list):
for ii in path_list:
data_dir = os.path.join(cwd,'DataSet',ii)
for image in os.listdir(os.path.join(data_dir)):
image_path = os.path.join(data_dir,image_class,image)
try:
img = cv2.imread(image_path)
tip = imghdr.what(image_path)
if tip not in image_exts:
print('Image not in ext list {}'.format(image_path))
os.remove(image_path)
except Exception as e:
print("Issue with image {}".format(image_path))
remove_other_images(path_list)

上面的代码删除了扩展名不是指定格式的图像。

另外图像可能包含许多重复的图像,所以我们必须从每个文件夹中删除重复的图像。

cwd = os.getcwd()
path_list = ['drawings','hentai','neutral','porn','sexy']
def remove_dup_images(path_list):
for ii in path_list:
os.chdir(os.path.join(cwd,'DataSet',ii))
filelist = os.listdir()
duplicates = []
hash_keys = dict()
for index, filename in enumerate(filelist):
if os.path.isfile(filename):
with open(filename,'rb') as f:
filehash = hashlib.md5(f.read()).hexdigest()
if filehash not in hash_keys:
hash_keys[filehash] = index
else:
duplicates.append((index,hash_keys[filehash]))

for index in duplicates:
os.remove(filelist[index[0]])
print('{} duplicates removed from {}'.format(len(duplicates),ii))
remove_dup_images(path_list)

这里我们使用hashlib.md5编码来查找每个类中的重复图像。

Md5为每个图像创建一个唯一的哈希值,如果哈希值重复(重复图像),那么我们将重复图片添加到一个列表中,稍后进行删除。

因为使用TensorFlow框架所以需要判断是否被TensorFlow支持,所以我们这里加一个判断:

import tensorflow as tf
os.chdir('{data-set} directory')
cwd = os.getcwd()
for ii in path_list:
os.chdir(os.path.join(cwd,ii))
filelist = os.listdir()
for image_file in filelist:
with open(image_file, 'rb') as f:
image_data = f.read()
# Check the file format
_, ext = os.path.splitext(image_file)
if ext.lower() not in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
print('Unsupported image format:', ext)
os.remove(os.path.join(cwd,ii,image_file)) 
else:
# Decode the image
try:
image = tf.image.decode_image(image_data)
except:
print(image_file)
print("unspported")
os.remove(os.path.join(cwd,ii,image_file))

以上就是数据准备的所有工作,在清理完数据后,我们可以拆分数据。比如分割创建一个训练、验证和测试文件夹,并手动添加文件夹中的图像,我们将80%用于训练,10%用于验证,10%用于测试。

模型

首先导入tensorflow

import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import hashlib
from imageio import imread
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.layers import Flatten,Dense,Input
from tensorflow.keras.models import Model,Sequential
from keras import optimizers

对于图像,默认大小设置为224,224。

IMAGE_SIZE = [224,224]

可以使用ImageDataGenerator库,进行数据增强。数据增强也叫数据扩充,是为了增加数据集的大小。ImageDataGenerator根据给定的参数创建新图像,并将其用于训练(注意:当使用ImageDataGenerator时,原始数据将不用于训练)。

train_datagen = ImageDataGenerator(
rescale=1./255,
preprocessing_function=preprocess_input,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')

对于测试集也是这样:

test_datagen = ImageDataGenerator(rescale=1./255)

为了演示,我们直接使用VGG模型

vgg = VGG16(input_shape=IMAGE_SIZE+[3],weights='imagenet',include_top=False

然后冻结前面的层:

for layer in vgg.layers:
layer.trainable = False

最后我们加入自己的分类头:

x = Flatten()(vgg.output)
prediction = Dense(5,activation='softmax')(x)
model = Model(inputs=vgg.input, outputs=prediction)
model.summary()

模型是这样的:

训练

看看我们训练集:

train_set = train_datagen.flow_from_directory('DataSet/train',
target_size=(224,224),
batch_size=32,
class_mode='sparse')

验证集

val_set = train_datagen.flow_from_directory('DataSet/validation',
target_size=(224,224),
batch_size=32,
class_mode='sparse')

使用'
sparse_categorical_crossentropy '损失,这样可以将标签编码为整数而不是独热编码。

from tensorflow.keras.metrics import MeanSquaredError
from tensorflow.keras.metrics import CategoricalAccuracy
adam = optimizers.Adam()
model.compile(loss='sparse_categorical_crossentropy',
optimizer=adam,
metrics=['accuracy',MeanSquaredError(name='val_loss'),CategoricalAccuracy(name='val_accuracy')])

然后就可以训练了:

from datetime import datetime
from keras.callbacks import ModelCheckpoint
log_dir = 'vg_log'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir)
start = datetime.now()
history = model.fit_generator(train_set,
validation_data=val_set,
epochs=100,
steps_per_epoch=len(train_set)// batch_size,
validation_steps=len(val_set)//batch_size,
callbacks=[tensorboard_callback],
verbose=1)
duration = datetime.now() - start
print("Time taken for training is ",duration)

模型训练了100次。得到了80%的验证准确率。f1得分为93%

预测

下面的函数将获取一个图像列表并根据该列表进行预测。

import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
def print_classes(images,model):
classes = ['Drawing','Hentai','Neutral','Porn','Sexual']
fig, ax = plt.subplots(ncols=len(images), figsize=(20,20))
for idx,img in enumerate(images):
img = mpimg.imread(img)
resize = tf.image.resize(img,(224,224))
result = model.predict(np.expand_dims(resize/255,0))
result = np.argmax(result)
if classes[result] == 'Porn':
img = gaussian_filter(img, sigma=6)
elif classes[result] == 'Sexual':
img = gaussian_filter(img, sigma=6)
elif classes[result] == 'Hentai':
img = gaussian_filter(img, sigma=6)
ax[idx].imshow(img)
ax[idx].title.set_text(classes[result])
li = ['test1.jpeg','test2.jpeg','test3.jpeg','test4.jpeg','test5.jpeg']
print_classes(li,model)

看结果还是可以的。

最后,本文的源代码:

https://github.com/Nikhilthalappalli/NSFW-Classifier

但是我觉得源代码不重要,Alex Kim的项目才是你需要的:

https://github.com/alex000kim/nsfw_data_scraper

作者:Nikhil Thalappalli

相关推荐

ssh终端xshell日志查看命令(xshell怎么看日志)

现在我们云服务器运维较多用的是SSH工具,其中常用的包括PUTTY、XSHELL等,其实大同小异界面UI稍微不同,但是都可以进入远程连接。这里有朋友提到如何查看服务器的日志文件,这个其实和是否使用XS...

使用 Fail Ban 日志分析 SSH 攻击行为

通过分析`fail2ban`日志可以识别和应对SSH暴力破解等攻击行为。以下是详细的操作流程和关键分析方法:---###**一、Fail2ban日志位置**Fail2ban的日志路径因系统配置...

如何高效读取Linux日志文件?这些命令要熟记于心!

在Linux系统中,日志文件通常存储在/var/log目录下。比如,/var/log/syslog(或/var/log/messages,视发行版而定)记录系统整体事件,/var/log/a...

Windows服务器远程登录日志查询方法,linux查看登录日志方法

概述本文介绍Windows、Linux服务器查询系统的远程登录日志方法。根据服务器所使用的操作系统不同,有以下两种查询方法。Linux操作系统的登录日志查询通过远程连接登录Linux服务器,使用roo...

iptables防火墙如何记录日志(防火墙日志查看)

例如:记录所有ssh服务的登录的日志首先,我们需要了解如何将所有的iptables的INPUT链数据包记录到/var/log/messages中。如果你已经有一些iptables规则了,那么将记录日志...

如何安全管理SSH密钥以防止服务器被入侵

SSH密钥安全管理实施指南(2025年更新版)一、密钥生成与存储规范高强度密钥生成bashCopyCodessh-keygen-ted25519-a100#生成ED25519算法密钥(比...

在CentOS上安装nginx服务器(centos搭建代理服务器)

一、环境描述1.虚拟机配置CPU:单核内存:2GB硬盘:120GBIP:10.24.17.1082.操作系统版本:CentOS6.6x86_64安装方式:Minimal3.虚拟化环境VM...

CentOS7安全加固的一份整理规划建议

◆更新系统:及时更新CentOS7操作系统版本和安全补丁,确保系统以最新状态运行。◆关闭不必要的服务:在运行系统时,应关闭不需要的服务和端口,以减少系统暴露的攻击面。◆安装防火墙:使用iptables...

第四十七天-二叉树,centOS安装tomcat,Maven,vsftpd

学习笔记:1.Maven是Apache下的一个纯Java开发的开源项目。基于项目对象模型(缩写:POM)概念,Maven利用一个中央信息片断能管理一个项目的构建、报告和文档等步骤。Maven...

Linux远程桌面连接使用教程 Widows终端远程连接Linux服务器

一、前言为什么不是远程连接Linux服务器?因为我不会,远程连接window我就用电脑自带的“远程桌面连接”。以下所述都是在CentOS操作系统下的。服务器刚换成Linux的时候很迷茫,感觉无从下手...

CentOS 安全加固操作,保护你的操作系统

系统加固是保障系统安全的重要手段,对于维护企业数据安全、用户隐私以及系统稳定运行具有重要意义。加固后的系统更加健壮和稳定,能够有效减少因安全问题导致的系统故障和停机时间,提高系统的可用性和可靠性。通过...

Dockerfile部署Java项目(docker如何部署java项目)

1、概述本文主要会简单介绍什么是Docker,什么是Dockerfile,如何安装Docker,Dockerfile如何编写,如何通过Dockerfile安装jar包并外置yaml文件以及如何通过do...

CentOS7云主机部署Fail2ban阻断SSH暴力破解

关于Fail2banFail2ban可以监视你的系统日志,然后匹配日志的错误信息(正则式匹配)执行相应的屏蔽动作(一般情况下是调用防火墙屏蔽)例如:当有人在试探你的HTTP、SSH、SMTP、FTP密...

在CentOS7上用源码编译安装PostgreSQL

1、新建postgres用户#useraddpostgres&&passwdpostgres2、安装依赖包#yum-yinstallmakegccgcc-c++readline...

pure-ftpd 使用(ftp prompt命令)

pure-ftpd是一个免费的ftp软件,其他介绍就不多说了。我们直接开始主题安装centosyuminstallepel-releaseyuminstallpure-ftpd配置备份原配置...

取消回复欢迎 发表评论: